Updates

New open-source models and the to-do list are published at https://github.com/DunZhang/Stella/blob/main/news_and_todo.md.

You can also find these models on my homepage.

Introduction

The models are trained from Alibaba-NLP/gte-large-en-v1.5 and Alibaba-NLP/gte-Qwen2-1.5B-instruct. Thanks to their authors for these contributions!

We simplify prompt usage by providing two prompts that cover most general tasks: one for sentence-to-passage (s2p) tasks and one for sentence-to-sentence (s2s) tasks.

Prompt for the s2p task (e.g. retrieval):

Instruct: Given a web search query, retrieve relevant passages that answer the query.\nQuery: {query}

Prompt for the s2s task (e.g. semantic textual similarity):

Instruct: Retrieve semantically similar text.\nQuery: {query}
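Both templates can be applied with plain string formatting. The sketch below is only an illustration: `PROMPTS` and `build_query` are hypothetical names that are not part of the model's code, and only the template strings are taken from this card.

```python
# Prompt templates copied from this model card; the helper names are illustrative.
PROMPTS = {
    "s2p": "Instruct: Given a web search query, retrieve relevant passages that answer the query.\nQuery: {query}",
    "s2s": "Instruct: Retrieve semantically similar text.\nQuery: {query}",
}


def build_query(query: str, task: str = "s2p") -> str:
    """Return the query text wrapped in the prompt for the given task."""
    if task not in PROMPTS:
        raise ValueError(f"unknown task {task!r}; expected one of {sorted(PROMPTS)}")
    return PROMPTS[task].format(query=query)


print(build_query("What are some ways to reduce stress?", task="s2s"))
```

Documents are passed to the model as-is; only queries are wrapped in a prompt.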

The final training stage uses MRL (Matryoshka Representation Learning), so the models support multiple output dimensions: 512, 768, 1024, 2048, 4096, 6144 and 8192.

The higher the dimension, the better the performance. Generally speaking, 1024d is good enough. The MTEB score of 1024d is only 0.001 lower than 8192d.
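One practical reason to prefer a smaller dimension is index size. The following sketch estimates the raw storage needed for embeddings at each dimension listed above. It assumes float32 values (4 bytes each) and ignores any overhead added by a vector index; `index_size_gib` is a hypothetical helper.

```python
# Dimensions listed in this card; float32 storage (4 bytes per value) is an assumption.
DIMS = [512, 768, 1024, 2048, 4096, 6144, 8192]
BYTES_PER_FLOAT32 = 4


def index_size_gib(num_vectors: int, dim: int) -> float:
    """Raw size in GiB of num_vectors float32 embeddings of width dim."""
    return num_vectors * dim * BYTES_PER_FLOAT32 / 2**30


for dim in DIMS:
    print(f"{dim:>5}d: {index_size_gib(1_000_000, dim):.2f} GiB per million vectors")
```

Under these assumptions, 8192d embeddings take eight times the space of 1024d embeddings.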

Model directory structure

The model directory structure is simple: it is a standard SentenceTransformer directory plus a series of 2_Dense_{dims} folders, where dims is the final vector dimension.

For example, the 2_Dense_256 folder stores the Linear weights that project vectors to 256 dimensions. See the Usage section for specific instructions on how to use them.
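To see which dimensions a local clone actually ships, you can scan it for folders that follow this naming scheme. This is a minimal sketch that assumes the model has already been cloned to disk; `available_dims` is a hypothetical helper, not part of any library.

```python
import re
from pathlib import Path

# Matches folder names such as 2_Dense_256 or 2_Dense_8192.
DENSE_DIR = re.compile(r"^2_Dense_(\d+)$")


def available_dims(model_dir: str) -> list[int]:
    """Return the sorted output dimensions found as 2_Dense_{dims} folders."""
    dims = []
    for child in Path(model_dir).iterdir():
        match = DENSE_DIR.match(child.name)
        if child.is_dir() and match:
            dims.append(int(match.group(1)))
    return sorted(dims)
```

Calling `available_dims` with the path of your clone returns a list of integers, one per 2_Dense_{dims} folder found.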

Usage

You can use either the SentenceTransformers or the transformers library to encode text.

Sentence Transformers

from sentence_transformers import SentenceTransformer

# This model supports two prompts: "s2p_query" and "s2s_query" for sentence-to-passage and sentence-to-sentence tasks, respectively.
# They are defined in `config_sentence_transformers.json`
query_prompt_name = "s2p_query"
queries = [
    "What are some ways to reduce stress?",
    "What are the benefits of drinking green tea?",
]
# docs do not need any prompts
docs = [
    "There are many effective ways to reduce stress. Some common techniques include deep breathing, meditation, and physical activity. Engaging in hobbies, spending time in nature, and connecting with loved ones can also help alleviate stress. Additionally, setting boundaries, practicing self-care, and learning to say no can prevent stress from building up.",
    "Green tea has been consumed for centuries and is known for its potential health benefits. It contains antioxidants that may help protect the body against damage caused by free radicals. Regular consumption of green tea has been associated with improved heart health, enhanced cognitive function, and a reduced risk of certain types of cancer. The polyphenols in green tea may also have anti-inflammatory and weight loss properties.",
]

# NOTE: the default dimension is 1024. To use another dimension, clone the model and edit `modules.json`,
# replacing `2_Dense_1024` with the folder for that dimension, e.g. `2_Dense_256` or `2_Dense_8192`.
# Load the model on GPU
model = SentenceTransformer("dunzhang/stella_en_400M_v5", trust_remote_code=True).cuda()
# You can also load the model without `use_memory_efficient_attention` and `unpad_inputs`; that configuration also works on CPU.
# model = SentenceTransformer(
#     "dunzhang/stella_en_400M_v5",
#     trust_remote_code=True,
#     device="cpu",
#     config_kwargs={"use_memory_efficient_attention": False, "unpad_inputs": False}
# )
query_embeddings = model.encode(queries, prompt_name=query_prompt_name)
doc_embeddings = model.encode(docs)
print(query_embeddings.shape, doc_embeddings.shape)
# (2, 1024) (2, 1024)

similarities = model.similarity(query_embeddings, doc_embeddings)
print(similarities)
# tensor([[0.8398, 0.2990],
#         [0.3282, 0.8095]])
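The `modules.json` edit mentioned in the comments above can be scripted. The sketch below assumes a local clone whose `modules.json` is a JSON list of module entries, each with a `path` key, as in the standard SentenceTransformer layout. `set_dense_dim` is a hypothetical helper, and it overwrites the file in place, so keep a backup if you need the original.

```python
import json
from pathlib import Path


def set_dense_dim(model_dir: str, dim: int) -> None:
    """Point the Dense module in modules.json at the 2_Dense_{dim} folder."""
    target = f"2_Dense_{dim}"
    if not (Path(model_dir) / target).is_dir():
        raise FileNotFoundError(f"{target} not found in {model_dir}")

    modules_path = Path(model_dir) / "modules.json"
    modules = json.loads(modules_path.read_text(encoding="utf-8"))

    # Only entries that already point at a 2_Dense_* folder are changed.
    for module in modules:
        if module.get("path", "").startswith("2_Dense_"):
            module["path"] = target

    modules_path.write_text(json.dumps(modules, indent=2), encoding="utf-8")
```

After running it, load the model from the local directory rather than from the hub name so that the edited `modules.json` is used.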

Transformers

import os
import torch
from transformers import AutoModel, AutoTokenizer
from sklearn.preprocessing import normalize

query_prompt = "Instruct: Given a web search query, retrieve relevant passages that answer the query.\nQuery: "
queries = [
    "What are some ways to reduce stress?",
    "What are the benefits of drinking green tea?",
]
queries = [query_prompt + query for query in queries]
# docs do not need any prompts
docs = [
    "There are many effective ways to reduce stress. Some common techniques include deep breathing, meditation, and physical activity. Engaging in hobbies, spending time in nature, and connecting with loved ones can also help alleviate stress. Additionally, setting boundaries, practicing self-care, and learning to say no can prevent stress from building up.",
    "Green tea has been consumed for centuries and is known for its potential health benefits. It contains antioxidants that may help protect the body against damage caused by free radicals. Regular consumption of green tea has been associated with improved heart health, enhanced cognitive function, and a reduced risk of certain types of cancer. The polyphenols in green tea may also have anti-inflammatory and weight loss properties.",
]

# The path of your model after cloning it
model_dir = "{Your MODEL_PATH}"

vector_dim = 1024
vector_linear_directory = f"2_Dense_{vector_dim}"
model = AutoModel.from_pretrained(model_dir, trust_remote_code=True).cuda().eval()
# You can also load the model without `use_memory_efficient_attention` and `unpad_inputs`.
# That configuration also works on CPU if you drop the `.cuda()` calls in this script.
# model = AutoModel.from_pretrained(model_dir, trust_remote_code=True, use_memory_efficient_attention=False, unpad_inputs=False).cuda().eval()
tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
vector_linear = torch.nn.Linear(in_features=model.config.hidden_size, out_features=vector_dim)
vector_linear_dict = {
    k.replace("linear.", ""): v for k, v in
    torch.load(os.path.join(model_dir, f"{vector_linear_directory}/pytorch_model.bin")).items()
}
vector_linear.load_state_dict(vector_linear_dict)
vector_linear.cuda()

# Embed the queries
with torch.no_grad():
    input_data = tokenizer(queries, padding="longest", truncation=True, max_length=512, return_tensors="pt")
    input_data = {k: v.cuda() for k, v in input_data.items()}
    attention_mask = input_data["attention_mask"]
    last_hidden_state = model(**input_data)[0]
    last_hidden = last_hidden_state.masked_fill(~attention_mask[..., None].bool(), 0.0)
    query_vectors = last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
    query_vectors = normalize(vector_linear(query_vectors).cpu().numpy())

# Embed the documents
with torch.no_grad():
    input_data = tokenizer(docs, padding="longest", truncation=True, max_length=512, return_tensors="pt")
    input_data = {k: v.cuda() for k, v in input_data.items()}
    attention_mask = input_data["attention_mask"]
    last_hidden_state = model(**input_data)[0]
    last_hidden = last_hidden_state.masked_fill(~attention_mask[..., None].bool(), 0.0)
    docs_vectors = last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
    docs_vectors = normalize(vector_linear(docs_vectors).cpu().numpy())

print(query_vectors.shape, docs_vectors.shape)
# (2, 1024) (2, 1024)

similarities = query_vectors @ docs_vectors.T
print(similarities)
# [[0.8397531  0.29900077]
#  [0.32818374 0.80954516]]
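The pooling and normalization steps in the script above can be hard to read in tensor form. The sketch below reproduces them in plain Python on toy data: a masked mean over token vectors, followed by L2 normalization. It is an illustration only; it omits the Linear projection that the real script applies between the two steps, and `masked_mean` and `l2_normalize` are hypothetical helper names.

```python
import math


def masked_mean(hidden, mask):
    """Average the token vectors whose mask entry is 1; padding is ignored."""
    dim = len(hidden[0])
    total = [0.0] * dim
    for vector, keep in zip(hidden, mask):
        if keep:
            total = [t + v for t, v in zip(total, vector)]
    count = sum(mask)
    return [t / count for t in total]


def l2_normalize(vector):
    """Scale a vector to unit Euclidean length."""
    norm = math.sqrt(sum(v * v for v in vector))
    return [v / norm for v in vector]


# Toy input: three token vectors, the last one is padding.
hidden = [[1.0, 2.0], [3.0, 4.0], [9.0, 9.0]]
mask = [1, 1, 0]

pooled = masked_mean(hidden, mask)  # [2.0, 3.0]
unit = l2_normalize(pooled)
print(pooled, unit)
```

Because both query and document vectors are unit length after normalization, the matrix product in the script yields cosine similarities.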

FAQ

Q: What are the training details?

A: The training method and datasets will be released in the future (the timing is unknown; they may be described in a paper).

Q: How do I choose a suitable prompt for my own task?

A: In most cases, use the s2p or s2s prompt. These two prompts account for the vast majority of the training data.

Q: How can I reproduce the MTEB results?

A: Please use the evaluation scripts from Alibaba-NLP/gte-Qwen2-1.5B-instruct or intfloat/e5-mistral-7b-instruct.

Q: Why does each dimension have its own linear weight?

A: MRL can be trained in several ways; we chose this one because it performed best.

Q: What is the sequence length of the models?

A: 512 is recommended. In our experiments, almost all models perform poorly on specialized long-text retrieval datasets. In addition, the model was trained on data with a length of 512. This is a possible area for future optimization.

If you have any questions, please start a discussion in the community.

Model size: 435M params (Safetensors, tensor type F32)