# (page metadata: 106 lines, 1.4 KiB, Python)
import tensorflow as tf
|
||
|
from tensorflow import keras
|
||
|
|
||
|
import numpy as np
|
||
|
|
||
|
import os
|
||
|
import sys
|
||
|
|
||
|
|
||
|
from data import data
|
||
|
|
||
|
# Make sure the output folders exist: saved arrays go to runs/, plots to imgs/.
for _out_dir in ('./runs', './imgs'):
    os.makedirs(_out_dir, exist_ok=True)

# Run index taken from the optional first command-line argument (default 0).
# Later in the script it names the files written to runs/ and imgs/.
dex = int(sys.argv[1]) if len(sys.argv) > 1 else 0

def _draw_datasets(n_train=1000, n_eval=10000, eval_seed=12):
    """Draw the training and evaluation sets from ``data`` using the global NumPy RNG.

    The call order is significant and must not be changed:
      1. draw a fresh seed from the current global RNG;
      2. draw the training set (it therefore varies from run to run);
      3. re-seed with ``eval_seed`` and draw the evaluation set, so the
         evaluation points are the same in every run;
      4. re-seed with the seed from step 1, so later NumPy randomness is no
         longer pinned to ``eval_seed``.

    Returns:
        (seed, train, evaluation): the seed drawn in step 1, then the two datasets.
    """
    fresh_seed = np.random.randint(100000)
    train = data(n_train)
    np.random.seed(eval_seed)
    evaluation = data(n_eval)
    np.random.seed(fresh_seed)
    return fresh_seed, train, evaluation


# x: training points (random per run); X: fixed evaluation points.
seed, x, X = _draw_datasets()

# Small fully connected network: two ReLU hidden layers of width 5,
# then a single linear output unit.
inputs = keras.layers.Input(shape=x.shape[1:])
hidden = inputs
for width in (5, 5):
    hidden = keras.layers.Dense(width, activation='relu')(hidden)
outputs = keras.layers.Dense(1, activation='linear')(hidden)

model = keras.models.Model(inputs=inputs, outputs=outputs)
model.compile(optimizer='adam', loss='mse')

# Every training target is 1.0, so the network is trained to map the
# training points to a constant output. Training stops early once val_loss
# has not improved for 10 consecutive epochs.
targets = np.ones(len(x))
stopper = keras.callbacks.EarlyStopping(monitor='val_loss', patience=10)
model.fit(
    x,
    targets,
    epochs=500,
    batch_size=25,
    validation_split=0.2,  # hold out 20% of x to compute val_loss
    verbose=1,
    shuffle=True,
    callbacks=[stopper],
)

# Evaluation phase: score the fixed evaluation set X with the trained model.
x = X
p = model.predict(x)

# Mean prediction over every sample (and output unit).
mp = np.mean(p)

# Per-sample RMS deviation from the mean prediction, taken over the last axis.
# The model ends in Dense(1), so this works out to |p - mp| per sample.
d = np.sqrt(np.mean((p - mp) ** 2, axis=-1))

np.savez_compressed(f"runs/{dex}", d=d, x=x, p=p, mp=mp)

# Order the evaluation points by their deviation score d, smallest first.
# sorted() is stable, so tied scores keep their original relative order.
order = sorted(range(len(d)), key=lambda i: d[i])

# Show the lowest- and highest-scoring (point, score) pairs.
print((x[order[0]], d[order[0]]), (x[order[-1]], d[order[-1]]))

# Points rearranged into score order.
sx = np.array([x[i] for i in order])

# Project-local plotting module; presumably re-exports matplotlib.pyplot as
# `plt` (TODO confirm).
from plt import plt

# Colour ramp over the score-sorted points in sx: the lowest score
# (index 0) is pure green, and the colour shifts toward red as the score rises.
col1 = np.array([1.0, 0.0, 0.0])  # red
col2 = np.array([0.0, 1.0, 0.0])  # green
ln = len(sx)
cols = [col1 * (i / ln) + col2 * (1 - i / ln) for i in range(ln)]

# NOTE(review): plots the first two feature columns, which assumes the data
# has at least 2 features.
plt.scatter(sx[:, 0], sx[:, 1], c=cols)
plt.savefig(f"imgs/{dex}.png")

# Was `plt.how()`: a typo that raised AttributeError after the figure was saved.
plt.show()