izhx commited on
Commit
af8c4b6
·
verified ·
1 Parent(s): 2d7b768

Update scripts/gte_embedding.py

Browse files
Files changed (1) hide show
  1. scripts/gte_embedding.py +32 -68
scripts/gte_embedding.py CHANGED
@@ -1,42 +1,25 @@
1
- import logging
2
- from typing import Dict, Optional, List, Tuple
3
- import os
4
-
5
- import heapq
6
- import json
7
- import logging
8
- import os
9
- import queue
10
- import sys
11
- import time
12
- from tqdm import tqdm
13
 
14
- import torch
15
  from collections import defaultdict
16
- from torch.utils.data._utils.worker import ManagerWatchdog
17
- import numpy as np
18
- import torch.distributed as dist
19
- from torch import nn, Tensor
20
- import torch.nn.functional as F
21
- from transformers import AutoModel, AutoTokenizer
22
- from transformers.file_utils import ModelOutput
23
 
24
- logger = logging.getLogger(__name__)
 
 
 
25
 
26
 
27
- class GTEEmbeddidng(nn.Module):
28
  def __init__(self,
29
  model_name: str = None,
30
  normalized: bool = True,
31
- pooling_method: str = 'cls',
32
  use_fp16: bool = True,
33
  device: str = None
34
  ):
35
  super().__init__()
36
- self.load_model(model_name)
37
- self.vocab_size = self.model.config.vocab_size
38
  self.normalized = normalized
39
- self.pooling_method = pooling_method
40
  if device:
41
  self.device = torch.device(device)
42
  else:
@@ -49,40 +32,13 @@ class GTEEmbeddidng(nn.Module):
49
  else:
50
  self.device = torch.device("cpu")
51
  use_fp16 = False
52
- self.model.to(self.device)
53
- self.sparse_linear.to(self.device)
54
- if use_fp16:
55
- self.model.half()
56
- self.sparse_linear.half()
57
-
58
- def load_model(self, model_name):
59
- if not os.path.exists(model_name):
60
- cache_folder = os.getenv('HF_HUB_CACHE')
61
- model_name = snapshot_download(repo_id=model_name,
62
- cache_dir=cache_folder,
63
- ignore_patterns=['flax_model.msgpack', 'rust_model.ot', 'tf_model.h5'])
64
-
65
- self.model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
66
- self.sparse_linear = torch.nn.Linear(in_features=self.model.config.hidden_size, out_features=1)
67
  self.tokenizer = AutoTokenizer.from_pretrained(model_name)
68
- self.model.eval()
69
- if os.path.exists(os.path.join(model_name, 'sparse_linear.pt')):
70
- logger.info('loading existing sparse_linear---------')
71
- self.load_pooler(model_dir=model_name)
72
- else:
73
- logger.warring('The parameters of sparse linear is not found')
74
-
75
- def dense_embedding(self, hidden_state, mask):
76
- if self.pooling_method == 'cls':
77
- return hidden_state[:, 0]
78
- elif self.pooling_method == 'mean':
79
- s = torch.sum(hidden_state * mask.unsqueeze(-1).float(), dim=1)
80
- d = mask.sum(axis=1, keepdim=True).float()
81
- return s / d
82
-
83
- def sparse_embedding(self, hidden_state, input_ids, return_embedding: bool = True):
84
- token_weights = torch.relu(self.sparse_linear(hidden_state))
85
- return token_weights
86
 
87
  def _process_token_weights(self, token_weights: np.ndarray, input_ids: list):
88
  # conver to dict
@@ -127,7 +83,7 @@ class GTEEmbeddidng(nn.Module):
127
 
128
  @torch.no_grad()
