
RoSEtta

RoSEtta (RoFormer-based Sentence Encoder through Distillation) is a general-purpose Japanese text embedding model that excels at retrieval tasks. It supports a maximum sequence length of 1024 tokens, so it can take long texts as input. It can run on a CPU and is designed both to measure semantic similarity between sentences and to serve as a retrieval model that searches for passages matching a query.

Key features:

  • Uses RoPE (Rotary Position Embedding)
  • Maximum sequence length of 1024 tokens
  • Distilled from large sentence embedding models
  • Specialized for retrieval tasks
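RoPE encodes position by rotating pairs of query/key features through position-dependent angles, so the dot product between two rotated vectors depends only on their relative offset. The following is a minimal, generic sketch of that idea for illustration only; it is not RoSEtta's actual implementation. The function name `apply_rope` is hypothetical, and the `base` of 10000 is a commonly used value assumed here. Implementations also differ in how they pair dimensions; this sketch pairs adjacent features.

```python
import torch


def apply_rope(x: torch.Tensor, positions: torch.Tensor, base: float = 10000.0) -> torch.Tensor:
    """Rotate adjacent feature pairs of x by position-dependent angles (illustrative sketch).

    x: [seq_len, dim] with an even dim; positions: [seq_len] integer positions.
    """
    dim = x.size(-1)
    # One rotation frequency per feature pair: base^(-2i/dim)
    inv_freq = base ** (-torch.arange(0, dim, 2, dtype=torch.float32) / dim)  # [dim/2]
    angles = positions.float().unsqueeze(-1) * inv_freq  # [seq_len, dim/2]
    cos, sin = angles.cos(), angles.sin()
    x_even, x_odd = x[..., 0::2], x[..., 1::2]
    rotated = torch.empty_like(x)
    # Standard 2D rotation applied to each (even, odd) feature pair
    rotated[..., 0::2] = x_even * cos - x_odd * sin
    rotated[..., 1::2] = x_even * sin + x_odd * cos
    return rotated
```

Because each pair is rotated rather than shifted, vector norms are preserved, and a query at position 5 paired with a key at position 3 yields the same score as positions 12 and 10.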

During inference, the prefix "query: " or "passage: " is required. Please check the Usage section for details.
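Because forgetting a prefix silently degrades results, it can be convenient to add prefixes programmatically. This is a minimal sketch; the helper `add_prefix` is hypothetical and not part of the model or any library, and leaving already-prefixed texts unchanged is a design choice assumed here.

```python
from typing import Iterable, List

# Prefixes required by RoSEtta at inference time
QUERY_PREFIX = "query: "
PASSAGE_PREFIX = "passage: "


def add_prefix(texts: Iterable[str], kind: str) -> List[str]:
    """Prepend the prefix for `kind` ("query" or "passage") to each text.

    Hypothetical helper: texts that already start with the prefix are left unchanged.
    """
    prefixes = {"query": QUERY_PREFIX, "passage": PASSAGE_PREFIX}
    if kind not in prefixes:
        raise ValueError(f"kind must be 'query' or 'passage', got {kind!r}")
    prefix = prefixes[kind]
    return [t if t.startswith(prefix) else prefix + t for t in texts]
```

For example, `add_prefix(["日本で一番高い山は?"], "query")` returns `["query: 日本で一番高い山は?"]`.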

Model Description

This model is based on the RoFormer architecture. It was first pre-trained with a masked language modeling (MLM) loss and then trained with weakly supervised learning. It was further trained through distillation from several large embedding models and through multi-stage contrastive learning (as in GLuCoSE v2).

  • Maximum Sequence Length: 1024 tokens
  • Output Dimensionality: 768 dimensions
  • Similarity Function: Cosine Similarity

Usage

Direct Usage (Sentence Transformers)

You can perform inference using SentenceTransformer with the following code:

from sentence_transformers import SentenceTransformer
import torch.nn.functional as F

# Download from the 🤗 Hub
# The argument "trust_remote_code=True" is required to load the model
model = SentenceTransformer("pkshatech/RoSEtta-base-ja", trust_remote_code=True)

# Each input text should start with "query: " or "passage: ".
# For tasks other than retrieval, you can simply use the "query: " prefix.
sentences = [
    'query: PKSHAはどんな会社ですか?',
    'passage: 研究開発したアルゴリズムを、多くの企業のソフトウエア・オペレーションに導入しています。',
    'query: 日本で一番高い山は?',
    'passage: 富士山(ふじさん)は、標高3776.12 m、日本最高峰(剣ヶ峰)の独立峰で、その優美な風貌は日本国外でも日本の象徴として広く知られている。',
]
embeddings = model.encode(sentences, convert_to_tensor=True)
print(embeddings.shape)
# [4, 768]

# Get the similarity scores for the embeddings
similarities = F.cosine_similarity(embeddings.unsqueeze(0), embeddings.unsqueeze(1), dim=2)
print(similarities)
# [[1.0000, 0.5910, 0.4332, 0.5421],
# [0.5910, 1.0000, 0.4977, 0.6969],
# [0.4332, 0.4977, 1.0000, 0.7475],
# [0.5421, 0.6969, 0.7475, 1.0000]]
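To use the model as a retrieval system, passages can be ranked against a query by cosine similarity. The sketch below assumes embeddings are `torch.Tensor`s such as those returned by `model.encode(..., convert_to_tensor=True)` above; the function name `rank_passages` and its `top_k` parameter are hypothetical.

```python
import torch
import torch.nn.functional as F


def rank_passages(query_emb: torch.Tensor, passage_embs: torch.Tensor, top_k: int = 3):
    """Return (scores, indices) of the top_k passages most similar to the query.

    query_emb: [dim]; passage_embs: [num_passages, dim].
    """
    # Cosine similarity between the single query and every passage -> [num_passages]
    scores = F.cosine_similarity(query_emb.unsqueeze(0), passage_embs, dim=1)
    k = min(top_k, passage_embs.size(0))  # never request more passages than exist
    top = torch.topk(scores, k=k)
    return top.values, top.indices
```

For example, `rank_passages(embeddings[2], embeddings[[1, 3]], top_k=1)` scores the Mt. Fuji query against the two passages; the returned index refers to a position within the passage tensor, so based on the similarities printed above it would point to the Mt. Fuji passage (index 1).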

Direct Usage (Transformers)

You can perform inference using Transformers with the following code:

import torch.nn.functional as F
from torch import Tensor
from transformers import AutoTokenizer, AutoModel

def mean_pooling(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
    # Zero out the hidden states of padding tokens
    emb = last_hidden_states * attention_mask.unsqueeze(-1)
    # Average over the non-padding tokens only
    emb = emb.sum(dim=1) / attention_mask.sum(dim=1).unsqueeze(-1)
    return emb

# Download from the 🤗 Hub
tokenizer = AutoTokenizer.from_pretrained("pkshatech/RoSEtta-base-ja")
# The argument "trust_remote_code=True" is required to load the model
model = AutoModel.from_pretrained("pkshatech/RoSEtta-base-ja", trust_remote_code=True)

# Each input text should start with "query: " or "passage: ".
# For tasks other than retrieval, you can simply use the "query: " prefix.
sentences = [
    'query: PKSHAはどんな会社ですか?',
    'passage: 研究開発したアルゴリズムを、多くの企業のソフトウエア・オペレーションに導入しています。',
    'query: 日本で一番高い山は?',
    'passage: 富士山(ふじさん)は、標高3776.12 m、日本最高峰(剣ヶ峰)の独立峰で、その優美な風貌は日本国外でも日本の象徴として広く知られている。',
]

# Tokenize the input texts
batch_dict = tokenizer(sentences, max_length=1024, padding=True, truncation=True, return_tensors='pt')

outputs = model(**batch_dict)
embeddings = mean_pooling(outputs.last_hidden_state, batch_dict['attention_mask'])
print(embeddings.shape)
# [4, 768]

# Get the similarity scores for the embeddings
similarities = F.cosine_similarity(embeddings.unsqueeze(0), embeddings.unsqueeze(1), dim=2)
print(similarities)
# [[1.0000, 0.5910, 0.4332, 0.5421],
# [0.5910, 1.0000, 0.4977, 0.6969],
# [0.4332, 0.4977, 1.0000, 0.7475],
# [0.5421, 0.6969, 0.7475, 1.0000]]
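For larger corpora, for example when running on a CPU, it may help to encode texts in batches without tracking gradients. The sketch below reuses the tokenizer settings and the mean pooling from the code above; the function name `encode_batched` and the default `batch_size` of 32 are assumptions, not part of the model card.

```python
import torch


def encode_batched(texts, tokenizer, model, batch_size=32, max_length=1024):
    """Embed `texts` in batches with mean pooling (hypothetical helper).

    Each text should already carry the "query: " or "passage: " prefix.
    Returns a tensor of shape [len(texts), hidden_size].
    """
    model.eval()  # disable dropout for inference
    chunks = []
    with torch.no_grad():  # gradients are not needed at inference time
        for start in range(0, len(texts), batch_size):
            batch = tokenizer(
                texts[start:start + batch_size],
                max_length=max_length,
                padding=True,
                truncation=True,
                return_tensors="pt",
            )
            outputs = model(**batch)
            mask = batch["attention_mask"].unsqueeze(-1)  # [B, T, 1]
            # Mean over non-padding tokens, as in mean_pooling above
            summed = (outputs.last_hidden_state * mask).sum(dim=1)  # [B, H]
            chunks.append(summed / mask.sum(dim=1))  # divide by token counts [B, 1]
    return torch.cat(chunks, dim=0)
```

With the objects defined above, `encode_batched(sentences, tokenizer, model)` should produce embeddings of shape [4, 768], like the single-batch version.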

Training Details

RoSEtta is trained through the following steps:

Step 1: Pre-training

Step 2: Weakly supervised learning

Step 3: Ensemble distillation
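The card does not describe how the teacher models are combined. Purely as an assumption for illustration, one common formulation averages the teachers' in-batch cosine-similarity matrices (which sidesteps differing embedding dimensions) and penalizes the student's similarity matrix with mean squared error. The function names are hypothetical, and this is not presented as the authors' method.

```python
from typing import Sequence

import torch
import torch.nn.functional as F


def similarity_matrix(emb: torch.Tensor) -> torch.Tensor:
    """Pairwise cosine similarities within a batch: [B, dim] -> [B, B]."""
    emb = F.normalize(emb, dim=-1)
    return emb @ emb.T


def ensemble_distillation_loss(student_emb: torch.Tensor, teacher_embs: Sequence[torch.Tensor]) -> torch.Tensor:
    """MSE between the student's and the averaged teachers' similarity matrices.

    Assumed setup: teacher embeddings are precomputed for the same batch of texts
    (no gradients) and may have different dimensions from the student and each other.
    """
    target = torch.stack([similarity_matrix(t) for t in teacher_embs]).mean(dim=0)
    return F.mse_loss(similarity_matrix(student_emb), target)
```

Matching similarity structure rather than raw vectors lets a 768-dimensional student learn from teachers with larger output dimensions without a projection layer.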

Step 4: Contrastive learning

  • Triplets were created from JSNLI, MNLI, PAWS-X, JSeM, and Mr. TyDi and used for training.
  • This stage aimed to improve the model's overall performance as a general-purpose sentence embedding model.
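The card does not state which loss was used for this stage. As a hedged illustration only, a common way to train on (anchor, positive, hard negative) triplets is an InfoNCE-style loss with in-batch negatives. The function name `triplet_info_nce` and the temperature of 0.05 are assumptions, not the authors' settings.

```python
import torch
import torch.nn.functional as F


def triplet_info_nce(anchor: torch.Tensor, positive: torch.Tensor, negative: torch.Tensor,
                     temperature: float = 0.05) -> torch.Tensor:
    """InfoNCE loss over (anchor, positive, hard-negative) embedding triplets.

    anchor, positive, negative: [B, dim]. For anchor i, positive i is the target;
    the other positives and all hard negatives in the batch act as negatives.
    """
    a = F.normalize(anchor, dim=-1)
    candidates = F.normalize(torch.cat([positive, negative], dim=0), dim=-1)  # [2B, dim]
    logits = a @ candidates.T / temperature  # temperature-scaled cosine similarities [B, 2B]
    targets = torch.arange(anchor.size(0), device=anchor.device)  # positive i sits in column i
    return F.cross_entropy(logits, targets)
```

The loss approaches zero when each anchor is much closer to its own positive than to every other candidate, and grows when a hard negative scores higher than the positive.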

Step 5: Search-specific contrastive learning

Benchmarks

Retrieval

Evaluated with MIRACL-ja, JQaRA, JaCWIR, and MLDR-ja.

Model | Size | MIRACL Recall@5 | JQaRA nDCG@10 | JaCWIR MAP@10 | MLDR nDCG@10
intfloat/multilingual-e5-large | 0.6B | 89.2 | 55.4 | 87.6 | 29.8
cl-nagoya/ruri-large | 0.3B | 78.7 | 62.4 | 85.0 | 37.5
intfloat/multilingual-e5-base | 0.3B | 84.2 | 47.2 | 85.3 | 25.4
cl-nagoya/ruri-base | 0.1B | 74.3 | 58.1 | 84.6 | 35.3
pkshatech/GLuCoSE-base-ja | 0.1B | 53.3 | 30.8 | 68.6 | 25.2
RoSEtta | 0.2B | 79.3 | 57.7 | 83.8 | 32.3

Note: Results for OpenAI's small embedding model on JQaRA and JaCWIR are quoted from the JQaRA and JaCWIR benchmarks.

JMTEB

Evaluated with JMTEB.

The average score is a macro-average.

Model | Size | Avg. | Retrieval | STS | Classification | Reranking | Clustering | PairClassification
OpenAI/text-embedding-3-small | - | 69.18 | 66.39 | 79.46 | 73.06 | 92.92 | 51.06 | 62.27
OpenAI/text-embedding-3-large | - | 74.05 | 74.48 | 82.52 | 77.58 | 93.58 | 53.32 | 62.35
intfloat/multilingual-e5-large | 0.6B | 70.90 | 70.98 | 79.70 | 72.89 | 92.96 | 51.24 | 62.15
cl-nagoya/ruri-large | 0.3B | 73.31 | 73.02 | 83.13 | 77.43 | 92.99 | 51.82 | 62.29
intfloat/multilingual-e5-base | 0.3B | 68.61 | 68.21 | 79.84 | 69.30 | 92.85 | 48.26 | 62.26
cl-nagoya/ruri-base | 0.1B | 71.91 | 69.82 | 82.87 | 75.58 | 92.91 | 54.16 | 62.38
pkshatech/GLuCoSE-base-ja | 0.1B | 67.29 | 59.02 | 78.71 | 76.82 | 91.90 | 49.78 | 66.39
RoSEtta | 0.2B | 72.45 | 73.21 | 81.39 | 72.41 | 92.69 | 53.23 | 61.74

Authors

Chihiro Yano, Mocho Go, Hideyuki Tachibana, Hiroto Takegawa, Yotaro Watanabe

License

This model is published under the Apache License, Version 2.0.

Model size: 190M params (Safetensors, tensor type F32)