Update app.py
Browse files
app.py
CHANGED
@@ -18,7 +18,7 @@ import numpy as np
|
|
18 |
from tensorflow.keras import Sequential, Input
|
19 |
from tensorflow.keras.layers import Dense, Dropout
|
20 |
from tensorflow.keras.optimizers import SGD
|
21 |
-
from tensorflow.keras.models import load_model
|
22 |
|
23 |
import redis
|
24 |
import os
|
@@ -132,8 +132,8 @@ async def train_and_save_model():
|
|
132 |
train_y = list(training[:, 1])
|
133 |
|
134 |
print("Loading or creating model...")
|
135 |
-
if
|
136 |
-
model = load_model('chatbot_model')
|
137 |
else:
|
138 |
model = Sequential()
|
139 |
model.add(Input(shape=(len(train_x[0]),)))
|
@@ -163,7 +163,7 @@ async def train_and_save_model():
|
|
163 |
classes = sorted(list(set(existing_classes + classes)))
|
164 |
r.set('classes', pickle.dumps(classes))
|
165 |
|
166 |
-
model
|
167 |
|
168 |
print("Data and model saved. Re-training...")
|
169 |
|
@@ -179,7 +179,7 @@ class ChatMessage(BaseModel):
|
|
179 |
async def chat(message: ChatMessage):
|
180 |
words = pickle.loads(r.get('words'))
|
181 |
classes = pickle.loads(r.get('classes'))
|
182 |
-
model = load_model('chatbot_model')
|
183 |
|
184 |
sentence_words = nltk.word_tokenize(message.message)
|
185 |
sentence_words = [lemmatizer.lemmatize(word.lower()) for word in sentence_words]
|
|
|
18 |
from tensorflow.keras import Sequential, Input
|
19 |
from tensorflow.keras.layers import Dense, Dropout
|
20 |
from tensorflow.keras.optimizers import SGD
|
21 |
+
from tensorflow.keras.models import load_model, save_model
|
22 |
|
23 |
import redis
|
24 |
import os
|
|
|
132 |
train_y = list(training[:, 1])
|
133 |
|
134 |
print("Loading or creating model...")
|
135 |
+
if os.path.exists('chatbot_model.h5'):
|
136 |
+
model = load_model('chatbot_model.h5')
|
137 |
else:
|
138 |
model = Sequential()
|
139 |
model.add(Input(shape=(len(train_x[0]),)))
|
|
|
163 |
classes = sorted(list(set(existing_classes + classes)))
|
164 |
r.set('classes', pickle.dumps(classes))
|
165 |
|
166 |
+
save_model(model, 'chatbot_model.h5')
|
167 |
|
168 |
print("Data and model saved. Re-training...")
|
169 |
|
|
|
179 |
async def chat(message: ChatMessage):
|
180 |
words = pickle.loads(r.get('words'))
|
181 |
classes = pickle.loads(r.get('classes'))
|
182 |
+
model = load_model('chatbot_model.h5')
|
183 |
|
184 |
sentence_words = nltk.word_tokenize(message.message)
|
185 |
sentence_words = [lemmatizer.lemmatize(word.lower()) for word in sentence_words]
|