//hub-acc querying for undirected graph
#include "ol/pregel-ol-dev.h"
#include "utils/type.h"

#define DEBUG_MODE//comment it out when running experiments
//input line format: (from hubacc_ug_merge.cpp)
//vid \t in_H==false num nb1 nb2 ... h_num h1 d1 h2 d2 ...   (hi: entry hub, di: hop distance to hi)
//vid \t in_H==true num nb1 nb2 ... dstHVid1 hop1 dstHVid2 hop2 ...

//output line format: src dst \t hop_dist   (or "src dst \t not reachable")

//logic:
//superstep 1:
//- s sends all entry_vertices v a msg dist(s, v) to activate them;
//- t sends entry list L(t) to aggregator
//superstep 2:
//- v in L(s) gets L(t) from aggregator, computes dist(s, v)+dist(v, u)+dist(u, t) for all u in L(t);
//- v sends the min dist to aggregator

//special case:
//if s is high-deg, look up dist(s, t) from L(s)
//- if found, t is also high-deg: return dist(s, t) and terminate (technically, s stashes dist(s, t) at the end of its nbs so that the aggregator can record it for dumping);
//- otherwise, s does not send anything in superstep 1, and gets L(t) in superstep 2 to compute an upper bound
//(if t is in H, t sends (t, dist_tt=0) to agg, since s might not be in H)

string in_path("/ol_merged"); //graph input read by main()
string out_path("/ol_out"); //query results written by main()
bool use_combiner(true); //whether main() registers SPQueryCombiner

//--------------------------------------------------
//Step 1: define static field of vertex: adj-list
struct SPQueryNQValue {
	vector<int> nbs;
	void * list;
	bool in_H;

	void init(bool is_inH) {
		in_H = is_inH;
		if (in_H)
			list = new hash_map<int, int>;
			else list=new vector<intpair>;
		}

		vector<intpair>* get_entry_list() const
		{
			return (vector<intpair>*)list;
		}

		hash_map<int, int>* get_hub_table() const
		{
			return (hash_map<int, int>*)list;
		}

		~SPQueryNQValue()
		{
			if(in_H) delete (hash_map<int, int>*)list;
			else delete (vector<intpair>*)list;
		}
	};

ibinstream & operator<<(ibinstream & m, const SPQueryNQValue & v) {
	//in_H is written before the list because the reader uses it to choose the list type
	m << v.nbs << v.in_H;
	if (!v.in_H)
		return m << *(v.get_entry_list());
	return m << *(v.get_hub_table());
}

obinstream & operator>>(obinstream & m, SPQueryNQValue & v) {
	//in_H arrives before the list and decides which list type follows
	m >> v.nbs >> v.in_H;
	v.init(v.in_H); //the list storage must exist before it is filled
	if (!v.in_H)
		return m >> *(v.get_entry_list());
	return m >> *(v.get_hub_table());
}

//--------------------------------------------------
//Step 2: define query type: here, it is intpair (src, dst)

//--------------------------------------------------
//Step 3: define query-specific vertex state: intpair (hop_from_src, hop_to_dst)
int not_reached(-1); //qvalue entry value meaning "not (yet) reached by this search direction"

//--------------------------------------------------
//Step 4: define msg type: here, it is char
char fwd_msg(0x1); //bit pattern 01: forward search (from src)
char back_msg(0x2); //bit pattern 10: backward search (from dst)
char met_msg(0x3); //bit pattern 11: OR of fwd_msg and back_msg

//--------------------------------------------------
//Step 5: define vertex class

struct SPQueryAggField {
	int min; //smallest qvalue().v1 + qvalue().v2 over vertices reached from both sides (superstep > 2)
	vector<intpair> hubgate; //L(t) as (hub, dist) pairs, or just (t, 0) when t is in H
	int hop; //upper bound on dist(s, t) derived from the hub tables
};

