Yhhxhfh commited on
Commit
6c9f8ba
1 Parent(s): 879e7eb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -5
app.py CHANGED
@@ -18,7 +18,7 @@ import numpy as np
18
  from tensorflow.keras import Sequential, Input
19
  from tensorflow.keras.layers import Dense, Dropout
20
  from tensorflow.keras.optimizers import SGD
21
- from tensorflow.keras.models import load_model
22
 
23
  import redis
24
  import os
@@ -132,8 +132,8 @@ async def train_and_save_model():
132
  train_y = list(training[:, 1])
133
 
134
  print("Loading or creating model...")
135
- if r.exists('model'):
136
- model = load_model('chatbot_model')
137
  else:
138
  model = Sequential()
139
  model.add(Input(shape=(len(train_x[0]),)))
@@ -163,7 +163,7 @@ async def train_and_save_model():
163
  classes = sorted(list(set(existing_classes + classes)))
164
  r.set('classes', pickle.dumps(classes))
165
 
166
- model.save('chatbot_model')
167
 
168
  print("Data and model saved. Re-training...")
169
 
@@ -179,7 +179,7 @@ class ChatMessage(BaseModel):
179
  async def chat(message: ChatMessage):
180
  words = pickle.loads(r.get('words'))
181
  classes = pickle.loads(r.get('classes'))
182
- model = load_model('chatbot_model')
183
 
184
  sentence_words = nltk.word_tokenize(message.message)
185
  sentence_words = [lemmatizer.lemmatize(word.lower()) for word in sentence_words]
 
18
  from tensorflow.keras import Sequential, Input
19
  from tensorflow.keras.layers import Dense, Dropout
20
  from tensorflow.keras.optimizers import SGD
21
+ from tensorflow.keras.models import load_model, save_model
22
 
23
  import redis
24
  import os
 
132
  train_y = list(training[:, 1])
133
 
134
  print("Loading or creating model...")
135
+ if os.path.exists('chatbot_model.h5'):
136
+ model = load_model('chatbot_model.h5')
137
  else:
138
  model = Sequential()
139
  model.add(Input(shape=(len(train_x[0]),)))
 
163
  classes = sorted(list(set(existing_classes + classes)))
164
  r.set('classes', pickle.dumps(classes))
165
 
166
+ save_model(model, 'chatbot_model.h5')
167
 
168
  print("Data and model saved. Re-training...")
169
 
 
179
  async def chat(message: ChatMessage):
180
  words = pickle.loads(r.get('words'))
181
  classes = pickle.loads(r.get('classes'))
182
+ model = load_model('chatbot_model.h5')
183
 
184
  sentence_words = nltk.word_tokenize(message.message)
185
  sentence_words = [lemmatizer.lemmatize(word.lower()) for word in sentence_words]