/*
 * forest.h
 *
 *  Created on: Apr 11, 2011
 *      Author: jianjun
 */

#ifndef FOREST_H_
#define FOREST_H_

#include <deque>
#include <string>
#include <vector>
#include "const.h"

using namespace std;

class Node {
#ifdef VERBOSE
	static int count;
	int id;
#endif
	string center; // The representative of this node.
	char bit; // The hash bit (0/1);
	deque <string> obList;
	int height; // height of the node (>=0).
	deque < Node* > children;
	Node * parent; // the parent node.
public:
	inline Node(int h) {
		center = "";
		height = h;
		parent = nil;
#ifdef VERBOSE
		id = count;
		count ++;
#endif
	}

	~Node() {
		delChildren();
	}

//	Node * shrinkNode(Dist & dist, Global & gb) {
//		Node * p = this;
//		while (p->getSize() > gb.szLimit) {
//			if (p->children.empty()) {
//				break;
//			}else {
//				deque <Node*> ::iterator it;
//				size_t max = 0;
//				Node * maxP = nil;
//				for (it = p->children.begin(); it != p->children.end(); it ++) {
//					Node * c = *it;
//					size_t sz = c->getSize();
//					if (sz > max) {
//						max = sz;
//						maxP = c;
//					}
//				}
//				p = maxP;
//			}
//		}
//		return p;
//	}

	/* Return true if this node contains an object with the substring
	 * "native" in its id.
	 */
	bool hasNative(Dist & dist) {
		deque <string> ::iterator it;
		for (it = obList.begin(); it != obList.end(); it ++) {
			string ob = *it;
			string strId = dist.getStrId(ob);
			string::size_type pos = strId.find("native");
			if (pos != string::npos) {
				return true;
			}
		}
		return false;
	}

	string getTrace() {
		stringstream trace;
		Node * p = this;
		while (p != nil) {
			trace << ", " << p->getSize();
			trace << ", " << p->getHeight();
			trace << ")";
			p= p->parent;
		}
		return trace.str();
	}

#ifdef VERBOSE
	inline int getId() {
		return id;
	}
#endif

	inline char getBit() {
		return bit;
	}

	inline void setBit(char b) {
		bit = b;
	}

	inline void delChildren() {
		if ( ! children.empty() ) {
			deque < Node* > ::iterator it;
			for (it = children.begin(); it != children.end(); it ++) {
				Node * p = *it;
				delete p;
			}
			children.clear();
		}
	}

	void print(ostream & os) {
#ifdef VERBOSE
		os << "id " << id << "h " << height << " obList.sz " << obList.size() << endl;
#endif
		deque <Node*> ::iterator it;
		for (it = children.begin(); it != children.end(); it ++) {
			Node * p = *it;
			p->print(os);
		}
	}

	inline deque <Node*> & getChildren() {
		return children;
	}

	inline int getHeight() const {
		return height;
	}

	inline int getNumObjects() const {
		return obList.size();
	}

	inline size_t getSize() const {
		return obList.size();
	}


	inline string getCenter() {
		return center;
	}

	inline string getOb() {
		return obList.front();
	}

	inline deque <string> & getObList() {
		return obList;
	}

	inline bool isLeaf() const {
		return children.empty();
	}

	inline Node * getParent() {
		return parent;
	}

	inline void addObject(string & ob) {
		obList.push_back(ob);
	}

	inline void addChild(Node * c) {
		children.push_back(c);
		c->parent = this;
	}

	inline void setObjList(deque <string> & inList) {
		deque <string> ::iterator it;
		for (it = inList.begin(); it != inList.end(); it ++) {
			obList.push_back(*it);
		}
	}

	inline void setObjList(deque <strStrPair> & inList) {
		deque <strStrPair> ::iterator it;
		for (it = inList.begin(); it != inList.end(); it ++) {
			obList.push_back(it->first);
		}
	}

