hotchpotch's picture
Update README.md
6597b2f verified
|
raw
history blame
21.4 kB
---
language:
- ja
- en
license: mit
tags:
- sentence-transformers
- sentence-similarity
- feature-extraction
- generated_from_trainer
- dataset_size:16897699
- loss:MatryoshkaLoss
- loss:MultipleNegativesRankingLoss
datasets:
- sentence-transformers/msmarco-co-condenser-margin-mse-sym-mnrl-mean-v1
- sentence-transformers/squad
- sentence-transformers/all-nli
- sentence-transformers/trivia-qa
- nthakur/swim-ir-monolingual
- sentence-transformers/miracl
- sentence-transformers/mr-tydi
- hotchpotch/sentence_transformer_japanese
library_name: sentence-transformers
---
以下の文章は、ブログ記事⭐️からの転載です。
# 100倍速で実用的な文章ベクトルを作れる、日本語 StaticEmbedding を公開
文章の密ベクトルは、情報検索・文章判別・類似文章抽出など、さまざまな用途に使うことができます。しかしながら最先端のTransformerモデルは小さいモデルでも、とりわけCPUでは遅く、変換速度が実用でないこともしばしばです。
しかしながら、先日公開されたTransformerモデル「ではない」 [StaticEmbedding](https://huggingface.co./blog/static-embeddings)は、例えば [intfloat/multilingual-e5-small](https://huggingface.co./intfloat/multilingual-e5-small) (以下mE5-small)とのベンチマーク比較では85%のスコアという実用できる性能で、かつCPUで動作時に126倍高速に文ベクトルを作成することができる、という驚きの速度です。
というわけで、早速日本語(と英語)で学習させたモデル sentence-embedding-japanese を作成し、公開しました。
- https://huggingface.co./hotchpotch/static-embedding-japanese
日本語の文章ベクトルの性能を評価する JMTEB の結果は以下です。確かに mE5-small には若干及ばないまでも、タスクによっては勝っていたりしますし、[他の日本語baseサイズbertモデルよりもスコアが高いこともある](https://github.com/sbintuitions/JMTEB/blob/main/leaderboard.md)ぐらい、最低限実用に達している性能が出ていますね。本当にそんなに性能出るのか実際に学習させてみるまで半信半疑でしたが、すごいですね。
| Model | Avg(micro) | Retrieval | STS | Classification | Reranking | Clustering | PairClassification |
| ---------------------------------------- | ---------- | --------- | ----- | -------------- | --------- | ---------- | ------------------ |
| text-embedding-3-small | 69.18 | 66.39 | 79.46 | 73.06 | 92.92 | 51.06 | 62.27 |
| multilingual-e5-small | 67.71 | 67.27 | 80.07 | 67.62 | 93.03 | 46.91 | 62.19 |
| **static-embedding-japanese** | 66.66 | **67.92** | **80.16** | **67.96** | 91.87 | 35.83 | **62.37** |
なお、StaticEmbedding 日本語モデル学習などの技術的なことは記事の後半に書いているので、興味がある方はどうぞ。
## 利用方法
利用は簡単、SentenceTransformer を使っていつもの方法で文ベクトルを作れます。今回はGPUを使わず、CPUで実行してみましょう。なお SentenceTransformer は 3.3.1 で試しています。
```
pip install "sentence-transformers>=3.3.1"
```
```python
from sentence_transformers import SentenceTransformer
model_name = "hotchpotch/static-embedding-japanese"
model = SentenceTransformer(model_name, device="cpu")
query = "美味しいラーメン屋に行きたい"
docs = [
"素敵なカフェが近所にあるよ。落ち着いた雰囲気でゆっくりできるし、窓際の席からは公園の景色も見えるんだ。",
"新鮮な魚介を提供する店です。地元の漁師から直接仕入れているので鮮度は抜群ですし、料理人の腕も確かです。",
"あそこは行きにくいけど、隠れた豚骨の名店だよ。スープが最高だし、麺の硬さも好み。",
"おすすめの中華そばの店を教えてあげる。とりわけチャーシューが手作りで柔らかくてジューシーなんだ。",
]
embeddings = model.encode([query] + docs)
print(embeddings.shape)
similarities = model.similarity(embeddings[0], embeddings[1:])
for i, similarity in enumerate(similarities[0].tolist()):
print(f"{similarity:.04f}: {docs[i]}")
```
```
(5, 1024)
0.1040: 素敵なカフェが近所にあるよ。落ち着いた雰囲気でゆっくりできるし、窓際の席からは公園の景色も見えるんだ。
0.2521: 新鮮な魚介を提供する店です。地元の漁師から直接仕入れているので鮮度は抜群ですし、料理人の腕も確かです。
0.4835: あそこは行きにくいけど、隠れた豚骨の名店だよ。スープが最高だし、麺の硬さも好み。
0.3199: おすすめの中華そばの店を教えてあげる。とりわけチャーシューが手作りで柔らかくてジューシーなんだ。
```
このように、queryにマッチする文章のスコアが高くなるように計算できてますね。この例文では、例えばBM25ではqueryに含まれる「ラーメン」のような直接的な単語が文章に出ていないため、うまくマッチさせることが難しいでしょう。
また速度も、CPUで文ベクトルを作った方は少ない文章量でもだいぶ時間がかかるな、という経験をされた方も多いと思いますが、StaticEmbedding モデルではCPUがそこそこ速ければ一瞬で終わると思います。さすが100倍速。
## なぜCPUで推論が高速なの?
StaticEmbedding はTransformerモデルではありません。つまりTrasformerの特徴であるアテンションの計算が一切ないです。文章に出てくる単語トークンを1024次元のテーブルに保存して、文ベクトルではそれの平均をとっているだけです。なお、アテンションがないので、文脈の理解などはしていません。
また PyTorch の nn.EmbeddingBag を使って、全てを連結したトークンとオフセットを渡して処理することで、PyTorch の最適化で高速なCPU並列処理とメモリアクセスがされているようです。
![](https://huggingface.co./datasets/huggingface/documentation-images/resolve/main/blog/static-embeddings/similarity_speed.png)
[元記事の速度評価結果によると](https://huggingface.co./blog/static-embeddings#multilingual-similarity-4)CPUではmE5-smallと比べて126倍速らしいですね。
## 評価結果
JMTEBでの全ての評価結果は[こちらJSONファイルに記載](https://huggingface.co./hotchpotch/static-embedding-japanese/blob/main/JMTEB/summary.json)しています。[JMTEB Leaderboard](https://github.com/sbintuitions/JMTEB/blob/main/leaderboard.md)で見比べると、差がわかるでしょう。JMTEBの全体の評価結果はモデルサイズを考えると、すこぶる良好です。なお、JMTEB で評価された方は、mr-tidy タスクの700万文章のベクトル化に時間がかなりかかる(モデルにもよりますがRTX4090で1~4時間ほど)と思います。これもStaticEmbeddingsでは非常に速く、RTX4090では約4分で処理終えることができました。
### 情報検索でBM25の置き換えができそうか?
JMTEBの中の情報検索タスクの[Retrievalの結果](https://huggingface.co./hotchpotch/static-embedding-japanese/blob/main/JMTEB/summary.json)を見てみましょう。StaticEmbedding では mr-tidy の項目が著しく悪いですね。mr-tidyは他のタスクに比べて文章量が圧倒的に多く(700万文章)、つまる所大量の文章を検索するようなタスクでは結果が悪い可能性がありそうです。文脈を無視したた単純なトークンの平均なので、増えれば増えるほど似た平均の文章が出てくるとすると、そういう結果にもなり得そうですね。
ので、大量の文章の場合、BM25よりもだいぶ性能が悪い可能性がありそうです。ただ、少ない文章で、ずばりの単語マッチが少ない場合は、BM25よりも良好な結果になることが多そうですね。
なお情報検索タスクの jaqket の結果が他のモデルに対してやたら良いのは、JQaRa (dev, unused)を学習しているからといっても高すぎる感じで謎です。test の情報リークはしていないとは思うのですが…。
### クラスタリング結果が悪い
こちらも詳細は追っかけていませんが、スコア的には他のモデルよりもだいぶ悪い結果ですね。クラス分類タスクは悪くないので不思議です。埋め込み空間がマトリョーシカ表現学習で作られた影響もあるのでしょうか。
## JQaRA, JaCWIR でのリランキングタスク
[JQaRA](https://huggingface.co./datasets/hotchpotch/JQaRA) の結果はこちら。
| model_names | ndcg@10 | mrr@10 |
|:-----------------------------------------------------------------------------------------|----------:|---------:|
| [static-embedding-japanese](https://huggingface.co./hotchpotch/static-embedding-japanese) | 0.4704 | 0.6814 |
| bm25 | 0.458 | 0.702 |
| [multilingual-e5-small](https://huggingface.co./intfloat/multilingual-e5-small) | 0.4917 | 0.7291 |
[JaCWIR](https://huggingface.co./datasets/hotchpotch/JaCWIR) の結果はこちら。
| model_names | map@10 | hits@10 |
|:-----------------------------------------------------------------------------------------|---------:|----------:|
| [static-embedding-japanese](https://huggingface.co./hotchpotch/static-embedding-japanese) | 0.7642 | 0.9266 |
| bm25 | 0.8408 | 0.9528 |
| [multilingual-e5-small](https://huggingface.co./intfloat/multilingual-e5-small) | 0.869 | 0.97 |
JQaRa 評価は BM25 よりは若干良く、mE5-small よりは若干低い、JaCWIR は BM25, mE5よりだいぶ低い感じの結果になりました。
JaCWIR はWeb文章のタイトルと概要文なので、いわゆる「綺麗な」文章ではないケースも多く、transformerモデルはノイズに強いので、単純なトークン平均のStaticEmbeddingでは悪い結果になりそうです。BM25は特徴的な単語にマッチしやすいので、JaCWIR でもノイズとなるような単語はクエリにマッチしないため、Transformer モデルと競争力のある結構良い結果を残します。
この結果から、StaticEmbedding は Transformer / BM25 に比べ、ノイズを多く含む文章の場合はスコアが悪い可能性があります。
## 出力次元の削減
StaticEmbedding で出力される次元は、学習次第ですが今回作成したものは1024次元とそこそこのサイズです。次元数が大きいと、推論後のタスク(クラスタリングや情報検索など)に計算コストがかかってしまいます。しかしながら、学習時にマトリョーシカ表現学習([Matryoshka Representation Learning(MRL)](https://arxiv.org/abs/2205.13147))をしているため、1024次元をさらに小さな次元へと簡単に次元削減ができます。
MRLは、学習時に先頭のベクトルほど重要な次元を持ってくることで、例えば1024次元でも先頭の32,64,128,256...次元だけを使って後ろを切り捨てるだけで、ある程度良好な結果を示しています。
![](https://huggingface.co./datasets/huggingface/documentation-images/resolve/main/blog/static-embeddings/nano_beir_matryoshka.png)
このグラフ参照元の[StaticEmbedding の記事](https://huggingface.co./blog/static-embeddings#matryoshka-evaluation)によると、128次元で91.87%, 256次元で95.79%, 512次元で98.53%の性能を維持しているようです。精度にそこまでシビアではないが、その後の計算コストを下げたい場合、ガッと次元削減して使う、という用途にも使えそうですね。
## StaticEmbedding モデルを作ってみて
正直、単純なトークンのembeddingsの平均でそんなに性能出るのか半信半疑だったのですが、実際に学習させてみてシンプルなアーキテクチャなのに性能の高さにびっくりしました。Transformer 全盛のこの時代に、古き良き単語埋め込みの活用モデルで、実世界で利活用できそうなモデルの出現に驚きを隠せません。
CPUでの推論速度が速い文ベクトル作成モデルは、ローカルCPU環境で大量の文章の変換などはもとより、エッジデバイスだったりネットワークが遅い(リモートの推論サーバを叩けない)環境だったり、色々と活用しがいがありそうですね。
---
# StaticEmbedding 日本語モデル学習のテクニカルノート
## なぜうまく学習できるのか
StaticEmbedding は非常にシンプルで、文章をトークナイズしたIDで単語の埋め込みベクトルが格納されているEmbeddingBagテーブルからN次元(今回は1024次元)のベクトルを取得し、その平均を取るだけです。
これまで、単語埋め込みベクトルといえば、word2vec や GloVe のように Skip-gram や CBOW を用いて単語の周辺を学習してきました。しかし、StaticEmbedding では文章全体を用いて学習しています。また、対照学習を使って大量の文章を巨大バッチで学習しており、良い単語の埋め込み表現の学習に成功しているようです。
## 学習データセット
日本語モデル学習にあたり、対照学習で利用できるデータセットとして、以下を作成し使用しました。
- [hotchpotch/sentence_transformer_japanese](https://huggingface.co./datasets/hotchpotch/sentence_transformer_japanese)
- [SentenceTransformer で学習しやすいカラム名と構造](https://sbert.net/docs/sentence_transformer/loss_overview.html)に整えたものです。
- `(anchor, positive)`, `(anchor, positive, negative)`, `(anchor, positive, negative_1, ..., negative_n)` といった構造になっています。
- 以下のデータセットを基に hotchpotch/sentence_transformer_japanese を作成しました。毎度ながらデータセットの作者の方々・とりわけ hpprc 氏に感謝です。
- https://huggingface.co./datasets/hpprc/emb
- https://huggingface.co./datasets/hotchpotch/hpprc_emb-scores のリランカースコアを使用し、positive(>=0.7) / negative(<=0.3) のフィルタリングを行いました。
- https://huggingface.co./datasets/hpprc/llmjp-kaken
- https://huggingface.co./datasets/hpprc/msmarco-ja
- [https://huggingface.co./datasets/hotchpotch/msmarco-ja-hard-negatives](https://huggingface.co./datasets/hotchpotch/msmarco-ja-hard-negatives) のリランカースコアを用いて、positive(>=0.7) / negative(<=0.3) のフィルタリングを行いました。
- https://huggingface.co./datasets/hpprc/mqa-ja
- https://huggingface.co./datasets/hpprc/llmjp-warp-html
- 上記の作成したデータセットの中で、以下を使用しました。なお、情報検索を強化したかったため、情報検索に適したデータセットのデータはオーギュメンテーションで件数を多めに学習させています。
- httprc_auto-wiki-nli-triplet
- httprc_auto-wiki-qa
- httprc_auto-wiki-qa-nemotron
- httprc_auto-wiki-qa-pair
- httprc_baobab-wiki-retrieval
- httprc_janli-triplet
- httprc_jaquad
- httprc_jqara
- httprc_jsnli-triplet
- httprc_jsquad
- httprc_miracl
- httprc_mkqa
- httprc_mkqa-triplet
- httprc_mr-tydi
- httprc_nu-mnli-triplet
- httprc_nu-snli-triplet
- httprc_quiz-no-mori
- httprc_quiz-works
- httprc_snow-triplet
- httprc_llmjp-kaken
- httprc_llmjp_warp_html
- httprc_mqa_ja
- httprc_msmarco_ja
- 英語データセットには、以下のデータセットを利用しています。
- [sentence-transformers/msmarco-co-condenser-margin-mse-sym-mnrl-mean-v1](https://huggingface.co./datasets/sentence-transformers/msmarco-co-condenser-margin-mse-sym-mnrl-mean-v1)
- [sentence-transformers/squad](https://huggingface.co./datasets/sentence-transformers/squad)
- [sentence-transformers/all-nli](https://huggingface.co./datasets/sentence-transformers/all-nli)
- [sentence-transformers/trivia-qa](https://huggingface.co./datasets/sentence-transformers/trivia-qa)
- [nthakur/swim-ir-monolingual](https://huggingface.co./datasets/nthakur/swim-ir-monolingual)
- [sentence-transformers/miracl](https://huggingface.co./datasets/sentence-transformers/miracl)
- [sentence-transformers/mr-tydi](https://huggingface.co./datasets/sentence-transformers/mr-tydi)
## 日本語トークナイザ
StaticEmbedding を学習するためには、HuggingFace のトークナイザライブラリの tokenizer.json 形式で処理可能なトークナイザを使うと簡単そうだったので、 [hotchpotch/xlm-roberta-japanese-tokenizer](https://huggingface.co./hotchpotch/xlm-roberta-japanese-tokenizer) というトークナイザを作成しました。語彙数は 32,768 です。
このトークナイザは、wikipedia 日本語、wikipedia 英語(サンプリング)、cc-100(日本語, サンプリング)のデータを unidic で分割し、sentencepiece unigram で学習したものです。XLM-Roberta 形式の日本語トークナイザとしても機能します。今回はこのトークナイザを利用しました。
## ハイパーパラメータ
[大元の学習コード](https://huggingface.co./blog/static-embeddings#english-retrieval-2)との変更点やメモは以下の通りです。
- batch_size を大元の 2048 から 6072 に設定しました。
- 対照学習で巨大なバッチを処理するとき、同一バッチ内にポジティブとネガティブが含まれると学習に悪影響を与える可能性があります。これを防ぐために [BatchSamplers.NO_DUPLICATES](https://sbert.net/docs/package_reference/sentence_transformer/sampler.html) オプションがあります。しかし、バッチサイズが巨大だと同一バッチに含めないためのサンプリング処理に時間がかかることがあります。
- 今回は `BatchSamplers.NO_DUPLICATES` を指定し、RTX4090 の 24GB に収まる 6072 に設定しました。バッチサイズはさらに大きい方が結果が良い可能性があります。
- epoch数を1から2に変更しました
- 1よりも2の方が良い結果になりました。ただし、データサイズがもっと大きければ、1の方が良い可能性があります。
- スケジューラ
- 標準のlinearから、経験則でより良いと感じるcosineに変更しました。
- オプティマイザ
- 標準のAdamW のままです。adafactorに変更した場合、収束が悪くなりました。
- learning_rate
- 2e-1 のままです。値が巨大すぎるのではないかと疑問に思いましたが、低くすると結果が悪化しました。
- dataloader_prefetch_factor=4
- dataloader_num_workers=15
- トークナイズとバッチサンプラのサンプリングに時間がかかるため、大きめに設定しました。
## 学習リソース
- CPU
- Ryzen9 7950X
- GPU
- RTX4090
- memory
- 64GB
このマシンリソースでの学習にかかった時間は約4時間でした。GPUのコア負荷は非常に小さく、他のtransformerモデルでは学習時に90%前後で張り付くのに対して、StaticEmbeddingではほとんど0%でした。これは、巨大なバッチをGPUメモリに転送する時間が大半を占めているためかと思われます。そのため、GPUメモリの帯域幅が速くなれば、学習速度がさらに向上する可能性があります。
## さらなる性能向上へ
今回利用したトークナイザはStaticEmbedding向けに特化したものではないため、より適したトークナイザを使用すれば性能が向上する可能性があります。バッチサイズをさらに巨大化することで、学習の安定性が向上し、性能向上が見込めるかもしれません。
また、さまざまなドメインや合成データセットを利用するなど、より幅広い文章リソースを学習に組み込むことで、さらなる性能向上が期待できます。
## 大元の学習コード
学習に使用したコードは、以下で MIT ライセンスで公開しています。スクリプトを実行すれば再現できる、はず...!
- https://huggingface.co./hotchpotch/static-embedding-japanese/blob/main/trainer.py
## ライセンス
static-embedding-japanese は MIT ライセンスで公開しています。