From ed0d441974975a93c07ece444d95797c3e9dcb9e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Simon=20Kl=C3=BCttermann?= Date: Sun, 19 Sep 2021 19:08:06 +0200 Subject: [PATCH] made sligthly more useable --- stroo.py | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/stroo.py b/stroo.py index 835bca0..8b66a99 100644 --- a/stroo.py +++ b/stroo.py @@ -69,23 +69,24 @@ def train_model(data,n=3): ngrams=multigramm(data,n=n) data=ngramtrafo(data,ngrams,n=n) + pm=0.0 + while pm**2<0.0001: + #tensorflow stuff (not at all optimised) + inp=keras.Input(data.shape[1:]) + q=inp + q=keras.layers.Dense(10,activation="relu",use_bias=False)(q) + q=keras.layers.Dense(4,activation="relu",use_bias=False)(q) + q=keras.layers.Dense(1,activation="relu",use_bias=False)(q) - #tensorflow stuff (not at all optimised) - inp=keras.Input(data.shape[1:]) - q=inp - q=keras.layers.Dense(10,activation="relu",use_bias=False)(q) - q=keras.layers.Dense(4,activation="relu",use_bias=False)(q) - q=keras.layers.Dense(1,activation="relu",use_bias=False)(q) + model=keras.models.Model(inp,q) - model=keras.models.Model(inp,q) - - model.compile("adam","mse") - model.fit(data,np.ones(len(data),dtype="float"), + model.compile("adam","mse") + model.fit(data,np.ones(len(data),dtype="float"), batch_size=100, epochs=50, validation_split=0.1) - pm=np.mean(model.predict(data)) + pm=np.mean(model.predict(data)) return stroo(model,ngrams,n=n,m=pm)