Yhhxhfh commited on
Commit
d388d30
1 Parent(s): 9f4a9e1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -5
app.py CHANGED
@@ -18,6 +18,7 @@ import concurrent.futures
18
  import multiprocessing
19
  import io
20
  import os
 
21
 
22
  import numpy as np
23
  from tensorflow.keras import Sequential
@@ -143,8 +144,11 @@ async def train_and_save_model():
143
 
144
  print("Loading or creating model...")
145
  if r.exists('chatbot_model'):
146
- with io.BytesIO(r.get('chatbot_model')) as f:
147
- model = load_model(f)
 
 
 
148
  else:
149
  input_layer = Input(shape=(len(train_x[0]),))
150
  layer1 = Dense(128, activation='relu')(input_layer)
@@ -199,7 +203,11 @@ async def chat(message: ChatMessage):
199
  words = pickle.loads(r.get('words'))
200
  classes = pickle.loads(r.get('classes'))
201
  with io.BytesIO(r.get('chatbot_model')) as f:
202
- model = load_model(f)
 
 
 
 
203
 
204
  sentence_words = nltk.word_tokenize(message.message)
205
  sentence_words = [lemmatizer.lemmatize(word.lower()) for word in sentence_words]
@@ -220,8 +228,7 @@ async def chat(message: ChatMessage):
220
 
221
  r.rpush('user_questions', message.message)
222
 
223
- # Auto-train after receiving a new message
224
- asyncio.create_task(train_and_save_model())
225
 
226
  return return_list
227
 
 
18
  import multiprocessing
19
  import io
20
  import os
21
+ import tempfile
22
 
23
  import numpy as np
24
  from tensorflow.keras import Sequential
 
144
 
145
  print("Loading or creating model...")
146
  if r.exists('chatbot_model'):
147
+ with tempfile.NamedTemporaryFile(delete=False) as temp_file:
148
+ temp_file.write(r.get('chatbot_model'))
149
+ temp_file_name = temp_file.name
150
+ model = load_model(temp_file_name)
151
+ os.remove(temp_file_name)
152
  else:
153
  input_layer = Input(shape=(len(train_x[0]),))
154
  layer1 = Dense(128, activation='relu')(input_layer)
 
203
  words = pickle.loads(r.get('words'))
204
  classes = pickle.loads(r.get('classes'))
205
  with io.BytesIO(r.get('chatbot_model')) as f:
206
+ with tempfile.NamedTemporaryFile(delete=False) as temp_file:
207
+ temp_file.write(f.read())
208
+ temp_file_name = temp_file.name
209
+ model = load_model(temp_file_name)
210
+ os.remove(temp_file_name)
211
 
212
  sentence_words = nltk.word_tokenize(message.message)
213
  sentence_words = [lemmatizer.lemmatize(word.lower()) for word in sentence_words]
 
228
 
229
  r.rpush('user_questions', message.message)
230
 
231
+ asyncio.create_task(train_and_save_model())
 
232
 
233
  return return_list
234