made sligthly more useable

This commit is contained in:
Simon Klüttermann 2021-09-19 19:08:06 +02:00
parent c175a9366c
commit ed0d441974
1 changed files with 12 additions and 11 deletions

View File

@ -69,23 +69,24 @@ def train_model(data,n=3):
ngrams=multigramm(data,n=n) ngrams=multigramm(data,n=n)
data=ngramtrafo(data,ngrams,n=n) data=ngramtrafo(data,ngrams,n=n)
pm=0.0
while pm**2<0.0001:
#tensorflow stuff (not at all optimised)
inp=keras.Input(data.shape[1:])
q=inp
q=keras.layers.Dense(10,activation="relu",use_bias=False)(q)
q=keras.layers.Dense(4,activation="relu",use_bias=False)(q)
q=keras.layers.Dense(1,activation="relu",use_bias=False)(q)
#tensorflow stuff (not at all optimised) model=keras.models.Model(inp,q)
inp=keras.Input(data.shape[1:])
q=inp
q=keras.layers.Dense(10,activation="relu",use_bias=False)(q)
q=keras.layers.Dense(4,activation="relu",use_bias=False)(q)
q=keras.layers.Dense(1,activation="relu",use_bias=False)(q)
model=keras.models.Model(inp,q) model.compile("adam","mse")
model.fit(data,np.ones(len(data),dtype="float"),
model.compile("adam","mse")
model.fit(data,np.ones(len(data),dtype="float"),
batch_size=100, batch_size=100,
epochs=50, epochs=50,
validation_split=0.1) validation_split=0.1)
pm=np.mean(model.predict(data)) pm=np.mean(model.predict(data))
return stroo(model,ngrams,n=n,m=pm) return stroo(model,ngrams,n=n,m=pm)