#include "mem_graph.h"
#include <stack>
#include <ext/hash_set>

#define hash_set __gnu_cxx::hash_set

using namespace std;

// DFSGraph extends mem_graph with iterative (explicit-stack) depth-first
// traversals: pre/post-order numbering, strongly connected components (SCC)
// and output of the condensed (SCC-level) graph.
//
// Vertex IDs are mapped to array positions as pos = vid - start_mode, and the
// per-vertex arrays below have V entries.
//
// NOTE: the result arrays are owned through raw pointers and this class
// declares no copy constructor / assignment operator, so a compiler-generated
// copy would share the arrays and free them twice. Avoid copying instances.
class DFSGraph: public mem_graph {

public:

	int *preorder;     // preorder[pos]: order in which the vertex was first reached by DFS(); NULL until DFS() runs
	int *postorder;    // postorder[pos]: order in which the vertex was finished by DFS(); NULL until DFS() runs
	int *scc;          // scc[pos]: SCC id of the vertex, in [0, num_scc); NULL until compute_scc() runs
	int num_scc;       // Number of SCCs in the graph, -1 while not computed
	int num_scc_edges; // Number of edges of the condensed graph, -1 while not computed

	// Forwards (v, e, mode) to mem_graph and marks every derived result as "not computed".
	DFSGraph(int v, int e, int mode) :
			mem_graph(v, e, mode),
			preorder(NULL),
			postorder(NULL),
			scc(NULL),
			num_scc(-1),
			num_scc_edges(-1) {
	}

	// Computes preorder[] and postorder[] for all vertices with an iterative DFS
	// over out-edges. Roots are tried in increasing position order, and children
	// in adjacency-list order. Calling it again recomputes the arrays.
	void DFS() {
		// Release results of an earlier run so repeated calls do not leak.
		// Pointers are reset first so the destructor stays safe if new[] throws.
		delete[] preorder;
		preorder = NULL;
		delete[] postorder;
		postorder = NULL;
		preorder = new int[V];
		postorder = new int[V];

		stack<int> vertexStack; //Vertices on the current DFS path
		stack<int> nextChildStack; //For each vertex on the path: index of its next out-neighbor to examine
		bool *visit = new bool[V];
		memset(visit, 0, V * sizeof(bool)); //Nothing visited yet

		int nxt_pre = 0;
		int nxt_post = 0;

		for (int root = 0; root < V; root++) //root only scans for unvisited start vertices
		{
			if (visit[root]) continue; //already covered by an earlier tree
			vertexStack.push(root + start_mode);
			nextChildStack.push(0);
			while (!vertexStack.empty()) {
				int vid = vertexStack.top();
				int pos = vid - start_mode;
				if (!visit[pos]) //first time accessing vid
				{
					preorder[pos] = nxt_pre++;
					visit[pos] = true;
				}
				int size = get_out_size(vid);
				int *adj = get_out_adj(vid);
				if (nextChildStack.top() < size) {
					//Consume one neighbor slot of vid, whether or not we descend into it
					int childID = adj[nextChildStack.top()++];
					if (!visit[childID - start_mode]) {
						vertexStack.push(childID);
						nextChildStack.push(0);
					}
				}
				else {
					//All neighbors handled: vid is finished
					postorder[pos] = nxt_post++;
					vertexStack.pop();
					nextChildStack.pop();
				}
			}
		}
		delete[] visit; //allocated with new[]
	}

	// Frees the result arrays (delete[] on NULL is a no-op).
	~DFSGraph() {
		delete[] preorder;
		delete[] postorder;
		delete[] scc;
	}

	//======================= below is for SCC =======================

