Yhhxhfh commited on
Commit
73dbfa8
1 Parent(s): b3d9739

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -8
app.py CHANGED
@@ -121,7 +121,7 @@ async def train_and_save_model():
121
 
122
  training.append([bag, output_row])
123
 
124
- if not training:
125
  print("No training data yet. Waiting...")
126
  await asyncio.sleep(60)
127
  continue
@@ -139,7 +139,7 @@ async def train_and_save_model():
139
  model.add(Dropout(0.5))
140
  model.add(Dense(64, activation='relu'))
141
  model.add(Dropout(0.5))
142
- model.add(Dense(len(train_y[0]), activation='softmax'))
143
 
144
  sgd = SGD(learning_rate=0.01, momentum=0.9, nesterov=True)
145
  model.compile(loss='categorical_crossentropy', optimizer=sgd, metrics=['accuracy'])
@@ -190,23 +190,21 @@ async def chat(message: ChatMessage):
190
 
191
  p = model.predict(np.array([bag]))[0]
192
  ERROR_THRESHOLD = 0.25
193
- results = [[i,p] for i,p in enumerate(p) if p>ERROR_THRESHOLD]
194
  results.sort(key=lambda x: x[1], reverse=True)
195
  return_list = []
196
- for i,p in results:
197
  return_list.append({"intent": classes[i], "probability": str(p)})
198
 
199
  r.rpush('user_questions', message.message)
200
 
201
  return return_list
202
 
203
-
204
  @app.post("/tag")
205
  async def tag_question(question: str, tag: str):
206
  r.set(f"tag:{question}", tag)
207
  return {"message": "Tag saved"}
208
 
209
-
210
  html_code = """
211
  <!DOCTYPE html>
212
  <html>
@@ -304,12 +302,10 @@ html_code = """
304
  </html>
305
  """
306
 
307
-
308
  @app.get("/", response_class=HTMLResponse)
309
  async def root():
310
  return html_code
311
 
312
-
313
  if __name__ == "__main__":
314
  initialize_redis()
315
  training_process = multiprocessing.Process(target=start_training_loop)
 
121
 
122
  training.append([bag, output_row])
123
 
124
+ if not training:
125
  print("No training data yet. Waiting...")
126
  await asyncio.sleep(60)
127
  continue
 
139
  model.add(Dropout(0.5))
140
  model.add(Dense(64, activation='relu'))
141
  model.add(Dropout(0.5))
142
+ model.add(Dense(len(classes), activation='softmax'))
143
 
144
  sgd = SGD(learning_rate=0.01, momentum=0.9, nesterov=True)
145
  model.compile(loss='categorical_crossentropy', optimizer=sgd, metrics=['accuracy'])
 
190
 
191
  p = model.predict(np.array([bag]))[0]
192
  ERROR_THRESHOLD = 0.25
193
+ results = [[i, p] for i, p in enumerate(p) if p > ERROR_THRESHOLD]
194
  results.sort(key=lambda x: x[1], reverse=True)
195
  return_list = []
196
+ for i, p in results:
197
  return_list.append({"intent": classes[i], "probability": str(p)})
198
 
199
  r.rpush('user_questions', message.message)
200
 
201
  return return_list
202
 
 
203
  @app.post("/tag")
204
  async def tag_question(question: str, tag: str):
205
  r.set(f"tag:{question}", tag)
206
  return {"message": "Tag saved"}
207
 
 
208
  html_code = """
209
  <!DOCTYPE html>
210
  <html>
 
302
  </html>
303
  """
304
 
 
305
  @app.get("/", response_class=HTMLResponse)
306
  async def root():
307
  return html_code
308
 
 
309
  if __name__ == "__main__":
310
  initialize_redis()
311
  training_process = multiprocessing.Process(target=start_training_loop)