	/* output a node to a file */
	void out(string & outFileName, Dist & dist) {
		ofstream os(outFileName.c_str());

		deque <string> ::iterator it;
		for (it = obList.begin(); it != obList.end(); it ++) {
			string id = *it;
			os << dist.getStrId(id) << endl;
		}

		os.close();
	}

	/* For debuging purpose. Pick the object with the lowest Energy as representative
	 * of this node.
	 */
	string lowestEnergy(Global & gb, Dist & dist) {
		// To select the one with lowest rmsd.
		float min = 999999.9;
		string minOb;
		deque <string> ::iterator it;
		for (it = obList.begin(); it != obList.end(); it ++) {
			string ob = *it;
			string s = gb.getEnergy(dist.getStrId(ob));
			float v = atof(s.c_str());
			if (v < min) {
				min = v;
				minOb = ob;
			}
		}
		return minOb;
	}

	/* For debuging purpose. Pick the object with the lowest RMSD to the native as representative
	 * of this node.
	 */
	string lowestRMSD(Global & gb, Dist & dist) {
		// To select the one with lowest rmsd.
		float min = 9999.9;
		string minOb;
		deque <string> ::iterator it;
		for (it = obList.begin(); it != obList.end(); it ++) {
			string ob = *it;
			string s = gb.getClass(dist.getStrId(ob));
			float v = atof(s.c_str());
			if (v < min) {
				min = v;
				minOb = ob;
			}
		}
		return minOb;
	}

	/* Find the representative of this node by finding the center of it.
	 */
	string compCenter(Dist & dist) {
		//Locate the center
		float min = Dist::maxDist;
		string center = "N/A";
		deque <string> ::iterator it;
		for (it = obList.begin(); it != obList.end(); it ++) {
			string iOb = *it;
			deque <string> ::iterator jt;
			float ls = 0;
			for (jt = obList.begin(); jt != obList.end(); jt ++) {
				string jOb = *jt;
				float d = dist.get(jOb, iOb);
				ls += d;
			}
			if (ls < min) {
				min = ls;
				center = iOb;
			}
		}
		return center;
	}


	/* Return an object with the lowest total distance to pivots/references */
	string lowestRd(Global & gb) {
		double min = Dist::maxDist;
		string minOb = "";
		deque <string> ::iterator it;
		for (it = obList.begin(); it != obList.end(); it ++) {
			double score = gb.ob2rd[*it];
			if (score < min) {
				min = score;
				minOb = *it;
			}
		}
		return minOb;
	}

	/* Select a representative from the node */
	string getRep(Global & gb, Dist & dist) {
		string id;

		if (center != "") id = center;
		else {
			id = lowestRd(gb);
			center = id;
		}

		return id;
	}

};

typedef pair <Node*, float> NodeFloatPair;

struct NodeSortProcess : public std::binary_function<NodeFloatPair &, NodeFloatPair &, bool>
{
	bool operator()(NodeFloatPair & a, NodeFloatPair & b) const {
	//return true if left is logically less then right for given comparison
		if(a.second < b.second)
			return true;
		else return false;
	}
};

class Tree {
	Node * root;
	int leafThres; // The number of objects is below this threshold, build a leaf.
	int refArrSize;
	int * refArr;// The index of the list reference objects to use in this tree.
public:
	Tree(int denThres) {
		root = nil;
		leafThres = 1;
		refArr = nil;
	}

	~Tree() {
		delete root;
		if (refArr != nil) {
			delete [] refArr;
		}
	}

	void printRefArr(ostream & os) {
		for (int i = 0; i < refArrSize; i ++) {
			os << refArr[i] << " ";
		}
		os << endl;
	}

	void initRefArr(int hashSize) {
		refArrSize = hashSize;
		refArr = new int[hashSize];
	}

	int * getRefArr() {
		return refArr;
	}

	void print(ostream & os) {
		root->print(os);
	}

	void setLeafThres(int val) {
		leafThres = val;
	}

