capradeepgujaran committed on
Commit
3303167
1 Parent(s): ea34aa6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -3
app.py CHANGED
@@ -128,17 +128,19 @@ def process_upload(api_key, files):
128
  return f"No valid documents were indexed. Errors: {'; '.join(error_messages)}", None
129
 
130
  def calculate_similarity(response, ground_truth):
 
131
  response_embedding = sentence_model.encode(response, convert_to_tensor=True)
132
  truth_embedding = sentence_model.encode(ground_truth, convert_to_tensor=True)
133
 
134
- # Normalize the embeddings
135
- response_embedding = response_embedding / np.linalg.norm(response_embedding)
136
- truth_embedding = truth_embedding / np.linalg.norm(truth_embedding)
137
 
138
  # Calculate cosine similarity using sklearn's cosine_similarity function
139
  similarity = cosine_similarity(response_embedding.reshape(1, -1), truth_embedding.reshape(1, -1))[0][0]
140
  return similarity * 100 # Convert to percentage
141
 
 
142
  def query_app(query, model_name, use_similarity_check, openai_api_key):
143
  global vector_index, query_log
144
 
 
128
  return f"No valid documents were indexed. Errors: {'; '.join(error_messages)}", None
129
 
130
  def calculate_similarity(response, ground_truth):
131
+ # Encode the response and ground truth
132
  response_embedding = sentence_model.encode(response, convert_to_tensor=True)
133
  truth_embedding = sentence_model.encode(ground_truth, convert_to_tensor=True)
134
 
135
+ # Explicitly normalize the embeddings (should result in unit vectors)
136
+ response_embedding = response_embedding / response_embedding.norm(p=2)
137
+ truth_embedding = truth_embedding / truth_embedding.norm(p=2)
138
 
139
  # Calculate cosine similarity using sklearn's cosine_similarity function
140
  similarity = cosine_similarity(response_embedding.reshape(1, -1), truth_embedding.reshape(1, -1))[0][0]
141
  return similarity * 100 # Convert to percentage
142
 
143
+
144
  def query_app(query, model_name, use_similarity_check, openai_api_key):
145
  global vector_index, query_log
146