# Ghgg / app.py (Hugging Face Space file by Yhhxhfh)
# Commit a5501b7 ("Update app.py", verified), 10.5 kB
import uvicorn
import nltk
nltk.download('punkt')
nltk.download('wordnet')
nltk.download('omw-1.4')
nltk.download('averaged_perceptron_tagger')
nltk.download('punkt_tab') # Download the punkt_tab resource
from nltk.stem import WordNetLemmatizer
from nltk.corpus import wordnet
from tqdm import tqdm
from tqdm.keras import TqdmCallback
import json
import pickle
import random
import asyncio
import concurrent.futures
import multiprocessing
import io
import os
import numpy as np
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Dense, Dropout, Input
from tensorflow.keras.optimizers import SGD
from tensorflow.keras.models import load_model, save_model
import redis
import os
from dotenv import load_dotenv
from fastapi import FastAPI
from fastapi.responses import HTMLResponse
from pydantic import BaseModel
load_dotenv()
app = FastAPI()
lemmatizer = WordNetLemmatizer()
# Redis connection settings come from the environment (optionally populated
# from a .env file by load_dotenv above). Host and port fall back to Redis'
# standard local defaults so an unset REDIS_PORT no longer crashes with
# int(None).
redis_password = os.getenv("REDIS_PASSWORD")
r = redis.Redis(
    host=os.getenv("REDIS_HOST", "localhost"),
    port=int(os.getenv("REDIS_PORT", "6379")),
    password=redis_password,
)
def initialize_redis():
    """Check that the module-level Redis client ``r`` can reach the server.

    Sends a PING; on success prints a confirmation, otherwise prints an error
    to stderr and terminates the process with exit status 1.
    """
    import sys

    try:
        r.ping()
    except (redis.exceptions.ConnectionError, redis.exceptions.TimeoutError):
        # TimeoutError is not a ConnectionError subclass in redis-py, so it
        # must be listed explicitly to get the clean exit instead of a traceback.
        print("Error connecting to Redis. Exiting.", file=sys.stderr)
        # sys.exit rather than the site-provided exit(), which is meant for
        # the interactive shell and is absent under `python -S`.
        sys.exit(1)
    print("Redis connection successful.")
def _load_intents():
    """Return the intents document, seeding Redis from ``intents.json`` once.

    The document is read from the Redis key ``intents``. When that key is
    missing, the local ``intents.json`` (or ``{"intents": []}`` if the file
    does not exist) is uploaded to Redis and returned.
    """
    stored = r.get('intents')
    if stored is not None:
        return json.loads(stored)
    if os.path.exists('intents.json'):
        # JSON text is UTF-8; don't depend on the platform's default encoding.
        with open('intents.json', encoding='utf-8') as f:
            intents = json.load(f)
        r.set('intents', json.dumps(intents))
        print("Intents loaded from local file and uploaded to Redis.")
    else:
        intents = {"intents": []}
        r.set('intents', json.dumps(intents))
        print("intents.json not found locally, creating empty intents in Redis.")
    return intents


def _user_question_documents():
    """Return ``(tokens, tag)`` documents for every logged user question.

    Questions come from the Redis list ``user_questions``; each is labelled
    with the value of ``tag:<question>``, or ``"unknown"`` when untagged.
    """
    questions = [raw.decode('utf-8') for raw in r.lrange('user_questions', 0, -1)]
    if not questions:
        return []
    # One MGET round trip instead of one GET per question.
    raw_tags = r.mget([f"tag:{question}" for question in questions])
    return [
        (nltk.word_tokenize(question),
         raw_tag.decode('utf-8') if raw_tag is not None else "unknown")
        for question, raw_tag in zip(questions, raw_tags)
    ]


def _synonym_patterns(pool, patterns, samples):
    """Return new synonym variants of ``patterns`` (at most ``samples``).

    ``samples`` candidates are generated in ``pool`` with
    ``generate_synonym_pattern``; empty results, duplicates and exact copies
    of existing patterns are discarded.
    """
    if not patterns or samples <= 0:
        return []
    frozen = tuple(patterns)  # immutable snapshot shipped to the workers
    seen = set(frozen)
    variants = []
    candidates = pool.imap(generate_synonym_pattern, [frozen] * samples, chunksize=256)
    for candidate in tqdm(candidates, total=samples, desc="Generating synonyms", leave=False):
        if candidate and candidate not in seen:
            seen.add(candidate)
            variants.append(candidate)
    return variants


