#include "ol/pregel-ol-dev.h"
#include "utils/type.h"

//input directory of adjacency lists (copied into WorkerParams::input_path by main)
string in_path("/toy_ug");
//output directory for the per-query results (WorkerParams::output_path)
string out_path("/ol_out");
//when true, main installs BiBFSCombiner so messages to one vertex are OR-ed together
bool use_combiner(true);

//Da's 2nd round change:
//opt 1: if either the forward or the backward frontier has no newly reached vertices, force termination
//(waiting for both frontiers to become empty may take a long time)
//another change: only src reports the result during dumping

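//input line format for the undirected graph: vertexID \t neighbor_count neighbor1 neighbor2 neighbor3 ...
//(the ID is followed by a tab, the remaining fields are separated by spaces)
//every edge has length 1, i.e. the graph is unweighted

//output format (one line per query, written by src): src dst \t hop_count, or src dst \t not reachable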

//step1: define the non-query value:
//in this case it is the adjacency list, plus the OPT1 frontier tag
struct BiBFSValue{
	//OPT1 frontier flags written by BiBFSVertex::compute: the fwd_message bit is set
	//when the vertex first gets hop_to_src, the back_message bit when it first gets
	//hop_to_dst, and compute() clears the tag at the start of every call.
	//The tag is not serialized by operator<< / operator>>, so the constructor gives
	//it a defined starting value instead of leaving it indeterminate.
	char tag; //ADDED FOR OPT1
	vector<int> nbs; //neighbor vertex IDs

	BiBFSValue(): tag(0){}
};

//Serializes the non-query value: only the adjacency list goes into the stream;
//tag is left out (BiBFSVertex::compute resets it at the start of every call).
ibinstream & operator<<(ibinstream& out, const BiBFSValue &value){
	out << value.nbs;
	return out;
}

//Deserializes a non-query value written by operator<<: fills only the adjacency
//list and leaves tag untouched.
obinstream & operator>>(obinstream& in, BiBFSValue &value){
	in >> value.nbs;
	return in;
}

//step2: define query type which is intpair <src,dst>

//step3: define the query-specific vertex state (query value), which is intpair<hop_to_src,hop_to_dst>

//step4: define message type which in this case is char
//Messages are direction bit flags. BiBFSCombiner ORs them together, so a vertex
//reached from both sides in one superstep receives met_message.
const char fwd_message = 1;                          //binary 01: from the src side
const char back_message = 2;                         //binary 10: from the dst side
const char met_message = fwd_message | back_message; //binary 11: both sides


//step5: define vertex class

//hop-count sentinel: the vertex has not yet been reached from that side
const int not_reached = -1;

class BiBFSVertex : public VertexOL<VertexID,intpair,BiBFSValue,char,intpair>{
public:
	//step5.1: define query specific setup value
	//Initial query value <hop_to_src, hop_to_dst>: 0 on the side(s) where this
	//vertex is the query's src (v1) / dst (v2), not_reached otherwise.
	virtual intpair init_value(intpair& query){
		intpair pair(not_reached,not_reached);
		if(id == query.v1) pair.v1 = 0;
		if(id == query.v2) pair.v2 = 0;
		return pair;
	}

	//step5.2: define compute function
	//Superstep 1: src floods fwd_message and dst floods back_message to their
	//neighbors; a query with src == dst ends immediately.
	//Later supersteps: OR together the incoming direction bits. For each direction
	//this vertex is reached from for the first time, record the hop count and keep
	//flooding in that direction. A vertex holding both hop counts is a meeting
	//point, which ends the query.
	virtual void compute(MessageContainer& messages){
		nqvalue().tag = 0;//ADDED FOR OPT1: the tag only describes this call

		if(superstep() == 1){
			const intpair& query = get_query();
			if(query.v1 == query.v2){
				forceTerminate(); //trivial query: src is dst
			}else if(id == query.v1){
				send_to_nbs(fwd_message);
			}else if(id == query.v2){
				send_to_nbs(back_message);
			}
		}else{
			//not the first superstep: merge the direction bits of all messages
			char bor = 0;
			for(size_t i = 0; i < messages.size(); ++i){
				bor |= messages[i];
				if(bor == met_message) break; //both bits already set
			}

			if(bor & fwd_message) first_visit(qvalue().v1, fwd_message);   //forward search
			if(bor & back_message) first_visit(qvalue().v2, back_message); //backward search

			//check meet
			if(qvalue().v1 != not_reached && qvalue().v2 != not_reached){
				forceTerminate();
			}
		}
		vote_to_halt();
	}

private:
	//Sends msg to every vertex in this vertex's adjacency list.
	void send_to_nbs(char msg){
		vector<int>& nbs = nqvalue().nbs;
		for(size_t i = 0; i < nbs.size(); ++i){
			send_message(nbs[i], msg);
		}
	}

	//Handles an arrival of direction msg (fwd_message or back_message), where hop is
	//the matching qvalue() slot. On the first arrival (hop still not_reached), it
	//marks the OPT1 tag bit, sets hop to superstep() - 1 and keeps flooding msg.
	//Later arrivals are ignored.
	void first_visit(int& hop, char msg){
		if(hop != not_reached) return;
		nqvalue().tag |= msg;//ADDED FOR OPT1
		hop = superstep() - 1;
		send_to_nbs(msg);
	}
};

//step 6: define aggregator
class BiBFSAggregator : public Aggregator<BiBFSVertex, inttriplet, inttriplet>{
   public:
		//v1: smallest hop_to_src + hop_to_dst over aggregated vertices reached from both sides (INT_MAX: none)
		//v2: number of aggregated vertices whose tag has the back_message bit (backward frontier size)
		//v3: number of aggregated vertices whose tag has the fwd_message bit (forward frontier size)
		inttriplet triplet;//ADDED FOR OPT1

		//reset: no meeting distance known yet, both frontier counters empty
		virtual void init(){
			triplet.v2 = 0;
			triplet.v3 = 0;
			triplet.v1 = INT_MAX;
		}

		//folds one vertex into this worker's partial result
		virtual void stepPartial(BiBFSVertex* v){
			const int to_src = v->qvalue().v1;
			const int to_dst = v->qvalue().v2;
			const bool met = (to_src != not_reached) && (to_dst != not_reached);
			if(met && to_src + to_dst < triplet.v1){
				triplet.v1 = to_src + to_dst;
			}
			const char flags = v->nqvalue().tag;
			if(flags & back_message) ++triplet.v2;
			if(flags & fwd_message) ++triplet.v3;
		}

		//merges another partial result: minimum distance, summed counters
		virtual void stepFinal(inttriplet *part){
			if(part->v1 < triplet.v1) triplet.v1 = part->v1;
			triplet.v2 += part->v2;
			triplet.v3 += part->v3;
		}

		virtual inttriplet* finishPartial(){
			return &triplet;
		}

		//OPT1: a vertex only floods a direction when it is first reached from it, so an
		//empty forward or backward frontier means that search has stopped and the
		//query can end. This is skipped in superstep 1, where both counters are
		//always 0 because src/dst never set their tag there.
		virtual inttriplet* finishFinal(){
			const bool frontier_empty = (triplet.v2 == 0) || (triplet.v3 == 0);
			if(BiBFSVertex::superstep() > 1 && frontier_empty)
				BiBFSVertex::forceTerminate();//must call vertex's forceTerminate(), not the global.h one
			return &triplet;
		}
};

//step 7: define worker class
class BiBFSWorkerOL: public WorkerOL<BiBFSVertex,BiBFSAggregator>{
public:
	//scratch buffer for dump(); its longest line, "<int> <int>\tnot reachable\n",
	//needs at most 39 bytes including the terminating NUL
	char buf[50];

	BiBFSWorkerOL():WorkerOL<BiBFSVertex, BiBFSAggregator>(true){}

	//step 7.1 line to vertex
	//Parses "vertexID \t neighbor_count nb1 nb2 ..." and tokenizes line in place
	//with strtok. A missing neighbor_count is read as 0. If the line holds fewer
	//neighbors than announced, only those present are kept, instead of passing a
	//null token to atoi. Assumes the line contains at least the vertex ID.
	virtual BiBFSVertex* toVertex(char* line){
		BiBFSVertex* v = new BiBFSVertex;
		char* id = strtok(line, "\t");
		v->id = atoi(id);
		char* nb_nums = strtok(NULL, " ");
		int nb_num = (nb_nums != NULL) ? atoi(nb_nums) : 0;
		vector<int>& nbs = v->nqvalue().nbs;
		for(int i = 0; i < nb_num; ++i){
			char* nb = strtok(NULL, " ");
			if(nb == NULL) break; //truncated line: keep the neighbors read so far
			nbs.push_back(atoi(nb));
		}
		return v;
	}

	//step 7,2 query string to query pair
	//Parses "src dst" (space-separated); assumes both tokens are present.
	virtual intpair toQuery(char* line){
		char* p = strtok(line," ");
		int src = atoi(p);
		p = strtok(NULL," ");
		int dst = atoi(p);
		return intpair(src,dst);
	}

	//step 7.3 activate the query's src and dst vertices
	//(positions of -1 from get_vpos are skipped, presumably vertices not held here)
	virtual void init(VertexContainer& vertex_vec){
		int src = get_query().v1;
		int pos = get_vpos(src);
		if(pos != -1) activate(pos);


		int dst = get_query().v2;
		pos = get_vpos(dst);
		if(pos != -1) activate(pos);
	}

	//step 7.4 dump
	/*//old implementation: dump by meeting points (kept for reference only)
	virtual void dump(BiBFSVertex* v, BufferedWriter& writer){
        intpair pair = v->qvalue();
        if(pair.v1 != not_reached && pair.v2 != not_reached){
        	int hop = pair.v1 + pair.v2;
        	if(hop == *(get_agg())){
                 sprintf(buf,"%d\t%d %d\n",v->id,pair.v1,pair.v2);
                 writer.write(buf);
        	}
        }
	}
	*/
	//Only src writes a line: "src dst\thop_count", or "src dst\tnot reachable"
	//when the aggregator never saw a meeting vertex (min distance still INT_MAX).
	virtual void dump(BiBFSVertex* v, BufferedWriter& writer){//ANOTHER CHANGE
		if(v->id != get_query().v1) return; //only src reports the smallest hop-dist
		inttriplet& triplet = *(get_agg());
		int dst = get_query().v2;
		if(triplet.v1 == INT_MAX) snprintf(buf, sizeof(buf), "%d %d\tnot reachable\n", v->id, dst);
		else snprintf(buf, sizeof(buf), "%d %d\t%d\n", v->id, dst, triplet.v1);
		writer.write(buf);
	}
};


class BiBFSCombiner:public Combiner<char>
{
	public:
		//Merges two messages bound for the same vertex by OR-ing their direction
		//bits, so fwd_message and back_message together become met_message.
		virtual void combine(char & old, const char& new_msg)
		{
			old = static_cast<char>(old | new_msg);
		}
};

int main(int argc, char* argv[]){
	WorkerParams param;
	param.input_path=in_path;
	param.output_path=out_path;
	param.force_write=true;
	param.native_dispatcher=false;
	BiBFSWorkerOL worker;
	BiBFSCombiner combiner;
	if(use_combiner) worker.setCombiner(&combiner);
	worker.run(param);
	return 0;
}