from .vss import VSS
from .llm_reranker import LLMReranker
from .multi_vss import MultiVSS
from .bm25 import BM25
from .colbertv2 import Colbertv2


def get_model(args, skb, **kwargs):
    """Instantiate the model class selected by ``args.model``.

    Only the attributes of ``args`` needed by the selected model are read,
    so ``args`` does not have to define options for the other models.

    Args:
        args: Namespace-like object. ``args.model`` selects the class; the
            remaining attributes read depend on the branch taken:

            * ``'BM25'``: none.
            * ``'Colbertv2'``: ``dataset``, ``output_dir``, ``download_dir``,
              ``split``.
            * ``'VSS'``: ``emb_model``, ``query_emb_dir``, ``node_emb_dir``,
              ``device``.
            * ``'MultiVSS'``: the ``'VSS'`` attributes plus
              ``chunk_emb_dir``, ``aggregate``, ``chunk_size``,
              ``multi_vss_topk``.
            * ``'LLMReranker'``: ``emb_model``, ``llm_model``,
              ``query_emb_dir``, ``node_emb_dir``, ``max_retry``,
              ``llm_topk``, ``device``.
        skb: Object passed as the first positional argument to every model
            constructor.
        **kwargs: Extra keyword arguments. They are forwarded to
            ``Colbertv2`` only; every other model ignores them.

    Returns:
        An instance of the selected model class.

    Raises:
        ImportError: If constructing ``Colbertv2`` raises ``ImportError``;
            re-raised with an installation hint and the original error
            attached as the cause.
        NotImplementedError: If ``args.model`` is not one of the names above.
    """
    model_name = args.model

    if model_name == 'BM25':
        return BM25(skb)

    if model_name == 'Colbertv2':
        try:
            return Colbertv2(
                skb,
                dataset_name=args.dataset,
                save_dir=args.output_dir,
                download_dir=args.download_dir,
                human_generated_eval=args.split == 'human_generated_eval',
                **kwargs,
            )
        except ImportError as err:
            # Chain explicitly so the traceback still shows which import
            # actually failed, not just the installation hint.
            raise ImportError(
                "Please install the colbert package using `pip install colbert-ai`."
            ) from err

    if model_name == 'VSS':
        return VSS(
            skb,
            emb_model=args.emb_model,
            query_emb_dir=args.query_emb_dir,
            candidates_emb_dir=args.node_emb_dir,
            device=args.device,
        )

    if model_name == 'MultiVSS':
        return MultiVSS(
            skb,
            emb_model=args.emb_model,
            query_emb_dir=args.query_emb_dir,
            candidates_emb_dir=args.node_emb_dir,
            chunk_emb_dir=args.chunk_emb_dir,
            aggregate=args.aggregate,
            chunk_size=args.chunk_size,
            max_k=args.multi_vss_topk,
            device=args.device,
        )

    if model_name == 'LLMReranker':
        return LLMReranker(
            skb,
            emb_model=args.emb_model,
            llm_model=args.llm_model,
            query_emb_dir=args.query_emb_dir,
            candidates_emb_dir=args.node_emb_dir,
            max_cnt=args.max_retry,
            max_k=args.llm_topk,
            device=args.device,
        )

    # NOTE: disabled branch kept from the original file; `GT` is not
    # imported in this module.
    # if model_name == 'Ours':
    #     return GT(
    #         skb,
    #         query_emb_dir=args.query_emb_dir,
    #         candidates_emb_dir=args.node_emb_dir,
    #     )

    raise NotImplementedError(f'{model_name} not implemented')