Update app.py
Browse files
app.py
CHANGED
@@ -2,9 +2,9 @@ import uvicorn
|
|
2 |
import nltk
|
3 |
nltk.download('punkt')
|
4 |
nltk.download('wordnet')
|
5 |
-
nltk.download('punkt_tab')
|
6 |
nltk.download('omw-1.4')
|
7 |
nltk.download('averaged_perceptron_tagger')
|
|
|
8 |
from nltk.stem import WordNetLemmatizer
|
9 |
from nltk.corpus import wordnet
|
10 |
from tqdm import tqdm
|
@@ -16,6 +16,8 @@ import random
|
|
16 |
import asyncio
|
17 |
import concurrent.futures
|
18 |
import multiprocessing
|
|
|
|
|
19 |
|
20 |
import numpy as np
|
21 |
from tensorflow.keras import Sequential
|
@@ -39,6 +41,7 @@ lemmatizer = WordNetLemmatizer()
|
|
39 |
redis_password = os.getenv("REDIS_PASSWORD")
|
40 |
r = redis.Redis(host=os.getenv("REDIS_HOST"), port=int(os.getenv("REDIS_PORT")), password=redis_password)
|
41 |
|
|
|
42 |
def initialize_redis():
|
43 |
global r
|
44 |
try:
|
@@ -48,6 +51,7 @@ def initialize_redis():
|
|
48 |
print("Error connecting to Redis. Exiting.")
|
49 |
exit(1)
|
50 |
|
|
|
51 |
async def train_and_save_model():
|
52 |
global lemmatizer, r
|
53 |
while True:
|
@@ -56,16 +60,19 @@ async def train_and_save_model():
|
|
56 |
documents = []
|
57 |
ignore_words = ['?', '!']
|
58 |
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
|
|
|
|
|
|
69 |
|
70 |
print("Loading user questions from Redis...")
|
71 |
if not r.exists('user_questions_loaded'):
|
@@ -83,7 +90,7 @@ async def train_and_save_model():
|
|
83 |
classes.append("unknown")
|
84 |
r.set('user_questions_loaded', 1)
|
85 |
|
86 |
-
print("Processing intents from
|
87 |
for intent in intents['intents']:
|
88 |
for pattern in intent['patterns']:
|
89 |
w = nltk.word_tokenize(pattern)
|
@@ -134,8 +141,9 @@ async def train_and_save_model():
|
|
134 |
train_y = np.array([row[1] for row in training])
|
135 |
|
136 |
print("Loading or creating model...")
|
137 |
-
if
|
138 |
-
|
|
|
139 |
else:
|
140 |
input_layer = Input(shape=(len(train_x[0]),))
|
141 |
layer1 = Dense(128, activation='relu')(input_layer)
|
@@ -150,16 +158,19 @@ async def train_and_save_model():
|
|
150 |
model.compile(loss='categorical_crossentropy', optimizer=sgd, metrics=['accuracy'])
|
151 |
|
152 |
print("Training the model...")
|
153 |
-
model.fit(train_x, train_y, epochs=1, batch_size=len(train_x), verbose=0, callbacks=[TqdmCallback(verbose=2)])
|
154 |
|
155 |
print("Saving data to Redis...")
|
156 |
r.set('words', pickle.dumps(words))
|
157 |
r.set('classes', pickle.dumps(classes))
|
158 |
|
159 |
-
|
|
|
|
|
160 |
|
161 |
print("Data and model saved. Re-training...")
|
162 |
|
|
|
163 |
def generate_synonym_pattern(patterns):
|
164 |
new_pattern = []
|
165 |
for word in random.choice(patterns).split():
|
@@ -171,19 +182,23 @@ def generate_synonym_pattern(patterns):
|
|
171 |
new_pattern.append(word)
|
172 |
return " ".join(new_pattern)
|
173 |
|
|
|
174 |
def start_training_loop():
    """Create a dedicated asyncio event loop and run the training coroutine on it.

    NOTE(review): `train_and_save_model()` runs a `while True:` loop, so this
    blocks its thread indefinitely; it appears designed to run in a worker
    thread/process — confirm against the (unseen) caller.
    """
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    loop.run_until_complete(train_and_save_model())
|
178 |
|
|
|
179 |
class ChatMessage(BaseModel):
    """Request-body schema for POST /chat (BaseModel presumably from pydantic;
    import not visible in this chunk — confirm)."""

    # The raw user message to tokenize and classify.
    message: str
|
181 |
|
|
|
182 |
@app.post("/chat")
|
183 |
async def chat(message: ChatMessage):
|
184 |
words = pickle.loads(r.get('words'))
|
185 |
classes = pickle.loads(r.get('classes'))
|
186 |
-
|
|
|
187 |
|
188 |
sentence_words = nltk.word_tokenize(message.message)
|
189 |
sentence_words = [lemmatizer.lemmatize(word.lower()) for word in sentence_words]
|
@@ -206,11 +221,13 @@ async def chat(message: ChatMessage):
|
|
206 |
|
207 |
return return_list
|
208 |
|
|
|
209 |
@app.post("/tag")
async def tag_question(question: str, tag: str):
    """Persist a manual tag for a question in Redis under the key `tag:<question>`.

    NOTE(review): plain `str` parameters are treated by FastAPI as query
    parameters, not a JSON body — confirm this is the intended API shape.
    """
    r.set(f"tag:{question}", tag)
    return {"message": "Tag saved"}
|
213 |
|
|
|
214 |
html_code = """
|
215 |
<!DOCTYPE html>
|
216 |
<html>
|
@@ -308,9 +325,11 @@ html_code = """
|
|
308 |
</html>
|
309 |
"""
|
310 |
|
|
|
311 |
@app.get("/", response_class=HTMLResponse)
|
312 |
async def root():
|
313 |
-
|
|
|
314 |
|
315 |
if __name__ == "__main__":
|
316 |
initialize_redis()
|
|
|
2 |
import nltk
|
3 |
nltk.download('punkt')
|
4 |
nltk.download('wordnet')
|
|
|
5 |
nltk.download('omw-1.4')
|
6 |
nltk.download('averaged_perceptron_tagger')
|
7 |
+
nltk.download('punkt_tab') # Download the punkt_tab resource
|
8 |
from nltk.stem import WordNetLemmatizer
|
9 |
from nltk.corpus import wordnet
|
10 |
from tqdm import tqdm
|
|
|
16 |
import asyncio
|
17 |
import concurrent.futures
|
18 |
import multiprocessing
|
19 |
+
import io
|
20 |
+
import os
|
21 |
|
22 |
import numpy as np
|
23 |
from tensorflow.keras import Sequential
|
|
|
41 |
redis_password = os.getenv("REDIS_PASSWORD")
|
42 |
r = redis.Redis(host=os.getenv("REDIS_HOST"), port=int(os.getenv("REDIS_PORT")), password=redis_password)
|
43 |
|
44 |
+
|
45 |
def initialize_redis():
|
46 |
global r
|
47 |
try:
|
|
|
51 |
print("Error connecting to Redis. Exiting.")
|
52 |
exit(1)
|
53 |
|
54 |
+
|
55 |
async def train_and_save_model():
|
56 |
global lemmatizer, r
|
57 |
while True:
|
|
|
60 |
documents = []
|
61 |
ignore_words = ['?', '!']
|
62 |
|
63 |
+
# Check if intents exist in Redis, otherwise load from local file and upload
|
64 |
+
if not r.exists('intents'):
|
65 |
+
if os.path.exists('intents.json'):
|
66 |
+
with open('intents.json') as f:
|
67 |
+
intents = json.load(f)
|
68 |
+
r.set('intents', json.dumps(intents))
|
69 |
+
print("Intents loaded from local file and uploaded to Redis.")
|
70 |
+
else:
|
71 |
+
intents = {"intents": []}
|
72 |
+
r.set('intents', json.dumps(intents))
|
73 |
+
print("intents.json not found locally, creating empty intents in Redis.")
|
74 |
+
else:
|
75 |
+
intents = json.loads(r.get('intents'))
|
76 |
|
77 |
print("Loading user questions from Redis...")
|
78 |
if not r.exists('user_questions_loaded'):
|
|
|
90 |
classes.append("unknown")
|
91 |
r.set('user_questions_loaded', 1)
|
92 |
|
93 |
+
print("Processing intents from Redis...")
|
94 |
for intent in intents['intents']:
|
95 |
for pattern in intent['patterns']:
|
96 |
w = nltk.word_tokenize(pattern)
|
|
|
141 |
train_y = np.array([row[1] for row in training])
|
142 |
|
143 |
print("Loading or creating model...")
|
144 |
+
if r.exists('chatbot_model'):
|
145 |
+
with io.BytesIO(r.get('chatbot_model')) as f:
|
146 |
+
model = load_model(f)
|
147 |
else:
|
148 |
input_layer = Input(shape=(len(train_x[0]),))
|
149 |
layer1 = Dense(128, activation='relu')(input_layer)
|
|
|
158 |
model.compile(loss='categorical_crossentropy', optimizer=sgd, metrics=['accuracy'])
|
159 |
|
160 |
print("Training the model...")
|
161 |
+
model.fit(train_x, train_y, epochs=1, batch_size=len(train_x), verbose=0, callbacks=[TqdmCallback(verbose=2)])
|
162 |
|
163 |
print("Saving data to Redis...")
|
164 |
r.set('words', pickle.dumps(words))
|
165 |
r.set('classes', pickle.dumps(classes))
|
166 |
|
167 |
+
with io.BytesIO() as f:
|
168 |
+
save_model(model, f)
|
169 |
+
r.set('chatbot_model', f.getvalue())
|
170 |
|
171 |
print("Data and model saved. Re-training...")
|
172 |
|
173 |
+
|
174 |
def generate_synonym_pattern(patterns):
|
175 |
new_pattern = []
|
176 |
for word in random.choice(patterns).split():
|
|
|
182 |
new_pattern.append(word)
|
183 |
return " ".join(new_pattern)
|
184 |
|
185 |
+
|
186 |
def start_training_loop():
|
187 |
loop = asyncio.new_event_loop()
|
188 |
asyncio.set_event_loop(loop)
|
189 |
loop.run_until_complete(train_and_save_model())
|
190 |
|
191 |
+
|
192 |
class ChatMessage(BaseModel):
    """Request-body schema for POST /chat.

    NOTE(review): BaseModel is presumably pydantic's (via FastAPI); the import
    is not visible in this chunk — confirm.
    """

    # The raw user message to tokenize and classify.
    message: str
|
194 |
|
195 |
+
|
196 |
@app.post("/chat")
|
197 |
async def chat(message: ChatMessage):
|
198 |
words = pickle.loads(r.get('words'))
|
199 |
classes = pickle.loads(r.get('classes'))
|
200 |
+
with io.BytesIO(r.get('chatbot_model')) as f:
|
201 |
+
model = load_model(f)
|
202 |
|
203 |
sentence_words = nltk.word_tokenize(message.message)
|
204 |
sentence_words = [lemmatizer.lemmatize(word.lower()) for word in sentence_words]
|
|
|
221 |
|
222 |
return return_list
|
223 |
|
224 |
+
|
225 |
@app.post("/tag")
async def tag_question(question: str, tag: str):
    """Persist a manual tag for a question in Redis under the key `tag:<question>`.

    NOTE(review): both parameters are plain `str`, so FastAPI will read them as
    query parameters rather than a JSON body — confirm this is intended.
    Returns a small confirmation payload.
    """
    r.set(f"tag:{question}", tag)
    return {"message": "Tag saved"}
|
229 |
|
230 |
+
|
231 |
html_code = """
|
232 |
<!DOCTYPE html>
|
233 |
<html>
|
|
|
325 |
</html>
|
326 |
"""
|
327 |
|
328 |
+
|
329 |
@app.get("/", response_class=HTMLResponse)
async def root():
    """Serve the chat UI at the site root.

    Returns the module-level `html_code` string (the embedded single-page
    chat client defined earlier in this file) as an HTML response.
    """
    return html_code
|
332 |
+
|
333 |
|
334 |
if __name__ == "__main__":
|
335 |
initialize_redis()
|