Update app.py
Browse files
app.py
CHANGED
@@ -121,7 +121,7 @@ async def train_and_save_model():
|
|
121 |
|
122 |
training.append([bag, output_row])
|
123 |
|
124 |
-
if not training:
|
125 |
print("No training data yet. Waiting...")
|
126 |
await asyncio.sleep(60)
|
127 |
continue
|
@@ -139,7 +139,7 @@ async def train_and_save_model():
|
|
139 |
model.add(Dropout(0.5))
|
140 |
model.add(Dense(64, activation='relu'))
|
141 |
model.add(Dropout(0.5))
|
142 |
-
model.add(Dense(len(
|
143 |
|
144 |
sgd = SGD(learning_rate=0.01, momentum=0.9, nesterov=True)
|
145 |
model.compile(loss='categorical_crossentropy', optimizer=sgd, metrics=['accuracy'])
|
@@ -190,23 +190,21 @@ async def chat(message: ChatMessage):
|
|
190 |
|
191 |
p = model.predict(np.array([bag]))[0]
|
192 |
ERROR_THRESHOLD = 0.25
|
193 |
-
results = [[i,p] for i,p in enumerate(p) if p>ERROR_THRESHOLD]
|
194 |
results.sort(key=lambda x: x[1], reverse=True)
|
195 |
return_list = []
|
196 |
-
for i,p in results:
|
197 |
return_list.append({"intent": classes[i], "probability": str(p)})
|
198 |
|
199 |
r.rpush('user_questions', message.message)
|
200 |
|
201 |
return return_list
|
202 |
|
203 |
-
|
204 |
@app.post("/tag")
|
205 |
async def tag_question(question: str, tag: str):
|
206 |
r.set(f"tag:{question}", tag)
|
207 |
return {"message": "Tag saved"}
|
208 |
|
209 |
-
|
210 |
html_code = """
|
211 |
<!DOCTYPE html>
|
212 |
<html>
|
@@ -304,12 +302,10 @@ html_code = """
|
|
304 |
</html>
|
305 |
"""
|
306 |
|
307 |
-
|
308 |
@app.get("/", response_class=HTMLResponse)
|
309 |
async def root():
|
310 |
return html_code
|
311 |
|
312 |
-
|
313 |
if __name__ == "__main__":
|
314 |
initialize_redis()
|
315 |
training_process = multiprocessing.Process(target=start_training_loop)
|
|
|
121 |
|
122 |
training.append([bag, output_row])
|
123 |
|
124 |
+
if not training:
|
125 |
print("No training data yet. Waiting...")
|
126 |
await asyncio.sleep(60)
|
127 |
continue
|
|
|
139 |
model.add(Dropout(0.5))
|
140 |
model.add(Dense(64, activation='relu'))
|
141 |
model.add(Dropout(0.5))
|
142 |
+
model.add(Dense(len(classes), activation='softmax'))
|
143 |
|
144 |
sgd = SGD(learning_rate=0.01, momentum=0.9, nesterov=True)
|
145 |
model.compile(loss='categorical_crossentropy', optimizer=sgd, metrics=['accuracy'])
|
|
|
190 |
|
191 |
p = model.predict(np.array([bag]))[0]
|
192 |
ERROR_THRESHOLD = 0.25
|
193 |
+
results = [[i, p] for i, p in enumerate(p) if p > ERROR_THRESHOLD]
|
194 |
results.sort(key=lambda x: x[1], reverse=True)
|
195 |
return_list = []
|
196 |
+
for i, p in results:
|
197 |
return_list.append({"intent": classes[i], "probability": str(p)})
|
198 |
|
199 |
r.rpush('user_questions', message.message)
|
200 |
|
201 |
return return_list
|
202 |
|
|
|
203 |
@app.post("/tag")
|
204 |
async def tag_question(question: str, tag: str):
|
205 |
r.set(f"tag:{question}", tag)
|
206 |
return {"message": "Tag saved"}
|
207 |
|
|
|
208 |
html_code = """
|
209 |
<!DOCTYPE html>
|
210 |
<html>
|
|
|
302 |
</html>
|
303 |
"""
|
304 |
|
|
|
305 |
@app.get("/", response_class=HTMLResponse)
|
306 |
async def root():
|
307 |
return html_code
|
308 |
|
|
|
309 |
if __name__ == "__main__":
|
310 |
initialize_redis()
|
311 |
training_process = multiprocessing.Process(target=start_training_loop)
|