# Python source (55 lines, 1.3 KiB) — header residue from a web page export.
import numpy as np
|
|
|
|
from tensorflow import keras
|
|
from mu import *
|
|
from n2ulayer import ulayer
|
|
|
|
from loss import loss
|
|
|
|
|
|
|
|
def choosenext(given, possble, *, epochs=100, batch_size=32,
               learning_rate=0.001, patience=10, verbose=1):
    """Fit a learned combination of candidate score vectors against ``given``.

    Author's stated intent: ``given`` is a list of scores and ``possble`` is a
    list of lists of scores; find the combination of the elements in
    ``possble`` that has the lowest correlation to ``given``. The combination
    is a small Keras model built with ``partr`` (from ``mu``) over ``ulayer``
    and trained with the project-local ``loss``.

    Args:
        given: Sequence of N reference scores, used as the training target
            (one value per sample).
        possble: Sequence of ``opt`` candidate score vectors, each of length
            N (must be 2-D once converted to an array). It is copied, so the
            caller's data is never reordered.
        epochs: Maximum number of training epochs (default 100).
        batch_size: Mini-batch size passed to ``model.fit`` (default 32).
        learning_rate: Adam learning rate (default 0.001).
        patience: Epochs without improvement in training loss before early
            stopping; the best weights are restored (default 10).
        verbose: Keras verbosity level. When falsy, the model summary is
            skipped too (default 1).

    Returns:
        The result of ``model.predict`` on the (N, opt) candidate matrix. Its
        shape is determined by ``partr``; presumably one combined score per
        sample (confirm against ``mu.partr``).

    Raises:
        ValueError: If ``possble`` is not 2-D or is empty, or if its vectors
            do not have the same length as ``given``.
    """
    # Work on a copy. The original shuffled the argument in place, which
    # reordered the caller's list/array (and any array it was a view of).
    candidates = np.array(possble)
    target = np.asarray(given)

    if candidates.ndim != 2:
        raise ValueError("possble must be a 2-D collection of score vectors")
    opt, n_samples = candidates.shape
    if opt == 0 or n_samples == 0:
        raise ValueError("possble must contain at least one non-empty candidate")
    if target.ndim < 1 or target.shape[0] != n_samples:
        raise ValueError("given must have one score per entry of each candidate")

    # Randomise the order of the candidates before they become input features.
    np.random.shuffle(candidates)
    # Samples along axis 0, one input feature per candidate: (N, opt).
    features = candidates.T
    # Targets as a column vector: (N, 1).
    target = np.expand_dims(target, axis=1)

    inp = keras.layers.Input(shape=features.shape[1:])
    # partr (from mu) builds the combining layer(s) from ulayer; its exact
    # semantics are defined outside this file.
    out = partr(inp, 1, opt, ulayer)

    model = keras.models.Model(inputs=inp, outputs=out)
    # ``learning_rate`` replaces the deprecated ``lr`` alias, which newer
    # Keras versions reject.
    model.compile(loss=loss,
                  optimizer=keras.optimizers.Adam(learning_rate=learning_rate))
    if verbose:
        model.summary()

    model.fit(
        features,
        target,
        batch_size=batch_size,
        epochs=epochs,
        verbose=verbose,
        # No validation split: per the author, this model cannot overfit.
        validation_split=0.0,
        shuffle=True,
        callbacks=[
            keras.callbacks.EarlyStopping(
                monitor="loss",
                patience=patience,
                restore_best_weights=True,
            )
        ],
    )

    return model.predict(features)
|
|
|
|
|
|
if __name__ == "__main__":
    # np.load on an .npz returns an NpzFile that keeps the archive open.
    # Indexing it reads the array into memory, so the file can be closed
    # right after extracting "ps".
    with np.load("merged.npz") as f:
        x = f["ps"]

    # Row 0 is the reference score vector; rows 1-4 are the candidates.
    given = x[0]
    possble = x[1:5]

    choosenext(given, possble)
|
|
|
|
|
|
|