import asyncio
import concurrent.futures
import json
import os
import pickle
import random

import nltk
import numpy as np
import redis
import uvicorn
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from fastapi.responses import HTMLResponse
from nltk.corpus import wordnet
from nltk.stem import WordNetLemmatizer
from pydantic import BaseModel
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Dense, Dropout, Input
from tensorflow.keras.models import load_model, save_model
from tensorflow.keras.optimizers import SGD
from tqdm import tqdm

# Fetch the NLTK resources used by word_tokenize, WordNetLemmatizer and wordnet.
nltk.download('punkt')
nltk.download('wordnet')
nltk.download('omw-1.4')
nltk.download('punkt_tab')

# Read configuration (Redis host/port/password) from a local .env file, if any.
load_dotenv()

app = FastAPI()

lemmatizer = WordNetLemmatizer()

redis_password = os.getenv("REDIS_PASSWORD")
# Fall back to the conventional local endpoint when a variable is unset or
# empty, instead of failing at import time on int(None) / int("").
r = redis.Redis(
    host=os.getenv("REDIS_HOST") or "localhost",
    port=int(os.getenv("REDIS_PORT") or "6379"),
    password=redis_password,
)

def initialize_redis():
    """Check that the module-level Redis client ``r`` can reach its server.

    Prints a status line in either case.

    Raises:
        SystemExit: with status 1 when the ping fails with a connection or
            timeout error.
    """
    try:
        r.ping()
    except (redis.exceptions.ConnectionError, redis.exceptions.TimeoutError):
        print("Error connecting to Redis. Exiting.")
        # The bare ``exit`` helper is injected by the ``site`` module for
        # interactive use and is not guaranteed to exist; raising SystemExit
        # is the portable equivalent.
        raise SystemExit(1) from None
    print("Redis connection successful.")

def _build_model(n_features, n_classes):
    """Return an uncompiled feed-forward classifier for bag-of-words input."""
    # Sequential takes layer instances; passing the tensors produced by
    # calling layers functional-style is rejected by Keras.
    return Sequential([
        Input(shape=(n_features,)),
        Dense(128, activation='relu'),
        Dropout(0.5),
        Dense(64, activation='relu'),
        Dropout(0.5),
        Dense(n_classes, activation='softmax'),
    ])


def _merge_with_stored(key, values):
    """Return the sorted union of *values* and the pickled list stored at *key* in Redis."""
    stored = r.get(key)
    if stored is None:
        return sorted(set(values))
    # NOTE(security): pickle.loads executes arbitrary code on crafted input.
    # This assumes the Redis instance is a trusted store -- confirm.
    return sorted(set(pickle.loads(stored)) | set(values))


