|
"""LangChain LLM wrapper around a locally hosted Llama-style causal language model."""

from __future__ import annotations

import logging
import os
from typing import Any, Dict, List, Optional

import torch
from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.pydantic_v1 import Extra
from transformers import AutoTokenizer, LlamaForCausalLM

logger = logging.getLogger(__name__)

# Avoid the tokenizers fork/parallelism warning when the process spawns workers.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

DEVICE = "cuda"
DEVICE_ID = "0"
CUDA_DEVICE = f"{DEVICE}:{DEVICE_ID}" if DEVICE_ID else DEVICE
|
|
def torch_gc():
    """Free cached CUDA memory after a generation round."""
    if torch.cuda.is_available():
        with torch.cuda.device(CUDA_DEVICE):
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
|
|
def auto_configure_device_map(num_gpus: int) -> Dict[str, int]:
    """Evenly spread the 32 transformer layers of a Llama-7B-sized model over the GPUs."""
    num_trans_layers = 32
    per_gpu_layers = 32 / num_gpus

    # GPU 0 also hosts the embeddings, final norm and LM head, so it is
    # pre-charged with the equivalent of two layers before any transformer
    # layer is assigned to it. Module names follow LlamaForCausalLM's layout.
    device_map = {
        'model.embed_tokens': 0,
        'model.norm': 0,
        'lm_head': 0,
    }

    used = 2
    gpu_target = 0
    for i in range(num_trans_layers):
        if used >= per_gpu_layers:
            gpu_target += 1
            used = 0
        assert gpu_target < num_gpus
        device_map[f'model.layers.{i}'] = gpu_target
        used += 1

    return device_map
|
|
class ChatLLM(LLM):
    max_token: int = 3000
    temperature: float = 0.75
    top_p: float = 0.9
    tokenizer: object = None
    model: object = None

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    def __init__(self):
        super().__init__()
|
    def from_model_id(
        self,
        model_id: str,
        device_map: Optional[Dict[str, int]] = None,
    ) -> None:
        """Load the tokenizer and model for ``model_id`` onto the available devices."""
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            trust_remote_code=True,
        )
        if torch.cuda.is_available():
            num_gpus = torch.cuda.device_count()
            if num_gpus < 2 and device_map is None:
                # Single GPU: load the 8-bit quantised model directly.
                self.model = LlamaForCausalLM.from_pretrained(
                    model_id,
                    trust_remote_code=True,
                    torch_dtype=torch.float16,
                    load_in_8bit=True,
                    load_in_4bit=False,
                    use_flash_attention_2=False,
                )
            else:
                from accelerate import dispatch_model

                model = LlamaForCausalLM.from_pretrained(
                    model_id,
                    torch_dtype=torch.float16,
                    load_in_8bit=True,
                    load_in_4bit=False,
                    use_flash_attention_2=False,
                    trust_remote_code=True,
                )
                if device_map is None:
                    device_map = auto_configure_device_map(num_gpus)
                # Spread the model's layers across the detected GPUs.
                self.model = dispatch_model(model, device_map=device_map)
        else:
            # 8-bit quantisation requires a CUDA device, so fall back to
            # full-precision weights on the CPU.
            self.model = LlamaForCausalLM.from_pretrained(
                model_id,
                torch_dtype=torch.float32,
                trust_remote_code=True,
            )
        self.model = self.model.eval()
|
    @property
    def _llm_type(self) -> str:
        return "ChatLLM"
|
    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        # LlamaForCausalLM has no ChatGLM-style ``chat()`` helper, so produce
        # a completion for the raw prompt with the standard ``generate`` API.
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        with torch.no_grad():
            output_ids = self.model.generate(
                **inputs,
                max_new_tokens=self.max_token,
                do_sample=True,
                temperature=self.temperature,
                top_p=self.top_p,
            )
        # Decode only the newly generated tokens, not the echoed prompt.
        response = self.tokenizer.decode(
            output_ids[0][inputs["input_ids"].shape[1]:],
            skip_special_tokens=True,
        )
        torch_gc()
        if stop is not None:
            response = enforce_stop_tokens(response, stop)
        return response
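

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module: the checkpoint
    # path below is a placeholder for a local Llama-compatible model.
    llm = ChatLLM()
    llm.from_model_id("/path/to/llama-7b-checkpoint")
    print(llm.invoke("Briefly explain what a vector database is."))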