Gopikanth123 committed on
Commit 00383dc · verified · 1 Parent(s): 064c178

Update main.py

Files changed (1)
  1. main.py +38 -41
main.py CHANGED
@@ -1,13 +1,12 @@
-import os
-import shutil
-from flask import Flask, render_template, request, jsonify
-from llama_index.core import StorageContext, load_index_from_storage, VectorStoreIndex, SimpleDirectoryReader, ChatPromptTemplate, Settings
-from llama_index.llms.huggingface import HuggingFaceInferenceAPI
-from llama_index.embeddings.huggingface import HuggingFaceEmbedding
-from huggingface_hub import InferenceClient
+from flask import Flask, render_template, request, jsonify
+import os
+import shutil
+from llama_index.core import StorageContext, load_index_from_storage, VectorStoreIndex, SimpleDirectoryReader, ChatPromptTemplate, Settings
+from llama_index.llms.huggingface import HuggingFaceInferenceAPI
+from llama_index.embeddings.huggingface import HuggingFaceEmbedding
+from huggingface_hub import InferenceClient
 from transformers import AutoTokenizer, AutoModel
 
-
 # Ensure HF_TOKEN is set
 HF_TOKEN = os.getenv("HF_TOKEN")
 if not HF_TOKEN:
@@ -28,15 +27,13 @@ Settings.llm = HuggingFaceInferenceAPI(
     max_new_tokens=512,
     generate_kwargs={"temperature": 0.1},
 )
-# Settings.embed_model = HuggingFaceEmbedding(
-#     model_name="BAAI/bge-small-en-v1.5"
-# )
-# Replace the embedding model with XLM-R
+
+# Configure embedding model (XLM-RoBERTa model for multilingual support)
 Settings.embed_model = HuggingFaceEmbedding(
-    model_name="xlm-roberta-base"  # XLM-RoBERTa model for multilingual support
+    model_name="xlm-roberta-base"  # Multilingual support
 )
 
-# Configure tokenizer and model if required
+# Configure tokenizer and model for multilingual responses
 tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base")
 model = AutoModel.from_pretrained("xlm-roberta-base")
 
@@ -49,35 +46,38 @@ os.makedirs(PERSIST_DIR, exist_ok=True)
 chat_history = []
 current_chat_history = []
 
+# Data ingestion function
 def data_ingestion_from_directory():
-    # Clear previous data by removing the persist directory
     if os.path.exists(PERSIST_DIR):
-        shutil.rmtree(PERSIST_DIR)  # Remove the persist directory and all its contents
+        shutil.rmtree(PERSIST_DIR)  # Remove the persist directory and its contents
 
-    # Recreate the persist directory after removal
     os.makedirs(PERSIST_DIR, exist_ok=True)
-
-    # Load new documents from the directory
     new_documents = SimpleDirectoryReader(PDF_DIRECTORY).load_data()
-
-    # Create a new index with the new documents
     index = VectorStoreIndex.from_documents(new_documents)
-
-    # Persist the new index
     index.storage_context.persist(persist_dir=PERSIST_DIR)
 
-def handle_query(query):
+# Function to handle the query and provide a response
+def handle_query(query, selected_language):
     context_str = ""
 
     # Build context from current chat history
     for past_query, response in reversed(current_chat_history):
         if past_query.strip():
             context_str += f"User asked: '{past_query}'\nBot answered: '{response}'\n"
+
+    # Define the response template based on selected language
+    if selected_language == 'telugu':
+        language_prompt = "మీరు తాజ్ హోటల్ చాట్‌బాట్, తాజ్ హోటల్ సహాయకుడు."
+    elif selected_language == 'hindi':
+        language_prompt = "आप ताज होटल चैटबोट हैं, ताज होटल सहायक।"
+    else:
+        language_prompt = "You are the Taj Hotel chatbot, Taj Hotel Helper."
 
     chat_text_qa_msgs = [
         (
             "user",
-            """You are the Taj Hotel chatbot, Taj Hotel Helper.
+            f"""
+            {language_prompt}
 
             **Your Role:**
             - Respond accurately and concisely in the user's preferred language (English, Telugu, or Hindi).
@@ -87,7 +87,7 @@ def handle_query(query):
             - **Context:**
             {context_str}
             - **User's Question:**
-            {query_str}
+            {query}
 
             **Response Guidelines:**
             1. **Language Adaptation:** Respond in the language of the question (English, Telugu, or Hindi).
@@ -98,28 +98,21 @@ def handle_query(query):
             5. **Actionable Help:** Offer suggestions or alternative steps to guide the user where applicable.
 
             **Response:** [Your concise response here]
-            """.format(context_str=context_str, query_str=query)
+            """
         )
     ]
 
-
-
-    text_qa_template = ChatPromptTemplate.from_messages(chat_text_qa_msgs)
+    text_qa_template = ChatPromptTemplate.from_messages(chat_text_qa_msgs)
 
+    # Load the index for querying
     storage_context = StorageContext.from_defaults(persist_dir=PERSIST_DIR)
     index = load_index_from_storage(storage_context)
-    # context_str = ""
-
-    # # Build context from current chat history
-    # for past_query, response in reversed(current_chat_history):
-    #     if past_query.strip():
-    #         context_str += f"User asked: '{past_query}'\nBot answered: '{response}'\n"
 
     query_engine = index.as_query_engine(text_qa_template=text_qa_template, context_str=context_str)
     print(f"Querying: {query}")
     answer = query_engine.query(query)
 
-    # Extracting the response
+    # Extracting the response
     if hasattr(answer, 'response'):
         response = answer.response
     elif isinstance(answer, dict) and 'response' in answer:
@@ -137,10 +130,10 @@ app = Flask(__name__)
 data_ingestion_from_directory()
 
 # Generate Response
-def generate_response(query):
+def generate_response(query, language):
     try:
         # Call the handle_query function to get the response
-        bot_response = handle_query(query)
+        bot_response = handle_query(query, language)
         return bot_response
     except Exception as e:
         return f"Error fetching the response: {str(e)}"
@@ -155,13 +148,17 @@ def index():
 def chat():
     try:
         user_message = request.json.get("message")
+        selected_language = request.json.get("language")  # Get selected language from the request
        if not user_message:
             return jsonify({"response": "Please say something!"})
 
-        bot_response = generate_response(user_message)
+        if selected_language not in ['english', 'telugu', 'hindi']:
+            return jsonify({"response": "Invalid language selected."})
+
+        bot_response = generate_response(user_message, selected_language)
         return jsonify({"response": bot_response})
     except Exception as e:
         return jsonify({"response": f"An error occurred: {str(e)}"})
 
 if __name__ == '__main__':
-    app.run(debug=True)
+    app.run(debug=True)
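For reviewers of the embedding swap: a quick way to check that xlm-roberta-base returns comparable vectors across the supported languages. This is a sketch, not part of the commit; it only reuses the model name the commit sets on Settings.embed_model, and the Hindi sentence means "What is the check-in time?".

from llama_index.embeddings.huggingface import HuggingFaceEmbedding

# Same model the commit configures via Settings.embed_model
embed_model = HuggingFaceEmbedding(model_name="xlm-roberta-base")

vec_en = embed_model.get_text_embedding("What is the check-in time?")
vec_hi = embed_model.get_text_embedding("चेक-इन का समय क्या है?")  # Hindi: "What is the check-in time?"
print(len(vec_en), len(vec_hi))  # should match (768 for xlm-roberta-base)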
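And a minimal client call against the updated chat route, which now requires a "language" field of 'english', 'telugu', or 'hindi' alongside "message". Assumptions not shown in the diff: the handler is mounted at /chat (the route decorator is outside the hunk) and the app runs locally on Flask's default port 5000; requests is not a dependency of this commit.

import requests

resp = requests.post(
    "http://127.0.0.1:5000/chat",  # assumed route and port, see note above
    json={
        "message": "What is the check-in time?",
        "language": "english",  # must be 'english', 'telugu', or 'hindi'
    },
)
print(resp.json()["response"])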