async def train_and_save_model():
    """Continuously train the intent classifier and persist its artefacts.

    Each pass of the endless loop:

    1. Reads ``intents.json``. If the file is missing, an empty skeleton is
       written, and the loop sleeps 60 seconds before retrying.
    2. On the first pass only (guarded by the ``user_questions_loaded`` Redis
       flag), adds the questions in the ``user_questions`` list as documents,
       labelled with their ``tag:<question>`` value or ``"unknown"``.
    3. Tokenizes every intent pattern into the vocabulary and document list.
    4. Merges vocabulary and labels with the lists already stored in Redis
       *before* vectorizing, so that the indices the model is trained on are
       exactly the ones written back for the ``/chat`` endpoint to use.
    5. Loads ``chatbot_model.h5`` when its input/output sizes still fit,
       otherwise builds a fresh model; compiles, fits and saves it.

    The coroutine never returns.
    """
    ignore_words = ['?', '!']

    while True:
        words = []
        classes = []
        documents = []

        try:
            with open('intents.json') as file:
                intents = json.load(file)
        except FileNotFoundError:
            intents = {"intents": []}
            with open('intents.json', 'w') as file:
                json.dump(intents, file, indent=4)
            print("intents.json created. Please populate it with training data.")
            await asyncio.sleep(60)
            continue

        print("Loading user questions from Redis...")
        # NOTE(review): the flag is never cleared, so questions pushed after
        # the first pass are not picked up by later passes -- confirm intent.
        if not r.exists('user_questions_loaded'):
            for raw_question in r.lrange('user_questions', 0, -1):
                question = raw_question.decode('utf-8')
                raw_tag = r.get(f"tag:{question}")
                # Untagged questions are trained under a catch-all label.
                tag = raw_tag.decode('utf-8') if raw_tag is not None else "unknown"
                documents.append((nltk.word_tokenize(question), tag))
                if tag not in classes:
                    classes.append(tag)
            r.set('user_questions_loaded', 1)

        print("Processing intents from intents.json...")
        for intent in intents['intents']:
            for pattern in intent['patterns']:
                tokens = nltk.word_tokenize(pattern)
                words.extend(tokens)
                documents.append((tokens, intent['tag']))
                if intent['tag'] not in classes:
                    classes.append(intent['tag'])

            print(f"Generating synonyms for intent '{intent['tag']}'...")
            # NOTE(review): as reconstructed here, the generated patterns are
            # appended after this intent's patterns were tokenized above and
            # `intents` is re-read from disk on every pass, so they do not
            # appear to reach the training set. Left unchanged pending a
            # decision on the intended augmentation.
            for _ in tqdm(range(100000), desc="Generating synonyms", leave=False):
                if not intent['patterns']:
                    break
                new_pattern = []
                for word in random.choice(intent['patterns']).split():
                    synonyms = wordnet.synsets(word)
                    if synonyms:
                        synonym = random.choice(synonyms[0].lemmas()).name()
                        new_pattern.append(synonym)
                    else:
                        new_pattern.append(word)
                intent['patterns'].append(" ".join(new_pattern))

        words = [lemmatizer.lemmatize(w.lower()) for w in words if w not in ignore_words]

        # Fold in what is already persisted *before* building the matrices:
        # merging and re-sorting after training would shift the indices and
        # leave the stored lists out of step with the trained model.
        words = _merge_with_stored('words', words)
        classes = _merge_with_stored('classes', classes)

        if not documents or not words:
            print("No training data yet. Waiting...")
            await asyncio.sleep(60)
            continue

        print("Creating training data...")
        class_index = {label: position for position, label in enumerate(classes)}
        bags = []
        targets = []
        for doc_tokens, label in documents:
            lemmas = {lemmatizer.lemmatize(token.lower()) for token in doc_tokens}
            bags.append([1 if w in lemmas else 0 for w in words])
            one_hot = [0] * len(classes)
            one_hot[class_index[label]] = 1
            targets.append(one_hot)

        train_x = np.array(bags)
        train_y = np.array(targets)

        print("Loading or creating model...")
        model = None
        if os.path.exists('chatbot_model.h5'):
            model = load_model('chatbot_model.h5')
            # A saved model is only reusable while the vocabulary size and
            # the number of labels are unchanged.
            if (model.input_shape[-1] != train_x.shape[1]
                    or model.output_shape[-1] != train_y.shape[1]):
                model = None
        if model is None:
            model = _build_model(train_x.shape[1], train_y.shape[1])

        sgd = SGD(learning_rate=0.01, momentum=0.9, nesterov=True)
        model.compile(loss='categorical_crossentropy', optimizer=sgd, metrics=['accuracy'])

        print("Training the model...")
        model.fit(train_x, train_y, epochs=200, batch_size=5)

        print("Saving data to Redis...")
        r.set('words', pickle.dumps(words))
        r.set('classes', pickle.dumps(classes))

        save_model(model, 'chatbot_model.h5')

        print("Data and model saved. Re-training...")

def start_training_loop():
    """Run ``train_and_save_model`` on a fresh event loop in the calling process.

    ``asyncio.run`` creates the loop, runs the coroutine and closes the loop
    on the way out, so the loop is not leaked if the coroutine ever raises.
    """
    asyncio.run(train_and_save_model())

class ChatMessage(BaseModel):
    """Request body accepted by the ``/chat`` endpoint."""

    # Raw text typed by the user.
    message: str

@app.post("/chat")
async def chat(message: ChatMessage):
    """Classify a user message and record it for later training.

    Args:
        message: request body carrying the user's text.

    Returns:
        A list of ``{"intent": <label>, "probability": <str>}`` dicts for
        every class scoring above 0.25, highest probability first.

    Raises:
        HTTPException: 503 when the vocabulary, the labels or the model file
            are not available yet (nothing has been trained so far).
    """
    raw_words = r.get('words')
    raw_classes = r.get('classes')
    if raw_words is None or raw_classes is None or not os.path.exists('chatbot_model.h5'):
        # Without this guard pickle.loads(None) / load_model would surface
        # as an opaque HTTP 500.
        raise HTTPException(
            status_code=503,
            detail="Model is not trained yet. Please try again later.",
        )

    # NOTE(security): pickle.loads executes arbitrary code on crafted input.
    # This assumes the Redis instance is a trusted store -- confirm.
    words = pickle.loads(raw_words)
    classes = pickle.loads(raw_classes)
    # Reloaded on every request, which picks up whatever the trainer saved last.
    model = load_model('chatbot_model.h5')

    lemmas = {
        lemmatizer.lemmatize(token.lower())
        for token in nltk.word_tokenize(message.message)
    }
    bag = [1 if w in lemmas else 0 for w in words]

    probabilities = model.predict(np.array([bag]))[0]
    ERROR_THRESHOLD = 0.25
    ranked = sorted(
        (
            (index, probability)
            for index, probability in enumerate(probabilities)
            if probability > ERROR_THRESHOLD
        ),
        key=lambda pair: pair[1],
        reverse=True,
    )
    return_list = [
        {"intent": classes[index], "probability": str(probability)}
        for index, probability in ranked
    ]

    # Keep the raw question so it can be tagged and used as training data.
    r.rpush('user_questions', message.message)

    return return_list