	/* Build a tree using root.
	 * refHash Map each object to a string of bits (e.g., {1,0}).
	 * objList: the whole list of objects.
	 * refArr: the index of the array of reference objects to use in this tree.
	 * maxLevel: the maximal height (>=0) of the tree is maxLevel - 1. It is also the size of
	 * refArr.
	 */
	void buildTree(Global & gb, map <string, string> & refHash, deque <string> & objList,
			int* refArr, int maxLevel) {
#ifdef VERBOSE
debugout << "To build a tree.." << endl;
#endif
		root = new Node(0);
		root->setObjList(objList);

		deque <Node*> frontline;
		frontline.push_back(root);

		int maxHeight = gb.hashSize;
		int maxSize = objList.size();

		while ( ! frontline.empty()) {
			Node * curP = frontline.front();
			frontline.pop_front();
			int h = curP->getHeight();

			int sz = curP->getSize();
			double a = h / (double)maxHeight;
			double b = log(sz) / log((double)maxSize);

			if (a <= b) { // Only generate half of the tree on the left of b=a
				int refIndex = refArr[h];

				deque <string> & objects = curP->getObList();
				if ((int)objects.size() <= leafThres || h == maxLevel - 1) {// Create a leaf.
					// Currently do nothing
				}else {// Continue to split.
	//				Node * newNodes[gb.splitFactor];
	//
	//				for (int i = 0; i < gb.splitFactor; i ++) {
	//						newNodes[i] = new Node(h + 1);
	//				}
	//
	//				deque <string> ::iterator it;
	//				for (it = objects.begin(); it != objects.end(); it ++) {
	//					string ob = *it;
	//					string & bitStr = refHash[ob];
	//					char bit = bitStr.at(refIndex);
	//					int index = gb.hex2int(bit);
	//					//lists[index].push_back(ob);
	//					newNodes[index]->addObject(ob);
	//				}
	//
	//				for (int i = 0; i < gb.splitFactor; i ++) {
	//					if (newNodes[i]->getNumObjects() > 0) {
	//						curP->addChild(newNodes[i]);
	//						frontline.push_back(newNodes[i]);
	//					}else delete newNodes[i];
	//				}
					deque <string> lists[gb.splitFactor];

					deque <string> ::iterator it;
					for (it = objects.begin(); it != objects.end(); it ++) {
						string ob = *it;
						string & bitStr = refHash[ob];
						char bit = bitStr.at(refIndex);
						int index = gb.hex2int(bit);
						lists[index].push_back(ob);
					}

					for (int i = 0; i < gb.splitFactor; i ++) {
						if (lists[i].size() > 0) {
							Node * newNode = new Node(h + 1);
							newNode->setBit(gb.getBit(i));
							newNode->setObjList(lists[i]);
							curP->addChild(newNode);
							frontline.push_back(newNode);
						}
					}
				}
			}
		}
#ifdef VERBOSE
debugout << "Tree built. Done." << endl;
#endif
	}

	/* Find the leaf node for every string in this tree, and return
	 * the result in loc
	 */
	void locateNode(map <string, Node*> & loc) {
		deque <Node*> q;
		q.push_back(root);
		while( ! q.empty()) {
			Node * p = q.front();
			q.pop_front();

			deque <Node *> & children = p->getChildren();
			if (children.size() > 0) {
				deque <Node *> ::iterator ct;
				for (ct = children.begin(); ct != children.end(); ct ++) {
					q.push_back(*ct);
				}
			}else {// Leaf node.
				deque <string> & obList = p->getObList();
				deque <string> ::iterator it;
				for (it = obList.begin(); it != obList.end(); it ++) {
					loc[*it] = p;
				}
			}
		}
	}