def _build_training_data(documents, words, classes):
    """Encode documents as bag-of-words inputs and one-hot class targets.

    Returns ``(train_x, train_y)`` shaped ``(len(documents), len(words))``
    and ``(len(documents), len(classes))``. Tokens are lower-cased and
    lemmatized before lookup; tokens not in ``words`` are ignored.
    """
    word_index = {word: i for i, word in enumerate(words)}
    class_index = {tag: i for i, tag in enumerate(classes)}
    train_x = np.zeros((len(documents), len(words)), dtype=np.float32)
    train_y = np.zeros((len(documents), len(classes)), dtype=np.float32)
    for row, (tokens, tag) in enumerate(documents):
        for token in tokens:
            column = word_index.get(lemmatizer.lemmatize(token.lower()))
            if column is not None:
                train_x[row, column] = 1
        train_y[row, class_index[tag]] = 1
    return train_x, train_y


def _build_model(input_dim, output_dim):
    """Return a new, uncompiled classifier: 128-64 ReLU layers with dropout, softmax output."""
    # Sequential takes layer objects; Input declares the input width.
    return Sequential([
        Input(shape=(input_dim,)),
        Dense(128, activation='relu'),
        Dropout(0.5),
        Dense(64, activation='relu'),
        Dropout(0.5),
        Dense(output_dim, activation='softmax'),
    ])


def _model_from_bytes(data):
    """Deserialize a model stored as ``.keras`` file bytes.

    Keras loads models from a file path, not a file-like object, so the bytes
    are written to a temporary file first.
    """
    import tempfile

    with tempfile.TemporaryDirectory() as tmp_dir:
        path = os.path.join(tmp_dir, 'chatbot_model.keras')
        with open(path, 'wb') as f:
            f.write(data)
        return load_model(path)


def _model_to_bytes(model):
    """Serialize ``model`` to the bytes of a ``.keras`` file (via a temp file)."""
    import tempfile

    with tempfile.TemporaryDirectory() as tmp_dir:
        path = os.path.join(tmp_dir, 'chatbot_model.keras')
        save_model(model, path)
        with open(path, 'rb') as f:
            return f.read()


def _load_reusable_model(words, classes):
    """Return the stored model if it was trained on exactly ``words``/``classes``.

    Returns None when there is no stored model, when the stored vocabulary or
    class list differs (input/output units would no longer line up), or when
    the stored blob cannot be loaded.
    """
    stored_model, stored_words, stored_classes = r.mget('chatbot_model', 'words', 'classes')
    if stored_model is None or stored_words is None or stored_classes is None:
        return None
    # NOTE(review): pickle.loads runs arbitrary code embedded in the data; this
    # is only safe while the Redis instance is private to this application.
    if pickle.loads(stored_words) != words or pickle.loads(stored_classes) != classes:
        print("Vocabulary or classes changed; building a new model.")
        return None
    try:
        return _model_from_bytes(stored_model)
    except (OSError, ValueError) as exc:
        print(f"Stored model could not be loaded ({exc}); building a new model.")
        return None


