from __future__ import annotations

import logging
import os
from typing import Any, Dict, List, Optional

import torch
from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.pydantic_v1 import Extra
from transformers import AutoTokenizer, LlamaForCausalLM

logger = logging.getLogger(__name__)

os.environ["TOKENIZERS_PARALLELISM"] = "false"

DEVICE = "cuda"
DEVICE_ID = "0"
CUDA_DEVICE = f"{DEVICE}:{DEVICE_ID}" if DEVICE_ID else DEVICE


def torch_gc():
    """Release cached GPU memory after a generation call."""
    if torch.cuda.is_available():
        with torch.cuda.device(CUDA_DEVICE):
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()


def auto_configure_device_map(num_gpus: int) -> Dict[str, int]:
    """Spread the 32 transformer layers evenly across the available GPUs.

    The word embeddings, final layer norm, and LM head are pinned to GPU 0 so
    that the input and output tensors stay on the same device.
    """
    num_trans_layers = 32
    per_gpu_layers = 32 / num_gpus
    device_map = {
        "transformer.word_embeddings": 0,
        "transformer.final_layernorm": 0,
        "lm_head": 0,
    }
    # GPU 0 already holds two extra modules, so start its layer budget at 2.
    used = 2
    gpu_target = 0
    for i in range(num_trans_layers):
        if used >= per_gpu_layers:
            gpu_target += 1
            used = 0
        assert gpu_target < num_gpus
        device_map[f"transformer.layers.{i}"] = gpu_target
        used += 1
    return device_map


class ChatLLM(LLM):
    max_token: int = 3000
    temperature: float = 0.75
    top_p: float = 0.9
    tokenizer: object = None
    model: object = None

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @property
    def _llm_type(self) -> str:
        return "ChatLLM"

    def from_model_id(
        self,
        model_id: str,
        device_map: Optional[Dict[str, int]] = None,
    ) -> None:
        """Load the tokenizer and model, sharding across GPUs when more than one is available."""
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_id, trust_remote_code=True
        )
        if torch.cuda.is_available():
            num_gpus = torch.cuda.device_count()
            if num_gpus < 2 and device_map is None:
                # Single GPU: load the 8-bit model directly.
                self.model = LlamaForCausalLM.from_pretrained(
                    model_id,
                    trust_remote_code=True,
                    torch_dtype=torch.float16,
                    load_in_8bit=True,
                    load_in_4bit=False,
                    use_flash_attention_2=False,
                )
            else:
                # Multiple GPUs: build a layer-to-device map and dispatch with accelerate.
                from accelerate import dispatch_model

                model = LlamaForCausalLM.from_pretrained(
                    model_id,
                    torch_dtype=torch.float16,
                    load_in_8bit=True,
                    load_in_4bit=False,
                    use_flash_attention_2=False,
                    trust_remote_code=True,
                )
                if device_map is None:
                    device_map = auto_configure_device_map(num_gpus)
                self.model = dispatch_model(model, device_map=device_map)
        else:
            # CPU fallback.
            self.model = LlamaForCausalLM.from_pretrained(
                model_id,
                torch_dtype=torch.float16,
                load_in_8bit=True,
                load_in_4bit=False,
                use_flash_attention_2=False,
                trust_remote_code=True,
            )
        self.model = self.model.eval()

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        # Relies on the custom `chat` helper exposed by the checkpoint via
        # trust_remote_code; it returns (response, history).
        response, _ = self.model.chat(
            self.tokenizer,
            prompt,
            max_length=self.max_token,
            temperature=self.temperature,
        )
        torch_gc()
        if stop is not None:
            response = enforce_stop_tokens(response, stop)
        return response
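

# --- Usage sketch (illustrative only; the model id below is a placeholder, ---
# --- not part of this module) -------------------------------------------------
#
#   llm = ChatLLM()
#   llm.from_model_id("path/to/your-llama-checkpoint")
#   answer = llm("What does this wrapper do?")
#   print(answer)
#
# Because ChatLLM subclasses LangChain's LLM base class, the populated instance
# can also be passed anywhere a LangChain LLM is expected, e.g. inside a chain.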