129
  def _encode(self,
130
- texts: Dict[str, Tensor] = None,
131
  dimension: int = None,
132
  max_length: int = 1024,
133
  batch_size: int = 16,
@@ -136,27 +92,22 @@ class GTEEmbeddidng(nn.Module):
136
 
137
  text_input = self.tokenizer(texts, padding=True, truncation=True, return_tensors='pt', max_length=max_length)
138
  text_input = {k: v.to(self.model.device) for k,v in text_input.items()}
139
- last_hidden_state = self.model(**text_input, return_dict=True).last_hidden_state
140
 
141
  output = {}
142
  if return_dense:
143
- dense_vecs = self.dense_embedding(last_hidden_state, text_input['attention_mask'])
144
- dense_vecs = dense_vecs[:, :dimension]
145
  if self.normalized:
146
  dense_vecs = torch.nn.functional.normalize(dense_vecs, dim=-1)
147
  output['dense_embeddings'] = dense_vecs
148
  if return_sparse:
149
- token_weights = self.sparse_embedding(last_hidden_state, text_input['input_ids']).squeeze(-1)
150
  token_weights = list(map(self._process_token_weights, token_weights.detach().cpu().numpy().tolist(),
151
  text_input['input_ids'].cpu().numpy().tolist()))
152
  output['token_weights'] = token_weights
153
 
154
  return output
155
 
156
- def load_pooler(self, model_dir):
157
- sparse_state_dict = torch.load(os.path.join(model_dir, 'sparse_linear.pt'), map_location='cpu')
158
- self.sparse_linear.load_state_dict(sparse_state_dict)
159
-
160
  def _compute_sparse_scores(self, embs1, embs2):
161
  scores = 0
162
  for token, weight in embs1.items():
@@ -188,3 +139,16 @@ class GTEEmbeddidng(nn.Module):
188
  self.compute_sparse_scores(embs1['token_weights'], embs2['token_weights']) * sparse_weight
189
  scores = scores.tolist()
190
  return scores
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2024 The GTE Team Authors and Alibaba Group.
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
 
 
 
 
 
 
 
 
 
4
 
 
5
  from collections import defaultdict
6
+ from typing import Dict, List, Tuple
 
 
 
 
 
 
7
 
8
+ import numpy as np
9
+ import torch
10
+ from transformers import AutoModelForTokenClassification, AutoTokenizer
11
+ from transformers.utils import is_torch_npu_available
12
 
13
 
14
+ class GTEEmbeddidng(torch.nn.Module):
15
  def __init__(self,
16
  model_name: str = None,
17
  normalized: bool = True,
 
18
  use_fp16: bool = True,
19
  device: str = None
20
  ):
21
  super().__init__()
 
 
22
  self.normalized = normalized
 
23
  if device:
24
  self.device = torch.device(device)
25
  else:
 
32
  else:
33
  self.device = torch.device("cpu")
34
  use_fp16 = False
35
+ self.use_fp16 = use_fp16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  self.tokenizer = AutoTokenizer.from_pretrained(model_name)
37
+ self.model = AutoModelForTokenClassification.from_pretrained(
38
+ model_name, trust_remote_code=True, torch_dtype=torch.float16 if self.use_fp16 else None
39
+ )
40
+ self.vocab_size = self.model.config.vocab_size
41
+ self.model.to(self.device)
 
 
 
 
 
 
 
 
 
 
 
 
 
42
 
43
  def _process_token_weights(self, token_weights: np.ndarray, input_ids: list):
44
  # conver to dict
 
83
 
84
  @torch.no_grad()
85
  def _encode(self,
86
+ texts: Dict[str, torch.Tensor] = None,
87
  dimension: int = None,
88
  max_length: int = 1024,
89
  batch_size: int = 16,
 
92
 
93
  text_input = self.tokenizer(texts, padding=True, truncation=True, return_tensors='pt', max_length=max_length)
94
  text_input = {k: v.to(self.model.device) for k,v in text_input.items()}
95
+ model_out = self.model(**text_input, return_dict=True)
96
 
97
  output = {}
98
  if return_dense:
99
+ dense_vecs = model_out.last_hidden_state[:, 0, :dimension]
 
100
  if self.normalized:
101
  dense_vecs = torch.nn.functional.normalize(dense_vecs, dim=-1)
102
  output['dense_embeddings'] = dense_vecs
103
  if return_sparse:
104
+ token_weights = torch.relu(model_out.logits).squeeze(-1)
105
  token_weights = list(map(self._process_token_weights, token_weights.detach().cpu().numpy().tolist(),
106
  text_input['input_ids'].cpu().numpy().tolist()))
107
  output['token_weights'] = token_weights
108
 
109
  return output
110
 
 
 
 
 
111
  def _compute_sparse_scores(self, embs1, embs2):
112
  scores = 0
113
  for token, weight in embs1.items():
 
139
  self.compute_sparse_scores(embs1['token_weights'], embs2['token_weights']) * sparse_weight
140
  scores = scores.tolist()
141
  return scores
142
+
143
+
144
+ if __name__ == '__main__':
145
+ gte = GTEEmbeddidng('Alibaba-NLP/gte-multilingual-base')
146
+ docs = [
147
+ "黑龙江离俄罗斯很近",
148
+ "哈尔滨是中国黑龙江省的省会,位于中国东北",
149
+ "you are the hero"
150
+ ]
151
+ print('docs', docs)
152
+ embs = gte.encode(docs, return_dense=True,return_sparse=True)
153
+ print('dense vecs', embs['dense_embeddings'])
154
+ print('sparse vecs', embs['token_weights'])