	/* Filter out nodes that are smaller than the leaf limit */
	void filterSmallNode(Global & gb, list <NodeFloatPair> & inList) {
		list <NodeFloatPair> ::iterator it;
		it = inList.begin();
		while (it != inList.end()) {
			Node * p = it->first;
			size_t sz = p->getSize();
			if (sz < gb.leafThres) {
				list <NodeFloatPair> ::iterator dt = it;
				it ++;
				inList.erase(dt);
			}else it ++;
		}
 	}

	/* pick the top objects in the diagonal */
	void pickDiag(Dist & dist, Global & gb, list <NodeFloatPair> & resList) {
		int maxHeight = gb.hashSize;
		int maxSize = dist.getObjNum();

		deque <Node*> q;
		q.push_back(root);
		//list <NodeFloatPair> sortedList;

#ifndef NDEBUG
ofstream os(debugFile);
#endif

		map <Node*, float> nodeInfo;
		while ( ! q.empty()) {
			Node * p = q.front();

			double a = p->getHeight() / (double)maxHeight;
			double b = log(p->getSize()) / log((double)maxSize);
			float d = a - b;//max(1-a, 1-b);
			nodeInfo[p] = d;

			Node * parent = p->getParent();
			map <Node*, float> ::iterator it = nodeInfo.find(parent);
			bool skipChildren = false; // If a node in a branch is already selected, overlook its children
			if (it != nodeInfo.end()) {
				float oldD = it->second;
				if ((oldD <=0 && d >=0) || (oldD >=0 && d <= 0)) { // Across y=x
#ifdef NDEBUG
					skipChildren = true; // For efficiency reason
#endif
					NodeFloatPair nfp(p, - p->getSize());
					//sortedList.push_back(nfp);
					resList.push_back(nfp);
				}
			}

			if ( ! skipChildren) {
				deque <Node*> & children = p->getChildren();
				if ( ! children.empty()) {
					deque <Node*> ::iterator it;
					for (it = children.begin(); it != children.end(); it ++) {
						q.push_back(*it);
					}
				}
			}
			q.pop_front();
		}
	}
};

class Forest {
	deque <Tree*> trees;
public:
	~Forest() {
		deque <Tree*> ::iterator it;
		for (it = trees.begin(); it != trees.end(); it ++) {
			Tree * t = *it;
			delete t;
		}
	}

	inline void addTree(Tree * t) {
		trees.push_back(t);
	}

	/* Select the top nodes in sortedList*/
	void selectNodes(Global & gb, Dist & dist, list <NodeFloatPair> & sortedList,
			deque <Node*> & nodeList) {

		int limit =  gb.sampleSize;

		list <NodeFloatPair> ::iterator jt;
		set <string>idSet;
		int count = 0;
		for (jt = sortedList.begin(); jt != sortedList.end() && count < limit; jt ++) {
			NodeFloatPair & sfp = *jt;
			Node * p = sfp.first;
			nodeList.push_back(p);
			count ++;
		}
	}

	/* Return the center */
	string pickCenter(Global & gb, Dist & dist, vector <string> & inList) {
		double min = dist.maxDist;
		string minOb = "";
		vector <string> ::iterator it;
		for (it = inList.begin(); it != inList.end(); it ++) {
			double ls = 0;
			string iOb = *it;
			vector <string> ::iterator jt;
			for (jt = inList.begin(); jt != inList.end(); jt ++) {
				string jOb = *jt;
				float d = dist.get(iOb, jOb);
				ls += d;
			}
			if (ls < min) {
				min = ls;
				minOb = iOb;
			}
		}
		return minOb;
	}

