from FlagEmbedding import FlagModel

# Load the MiniCPM-Embedding-Light model through FlagEmbedding.
# Queries are prefixed with the "Query: " instruction, embeddings are
# mean-pooled, and normalized output is requested.
model = FlagModel(
    "openbmb/MiniCPM-Embedding-Light",
    query_instruction_for_retrieval="Query: ",
    pooling_method="mean",
    trust_remote_code=True,
    normalize_embeddings=True,
    use_fp16=True,
)

# You can hack the __init__() method of the FlagEmbedding BaseEmbedder class
# to use flash_attention_2 for faster inference, by adding the two marked
# arguments to its from_pretrained() call:
#
# self.model = AutoModel.from_pretrained(
#     model_name_or_path,
#     trust_remote_code=trust_remote_code,
#     cache_dir=cache_dir,
#     # torch_dtype=torch.float16,  # add this line to use fp16
#     # attn_implementation="flash_attention_2",  # add this line to use flash_attention_2
# )

queries = ["中国的首都是哪里?"]  # "What is the capital of China?"
passages = ["beijing", "shanghai"]  # Beijing, Shanghai

# Queries and passages go through separate entry points; the query
# instruction configured above is meant for the query side.
embeddings_query = model.encode_queries(queries)
embeddings_doc = model.encode_corpus(passages)

# Query-by-passage similarity matrix. Normalized embeddings were requested,
# so this dot product should correspond to cosine similarity.
scores = embeddings_query @ embeddings_doc.T
print(scores.tolist())  # e.g. [[0.40356746315956116, 0.36183440685272217]]