Tonic commited on
Commit
e7481b0
1 Parent(s): 930288d

escape special characters

Browse files
Files changed (2) hide show
  1. app.py +22 -15
  2. utils.py +34 -1
app.py CHANGED
@@ -1,4 +1,4 @@
1
- # main.py
2
  import spaces
3
  from torch.nn import DataParallel
4
  from torch import Tensor
@@ -17,7 +17,7 @@ import gradio as gr
17
  import torch
18
  import torch.nn.functional as F
19
  from dotenv import load_dotenv
20
- from utils import load_env_variables, parse_and_route
21
  from globalvars import API_BASE, intention_prompt, tasks, system_message, model_name , metadata_prompt
22
 
23
 
@@ -49,12 +49,12 @@ class EmbeddingGenerator:
49
 
50
  @spaces.GPU
51
  def compute_embeddings(self, input_text: str):
52
- # Get the intention
53
  intention_completion = self.intention_client.chat.completions.create(
54
  model="yi-large",
55
  messages=[
56
- {"role": "system", "content": intention_prompt},
57
- {"role": "user", "content": input_text}
58
  ]
59
  )
60
  intention_output = intention_completion.choices[0].message['content']
@@ -71,14 +71,14 @@ class EmbeddingGenerator:
71
  return f"Error: Task '{selected_task}' not found. Please select a valid task."
72
 
73
  query_prefix = f"Instruct: {task_description}\nQuery: "
74
- queries = [input_text]
75
 
76
  # Get the metadata
77
  metadata_completion = self.intention_client.chat.completions.create(
78
  model="yi-large",
79
  messages=[
80
- {"role": "system", "content": metadata_prompt},
81
- {"role": "user", "content": input_text}
82
  ]
83
  )
84
  metadata_output = metadata_completion.choices[0].message['content']
@@ -93,12 +93,9 @@ class EmbeddingGenerator:
93
  # Normalize embeddings
94
  query_embeddings = F.normalize(query_embeddings, p=2, dim=1)
95
  embeddings_list = query_embeddings.detach().cpu().numpy().tolist()
96
-
97
- # Include metadata in the embeddings
98
- embeddings_with_metadata = [{"embedding": emb, "metadata": metadata} for emb in embeddings_list]
99
 
100
  self.clear_cuda_cache()
101
- return embeddings_with_metadata
102
 
103
  def extract_metadata(self, metadata_output: str):
104
  # Regex pattern to extract key-value pairs
@@ -143,8 +140,18 @@ def add_documents_to_chroma(client, collection, documents: list, embedding_funct
143
  )
144
 
145
  def query_chroma(client, collection_name: str, query_text: str, embedding_function: MyEmbeddingFunction):
 
 
 
 
146
  db = Chroma(client=client, collection_name=collection_name, embedding_function=embedding_function)
147
- result_docs = db.similarity_search(query_text)
 
 
 
 
 
 
148
  return result_docs
149
 
150
 
@@ -164,13 +171,13 @@ def respond(
164
  top_p,
165
  ):
166
  retrieved_text = query_documents(message)
167
- messages = [{"role": "system", "content": system_message}]
168
  for val in history:
169
  if val[0]:
170
  messages.append({"role": "user", "content": val[0]})
171
  if val[1]:
172
  messages.append({"role": "assistant", "content": val[1]})
173
- messages.append({"role": "user", "content": f"{retrieved_text}\n\n{message}"})
174
  response = ""
175
  for message in intention_client.chat_completion(
176
  messages,
 
1
+ # app.py
2
  import spaces
3
  from torch.nn import DataParallel
4
  from torch import Tensor
 
17
  import torch
18
  import torch.nn.functional as F
19
  from dotenv import load_dotenv
20
+ from utils import load_env_variables, parse_and_route , escape_special_characters
21
  from globalvars import API_BASE, intention_prompt, tasks, system_message, model_name , metadata_prompt
22
 
23
 
 
49
 
50
  @spaces.GPU
51
  def compute_embeddings(self, input_text: str):
52
+ escaped_input_text = escape_special_characters(input_text)
53
  intention_completion = self.intention_client.chat.completions.create(
54
  model="yi-large",
55
  messages=[
56
+ {"role": "system", "content": escape_special_characters(intention_prompt)},
57
+ {"role": "user", "content": escaped_input_text}
58
  ]
59
  )
60
  intention_output = intention_completion.choices[0].message['content']
 
71
  return f"Error: Task '{selected_task}' not found. Please select a valid task."
72
 
73
  query_prefix = f"Instruct: {task_description}\nQuery: "
74
+ queries = [escaped_input_text]
75
 
76
  # Get the metadata
77
  metadata_completion = self.intention_client.chat.completions.create(
78
  model="yi-large",
79
  messages=[
80
+ {"role": "system", "content": escape_special_characters(metadata_prompt)},
81
+ {"role": "user", "content": escaped_input_text}
82
  ]
83
  )
84
  metadata_output = metadata_completion.choices[0].message['content']
 
93
  # Normalize embeddings
94
  query_embeddings = F.normalize(query_embeddings, p=2, dim=1)
95
  embeddings_list = query_embeddings.detach().cpu().numpy().tolist()
 
 
 
96
 
97
  self.clear_cuda_cache()
98
+ return embeddings_list, metadata
99
 
100
  def extract_metadata(self, metadata_output: str):
101
  # Regex pattern to extract key-value pairs
 
140
  )