	/* pick the center
	 *  nodeList: sorted by the average pairwise dist.
	 */
	Node * pickBestDist(Global & gb, Dist & dist, deque <Node*> & nodeList) {
		map <string, Node*> rep2node;// Store the centers of local minimums.
		set <string> s;
		deque <Node*> ::iterator it;
		for (it = nodeList.begin(); it != nodeList.end(); it ++) {
			Node * p = *it;
			string id = p->getRep(gb, dist);
			set <string> ::iterator st = s.find(id);
			if (st == s.end()) {// only include id that has not appeared before.
				rep2node[id] = p;
				s.insert(id);
			}
		}

		float min = dist.maxDist;
		Node * minP = nil;

		map <string, Node*> ::iterator jt;
		for (jt = rep2node.begin(); jt != rep2node.end(); jt ++) {
			{//if (q->getSize() > gb.leafThres) {
				string ob1 = jt->first;
				float ls = dist.bestTotalDist(gb, ob1);

				if (ls < min) {
					min = ls;
					minP = jt->second;
				}
			}
		}
		return minP;
	}


	/* Get the object with the lowest energy */
	string getLowestEnergy(Dist & dist, Global & gb) {
		list <strFloatPair> sortedList;
		deque <string> & obList = dist.getAllObjects();
		deque <string> ::iterator it;
		for (it = obList.begin(); it != obList.end(); it ++) {
			string id = *it;
			string strId = dist.getStrId(id);
			string energy = gb.getEnergy(strId);
			float val = atof(energy.c_str());
			strFloatPair sfp(id, val);
			sortedList.push_back(sfp);
		}
		sortedList.sort(SortProcess());
		return sortedList.front().first;
	}

	/* Compute the center of local minimum */
	void localCenter(Dist & dist, Global & gb, deque <string> & resList) {

		vector <Node *> centers;
		deque <Tree*> :: iterator it;

		int numValidTrees = 0;
		for (it = trees.begin(); it != trees.end(); it ++) {
			Tree * t = *it;
			list <NodeFloatPair> sortedList;
			t->pickDiag(dist, gb, sortedList);

			if (sortedList.size() > 0) {
				numValidTrees ++;
				sortedList.sort(NodeSortProcess());
				deque <Node*> nodeList;
				selectNodes(gb, dist, sortedList, nodeList); // Select local minima

				//double wLsMean;
				Node * p = nil;
				switch (gb.methodId) {
				case USER_LOCAL_CENTER:
				case REF_LOCAL_CENTER:
				case BD_LOCAL_CENTER:
					p = pickBestDist(gb, dist, nodeList);
					break;
				default:
					cout << "Wrong method ID" << endl;
					exit(1);
				}
				centers.push_back(p);
			}
		}

		// Analysis the centers from different trees.
		double ls = 0;

		list <strFloatPair> sfpList;
		vector <Node*> ::iterator vt;
		for (vt = centers.begin(); vt != centers.end(); vt ++) {
			Node * p1 = *vt;
			string ob1 = p1->getRep(gb, dist);
			float ls1 = 0;
			vector <Node*> ::iterator jt;
			for (jt = centers.begin(); jt != centers.end(); jt ++) {
				Node * p2 = *jt;
				string ob2 = p2->getRep(gb, dist);
				float d = dist.get(ob1, ob2);
				ls1 += d;
			}
			ls += ls1;
			//ss += ls1 * ls1;
			strFloatPair sfp(ob1, ls1);
			sfpList.push_back(sfp);
		}

		sfpList.sort(SortProcess());

		set <string> s;
		list <strFloatPair> ::iterator rt;
		int count = 0;

		for (rt = sfpList.begin();
				rt != sfpList.end() && (int)s.size() < gb.sampleSize; rt ++) {
			string id = rt->first;
			set <string> ::iterator st = s.find(id);
			if (st == s.end()) {
				s.insert(id);
				resList.push_back(id);
			}
			if (gb.assisFile != null) { // Output RMSD for results
				string strId = dist.getStrId(id);
				string rmsd = gb.getClass(strId);
				cout << "lc" << count++ << " " << strId << " rmsd " << rmsd;
				double avgLs1 = 0;
				if (centers.size() > 0) avgLs1 = rt->second / centers.size();
				cout << " ls1 " << avgLs1;

				cout << endl;
			}
		}
	}
};

#endif /* FOREST_H_ */