async def train_and_save_model(synonym_samples=100000):
    """Continuously retrain the intent classifier and publish it to Redis.

    Each iteration does the following:
      1. Loads the intents (seeding Redis from ``intents.json`` if needed).
      2. Labels every logged user question by its ``tag:<question>`` key,
         or ``"unknown"`` when it has none.
      3. Adds WordNet synonym variants to each intent's patterns.
      4. Builds bag-of-words / one-hot training arrays.
      5. Fits one epoch. It continues from the stored model when the
         vocabulary and classes are unchanged, otherwise from a fresh model.
      6. Writes ``words``, ``classes`` and ``chatbot_model`` to Redis in one
         transaction.
    When there is nothing to train on it sleeps 60 s and retries. Loops
    forever.

    Args:
        synonym_samples: synonym candidates drawn per intent per iteration
            (duplicates are dropped, so far fewer patterns are usually added).
    """
    ignore_words = {'?', '!'}
    while True:
        intents = _load_intents()

        print("Loading user questions from Redis...")
        # Rebuilt every iteration so the class set stays stable between runs.
        documents = _user_question_documents()
        classes = [tag for _, tag in documents]
        words = []

        print("Processing intents from Redis...")
        with multiprocessing.Pool() as pool:
            for intent in intents['intents']:
                tag = intent['tag']
                print(f"Generating synonyms for intent '{tag}'...")
                patterns = list(intent['patterns'])
                # Augment before tokenizing so the variants are actually trained on.
                patterns += _synonym_patterns(pool, patterns, synonym_samples)
                for pattern in patterns:
                    tokens = nltk.word_tokenize(pattern)
                    words.extend(tokens)
                    documents.append((tokens, tag))
                if patterns:
                    classes.append(tag)

        words = sorted({lemmatizer.lemmatize(w.lower()) for w in words if w not in ignore_words})
        classes = sorted(set(classes))

        print("Creating training data...")
        if not documents or not words:
            print("No training data yet. Waiting...")
            await asyncio.sleep(60)
            continue
        train_x, train_y = _build_training_data(documents, words, classes)

        print("Loading or creating model...")
        model = _load_reusable_model(words, classes)
        if model is None:
            model = _build_model(len(words), len(classes))
        sgd = SGD(learning_rate=0.01, momentum=0.9, nesterov=True)
        model.compile(loss='categorical_crossentropy', optimizer=sgd, metrics=['accuracy'])

        print("Training the model...")
        model.fit(train_x, train_y, epochs=1, batch_size=len(train_x), verbose=0,
                  callbacks=[TqdmCallback(verbose=2)])

        print("Saving data to Redis...")
        model_bytes = _model_to_bytes(model)
        # redis-py pipelines are transactional (MULTI/EXEC) by default, so the
        # three keys are updated together.
        with r.pipeline() as pipe:
            pipe.set('words', pickle.dumps(words))
            pipe.set('classes', pickle.dumps(classes))
            pipe.set('chatbot_model', model_bytes)
            pipe.execute()
        print("Data and model saved. Re-training...")
def generate_synonym_pattern(patterns):
    """Return a random pattern with each word replaced by a WordNet synonym.

    One pattern is chosen at random from ``patterns`` and split on
    whitespace. Each word that has WordNet synsets is replaced by a random
    lemma of its first synset; other words are kept unchanged.

    WordNet joins multi-word lemma names with ``_``. These are turned back
    into spaces so the result tokenizes like ordinary text.

    Returns "" when ``patterns`` is empty. Callers treat a falsy result as
    "no new pattern".
    """
    if not patterns:
        return ""
    replaced = []
    for word in random.choice(patterns).split():
        synsets = wordnet.synsets(word)
        if synsets:
            lemma_name = random.choice(synsets[0].lemmas()).name()
            replaced.append(lemma_name.replace('_', ' '))
        else:
            replaced.append(word)
    return " ".join(replaced)
def start_training_loop():
    """Run the ``train_and_save_model`` coroutine to completion in this process.

    Meant as the target of the background training process. ``asyncio.run``
    creates a fresh event loop and closes it if the coroutine ever exits,
    including when it raises.
    """
    asyncio.run(train_and_save_model())
# JSON request body accepted by POST /chat.
class ChatMessage(BaseModel):
    message: str  # the raw text the user typed
@app.post("/chat")
async def chat(message: ChatMessage):
    """Classify a user message with the model currently stored in Redis.

    Returns a list of ``{"intent": str, "probability": str}`` entries for
    every class whose predicted probability exceeds 0.25, highest first. The
    raw message is appended to the Redis list ``user_questions``.

    Responds with HTTP 503 while no trained model is available yet.
    """
    import tempfile

    from fastapi import HTTPException

    ERROR_THRESHOLD = 0.25

    # MGET reads the three keys in one atomic round trip.
    raw_words, raw_classes, raw_model = r.mget('words', 'classes', 'chatbot_model')
    if raw_words is None or raw_classes is None or raw_model is None:
        raise HTTPException(status_code=503, detail="Model not trained yet; try again later.")
    # NOTE(review): pickle.loads runs arbitrary code embedded in the data; this
    # is only safe while the Redis instance is private to this application.
    words = pickle.loads(raw_words)
    classes = pickle.loads(raw_classes)
    # Keras loads models from a file path, not a file-like object.
    with tempfile.TemporaryDirectory() as tmp_dir:
        model_path = os.path.join(tmp_dir, 'chatbot_model.keras')
        with open(model_path, 'wb') as f:
            f.write(raw_model)
        model = load_model(model_path)

    # Same preprocessing as training: lower-case, lemmatize, bag-of-words.
    tokens = {lemmatizer.lemmatize(token.lower()) for token in nltk.word_tokenize(message.message)}
    bag = [1 if word in tokens else 0 for word in words]
    probabilities = model.predict(np.array([bag]), verbose=0)[0]
    ranked = sorted(
        ((index, prob) for index, prob in enumerate(probabilities) if prob > ERROR_THRESHOLD),
        key=lambda item: item[1],
        reverse=True,
    )
    r.rpush('user_questions', message.message)
    return [{"intent": classes[index], "probability": str(prob)} for index, prob in ranked]
