File size: 757 Bytes
0685af6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
import torch
from transformers import AutoTokenizer, AutoModel
from chromadb import Documents, EmbeddingFunction, Embeddings


# Hugging Face Hub identifier of the pretrained ConvBERT checkpoint.
model_name = "YituTech/conv-bert-base"

# Tokenizer and encoder are loaded once, at import time, and kept as module
# globals so every embedding call reuses the same instances.
tokenizer, model = (
    AutoTokenizer.from_pretrained(model_name),
    AutoModel.from_pretrained(model_name),
)


class MyEmbeddingFunction(EmbeddingFunction[Documents]):
    """Chroma embedding function backed by the module-level ConvBERT model.

    Each document is tokenized and encoded on its own (no batching, no
    padding). Its embedding is the mean of the encoder's final hidden states
    over all token positions, including the tokenizer's special tokens.
    """

    def __call__(self, input: Documents) -> Embeddings:
        """Embed each document in ``input``.

        Args:
            input: Documents to embed. The parameter keeps the name ``input``
                to match the ``EmbeddingFunction`` interface this class
                implements.

        Returns:
            One embedding per document, in input order. Each embedding is a
            plain ``list`` of Python floats.
        """
        # Without truncation, a document longer than the model's position
        # embeddings makes the forward pass raise. ``model_max_length`` can be a
        # huge sentinel when the tokenizer config leaves it unset, so it is also
        # capped by the model's own limit.
        max_length = min(
            tokenizer.model_max_length,
            model.config.max_position_embeddings,
        )

        embeddings_list = []
        for text in input:
            tokens = tokenizer(
                text,
                return_tensors="pt",
                truncation=True,
                max_length=max_length,
            )
            with torch.no_grad():
                outputs = model(**tokens)
            # Average over the token axis (dim=1). Then drop only the batch axis
            # of size 1; a bare squeeze() would remove any other size-1 axis too.
            pooled = outputs.last_hidden_state.mean(dim=1).squeeze(0)
            # Some Chroma releases validate that each embedding is a list and
            # reject numpy arrays. A list of Python floats is accepted everywhere.
            embeddings_list.append(pooled.tolist())

        return embeddings_list