141
 
142
  def query_chroma(client, collection_name: str, query_text: str, embedding_function: MyEmbeddingFunction):
143
+ # Compute query embeddings and metadata
144
+ query_embeddings, query_metadata = embedding_function.embedding_generator.compute_embeddings(query_text)
145
+
146
+ # Initialize Chroma with the collection
147
  db = Chroma(client=client, collection_name=collection_name, embedding_function=embedding_function)
148
+
149
+ # Perform similarity search using the query embeddings and metadata
150
+ result_docs = db.similarity_search(
151
+ query_embeddings=query_embeddings,
152
+ query_metadata=query_metadata
153
+ )
154
+
155
  return result_docs
156
 
157
 
 
171
  top_p,
172
  ):
173
  retrieved_text = query_documents(message)
174
+ messages = [{"role": "system", "content": escape_special_characters(system_message)}]
175
  for val in history:
176
  if val[0]:
177
  messages.append({"role": "user", "content": val[0]})
178
  if val[1]:
179
  messages.append({"role": "assistant", "content": val[1]})
180
+ messages.append({"role": "user", "content": f"{retrieved_text}\n\n{escape_special_characters(message)}"})
181
  response = ""
182
  for message in intention_client.chat_completion(
183
  messages,
utils.py CHANGED
@@ -30,4 +30,37 @@ def parse_and_route(example_output: str):
30
  else:
31
  return {true_task: "Task description not found"}
32
  else:
33
- return "No true task found in the example output"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  else:
31
  return {true_task: "Task description not found"}
32
  else:
33
+ return "No true task found in the example output"
34
+
35
+ import json
36
+
37
+ def escape_special_characters(text: str) -> str:
38
+ """
39
+ Escapes special characters in the given text for JSON and cURL compatibility.
40
+ """
41
+ escaped_text = json.dumps(text)[1:-1]
42
+ curl_escaped_text = escaped_text.replace(" ", "\\ ")
43
+ curl_escaped_text = curl_escaped_text.replace("&", "\\&")
44
+ curl_escaped_text = curl_escaped_text.replace(";", "\\;")
45
+ curl_escaped_text = curl_escaped_text.replace("(", "\\(")
46
+ curl_escaped_text = curl_escaped_text.replace(")", "\\)")
47
+ curl_escaped_text = curl_escaped_text.replace("$", "\\$")
48
+ curl_escaped_text = curl_escaped_text.replace("`", "\\`")
49
+ curl_escaped_text = curl_escaped_text.replace("|", "\\|")
50
+ curl_escaped_text = curl_escaped_text.replace("*", "\\*")
51
+ curl_escaped_text = curl_escaped_text.replace("?", "\\?")
52
+ curl_escaped_text = curl_escaped_text.replace("<", "\\<")
53
+ curl_escaped_text = curl_escaped_text.replace(">", "\\>")
54
+ curl_escaped_text = curl_escaped_text.replace("!", "\\!")
55
+ curl_escaped_text = curl_escaped_text.replace("{", "\\{")
56
+ curl_escaped_text = curl_escaped_text.replace("}", "\\}")
57
+ curl_escaped_text = curl_escaped_text.replace("[", "\\[")
58
+ curl_escaped_text = curl_escaped_text.replace("]", "\\]")
59
+ curl_escaped_text = curl_escaped_text.replace("#", "\\#")
60
+ curl_escaped_text = curl_escaped_text.replace("%", "\\%")
61
+ curl_escaped_text = curl_escaped_text.replace("^", "\\^")
62
+ curl_escaped_text = curl_escaped_text.replace("=", "\\=")
63
+ curl_escaped_text = curl_escaped_text.replace("~", "\\~")
64
+ curl_escaped_text = curl_escaped_text.replace("'", "\\'")
65
+
66
+ return curl_escaped_text