# Label a logged user question with an intent tag. The tag is stored under the
# Redis key "tag:<question>", which the training loop reads when it turns
# logged questions into training documents.
@app.post("/tag")
async def tag_question(question: str, tag: str):
    redis_key = f"tag:{question}"
    r.set(redis_key, tag)
    return dict(message="Tag saved")
# Single-page chat UI served at "/". User input and model output are rendered
# with textContent (never innerHTML) so they cannot inject markup or script.
html_code = """
<!DOCTYPE html>
<html>
<head>
    <title>Chatbot</title>
    <style>
        body {
            font-family: sans-serif;
            background-color: #f4f4f4;
            margin: 0;
            padding: 0;
            display: flex;
            justify-content: center;
            align-items: center;
            min-height: 100vh;
        }
        #container {
            background-color: #fff;
            border-radius: 5px;
            box-shadow: 0 2px 5px rgba(0, 0, 0, 0.1);
            padding: 30px;
            width: 80%;
            max-width: 600px;
        }
        h1 {
            text-align: center;
            margin-bottom: 20px;
            color: #333;
        }
        #chatbox {
            height: 300px;
            overflow-y: auto;
            padding: 10px;
            border: 1px solid #ccc;
            border-radius: 5px;
            margin-bottom: 10px;
        }
        #chatbox p {
            margin: 5px 0;
        }
        #user_input {
            width: 100%;
            padding: 10px;
            border: 1px solid #ccc;
            border-radius: 5px;
            margin-bottom: 10px;
            box-sizing: border-box;
        }
        button {
            background-color: #4CAF50;
            color: white;
            padding: 10px 20px;
            border: none;
            border-radius: 5px;
            cursor: pointer;
        }
    </style>
</head>
<body>
    <div id="container">
        <h1>Chatbot</h1>
        <div id="chatbox"></div>
        <input type="text" id="user_input" placeholder="Type your message...">
        <button onclick="sendMessage()">Send</button>
    </div>
    <script>
        function appendLine(label, text) {
            const chatbox = document.getElementById('chatbox');
            const line = document.createElement('p');
            const who = document.createElement('b');
            who.textContent = label;
            line.appendChild(who);
            line.appendChild(document.createTextNode(' ' + text));
            chatbox.appendChild(line);
            chatbox.scrollTop = chatbox.scrollHeight;
        }

        function sendMessage() {
            const input = document.getElementById('user_input');
            const userInput = input.value;
            if (!userInput.trim()) {
                return;
            }
            input.value = '';
            appendLine('You:', userInput);
            fetch('/chat', {
                method: 'POST',
                headers: {'Content-Type': 'application/json'},
                body: JSON.stringify({"message": userInput})
            })
            .then(response => {
                if (!response.ok) {
                    throw new Error('Request failed with HTTP ' + response.status);
                }
                return response.json();
            })
            .then(data => {
                data.forEach(item => {
                    appendLine('Bot:', item.intent + ' (Probability: ' + item.probability + ')');
                });
            })
            .catch(error => {
                appendLine('Error:', error.message);
            });
        }

        document.getElementById('user_input').addEventListener('keydown', event => {
            if (event.key === 'Enter') {
                sendMessage();
            }
        });
    </script>
</body>
</html>
"""
# Serve the single-page chat UI held in the module-level html_code string.
@app.get("/", response_class=HTMLResponse)
async def root():
    page = html_code
    return page
if __name__ == "__main__":
    initialize_redis()
    # Training runs in a separate process so it does not block the API server.
    # It cannot be a daemon process: the training loop creates a
    # multiprocessing.Pool, and daemonic processes may not have children.
    training_process = multiprocessing.Process(target=start_training_loop)
    training_process.start()
    try:
        uvicorn.run(app, host="0.0.0.0", port=7860)
    finally:
        # The training loop never returns on its own; without this the
        # non-daemon child keeps the interpreter alive after the server stops.
        training_process.terminate()
        training_process.join()