shubham_nf/train.py

82 lines
2.2 KiB
Python
Raw Permalink Normal View History

2022-11-21 11:02:58 +01:00
import functools
from time import perf_counter
from time import time

import numpy as np
import tensorflow as tf
from sklearn.metrics import roc_auc_score
from tensorflow import keras
from tensorflow.keras import backend as K

from model import gen_model
def with_evaluation_cost(func):
    """Decorate ``func`` so its returned dict also records its runtime.

    The wrapped function must return a mutable mapping. The elapsed time of
    the call, in seconds, is written into that mapping in place under the key
    ``"evaluation_cost"`` (overwriting any value ``func`` put there), and the
    same mapping is returned.
    """

    # functools.wraps keeps the wrapped function's __name__/__doc__ visible.
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # perf_counter is monotonic, so the measured duration cannot go
        # negative or jump if the system clock is adjusted mid-call.
        start = perf_counter()
        result = func(*args, **kwargs)
        result["evaluation_cost"] = perf_counter() - start
        return result

    return wrapper
@with_evaluation_cost
def train_one(x, tx, ty, seed=0, lr_epoch_modulo=10, lr_factor=0.8,
              lr_minima=1e-4, initial_lr=0.001, epochs=1000, batch_size=30,
              shall_early=True, patience=10, *args, **kwargs):
    """Train one model with a fixed seed and score it on a labelled test set.

    Args:
        x: Training inputs; ``int(x.shape[1])`` is passed to ``gen_model`` as
            its first argument. Keras holds out the last 20% of ``x`` for
            validation (``validation_split=0.2``).
        tx: Test inputs scored with ``model.predict``.
        ty: Test labels passed to ``roc_auc_score``.
        seed: Seed for the NumPy and TensorFlow global RNGs.
        lr_epoch_modulo: Multiply the learning rate by ``lr_factor`` every
            this many epochs (never at epoch 0). ``0`` disables the decay.
        lr_factor: Multiplicative learning-rate decay factor.
        lr_minima: Floor for the learning rate. Decay only happens while the
            rate is above this value and never takes it below it.
        initial_lr: Initial Adam learning rate.
        epochs: Maximum number of training epochs.
        batch_size: Mini-batch size for ``fit``.
        shall_early: Whether to add ``EarlyStopping`` (restoring the best
            weights).
        patience: Early-stopping patience in epochs.
        *args, **kwargs: Forwarded unchanged to ``gen_model``.

    Returns:
        ``{"loss": ..., "auc": ...}``. ``loss`` is ``model.evaluate(x)`` over
        all of ``x``, and ``auc`` is the ROC AUC of ``-model.predict(tx)``
        against ``ty``. If either computation raises (e.g. NaN predictions
        after a diverged run make ``roc_auc_score`` fail), the sentinels
        ``loss=1e9`` and ``auc=-1.0`` are returned instead. The
        ``with_evaluation_cost`` decorator adds an ``"evaluation_cost"`` entry.
    """
    np.random.seed(seed)
    tf.random.set_seed(seed)
    model = gen_model(int(x.shape[1]), *args, **kwargs)
    # No loss is given to compile() and fit() gets no targets. Presumably
    # gen_model attaches its own loss (e.g. via add_loss) -- TODO confirm
    # in model.py.
    model.compile(optimizer=keras.optimizers.Adam(initial_lr))

    def schedule(epoch, lr):
        """Step-decay schedule for ``LearningRateScheduler``."""
        if (lr_epoch_modulo and epoch > 0
                and epoch % lr_epoch_modulo == 0 and lr > lr_minima):
            # Clamp so the final decay step cannot undershoot lr_minima.
            return max(lr * lr_factor, lr_minima)
        return lr

    callbacks = []
    if shall_early:
        callbacks.append(keras.callbacks.EarlyStopping(
            patience=patience, restore_best_weights=True))
    callbacks.append(keras.callbacks.LearningRateScheduler(schedule))
    callbacks.append(keras.callbacks.TerminateOnNaN())
    model.fit(x, None,
              epochs=epochs,
              batch_size=batch_size,
              validation_split=0.2,
              callbacks=callbacks)
    p = model.predict(tx)
    try:
        loss = model.evaluate(x)
        # Negated so that a lower model output ranks as more likely positive.
        # Presumably predict returns a likelihood-like score -- TODO confirm.
        auc = roc_auc_score(ty, -p)
    except Exception:
        # Best effort: a failed/diverged run gets sentinel scores that
        # train_many treats as a failure (auc < 0). Narrowed from a bare
        # except so Ctrl-C and SystemExit still propagate.
        loss = 1e9
        auc = -1.0
    return {"loss": loss, "auc": auc}
def dic_mean(dic):
    """Aggregate a non-empty list of per-run result dicts.

    Every key of the first run is averaged across all runs (``np.mean``).
    Two extra entries are appended:

    * ``"min_loss"``: the smallest ``"loss"`` over the runs.
    * ``"opt_auc"``: the mean ``"auc"`` of the run(s) whose loss equals
      that minimum.
    """
    summary = {}
    for key in dic[0]:
        summary[key] = np.mean([run[key] for run in dic])

    best_loss = np.min([run["loss"] for run in dic])
    summary["min_loss"] = best_loss

    best_aucs = []
    for run in dic:
        if run["loss"] == best_loss:
            best_aucs.append(run["auc"])
    summary["opt_auc"] = np.mean(best_aucs)
    return summary
@with_evaluation_cost
def train_many(x, tx, ty, count=10, *args, **kwargs):
    """Train with seeds ``0..count-1`` via ``train_one`` and aggregate.

    Extra positional ``*args`` line up with ``train_one``'s parameters after
    ``seed`` (``lr_epoch_modulo``, ``lr_factor``, ...). ``**kwargs`` are
    forwarded as-is. Training stops early after the first failed run
    (``auc < 0``), and that failed run is still included in the aggregate.

    Returns:
        ``dic_mean`` of the per-seed result dicts. The ``with_evaluation_cost``
        decorator then overwrites ``"evaluation_cost"`` with the total runtime
        of this call.

    Raises:
        ValueError: if ``count`` is less than 1 (there would be nothing to
            aggregate).
    """
    if count < 1:
        raise ValueError(f"count must be at least 1, got {count!r}")
    runs = []
    for seed in range(count):
        # Pass seed positionally: with ``seed=seed`` any non-empty *args
        # would also bind to ``seed`` and raise TypeError.
        run = train_one(x, tx, ty, seed, *args, **kwargs)
        runs.append(run)
        if run["auc"] < 0.0:
            # train_one reports a failed run with auc == -1.0; further seeds
            # with the same settings are not worth the time.
            break
    return dic_mean(runs)
if __name__ == "__main__":
    # Single-seed run on the cardio dataset. splits=1 is not a train_one
    # parameter, so it is forwarded through **kwargs to gen_model.
    # The context manager closes the .npz archive once the arrays are read.
    with np.load("cardio.npz") as data:
        x, tx, ty = data["x"], data["tx"], data["ty"]
    print(train_one(x, tx, ty,
                    epochs=1000,
                    shall_early=True,
                    patience=10,
                    lr_epoch_modulo=10,
                    lr_factor=0.8,
                    lr_minima=1e-4,
                    initial_lr=0.001,
                    batch_size=30,
                    splits=1))