ibinstream & operator<<(ibinstream & m, const SPQueryAggField & v) {
	//field order must mirror operator>> for SPQueryAggField
	return m << v.min << v.hubgate << v.hop;
}

obinstream & operator>>(obinstream & m, SPQueryAggField & v) {
	//field order must mirror operator<< for SPQueryAggField
	return m >> v.min >> v.hubgate >> v.hop;
}

class SPQueryVertex: public VertexOL<VertexID, intpair, SPQueryNQValue, char,
		intpair> {
public:

	//Step 5.1: define UDF1: query -> vertex's query-specific init state
	//qvalue = (hop_from_src, hop_to_dst): 0 at the query's own endpoint(s), not_reached elsewhere
	virtual intpair init_value(intpair& query) {
		intpair pair(not_reached, not_reached);
		if (id == query.v1)
			pair.v1 = 0;
		if (id == query.v2)
			pair.v2 = 0;
		return pair;
	}

	//Step 5.2: vertex.compute
	//superstep 1 seeds the hub search, superstep 2 derives a hub-based upper bound and starts
	//the bidirectional BFS, later supersteps run the BFS pruned by that bound
	virtual void compute(MessageContainer& messages) {
		if (superstep() == 1) //only s and t are active
			compute_superstep1();
		else if (superstep() == 2) //s, t and all v in L(s) are active
			compute_superstep2(messages);
		else
			compute_bfs(messages);
	}

private:
	//append dist to nbs encoded as -dist-1 (negative, to distinguish it from the neighbor ids in nbs);
	//SPQueryAgg::stepPartial reads it from the last slot
	void stash_dist(int dist) {
		nqvalue().nbs.push_back(-dist - 1);
	}

	void send_to_all_nbs(char msg) {
		vector<int> & nbs = nqvalue().nbs;
		for (size_t i = 0; i < nbs.size(); i++)
			send_message(nbs[i], msg);
	}

	//min over (u, dist_ut) in L(t) of dist_to_me + dist(me, u) + dist_ut, with dist(me, u) taken from
	//this hub's table. A hub absent from the table counts as distance 0 if it is this vertex itself,
	//and is skipped otherwise (hash_map::operator[] would have inserted a bogus 0 into the persistent table).
	//Returns INT_MAX when no candidate exists.
	int min_dist_via_hubs(int dist_to_me) {
		hash_map<int, int> & table = *(nqvalue().get_hub_table());
		vector<intpair> & L_t = ((SPQueryAggField*) get_agg())->hubgate;
		int best = INT_MAX;
		for (size_t i = 0; i < L_t.size(); i++) {
			int u = L_t[i].v1;
			int dist_me_u;
			hash_map<int, int>::iterator it = table.find(u);
			if (it != table.end())
				dist_me_u = it->second;
			else if (u == id)
				dist_me_u = 0;
			else
				continue; //u is not reachable from this hub
			int dist = dist_to_me + dist_me_u + L_t[i].v2;
			if (dist < best)
				best = dist;
		}
		return best;
	}

	void compute_superstep1() {
		intpair query = get_query();
		if (query.v1 == query.v2) {
			forceTerminate();
			return;
		}
		//t --L(t)--> agg is done in the aggregator; only s has work here
		if (id != query.v1)
			return;
		if (!nqvalue().in_H) {
#ifdef DEBUG_MODE//@@@@@@@@@@@@@@@
			cout << "s not in H: s -> v in L(s)" << endl; //@@@@@@@@@@@@@@@
#endif//@@@@@@@@@@@@@@@
			//s --dist(s, v)--> L(s); the distance travels as a char message
			vector<intpair>* list = nqvalue().get_entry_list();
			for (vector<intpair>::iterator it = list->begin();
					it != list->end(); it++)
				send_message(it->v1, (char) it->v2);
		} else { //src is in H, do not send to neighbors
			hash_map<int, int> & L_s = *(nqvalue().get_hub_table());
			hash_map<int, int>::iterator it = L_s.find(query.v2);
			if (it != L_s.end()) {
				//special case: t is in H
				int dist_st = it->second;
#ifdef DEBUG_MODE//@@@@@@@@@@@@@@@
				cout<<"s and T in H: s -> agg with "<<dist_st<<endl;//@@@@@@@@@@@@@@@
#endif//@@@@@@@@@@@@@@@
				stash_dist(dist_st); //picked up by SPQueryAgg::stepPartial for dumping
				forceTerminate();
			}
		}
		//s and t cannot vote to halt yet
	}

	void compute_superstep2(MessageContainer& messages) {
		intpair query = get_query();
		bool in_H = nqvalue().in_H;
		if (id == query.v1) { //src
			if (!in_H) //high-deg vertex does not propagate msgs
				send_to_all_nbs(fwd_msg); //forward broadcast
			else { //case 1: s is in H
				hash_map<int, int> & L_s = *(nqvalue().get_hub_table());
				if (L_s.find(query.v2) == L_s.end()) { //t not in L(s): combine L(s) with L(t)
					int bound = min_dist_via_hubs(0);
					if (bound != INT_MAX)
						stash_dist(bound);
#ifdef DEBUG_MODE//@@@@@@@@@@@@@@@
					cout<<"s in H, t not: compute "<<((SPQueryAggField*)get_agg())->hubgate.size()<<" candidcate distances, min = "<<bound<<endl;//@@@@@@@@@@@@@@@
#endif//@@@@@@@@@@@@@@@
				}
			}
		} else {
			if (id == query.v2 && !in_H) //high-deg vertex does not propagate msgs
				send_to_all_nbs(back_msg); //backward broadcast
			//case 2: s is not in H and this hub is some v in L(s). This includes t itself
			//when t is in H, so the candidate through entry hub t is not lost.
			if (in_H && messages.size() > 0) {
				int dist_sv = (unsigned char) messages[0]; //sent as (char) dist in superstep 1
				int bound = min_dist_via_hubs(dist_sv);
				if (bound != INT_MAX)
					stash_dist(bound);
			}
		}
		vote_to_halt();
	}

	//bidirectional BFS; hubs record their levels but do not propagate (hub paths are covered by the bound)
	void compute_bfs(MessageContainer& messages) {
		SPQueryAggField* agg = (SPQueryAggField*) get_agg();
		int ubound = agg->hop;
		int level = superstep() - 2;
		if (level > ubound / 2) { //pruned by upperbound
			vote_to_halt();
			return;
		}
		char bor = 0;
		for (size_t i = 0; i < messages.size(); i++) {
			bor |= messages[i];
			if (bor == met_msg)
				break;
		}
		if ((bor & fwd_msg) != 0 && qvalue().v1 == not_reached) { //first reached by forward propagation
			qvalue().v1 = level;
			if (!nqvalue().in_H) //high-deg vertex does not propagate msgs
				send_to_all_nbs(fwd_msg);
		}
		if ((bor & back_msg) != 0 && qvalue().v2 == not_reached) { //first reached by backward propagation
			qvalue().v2 = level;
			if (!nqvalue().in_H) //high-deg vertex does not propagate msgs
				send_to_all_nbs(back_msg);
		}
		if (qvalue().v1 != not_reached && qvalue().v2 != not_reached) //met
			forceTerminate();
		vote_to_halt();
	}
};

