import sys
import MySQLdb
import networkx as nx

# Build the undirected interaction graph from the first command-line argument.
#
# argv[1] is a list of edge records joined by "?", each record having the form
# "n1%n2%w" (two node names and a numeric value w). Every node name seen is
# also collected in `nodes`.
nodes = set()
G = nx.Graph()
if len(sys.argv) > 1:
	for entry in sys.argv[1].split("?"):
		# An empty argument or a stray leading/trailing "?" yields empty
		# records; skip them instead of failing with an IndexError.
		if not entry:
			continue

		# Split once per record. Indexing (rather than unpacking) keeps extra
		# trailing fields tolerated and still raises IndexError on records
		# with fewer than three fields.
		fields = entry.split("%")
		n1, n2, w = fields[0], fields[1], fields[2]

		nodes.add(n2)
		nodes.add(n1)
		# add_node is a no-op for nodes that already exist.
		G.add_node(n2)
		G.add_node(n1)

		# The edge weight is stored as (2 - w), so a larger w gives a smaller
		# path cost (presumably w is a confidence score -- confirm with the
		# caller). The keyword form is used because networkx 2.x no longer
		# accepts the attribute dict as a third positional argument.
		G.add_edge(n1, n2, weight=2 - float(w))

# Mandatory positional arguments (an IndexError is raised if any is missing):
#   argv[2] -> integer path mode
#   argv[3] -> flag string for 'rec' end nodes
#   argv[4] -> flag string for 'tf' end nodes
weighted = int(sys.argv[2])
rec, tf = sys.argv[3], sys.argv[4]



def parse_end_nodes(rec_arg, tf_arg):
	"""Build the initial end-node mapping from two "%"-separated name lists.

	rec_arg -- "%"-separated node names to mark as 'rec', or None.
	tf_arg  -- "%"-separated node names to mark as 'tf', or None.

	Returns a dict mapping node name -> 'rec' or 'tf'. A name that was marked
	'rec' and is then encountered in tf_arg is removed from the mapping; if it
	is encountered again afterwards it is added as 'tf'.

	Empty names are ignored. This makes a bare "%" (or a trailing "%") mean
	"no nodes" for both arguments, instead of inserting an empty-named node
	that cannot exist in the graph.
	"""
	result = {}

	if rec_arg is not None:
		for name in rec_arg.split("%"):
			if name:
				result[name] = 'rec'

	if tf_arg is not None:
		for name in tf_arg.split("%"):
			if not name:
				continue
			if result.get(name) == 'rec':
				result.pop(name)
			else:
				result[name] = 'tf'

	return result


# Optional argv[5] / argv[6] carry user-selected 'rec' and 'tf' end nodes.
end_nodes = parse_end_nodes(
	sys.argv[5] if len(sys.argv) > 5 else None,
	sys.argv[6] if len(sys.argv) > 6 else None)
					

# Look up the annotated end nodes among the graph's nodes and merge the ones
# the caller asked for (rec == '1' and/or tf == '1') into end_nodes.
#
# The node names come from the command line, so they are passed to the driver
# as query parameters instead of being concatenated into the SQL text.
node_ids = list(nodes)
rows = ()
if node_ids:  # "IN ()" is not valid SQL, so skip the query when there are no nodes
	placeholders = ",".join(["%s"] * len(node_ids))
	query = ("SELECT * FROM sp_analysis_end_nodes WHERE uniprot_id IN ("
	         + placeholders + ")")

	# NOTE(review): credentials are hard-coded here; consider moving them to
	# configuration outside the source file.
	conn = MySQLdb.connect(host="localhost",
	                       user="hippie",
	                       passwd="modX47g",
	                       db="mschaefer_hippie_v2")
	try:
		cursor = conn.cursor()
		try:
			cursor.execute(query, node_ids)
			rows = cursor.fetchall()
		finally:
			cursor.close()
	finally:
		# Always release the connection, even if the query fails.
		conn.close()

for row in rows:
	# row[0] is used as the node name and row[1] as its end-node type; only
	# the types 'rec' and 'tf' are ever copied, and only when the matching
	# flag is '1'.
	node_type = row[1]
	if (rec == '1' and node_type == 'rec') or (tf == '1' and node_type == 'tf'):
		end_nodes[row[0]] = node_type
		
		


# directed_edges: intended to hold all edges through which a path goes
#   0: directed
#   1: both directions
# NOTE(review): nothing in the code below reads or fills this dict.
directed_edges = {}

# With weighted == 2 the shortest paths minimise the 'weight' edge attribute;
# otherwise every edge counts the same (hop count).
path_weight = 'weight' if weighted == 2 else None

# Print every edge ("a%b", one per line) on every shortest path from each
# 'rec' end node to each 'tf' end node.
for e in end_nodes:
	if end_nodes[e] != 'rec':
		continue
	for f in end_nodes:
		if end_nodes[f] != 'tf':
			continue
		# End nodes supplied on the command line need not be part of the
		# graph; there is nothing to trace for them.
		if e not in G or f not in G:
			continue
		try:
			for p in nx.all_shortest_paths(G, e, f, weight=path_weight):
				# Consecutive node pairs along the path are its edges.
				for prev, prot in zip(p, p[1:]):
					print(prev + "%" + prot)
		except nx.NetworkXNoPath:
			# The two nodes are in different components: skip this pair
			# rather than aborting before the end-node listing is printed.
			continue

# Finally list every end node with its type ("name type", one per line).
for n in end_nodes:
	print(n + " " + end_nodes[n])
