  Trained on data in over 95 languages, this model is applicable to a broad range of use cases.
This model has three main benefits over comparable rerankers:

1. It has shown slightly higher performance on evaluation benchmarks.
2. It has been trained on more languages than any previous model.
3. It is a simple causal language model trained to output a string from "1" to "7".

This last point means that this model can be used natively with many widely available inference packages, including vLLM and LMDeploy.
This in turn allows our reranker to benefit from improvements to inference as and when these packages release them.
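Because the model emits a single score token, its output can be converted into a continuous relevance score by taking the expected value over the seven label tokens. Below is a minimal sketch of that computation with made-up probabilities; the vLLM and LMDeploy examples that follow do the same thing with real logprobs:

```python
import numpy as np

# Hypothetical probabilities for the score tokens "1".."7"
# (these numbers are illustrative, not real model output)
probs = np.array([0.02, 0.03, 0.05, 0.10, 0.20, 0.35, 0.25])
scores = np.arange(1, 8)

# Expected relevance score: sum over i of i * P(score = i), a value in [1, 7]
expected_score = (probs * scores).sum()
print(round(expected_score, 2))  # 5.48
```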

# How to use

#### vLLM

Install [vLLM](https://github.com/vllm-project/vllm/) using `pip install vllm`.

```python
from vllm import LLM, SamplingParams
import numpy as np

def make_reranker_input(t, q):
    return f"<<<Query>>>\n{q}\n\n<<<Context>>>\n{t}"

def make_reranker_training_datum(context, question):
    system_message = "Given a query and a piece of text, output a score of 1-7 based on how related the query is to the text. 1 means least related and 7 is most related."

    return [
        {"role": "system", "content": system_message},
        {"role": "user", "content": make_reranker_input(context, question)},
    ]

def get_prob(logprob_dict, tok_id):
    # Probability of a token ID, or 0 if it is not among the returned logprobs
    return np.exp(logprob_dict[tok_id].logprob) if tok_id in logprob_dict else 0

llm = LLM("lightblue/lb-reranker-v1.0")
sampling_params = SamplingParams(temperature=0.0, logprobs=14, max_tokens=1)
tok = llm.llm_engine.tokenizer.tokenizer
# Token IDs of the score labels "1".."7"
idx_tokens = [tok.encode(str(i))[0] for i in range(1, 8)]

query_texts = [
    ("What is the scientific name of apples?", "An apple is a round, edible fruit produced by an apple tree (Malus spp., among them the domestic or orchard apple; Malus domestica)."),
    ("What is the Chinese word for 'apple'?", "An apple is a round, edible fruit produced by an apple tree (Malus spp., among them the domestic or orchard apple; Malus domestica)."),
    ("What is the square root of 999?", "An apple is a round, edible fruit produced by an apple tree (Malus spp., among them the domestic or orchard apple; Malus domestica)."),
]

chats = [make_reranker_training_datum(c, q) for q, c in query_texts]
responses = llm.chat(chats, sampling_params)
# Probability of each score label for each (query, context) pair
probs = np.array([[get_prob(r.outputs[0].logprobs[0], y) for y in idx_tokens] for r in responses])

N = probs.shape[1]
M = probs.shape[0]
idxs = np.tile(np.arange(1, N + 1), M).reshape(M, N)

# Expected score for each pair, i.e. sum over i of i * P(score = i)
expected_vals = (probs * idxs).sum(axis=1)
print(expected_vals)
# [6.66570732 1.86686378 1.01102923]
```
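In a real retrieval pipeline you would typically score many candidate contexts against a single query and keep the highest-scoring ones. A hypothetical continuation of the script above that sorts the scored pairs:

```python
# Sort the scored pairs by descending expected score and keep the top k
# (hypothetical follow-up; here each pair happens to share the same context text)
top_k = 2
order = np.argsort(expected_vals)[::-1][:top_k]
for rank, i in enumerate(order, start=1):
    print(f"{rank}. score={expected_vals[i]:.2f} query={query_texts[i][0]!r}")
```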

#### LMDeploy

Install [LMDeploy](https://github.com/InternLM/lmdeploy) using `pip install lmdeploy`.

```python
# Un-comment this if running in a Jupyter notebook, Colab etc.
# import nest_asyncio
# nest_asyncio.apply()

from lmdeploy import GenerationConfig, ChatTemplateConfig, pipeline
import numpy as np

def make_reranker_input(t, q):
    return f"<<<Query>>>\n{q}\n\n<<<Context>>>\n{t}"

def make_reranker_training_datum(context, question):
    system_message = "Given a query and a piece of text, output a score of 1-7 based on how related the query is to the text. 1 means least related and 7 is most related."

    return [
        {"role": "system", "content": system_message},
        {"role": "user", "content": make_reranker_input(context, question)},
    ]

def get_prob(logprob_dict, tok_id):
    # LMDeploy returns plain float logprobs rather than objects
    return np.exp(logprob_dict[tok_id]) if tok_id in logprob_dict else 0

pipe = pipeline(
    "lightblue/lb-reranker-v1.0",
    chat_template_config=ChatTemplateConfig(
        model_name='qwen2d5',
        capability='chat'
    )
)
tok = pipe.tokenizer.model
# Token IDs of the score labels "1".."7"
idx_tokens = [tok.encode(str(i))[0] for i in range(1, 8)]

query_texts = [
    ("What is the scientific name of apples?", "An apple is a round, edible fruit produced by an apple tree (Malus spp., among them the domestic or orchard apple; Malus domestica)."),
    ("What is the Chinese word for 'apple'?", "An apple is a round, edible fruit produced by an apple tree (Malus spp., among them the domestic or orchard apple; Malus domestica)."),
    ("What is the square root of 999?", "An apple is a round, edible fruit produced by an apple tree (Malus spp., among them the domestic or orchard apple; Malus domestica)."),
]

chats = [make_reranker_training_datum(c, q) for q, c in query_texts]
responses = pipe(
    chats,
    gen_config=GenerationConfig(temperature=0.0, logprobs=14, max_new_tokens=1)
)
# Probability of each score label for each (query, context) pair
probs = np.array([[get_prob(r.logprobs[0], y) for y in idx_tokens] for r in responses])

N = probs.shape[1]
M = probs.shape[0]
idxs = np.tile(np.arange(1, N + 1), M).reshape(M, N)

# Expected score for each pair, i.e. sum over i of i * P(score = i)
expected_vals = (probs * idxs).sum(axis=1)
print(expected_vals)
# [7. 2. 1.]
```
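Note that in both scripts `get_prob` returns 0 for any score token that falls outside the top logprobs returned by the engine, so a small amount of probability mass can be dropped and the expectation can drift slightly low. If you want scores averaged over exactly the seven labels, you can renormalize each row first; this is a hypothetical post-processing step, not part of the original scripts:

```python
import numpy as np

def normalized_expected_scores(probs):
    # Renormalize each row to sum to 1 before taking the expectation,
    # guarding against probability mass lost to truncated logprobs.
    probs = np.asarray(probs, dtype=float)
    row_sums = probs.sum(axis=1, keepdims=True)
    safe_sums = np.where(row_sums > 0, row_sums, 1.0)  # avoid division by zero
    labels = np.arange(1, probs.shape[1] + 1)
    return ((probs / safe_sums) * labels).sum(axis=1)
```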

# Evaluation

We perform an evaluation on 9 datasets from the [BEIR benchmark](https://github.com/beir-cellar/beir) that none of the evaluated models have been trained upon (to our knowledge).