	// First pass of SCC computation: iterative DFS over out-edges that records
	// the finishing order of the vertices.
	// post_array: caller-owned array with at least V entries; on return,
	//             post_array[t] is the ID of the t-th vertex to finish.
	void DFS_forward(int *post_array)
	{
		stack<int> vertexStack; //Vertices on the current DFS path
		stack<int> nextChildStack; //For each vertex on the path: index of its next out-neighbor to examine
		bool *visit = new bool[V];
		memset(visit, 0, V * sizeof(bool)); //Nothing visited yet

		int time = 0;

		for (int root = 0; root < V; root++) //root only scans for unvisited start vertices
		{
			if (visit[root]) continue; //already covered by an earlier tree
			vertexStack.push(root + start_mode);
			nextChildStack.push(0);
			while (!vertexStack.empty()) {
				int vid = vertexStack.top();
				visit[vid - start_mode] = true;
				int size = get_out_size(vid);
				int *adj = get_out_adj(vid);
				if (nextChildStack.top() < size) {
					//Consume one neighbor slot of vid, whether or not we descend into it
					int childID = adj[nextChildStack.top()++];
					if (!visit[childID - start_mode]) {
						vertexStack.push(childID);
						nextChildStack.push(0);
					}
				}
				else {
					//vid is finished: append it to the finishing order
					post_array[time++] = vid;
					vertexStack.pop();
					nextChildStack.pop();
				}
			}
		}
		delete[] visit; //allocated with new[]
	}


	// Second pass of SCC computation: DFS over in-edges (i.e. on the reversed
	// graph), starting roots in decreasing finishing order. Every tree found
	// is one SCC; its vertices get the same id in scc[], and num_scc is set.
	// post_array: finishing order produced by DFS_forward (V vertex IDs).
	void DFS_reverse(int *post_array)
	{
		//compute_scc() allocates scc beforehand; allocate here as well so a
		//direct call does not write through a NULL pointer.
		if (scc == NULL) scc = new int[V];

		stack<int> vertexStack; //Vertices on the current DFS path
		stack<int> nextChildStack; //For each vertex on the path: index of its next in-neighbor to examine
		bool *visit = new bool[V];
		memset(visit, 0, V * sizeof(bool)); //Nothing visited yet

		int scc_id = -1;

		for (int time = V - 1; time >= 0; time--) //visit vertices in the order of decreasing postorder
		{
			int rootID = post_array[time];
			if (visit[rootID - start_mode]) continue; //already assigned to an SCC
			scc_id++; //new SCC found
			vertexStack.push(rootID);
			nextChildStack.push(0);
			while (!vertexStack.empty()) {
				int vid = vertexStack.top();
				int pos = vid - start_mode;
				visit[pos] = true;
				int size = get_in_size(vid);
				int *adj = get_in_adj(vid);
				if (nextChildStack.top() < size) {
					//Consume one neighbor slot of vid, whether or not we descend into it
					int childID = adj[nextChildStack.top()++];
					if (!visit[childID - start_mode]) {
						vertexStack.push(childID);
						nextChildStack.push(0);
					}
				}
				else {
					scc[pos] = scc_id;
					vertexStack.pop();
					nextChildStack.pop();
				}
			}
		}
		num_scc = scc_id + 1;
		delete[] visit; //allocated with new[]
	}


	// Computes the strongly connected components (Kosaraju's two-pass scheme:
	// DFS_forward for the finishing order, then DFS_reverse on in-edges).
	// Fills scc[] and num_scc. Calling it again recomputes the result.
	void compute_scc()
	{
		int *post_array = new int[V]; //records the vertices in finishing order
		memset(post_array, 0, sizeof(int) * V);
		//Release the result of an earlier run so repeated calls do not leak.
		delete[] scc;
		scc = NULL;
		scc = new int[V];
		DFS_forward(post_array);
		DFS_reverse(post_array);
		delete[] post_array; //allocated with new[]
	}

	// Returns the SCC id of vertex vid, or -1 if compute_scc() has not run.
	// vid is not range-checked.
	int get_scc(int vid)
	{
		if (scc) {
			return scc[vid - start_mode];
		}
		else{
			return -1;
		}
	}

