Yhhxhfh committed
Commit a5501b7
1 Parent(s): 6895902

Update app.py

Files changed (1):
  1. app.py +37 -18

app.py CHANGED
@@ -2,9 +2,9 @@ import uvicorn
 import nltk
 nltk.download('punkt')
 nltk.download('wordnet')
-nltk.download('punkt_tab')
 nltk.download('omw-1.4')
 nltk.download('averaged_perceptron_tagger')
+nltk.download('punkt_tab')  # Download the punkt_tab resource
 from nltk.stem import WordNetLemmatizer
 from nltk.corpus import wordnet
 from tqdm import tqdm
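
Review note: the punkt_tab download moves below the tagger and gains a comment, but every nltk.download() call still runs unconditionally at import time. A guard that skips resources already on disk (a sketch; ensure_nltk_resources is an invented helper, the lookup paths are NLTK's standard ones):

import nltk

def ensure_nltk_resources():
    # Map each downloadable resource to the path nltk.data.find() expects.
    resources = {
        'punkt': 'tokenizers/punkt',
        'punkt_tab': 'tokenizers/punkt_tab',
        'wordnet': 'corpora/wordnet',
        'omw-1.4': 'corpora/omw-1.4',
        'averaged_perceptron_tagger': 'taggers/averaged_perceptron_tagger',
    }
    for name, path in resources.items():
        try:
            nltk.data.find(path)  # raises LookupError when the resource is absent
        except LookupError:
            nltk.download(name)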
@@ -16,6 +16,8 @@ import random
 import asyncio
 import concurrent.futures
 import multiprocessing
+import io
+import os
 
 import numpy as np
 from tensorflow.keras import Sequential
@@ -39,6 +41,7 @@ lemmatizer = WordNetLemmatizer()
 redis_password = os.getenv("REDIS_PASSWORD")
 r = redis.Redis(host=os.getenv("REDIS_HOST"), port=int(os.getenv("REDIS_PORT")), password=redis_password)
 
+
 def initialize_redis():
     global r
     try:
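
Review note: int(os.getenv("REDIS_PORT")) raises TypeError if the variable is unset, so this module-level connection line fails before initialize_redis() can report anything. A defensive variant (a sketch; the fallback host and port are illustrative defaults, not from this commit):

import os
import redis

r = redis.Redis(
    host=os.getenv("REDIS_HOST", "localhost"),
    port=int(os.getenv("REDIS_PORT", "6379")),
    password=os.getenv("REDIS_PASSWORD"),
)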
@@ -48,6 +51,7 @@ def initialize_redis():
         print("Error connecting to Redis. Exiting.")
         exit(1)
 
+
 async def train_and_save_model():
     global lemmatizer, r
     while True:
@@ -56,16 +60,19 @@ async def train_and_save_model():
         documents = []
         ignore_words = ['?', '!']
 
-        try:
-            with open('intents.json') as file:
-                intents = json.load(file)
-        except FileNotFoundError:
-            intents = {"intents": []}
-            with open('intents.json', 'w') as file:
-                json.dump(intents, file, indent=4)
-            print("intents.json created. Please populate it with training data.")
-            await asyncio.sleep(60)
-            continue
+        # Check if intents exist in Redis, otherwise load from local file and upload
+        if not r.exists('intents'):
+            if os.path.exists('intents.json'):
+                with open('intents.json') as f:
+                    intents = json.load(f)
+                r.set('intents', json.dumps(intents))
+                print("Intents loaded from local file and uploaded to Redis.")
+            else:
+                intents = {"intents": []}
+                r.set('intents', json.dumps(intents))
+                print("intents.json not found locally, creating empty intents in Redis.")
+        else:
+            intents = json.loads(r.get('intents'))
 
         print("Loading user questions from Redis...")
         if not r.exists('user_questions_loaded'):
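
Review note: the new block is a read-through cache: Redis is checked first, the local intents.json seeds it on first run, and an empty document is written when neither exists. The same logic as a reusable helper (a sketch; load_json_through_redis is an invented name, the key and file names mirror the commit):

import json
import os

def load_json_through_redis(r, key, filename, default):
    # Prefer the cached copy in Redis.
    if r.exists(key):
        return json.loads(r.get(key))
    # Otherwise seed Redis from the local file, or from the default.
    data = default
    if os.path.exists(filename):
        with open(filename) as f:
            data = json.load(f)
    r.set(key, json.dumps(data))
    return data

# intents = load_json_through_redis(r, 'intents', 'intents.json', {"intents": []})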
@@ -83,7 +90,7 @@ async def train_and_save_model():
                 classes.append("unknown")
             r.set('user_questions_loaded', 1)
 
-        print("Processing intents from intents.json...")
+        print("Processing intents from Redis...")
         for intent in intents['intents']:
             for pattern in intent['patterns']:
                 w = nltk.word_tokenize(pattern)
@@ -134,8 +141,9 @@ async def train_and_save_model():
         train_y = np.array([row[1] for row in training])
 
         print("Loading or creating model...")
-        if os.path.exists('chatbot_model.h5'):
-            model = load_model('chatbot_model.h5')
+        if r.exists('chatbot_model'):
+            with io.BytesIO(r.get('chatbot_model')) as f:
+                model = load_model(f)
         else:
             input_layer = Input(shape=(len(train_x[0]),))
             layer1 = Dense(128, activation='relu')(input_layer)
@@ -150,16 +158,19 @@ async def train_and_save_model():
             model.compile(loss='categorical_crossentropy', optimizer=sgd, metrics=['accuracy'])
 
         print("Training the model...")
-        model.fit(train_x, train_y, epochs=1, batch_size=len(train_x), verbose=0, callbacks=[TqdmCallback(verbose=2)])
+        model.fit(train_x, train_y, epochs=1, batch_size=len(train_x), verbose=0, callbacks=[TqdmCallback(verbose=2)])
 
         print("Saving data to Redis...")
         r.set('words', pickle.dumps(words))
         r.set('classes', pickle.dumps(classes))
 
-        save_model(model, 'chatbot_model.h5')
+        with io.BytesIO() as f:
+            save_model(model, f)
+            r.set('chatbot_model', f.getvalue())
 
         print("Data and model saved. Re-training...")
 
+
 def generate_synonym_pattern(patterns):
     new_pattern = []
     for word in random.choice(patterns).split():
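
Review note: in the tf.keras versions this app appears to target, save_model() and load_model() accept a file path or an h5py.File object, not a bare io.BytesIO, so both new branches may raise at runtime. A round-trip through an in-memory HDF5 file (a sketch, assuming h5py is installed; the helper names are invented):

import io
import h5py
from tensorflow.keras.models import load_model, save_model

def model_to_bytes(model):
    buf = io.BytesIO()
    with h5py.File(buf, 'w') as f:  # h5py accepts file-like objects
        save_model(model, f)
    return buf.getvalue()

def model_from_bytes(blob):
    with h5py.File(io.BytesIO(blob), 'r') as f:
        return load_model(f)

# r.set('chatbot_model', model_to_bytes(model))
# model = model_from_bytes(r.get('chatbot_model'))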
@@ -171,19 +182,23 @@ def generate_synonym_pattern(patterns):
             new_pattern.append(word)
     return " ".join(new_pattern)
 
+
 def start_training_loop():
     loop = asyncio.new_event_loop()
     asyncio.set_event_loop(loop)
     loop.run_until_complete(train_and_save_model())
 
+
 class ChatMessage(BaseModel):
     message: str
 
+
 @app.post("/chat")
 async def chat(message: ChatMessage):
     words = pickle.loads(r.get('words'))
     classes = pickle.loads(r.get('classes'))
-    model = load_model('chatbot_model.h5')
+    with io.BytesIO(r.get('chatbot_model')) as f:
+        model = load_model(f)
 
     sentence_words = nltk.word_tokenize(message.message)
     sentence_words = [lemmatizer.lemmatize(word.lower()) for word in sentence_words]
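
Review note: /chat now pulls the full model out of Redis and deserializes it on every request. A process-local cache that reloads only when the training loop has stored new bytes (a sketch; get_model is an invented helper and model_from_bytes comes from the previous note):

_cached_model = None
_cached_blob = None

def get_model(r):
    global _cached_model, _cached_blob
    blob = r.get('chatbot_model')
    if blob != _cached_blob:  # the training loop wrote a new model
        _cached_model = model_from_bytes(blob)
        _cached_blob = blob
    return _cached_model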
@@ -206,11 +221,13 @@ async def chat(message: ChatMessage):
 
     return return_list
 
+
 @app.post("/tag")
 async def tag_question(question: str, tag: str):
     r.set(f"tag:{question}", tag)
     return {"message": "Tag saved"}
 
+
 html_code = """
 <!DOCTYPE html>
 <html>
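
Review note: with this signature FastAPI reads question and tag from the query string, while /chat expects a JSON body matching ChatMessage. Client calls for both (a sketch using requests; the base URL is illustrative):

import requests

base = "http://localhost:7860"

# /tag takes plain query parameters
requests.post(f"{base}/tag", params={"question": "hi", "tag": "greeting"})

# /chat takes a JSON body
requests.post(f"{base}/chat", json={"message": "hi"})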
@@ -308,9 +325,11 @@ html_code = """
 </html>
 """
 
+
 @app.get("/", response_class=HTMLResponse)
 async def root():
-    return html_code
+    return html_code
+
 
 if __name__ == "__main__":
     initialize_redis()