"""Generate and cache text embeddings for STaRK documents or queries.

In ``doc`` mode every id in ``skb.candidate_ids`` is embedded; in ``query``
mode every query of the QA split is embedded. The result is saved with
``torch.save`` as a ``{index: tensor of shape (1, dim)}`` dict under
``<emb_dir>/<dataset>/<emb_model>/<mode><suffix>/``. If that file already
exists, its entries are kept and only indices missing from it are embedded.
In ``doc`` mode the document texts are also cached in
``<data_dir>/<dataset>/doc<suffix>.csv`` so later runs can skip text building.
"""
import argparse
import os
import os.path as osp
import random
import sys

import pandas as pd
import torch
from tqdm import tqdm

# Make packages in the current working directory (e.g. ``models``) importable.
# Done before any project import so all of them can resolve from there.
sys.path.append('.')

from stark_qa import load_skb, load_qa
from stark_qa.tools.api import get_api_embeddings
from stark_qa.tools.api_lib.openai_emb import get_contriever, get_contriever_embeddings
from stark_qa.tools.local_encoder import get_llm2vec_embeddings, get_gritlm_embeddings
from models.model import get_embeddings
# NOTE: get_api_embeddings, get_llm2vec_embeddings and get_gritlm_embeddings are
# not referenced in this script.


def parse_args():
    """Parse the command line.

    Options declared with ``metavar="ENCODE"`` are encoder keyword arguments:
    every such option whose parsed value is not ``None`` is collected into
    ``encode_kwargs``, which ``main`` forwards to ``get_embeddings``.
    (``metavar`` is also what ``--help`` displays as the value placeholder.)

    Returns:
        tuple: ``(args, encode_kwargs)`` -- the parsed ``argparse.Namespace``
        and a ``dict`` of encoder keyword arguments.
    """
    parser = argparse.ArgumentParser()

    # Dataset and embedding model selection
    parser.add_argument('--dataset', default='prime', choices=['amazon', 'prime', 'mag'])
    parser.add_argument(
        '--emb_model',
        default='contriever',
        choices=[
            # Listed explicitly so `--emb_model contriever` is accepted, not only the default.
            'contriever',
            'text-embedding-ada-002',
            'text-embedding-3-small',
            'text-embedding-3-large',
            'voyage-large-2-instruct',
            'GritLM/GritLM-7B',
            'McGill-NLP/LLM2Vec-Meta-Llama-3-8B-Instruct-mntp',
            'all-mpnet-base-v2',  # for sentence transformer
        ],
    )

    # Mode settings
    parser.add_argument('--mode', default='query', choices=['doc', 'query'])

    # Path settings
    parser.add_argument("--data_dir", default="data/", type=str)
    parser.add_argument("--emb_dir", default="emb/", type=str)

    # Text settings
    parser.add_argument('--add_rel', action='store_true', default=False, help='add relation to the text')
    parser.add_argument('--compact', action='store_true', default=False, help='make the text compact when input to the model')

    # Evaluation settings
    parser.add_argument("--human_generated_eval", action="store_true",
                        help="if mode is `query`, then generating query embeddings on human generated evaluation split")

    # Batch and node settings
    parser.add_argument("--batch_size", default=1024, type=int)

    # Encoder kwargs (tagged with metavar="ENCODE"; see docstring)
    parser.add_argument("--n_max_nodes", default=None, type=int, metavar="ENCODE")
    parser.add_argument("--device", default=None, type=str, metavar="ENCODE")
    parser.add_argument("--peft_model_name", default=None, type=str, help="llm2vec pdft model", metavar="ENCODE")
    parser.add_argument("--instruction", type=str, help="gritl/llm2vec instruction", metavar="ENCODE")
    args = parser.parse_args()

    # Every dest here equals its option string without the leading "--",
    # so the action can be looked up by f'--{k}'.
    encode_kwargs = {k: v for k, v in vars(args).items()
                     if v is not None and parser._option_string_actions[f'--{k}'].metavar == "ENCODE"}
    return args, encode_kwargs


def _mode_suffix(args):
    """Return the suffix that distinguishes text variants in output/cache paths."""
    suffix = '_human_generated_eval' if args.human_generated_eval and args.mode == 'query' else ''
    suffix += '_no_rel' if not args.add_rel else ''
    suffix += '_no_compact' if not args.compact else ''
    return suffix


