disney/data.py

54 lines
1.1 KiB
Python
Raw Normal View History

2021-10-25 11:47:22 +02:00
import xmltodict
import numpy as np
def raw(path="Disney.graphml"):
    """Parse a GraphML file into the nested dict produced by xmltodict.

    Args:
        path: file to read; defaults to "Disney.graphml" in the current
            working directory.

    Returns:
        The xmltodict document: nested dicts keyed by element name, with
        attributes under "@name" keys and element text under "#text".
    """
    # Binary mode lets the XML parser decode the bytes according to the
    # document's own encoding declaration instead of the platform locale.
    with open(path, "rb") as f:
        return xmltodict.parse(f)
def basic_parse(parsed=None):
    """Return the raw <node> and <edge> records of the GraphML <graph>.

    Args:
        parsed: an already-parsed xmltodict document. When None, the default
            GraphML file is loaded via raw().

    Returns:
        (nodes, edges): two lists of xmltodict records. Both are always lists,
        even when the graph has a single node/edge or no edges at all.
    """
    if parsed is None:
        parsed = raw()
    graph = parsed["graphml"]["graph"]

    def as_list(value):
        # xmltodict collapses a lone child element into a plain dict and omits
        # the key entirely when there are no such children; normalize both.
        if value is None:
            return []
        if isinstance(value, list):
            return value
        return [value]

    return as_list(graph.get("node")), as_list(graph.get("edge"))
def read():
    """Load node feature vectors and the edge list from the GraphML document.

    Returns:
        (nodes, edges):
            nodes -- one list of floats per <node>, taken from the "#text" of
                each of its <data> children, ordered by ascending integer
                node id ("@id").
            edges -- one [source, target] int pair per <edge> (from "@source"
                and "@target"), in document order. These are raw node ids,
                not positions in ``nodes``; the two coincide only when the
                ids are exactly 0..N-1.
    """
    raw_nodes, raw_edges = basic_parse()
    keyed = []
    for node in raw_nodes:
        data = node["data"]
        # xmltodict yields a single dict (not a list) when a node has exactly
        # one <data> child; iterating that dict would walk its keys instead.
        if isinstance(data, dict):
            data = [data]
        # NOTE(review): features are taken in document order; this assumes
        # every node lists its <data> children in the same key order — confirm.
        features = [float(item["#text"]) for item in data]
        keyed.append((int(node["@id"]), features))
    edges = [[int(e["@source"]), int(e["@target"])] for e in raw_edges]
    keyed.sort(key=lambda pair: pair[0])
    nodes = [features for _, features in keyed]
    return nodes, edges
def adj(nodes=None, edges=None):
    """Build a batched symmetric adjacency matrix and node feature array.

    Args:
        nodes: per-node feature lists. Loaded with read() when None.
        edges: iterable of (i, j) pairs used as row indices into ``nodes``.
            Loaded with read() when None.

    Returns:
        (a, n): ``a`` has shape (1, N, N), with 1.0 at [0, i, j] and
        [0, j, i] for every edge (i, j) and 0.0 elsewhere. ``n`` is
        ``np.array(nodes)`` with a leading axis of length 1 added, i.e.
        shape (1, N, F) when every node has F features.

    Raises:
        ValueError: if an edge endpoint is not a valid row index in [0, N).
    """
    if nodes is None or edges is None:
        loaded_nodes, loaded_edges = read()
        if nodes is None:
            nodes = loaded_nodes
        if edges is None:
            edges = loaded_edges
    count = len(nodes)
    a = np.zeros((count, count))
    for src, dst in edges:
        # Endpoints are used directly as row indices, which is only valid when
        # node ids run 0..N-1. Reject anything else explicitly: negative ids
        # would otherwise wrap around silently under NumPy indexing.
        if not (0 <= src < count and 0 <= dst < count):
            raise ValueError(
                f"edge ({src}, {dst}) references a node outside 0..{count - 1}"
            )
        a[src, dst] = 1.0
        a[dst, src] = 1.0
    a = np.expand_dims(a, axis=0)
    n = np.expand_dims(np.array(nodes), axis=0)
    return a, n
def load_true(path="Disney.true"):
    """Load the ground-truth flags from a ';'-separated label file.

    A line is flagged True when it contains the substring ";1". Lines of at
    most two characters after stripping whitespace (e.g. blank lines) are
    skipped. NOTE(review): presumably one "<id>;<label>" line per node; confirm
    the line order matches the id-sorted node order produced by read().

    Args:
        path: label file to read; defaults to "Disney.true".

    Returns:
        1-D boolean array with one entry per kept line, in file order.
    """
    with open(path, "r") as f:
        lines = f.read().split("\n")
    flags = [";1" in line for line in lines if len(line.strip()) > 2]
    # dtype=bool keeps the result boolean even when no lines survive the
    # filter (np.array([]) would otherwise default to float64).
    return np.array(flags, dtype=bool)
if __name__ == "__main__":
    # Build the graph tensors first, then the ground-truth labels, and bundle
    # all three into one compressed archive (NumPy appends the .npz suffix).
    adjacency, features = adj()
    labels = load_true()
    np.savez_compressed("data", x=features, a=adjacency, t=labels)