	// Returns the preorder number of vertex vid, or -1 if DFS() has not run.
	// vid is not range-checked.
	int get_preorder(int vid)
	{
		if (preorder) {
			return preorder[vid - start_mode];
		}else{
			return -1;
		}
	}

	// Returns the postorder number of vertex vid, or -1 if DFS() has not run.
	// vid is not range-checked.
	int get_postorder(int vid)
	{
		if (postorder) {
			return postorder[vid - start_mode];
		}else{
			return -1;
		}
	}

	 //=======================Below is for computing condensed graph==============================

	// Returns the number of SCCs; terminates the program if they were not computed.
	int getSCC_V()
	{
		if(num_scc == -1)
		{
			cerr << "SCC computation not done!" << endl;
			exit(-1);
		}
		return num_scc;
	}

	// Returns the number of condensed-graph edges; terminates the program if
	// condensedGraph() has not run.
	int getSCC_E()
	{
		if(num_scc_edges == -1)
		{
			cerr << "condensed graph not computed!" << endl;
			exit(-1);
		}
		return num_scc_edges;
	}

	// Builds the condensed graph (one node per SCC, one edge per distinct pair
	// of different SCCs joined by an out-edge) and writes it to output_filename.
	// Line format: scc_id \t out_num out1 out2 ...
	// Also sets num_scc_edges. Terminates the program if SCCs were not computed.
	void condensedGraph(string output_filename)
	{
		if(num_scc == -1 || !scc)
		{
			cerr << "SCC computation not done!" << endl;
			exit(-1); //continuing would allocate num_scc == -1 tables / read a NULL scc
		}

		//condenseTable[s] = set of SCC ids reachable from SCC s by a single edge
		hash_set<int> *condenseTable = new hash_set<int>[num_scc];
		for(int pos = 0; pos < V; pos++){
			int vid = pos + start_mode;
			int scc_id = get_scc(vid);
			int size = get_out_size(vid);
			int *adj = get_out_adj(vid);
			for(int i = 0; i < size; i++){
				int nb_scc = get_scc(adj[i]);
				if(scc_id != nb_scc) //edges inside one SCC are dropped
					condenseTable[scc_id].insert(nb_scc);
			}
		}

		ofstream fout(output_filename.c_str());
		num_scc_edges = 0;
		for(int i = 0; i < num_scc; i++){
			int nb_size = condenseTable[i].size();

			num_scc_edges += nb_size;
			fout << i << '\t' << nb_size;
			for(hash_set<int>::iterator it = condenseTable[i].begin(); it != condenseTable[i].end(); ++it)
				fout << ' ' << *it;
			fout << '\n';
		}
		fout.close();
		delete[] condenseTable;
	}

	// Writes the graph's adjacency lists to output_filename, one vertex per line.
	// porder_tag = true, format: vid \t pre post in_num in1 in2 ... out_num out1 out2 ...
	// porder_tag = false, format: vid \t in_num in1 in2 ... out_num out1 out2 ...
	// With porder_tag = true the program terminates if DFS() has not run.
	void outputDG(string output_filename, bool porder_tag)
	{
		if(porder_tag && preorder == NULL)
		{
			cerr << "DFS() not done!" << endl;
			exit(-1);
		}

		ofstream fout(output_filename.c_str());
		int end = start_mode + V;
		for(int vid = start_mode; vid < end; vid++){
			fout << vid << '\t';
			if(porder_tag)
				fout << get_preorder(vid) << " " << get_postorder(vid) << " ";

			int in_deg = get_in_size(vid);
			int *adj_in = get_in_adj(vid);
			fout << in_deg;
			for(int i = 0; i < in_deg; i++) fout << ' ' << adj_in[i];

			int out_deg = get_out_size(vid);
			int *adj_out = get_out_adj(vid);
			fout << ' ' << out_deg;
			for(int i = 0; i < out_deg; i++) fout << ' ' << adj_out[i];
			fout << '\n';
		}
		fout.close();
	}

};

