import torch
from typing import List, Union

# Process-wide cache of already-instantiated embedding models, shared by both
# helpers below. LLM2Vec entries are keyed by `peft_model_name`, GritLM entries
# by `model_name`.
# NOTE(review): the key ignores the base `model_name` (LLM2Vec) and `device`,
# so a later call with the same key but different settings reuses the model
# that was cached first.
loaded_llm_models = {}


def get_llm2vec_embeddings(text: Union[str, List[str]],
                           model_name: str = 'McGill-NLP/LLM2Vec-Meta-Llama-3-8B-Instruct-mntp',
                           peft_model_name: str = 'McGill-NLP/LLM2Vec-Meta-Llama-3-8B-Instruct-mntp-unsup-simcse',
                           instruction: str = '',
                           device: str = 'cuda',
                           norm: bool = True) -> torch.Tensor:
    """
    Get LLM2Vec embeddings for the given text.

    Args:
        text (Union[str, List[str]]): The input text to be embedded. A single
            string is treated as a one-element list.
        model_name (str): The base model to use for embedding.
        peft_model_name (str): The PEFT model loaded on top of the base model.
            Also used as the key in the ``loaded_llm_models`` cache.
        instruction (str): Optional instruction. When non-empty, every input is
            passed to the encoder as an ``[instruction, text]`` pair.
        device (str): Forwarded as ``device_map`` when the model is loaded.
        norm (bool): If True, L2-normalise the embeddings along dimension 1.

    Returns:
        torch.Tensor: The embedding(s) of the input text(s), viewed as
        ``(len(text), -1)``.

    Raises:
        ImportError: If the ``llm2vec`` package is not installed.
    """
    try:
        from llm2vec import LLM2Vec
    except ImportError as err:
        raise ImportError("Please install the llm2vec package using `pip install llm2vec`.") from err

    # Load the model once per `peft_model_name` and reuse it on later calls.
    if peft_model_name in loaded_llm_models:
        l2v = loaded_llm_models[peft_model_name]
    else:
        l2v = LLM2Vec.from_pretrained(
            model_name,
            peft_model_name_or_path=peft_model_name,
            device_map=device,
            torch_dtype=torch.bfloat16,
        )
        loaded_llm_models[peft_model_name] = l2v

    if isinstance(text, str):
        text = [text]
    if instruction:
        text = [[instruction, t] for t in text]

    # All inputs are encoded in a single batch.
    embeddings = l2v.encode(text, batch_size=len(text))
    if norm:
        embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
    return embeddings.view(len(text), -1)


def get_gritlm_embeddings(text: Union[str, List[str]],
                          model_name: str = 'GritLM/GritLM-7B',
                          instruction: str = '',
                          device: str = 'cuda'
                          ) -> torch.Tensor:
    """
    Get GritLM embeddings for the given text.

    Args:
        text (Union[str, List[str]]): The input text to be embedded. A single
            string is treated as a one-element list.
        model_name (str): The model to use for embedding. Also used as the key
            in the ``loaded_llm_models`` cache.
        instruction (str): The instruction to be used for GritLM. It is wrapped
            in the ``<|user|>`` / ``<|embed|>`` prompt format; an empty
            instruction yields just the ``<|embed|>`` marker.
        device (str): Accepted for signature parity with
            ``get_llm2vec_embeddings``; this function does not currently use it.

    Returns:
        torch.Tensor: The embedding(s) of the input text(s), viewed as
        ``(len(text), -1)``.

    Raises:
        ImportError: If the ``gritlm`` package is not installed.
    """
    try:
        from gritlm import GritLM
    except ImportError as err:
        raise ImportError("Please install the gritlm package using `pip install gritlm`.") from err

    def gritlm_instruction(instruction):
        # Build the prompt prefix that precedes the text to embed.
        if instruction:
            return "<|user|>\n" + instruction + "\n<|embed|>\n"
        return "<|embed|>\n"

    # Load the model once per `model_name` and reuse it on later calls.
    if model_name in loaded_llm_models:
        gritlm_model = loaded_llm_models[model_name]
    else:
        gritlm_model = GritLM(model_name, torch_dtype=torch.bfloat16)
        loaded_llm_models[model_name] = gritlm_model

    if isinstance(text, str):
        text = [text]

    embeddings = gritlm_model.encode(text, instruction=gritlm_instruction(instruction))
    # The encoder output is converted from a NumPy array to a tensor here.
    embeddings = torch.from_numpy(embeddings)
    return embeddings.view(len(text), -1)