//--------------------------------------------------
//Step 6: define aggregator logic
class SPQueryAgg: public Aggregator<SPQueryVertex, SPQueryAggField,
		SPQueryAggField> {
public:
	SPQueryAggField field;

	//compute() stashes distances in the last slot of nbs as -dist-1; decoding as -(x+1)
	//avoids signed overflow when x == INT_MIN (i.e. dist == INT_MAX)
	static int decode_dist(int x) {
		return -(x + 1);
	}

	//if v's nbs ends with a stashed distance, remove it (nbs persists across queries) and return it in dist
	static bool pop_stashed_dist(SPQueryVertex* v, int & dist) {
		vector<int> & nbs = v->nqvalue().nbs;
		if (nbs.empty() || nbs.back() >= 0)
			return false;
		dist = decode_dist(nbs.back());
		nbs.pop_back();
		return true;
	}

	virtual void init() {
		int step = SPQueryVertex::superstep();
		field.min = INT_MAX;
		//L(t) is only read by compute() in superstep 2; drop it otherwise so stale entries never leak
		//into a new query and are not re-broadcast during the BFS supersteps
		if (step != 2)
			field.hubgate.clear();
		if (step == 1 || step == 2) {
			field.hop = INT_MAX;
		} else { //keep the bound established in superstep 2
			SPQueryAggField& old_field =
					*((SPQueryAggField*) (SPQueryVertex::get_agg()));
			field.hop = old_field.hop;
		}
	}

	virtual void stepPartial(SPQueryVertex* v) {
		int step = SPQueryVertex::superstep();
		intpair query = SPQueryVertex::get_query();
		if (step == 1) {
			//special case: s and t are in H, compute() already calls forceTerminate()
			int dist_st;
			if (v->id == query.v1 && pop_stashed_dist(v, dist_st))
				field.hop = dist_st; //for dump purpose
			//t --L(t)--> agg
			if (v->id == query.v2) {
				if (v->nqvalue().in_H) {
					intpair pair;
					pair.v1 = v->id;
					pair.v2 = 0;
					field.hubgate.push_back(pair);
#ifdef DEBUG_MODE//@@@@@@@@@@@@@@@
					cout << "t in H: send (" << pair.v1 << ", " << pair.v2
							<< ")" << endl; //@@@@@@@@@@@@@@@
#endif//@@@@@@@@@@@@@@@
				} else {
					field.hubgate = *(v->nqvalue().get_entry_list());
				}
			}
		} else if (step == 2) {
			int cur;
			//pop first (unconditionally) so the stash is always removed from nbs
			if (pop_stashed_dist(v, cur) && field.hop > cur) //v is in L(s), or is s itself
				field.hop = cur;
		} else {
			int dist1 = v->qvalue().v1;
			int dist2 = v->qvalue().v2;
			if (dist1 != not_reached && dist2 != not_reached) { //both reachable
				int dist = dist1 + dist2;
				if (dist < field.min)
					field.min = dist;
			}
		}
	}

	virtual void stepFinal(SPQueryAggField* part) {
		int step = SPQueryVertex::superstep();
		if (step == 1) {
			if (field.hop > part->hop)
				field.hop = part->hop; //this is necessary for the case when s and t are all in H
			if (!part->hubgate.empty())
				field.hubgate = part->hubgate;
		} else if (step == 2) {
			if (field.hop > part->hop)
				field.hop = part->hop;
		} else {
			if (part->min < field.min)
				field.min = part->min;
		}
	}

	virtual SPQueryAggField* finishPartial() {
		return &field;
	}

	virtual SPQueryAggField* finishFinal() {
#ifdef DEBUG_MODE//@@@@@@@@@@@@@@@
		cout << "bound = " << field.hop << ", min = " << field.min << endl; //@@@@@@@@@@@@@@@
#endif//@@@@@@@@@@@@@@@
		return &field;
	}
};

