# disney/main.py
# Web-view metadata: 76 lines, 2.3 KiB, Python; last change 2021-10-25 11:47:22 +02:00
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
import grapa as g
from grapa.functionals import *
from grapa.layers import *
from grapa.constants import *
keras=tf.keras
K=keras.backend
from data import adj
# Load the adjacency data (A) and the node data (X) from the project's
# data module, then join them along axis 2 with X first and A second.
# createmodel() later cuts this combined array back into the two parts.
# NOTE(review): presumably X is (samples, nodes, features) and A is
# (samples, nodes, nodes) -- confirm against data.adj.
A, X = adj()
data = np.concatenate([X, A], axis=2)
def createmodel():
    """Build the graph network and return its input and its final graph.

    Shapes are read from the module-level ``data`` array:
    ``data.shape[1]`` is used as the number of nodes ``gs`` and
    ``data.shape[2] - gs`` as the number of features per node. The input is
    cut into a feature part and an adjacency part, run through two ``gnl``
    graph update steps, and finally reduced to the parameter count that was
    recorded after the first update step.

    Returns:
        A pair ``(i, graph)``: the ``Input`` layer and the value returned by
        ``remparam`` for the processed graph.
    """
    # Input shape is data's shape without its first axis.
    i = Input(shape=data.shape[1:])
    gs = int(data.shape[1])  # number of initial nodes
    param = int(data.shape[2]) - gs  # number of features per node

    # Named ``graph`` rather than ``g`` so the ``import grapa as g`` module
    # alias is not shadowed inside this function.
    graph = grap(state(gs=gs, param=param))
    # Cut the input into the feature part (first ``param`` entries) and the
    # adjacency part (following ``gs`` entries).
    graph.X, graph.A = gcutparam(gs=gs, param1=param, param2=gs)([i])

    consts = getm()  # the standard constants object
    graph = gnl(graph, consts)  # first graph update step
    # Remember the parameter count so it can be restored at the end.
    oparam = graph.s.param

    # The compress/decompress stage is currently disabled. Re-enabling it
    # also requires returning through handlereturn(...), which needs the
    # initial graph.X plus the ``com`` and ``i2`` values from decompress.
    # graph = compress(graph, consts, 5, 12)  # gs nodes -> gs/5 nodes, +12 features
    # graph = gll(graph, consts, k=4)  # relearn the graph with top-k, k=4
    # consts.decompress = "paramlike"
    # graph, com, i2 = decompress(graph, consts, 5)  # gs/5 nodes -> gs nodes

    graph = gnl(graph, consts)  # final graph update step
    graph = remparam(graph, oparam)  # remove the surplus parameters again
    return i, graph
def prepare():
    """Create, plot, compile and summarise the model built by createmodel.

    The loss attached to the model is the mean of ``mse`` between the model
    output and a tensor of ones with the same shape as that output.

    Returns:
        The compiled ``Model``.
    """
    inputs, outputs = createmodel()
    model = Model(inputs, outputs)

    # Save a diagram of the model, including tensor shapes.
    plot_model(model, to_file="model.png", show_shapes=True)

    # The comparison target is an all-ones tensor shaped like the output.
    target = K.ones_like(outputs)
    model.add_loss(K.mean(mse(outputs, target)))

    # No loss is passed to compile(); the one added above is used.
    model.compile(Adam(lr=lr))
    model.summary()
    return model
# Build the model only when this file is executed as a script.
if __name__ == '__main__':
    model = prepare()