YAML Metadata Warning:
empty or missing YAML metadata in repo card
(https://huggingface.co/docs/hub/model-cards#model-card-metadata)
Load a model via load_corrector
If you have trained your own custom models using vec2text, you can load them with the `load_corrector` function.
def load_corrector(embedder: str) -> vec2text.trainers.Corrector:
    """Gets the Corrector object for the given embedder.

    In this excerpt only the ``"text-embedding-bge"`` embedder has
    checkpoints wired up; support may be expanded over time.

    Args:
        embedder: Name of the embedding model whose embeddings should be
            inverted. Must be a member of ``SUPPORTED_MODELS``.

    Raises:
        AssertionError: If ``embedder`` is not in ``SUPPORTED_MODELS``.
        NotImplementedError: If ``embedder`` is listed in ``SUPPORTED_MODELS``
            but no checkpoints are mapped for it below.
    """
    assert (
        embedder in SUPPORTED_MODELS
    ), f"embedder to invert `{embedder}` not in list of supported models: {SUPPORTED_MODELS}"
    if embedder == "text-embedding-bge":
        # NOTE(review): Hub repo ids normally have the form "owner/name"; if
        # these checkpoints live in subfolders of "ariya2357/vec2text" they may
        # need from_pretrained("ariya2357/vec2text", subfolder=...) — confirm.
        inversion_model = vec2text.models.InversionModel.from_pretrained(
            "ariya2357/vec2text/bge_msl48_inversion_50epochs"
        )
        model = vec2text.models.CorrectorEncoderModel.from_pretrained(
            "ariya2357/vec2text/bge_msl48_corrector_100epochs"
        )
    else:
        # Fail loudly instead of continuing with no models loaded.
        raise NotImplementedError(f"embedder `{embedder}` not implemented")
    # NOTE(review): this excerpt ends here, so as shown the function returns
    # None despite its annotation. The full implementation presumably wraps
    # `inversion_model` and `model` into a `vec2text.trainers.Corrector` and
    # returns it — confirm against the real vec2text `api` module.
# Build a Corrector for the "text-embedding-bge" embedder via the packaged `api` module.
# NOTE(review): this import rebinds `load_corrector`, shadowing the excerpt defined above.
from api import load_corrector
corrector = load_corrector("text-embedding-bge")
Invert embeddings with invert_embeddings
For example, using BGE as the embedder:
class BGE(BertModel):
    """``BertModel`` that pools to the first token and L2-normalises the result.

    Used here with the ``BAAI/bge-base-en-v1.5`` checkpoint, where the first
    sequence position holds the ``[CLS]`` token added by the tokenizer.
    """

    def __init__(self, config):
        super().__init__(config)
        # NOTE(review): presumably inspected by vec2text / transformers helpers
        # that special-case model-parallel setups; confirm it is still needed.
        self.model_parallel = False

    def forward(self, input_ids, attention_mask):
        """Encode a batch of token ids into unit-length sentence vectors.

        Args:
            input_ids: Batch-first tensor of token ids.
            attention_mask: Mask tensor aligned with ``input_ids``, passed
                straight through to ``BertModel.forward``.

        Returns:
            The first-token hidden state of every sequence, scaled to unit
            L2 norm along ``dim=1``.
        """
        encoder_out = super().forward(input_ids=input_ids, attention_mask=attention_mask)
        first_token_states = encoder_out.last_hidden_state[:, 0]
        return F.normalize(first_token_states, p=2, dim=1)
import functools


@functools.lru_cache(maxsize=None)
def _load_bge(model_name: str = "BAAI/bge-base-en-v1.5"):
    """Load the BGE tokenizer and encoder once per model name and reuse them."""
    tokenizer = AutoTokenizer.from_pretrained(model_name, do_lower_case=True)
    model = BGE.from_pretrained(model_name)
    return tokenizer, model


def get_embeddings_bge(text_list, *, max_length: int = 48) -> torch.Tensor:
    """Embed text with ``BAAI/bge-base-en-v1.5`` using the ``BGE`` encoder.

    Args:
        text_list: A single string, or a list of strings, to embed.
        max_length: Maximum tokens per input, special tokens included; longer
            inputs are truncated. The default of 48 presumably matches the
            ``msl48`` inversion/corrector checkpoints — confirm.

    Returns:
        Tensor with one L2-normalised embedding row per input (a single string
        yields a batch of one), computed without gradient tracking.
    """
    # The tokenizer/model pair is cached so repeated calls don't reload weights.
    tokenizer, model = _load_bge()
    encoding = tokenizer(
        text_list,
        add_special_tokens=True,
        max_length=max_length,
        truncation=True,  # explicit, so long inputs are cut without a warning
        padding=True,  # needed to stack several strings into one batch
        return_tensors="pt",
    )
    with torch.no_grad():
        embedding = model(encoding["input_ids"], encoding["attention_mask"])
    return embedding
# End-to-end example: embed a sentence with BGE, then try to recover the text
# from its embedding with the vec2text corrector.
from api import load_corrector,invert_embeddings
embedding = get_embeddings_bge("hello world")
corrector = load_corrector("text-embedding-bge")
# NOTE(review): `.cuda()` requires a CUDA-capable GPU; presumably the corrector's
# models also run on CUDA, so the embedding is moved there to match — confirm.
inverted_text = invert_embeddings(
embeddings=embedding.cuda(),
corrector=corrector,
num_steps=20,  # presumably the number of correction iterations — confirm in `api`
)
print(inverted_text)