//--------------------------------------------------
//Step 7: define worker class
class SPQueryWorkerOL: public WorkerOL_auto<SPQueryVertex, SPQueryAgg> {
public:
	char buf[50]; //one output line for dump()

	SPQueryWorkerOL() :
			WorkerOL_auto<SPQueryVertex, SPQueryAgg>(true) {
	}

	//Step 7.1: UDF: line -> vertex
	//format: vid \t in_H num nb1 ... nb_num, then "dstHVid1 hop1 dstHVid2 hop2 ..." if in_H,
	//otherwise "h_num h1 d1 h2 d2 ..." (entry hub, distance) pairs
	virtual SPQueryVertex* toVertex(char* line) {
		char * pch;
		SPQueryVertex* v = new SPQueryVertex();
		pch = strtok(line, "\t");
		v->id = atoi(pch);
		pch = strtok(NULL, " ");
		int in_H = atoi(pch);
		v->nqvalue().init(in_H);
		pch = strtok(NULL, " ");
		int num = atoi(pch);
		for (int i = 0; i < num; i++) {
			pch = strtok(NULL, " ");
			int nb = atoi(pch);
			v->nqvalue().nbs.push_back(nb);
		}
		if (in_H) { //reads in: dstHVid1 hop1 dstHVid2 hop2 ...
			hash_map<int, int>* table = v->nqvalue().get_hub_table();
			while ((pch = strtok(NULL, " ")) != NULL) {
				int dst = atoi(pch);
				pch = strtok(NULL, " ");
				if (pch == NULL)
					break; //dangling hub id without a hop count: avoid atoi(NULL)
				(*table)[dst] = atoi(pch);
			}
		} else { //reads in: h_num h1 d1 h2 d2 ...
			vector<intpair>* list = v->nqvalue().get_entry_list();
			pch = strtok(NULL, " ");
			int hnum = atoi(pch);
			for (int i = 0; i < hnum; i++) {
				intpair pair;
				pch = strtok(NULL, " ");
				if (pch == NULL)
					break; //fewer pairs than announced: avoid atoi(NULL)
				pair.v1 = atoi(pch);
				pch = strtok(NULL, " ");
				if (pch == NULL)
					break;
				pair.v2 = atoi(pch);
				list->push_back(pair);
			}
		}
		return v;
	}

