From db113b651b32cb964e367421be6e648e6e7a9ccd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Simon=20Kl=C3=BCttermann?= Date: Sun, 19 Sep 2021 19:19:53 +0200 Subject: [PATCH] added multiple outputs (nonorthogonal), sligthly improving the quality --- stroo.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/stroo.py b/stroo.py index 8b66a99..148a600 100644 --- a/stroo.py +++ b/stroo.py @@ -53,6 +53,8 @@ class stroo(): ng=ngramtrafo(data,s.grams, n=s.n) p=s.model.predict(ng) p=(p-s.m)**2 + while len(p.shape)>1: + p=np.mean(p,axis=1) return p @@ -69,24 +71,25 @@ def train_model(data,n=3): ngrams=multigramm(data,n=n) data=ngramtrafo(data,ngrams,n=n) - pm=0.0 - while pm**2<0.0001: + pm=[0.0] + while np.mean(pm)**2<0.0001: #tensorflow stuff (not at all optimised) inp=keras.Input(data.shape[1:]) q=inp q=keras.layers.Dense(10,activation="relu",use_bias=False)(q) - q=keras.layers.Dense(4,activation="relu",use_bias=False)(q) - q=keras.layers.Dense(1,activation="relu",use_bias=False)(q) + q=keras.layers.Dense(7,activation="relu",use_bias=False)(q) + os=3 + q=keras.layers.Dense(os,activation="relu",use_bias=False)(q) model=keras.models.Model(inp,q) model.compile("adam","mse") - model.fit(data,np.ones(len(data),dtype="float"), + model.fit(data,np.ones((len(data),os),dtype="float"), batch_size=100, epochs=50, validation_split=0.1) - pm=np.mean(model.predict(data)) + pm=np.mean(model.predict(data),axis=0) return stroo(model,ngrams,n=n,m=pm)