@app.post("/tag")
async def tag_question(question: str, tag: str):
    """Store ``tag`` as the label of ``question`` under the Redis key ``tag:<question>``."""
    redis_key = f"tag:{question}"
    r.set(redis_key, tag)
    confirmation = {"message": "Tag saved"}
    return confirmation

# Single-page chat UI served at "/". The script posts the typed message to
# /chat and renders the reply. Text is inserted through DOM text nodes rather
# than string-built innerHTML, so neither the user's input nor server data
# can inject markup into the page.
html_code = """
<!DOCTYPE html>
<html>
<head>
    <title>Chatbot</title>
    <style>
        body {
            font-family: sans-serif;
            background-color: #f4f4f4;
            margin: 0;
            padding: 0;
            display: flex;
            justify-content: center;
            align-items: center;
            min-height: 100vh;
        }

        #container {
            background-color: #fff;
            border-radius: 5px;
            box-shadow: 0 2px 5px rgba(0, 0, 0, 0.1);
            padding: 30px;
            width: 80%;
            max-width: 600px;
        }

        h1 {
            text-align: center;
            margin-bottom: 20px;
            color: #333;
        }

        #chatbox {
            height: 300px;
            overflow-y: auto;
            padding: 10px;
            border: 1px solid #ccc;
            border-radius: 5px;
            margin-bottom: 10px;
        }

        #chatbox p {
            margin: 5px 0;
        }

        #user_input {
            width: 100%;
            padding: 10px;
            border: 1px solid #ccc;
            border-radius: 5px;
            margin-bottom: 10px;
            box-sizing: border-box;
        }

        button {
            background-color: #4CAF50;
            color: white;
            padding: 10px 20px;
            border: none;
            border-radius: 5px;
            cursor: pointer;
        }
    </style>
</head>
<body>
    <div id="container">
        <h1>Chatbot</h1>
        <div id="chatbox"></div>
        <input type="text" id="user_input" placeholder="Type your message...">
        <button onclick="sendMessage()">Send</button>
    </div>

    <script>
        function appendLine(speaker, text) {
            let chatbox = document.getElementById('chatbox');
            let line = document.createElement('p');
            let label = document.createElement('b');
            label.textContent = speaker + ':';
            line.appendChild(label);
            line.appendChild(document.createTextNode(' ' + text));
            chatbox.appendChild(line);
        }

        function sendMessage() {
            let userInput = document.getElementById('user_input').value;
            document.getElementById('user_input').value = '';

            fetch('/chat', {
                method: 'POST',
                headers: {'Content-Type': 'application/json'},
                body: JSON.stringify({"message": userInput})
            })
            .then(response => response.json())
            .then(data => {
                appendLine('You', userInput);
                if (!Array.isArray(data)) {
                    appendLine('Bot', (data && data.detail) ? String(data.detail) : 'Something went wrong.');
                    return;
                }
                data.forEach(item => {
                    appendLine('Bot', item.intent + ' (Probability: ' + item.probability + ')');
                });
            });
        }
    </script>

</body>
</html>
"""

@app.get("/", response_class=HTMLResponse)
async def root():
    """Serve the chat page held in the module-level ``html_code`` string."""
    page = html_code
    return page

if __name__ == "__main__":
    # Fail fast when Redis is unreachable, before any worker is started.
    initialize_redis()

    # One worker is enough: exactly one long-running task is submitted.
    with concurrent.futures.ProcessPoolExecutor(max_workers=1) as executor:
        # The training loop never returns, so it gets its own process.
        executor.submit(start_training_loop)
        # Serve from inside the ``with`` block: leaving it calls
        # executor.shutdown(wait=True), which would wait on the never-ending
        # training task and keep the server from ever starting.
        uvicorn.run(app, host="0.0.0.0", port=7860)