	//Step 7.2: UDF: query string "src dst" -> query (src, dst)
	virtual intpair toQuery(char* line) {
		char * pch;
		pch = strtok(line, " ");
		int src = atoi(pch);
		pch = strtok(NULL, " ");
		int dst = atoi(pch);
		return intpair(src, dst);
	}

	//Step 7.3: UDF: vertex init (activate s and t if they exist)
	virtual void init(VertexContainer& vertex_vec) {
		int src = get_query().v1;
		int pos = get_vpos(src);
		if (pos != -1)
			activate(pos);
		//------
		int dst = get_query().v2;
		pos = get_vpos(dst);
		if (pos != -1)
			activate(pos);
	}

	//Step 7.4: UDF: task_dump
	//only src writes, so each query produces exactly one line: "src dst \t dist" or "... \t not reachable"
	virtual void dump(SPQueryVertex* vertex, BufferedWriter& writer) {
		int src = get_query().v1;
		int dst = get_query().v2;
		if (vertex->id != src)
			return;
		if (src == dst) {
			//the old format "%d %d\0\n" stopped at '\0': no distance and no newline were written
			snprintf(buf, sizeof(buf), "%d %d\t0\n", src, dst);
		} else {
			int best = (get_agg())->min; //BFS meeting distance
			int bound = (get_agg())->hop; //hub-based upper bound
			if (best > bound)
				best = bound;
			if (best == INT_MAX)
				snprintf(buf, sizeof(buf), "%d %d\tnot reachable\n", src, dst);
			else
				snprintf(buf, sizeof(buf), "%d %d\t%d\n", src, dst, best);
		}
		writer.write(buf);
	}
};

//merges the char messages bound for one vertex into their bitwise OR
class SPQueryCombiner: public Combiner<char> {
public:
	//relies on superstep-1 messages (distances) all going to different targets, so they are never merged
	virtual void combine(char & old, const char & new_msg) {
		old = (char) (old | new_msg);
	}
};

int main(int argc, char* argv[]) {
	//job configuration
	WorkerParams param;
	param.input_path = in_path;
	param.output_path = out_path;
	param.force_write = true;
	param.native_dispatcher = false;

	SPQueryWorkerOL worker;
	SPQueryCombiner combiner;
	if (use_combiner) {
		worker.setCombiner(&combiner);
	}

	worker.run(param);
	return 0;
}