zihanliu commited on
Commit
169f4c7
·
verified ·
1 Parent(s): dd61e9f

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +6 -7
README.md CHANGED
@@ -107,10 +107,7 @@ contexts = [
107
 
108
  ## convert query into a format as follows:
109
  ## user: {user}\nagent: {agent}\nuser: {user}
110
- formatted_query = ""
111
- for turn in query:
112
- formatted_query += turn['role'] + ": " + turn['content'] + "\n"
113
- formatted_query = formatted_query.strip()
114
 
115
  ## get query and context embeddings
116
  query_input = tokenizer(formatted_query, return_tensors='pt')
@@ -118,9 +115,11 @@ ctx_input = tokenizer(contexts, padding=True, return_tensors='pt')
118
  query_emb = query_encoder(**query_input).last_hidden_state[:, 0, :]
119
  ctx_emb = context_encoder(**ctx_input).last_hidden_state[:, 0, :]
120
 
121
- # Compute similarity scores using dot product
122
- score1 = query_emb @ ctx_emb[0]
123
- score2 = query_emb @ ctx_emb[1]
 
 
124
  ```
125
 
126
  ## License
 
107
 
108
  ## convert query into a format as follows:
109
  ## user: {user}\nagent: {agent}\nuser: {user}
110
+ formatted_query = '\n'.join([turn['role'] + ": " + turn['content'] for turn in messages]).strip()
 
 
 
111
 
112
  ## get query and context embeddings
113
  query_input = tokenizer(formatted_query, return_tensors='pt')
 
115
  query_emb = query_encoder(**query_input).last_hidden_state[:, 0, :]
116
  ctx_emb = context_encoder(**ctx_input).last_hidden_state[:, 0, :]
117
 
118
+ ## Compute similarity scores using dot product
119
+ similarities = query_emb.matmul(ctx_emb.transpose(0, 1)) # (1, num_ctx)
120
+
121
+ ## rank the similarity (from highest to lowest)
122
+ ranked_results = torch.argsort(similarities, dim=-1, descending=True) # (1, num_ctx)
123
  ```
124
 
125
  ## License