def _collect_pending_texts(args, ids, done, csv_cache, skb=None, qa_dataset=None):
    """Return ``(indices, texts)`` for the ids in *ids* that are not keys of *done*.

    In ``doc`` mode the texts are read from *csv_cache* when it exists.
    Otherwise they are built with ``skb.get_doc_info`` for *all* ids and
    written to *csv_cache*, so the cache always covers every candidate; only
    then are already-embedded ids dropped. In ``query`` mode the texts come
    from ``qa_dataset.get_query_by_qid``. The returned order follows *ids*.

    Raises:
        ValueError: if the indices stored in *csv_cache* differ from *ids*.
    """
    if args.mode == 'doc' and osp.exists(csv_cache):
        df = pd.read_csv(csv_cache)
        cache_dict = dict(zip(df['index'], df['text']))
        # Refuse a cache that does not cover exactly the current candidates.
        if set(cache_dict.keys()) != set(ids):
            raise ValueError('Indices in cache do not match the candidate indices.')
        indices = [idx for idx in ids if idx not in done]
        texts = [cache_dict[idx] for idx in tqdm(indices, desc="Filtering docs for new embeddings")]
        return indices, texts

    if args.mode == 'doc':
        all_texts = [skb.get_doc_info(idx, add_rel=args.add_rel, compact=args.compact)
                     for idx in tqdm(ids, desc="Gathering docs")]
        pd.DataFrame({'index': ids, 'text': all_texts}).to_csv(csv_cache, index=False)
        pending = [(idx, text) for idx, text in zip(ids, all_texts) if idx not in done]
        return [idx for idx, _ in pending], [text for _, text in pending]

    indices = [idx for idx in ids if idx not in done]
    texts = [qa_dataset.get_query_by_qid(idx) for idx in tqdm(indices, desc="Gathering docs")]
    return indices, texts


def _make_embed_fn(args, encode_kwargs):
    """Return a callable mapping a list of texts to their embeddings.

    The caller reshapes the callable's result with ``.view(len(batch), -1)``.
    """
    if args.emb_model == 'contriever':
        encoder, tokenizer = get_contriever(dataset_name=args.dataset)
        # Honor --device when given; otherwise keep the previous 'cuda' default.
        device = args.device or 'cuda'
        return lambda batch: get_contriever_embeddings(batch, encoder=encoder, tokenizer=tokenizer, device=device)
    return lambda batch: get_embeddings(batch, args.emb_model, **encode_kwargs)


def main():
    """Embed every pending document/query and save the embedding dict."""
    args, encode_kwargs = parse_args()

    mode_suffix = _mode_suffix(args)
    emb_dir = osp.join(args.emb_dir, args.dataset, args.emb_model, f'{args.mode}{mode_suffix}')
    csv_cache = osp.join(args.data_dir, args.dataset, f'{args.mode}{mode_suffix}.csv')
    print(f'Embedding directory: {emb_dir}')
    os.makedirs(emb_dir, exist_ok=True)
    os.makedirs(os.path.dirname(csv_cache), exist_ok=True)

    skb = qa_dataset = None
    if args.mode == 'doc':
        skb = load_skb(args.dataset)
        # Copy so the shuffle below does not reorder skb.candidate_ids itself.
        ids = list(skb.candidate_ids)
        emb_path = osp.join(emb_dir, 'candidate_emb_dict.pt')
    else:  # 'query' (argparse restricts --mode to doc/query)
        qa_dataset = load_qa(args.dataset, human_generated_eval=args.human_generated_eval)
        # Element [1] of each QA item is used as a query id (see get_query_by_qid).
        ids = [qa_dataset[i][1] for i in range(len(qa_dataset))]
        emb_path = osp.join(emb_dir, 'query_emb_dict.pt')
    random.shuffle(ids)

    # Load existing embeddings if they exist; their indices are skipped below.
    if osp.exists(emb_path):
        emb_dict = torch.load(emb_path)
        print(f'Loaded existing embeddings from {emb_path}. Size: {len(emb_dict)}')
    else:
        emb_dict = {}

    indices, texts = _collect_pending_texts(args, ids, emb_dict, csv_cache, skb=skb, qa_dataset=qa_dataset)

    print(f'Generating embeddings for {len(texts)} texts...')
    if texts:  # avoid loading a (possibly large) model when nothing is pending
        embed = _make_embed_fn(args, encode_kwargs)
        for start in tqdm(range(0, len(texts), args.batch_size), desc="Generating embeddings"):
            batch_texts = texts[start:start + args.batch_size]
            batch_indices = indices[start:start + args.batch_size]
            batch_embs = embed(batch_texts).view(len(batch_texts), -1).cpu()
            for idx, emb in zip(batch_indices, batch_embs):
                emb_dict[idx] = emb.view(1, -1)

    torch.save(emb_dict, emb_path)
    print(f'Saved {len(emb_dict)} embeddings to {emb_path}!')


if __name__ == '__main__':
    main()