from __future__ import annotations

import logging
import os
from typing import Dict, List, Optional

import torch
from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.pydantic_v1 import Extra
from transformers import AutoTokenizer, LlamaForCausalLM

logger = logging.getLogger(__name__)

# Silence tokenizer fork warnings if the process spawns workers.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

DEVICE = "cuda"
DEVICE_ID = "0"
CUDA_DEVICE = f"{DEVICE}:{DEVICE_ID}" if DEVICE_ID else DEVICE
def torch_gc():
    """Release cached GPU memory between generations."""
    if torch.cuda.is_available():
        with torch.cuda.device(CUDA_DEVICE):
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
def auto_configure_device_map(num_gpus: int) -> Dict[str, int]:
    """Spread the 32 decoder layers of a Llama-7B-class model evenly
    across ``num_gpus`` GPUs, keeping the embeddings, final norm and
    lm_head together on GPU 0."""
    num_trans_layers = 32
    per_gpu_layers = num_trans_layers / num_gpus
    # Module names follow transformers' LlamaForCausalLM layout.
    device_map = {
        'model.embed_tokens': 0,
        'model.norm': 0,
        'lm_head': 0,
    }
    # GPU 0 already holds the non-layer modules, so start its count at 2.
    used = 2
    gpu_target = 0
    for i in range(num_trans_layers):
        if used >= per_gpu_layers:
            gpu_target += 1
            used = 0
        assert gpu_target < num_gpus
        device_map[f'model.layers.{i}'] = gpu_target
        used += 1
    return device_map
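
# Example: with num_gpus=2 the map above pins the embeddings, final norm
# and lm_head plus decoder layers 0-13 on GPU 0, and layers 14-31 on
# GPU 1 (GPU 0 is charged two "layers" up front for the non-layer modules).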
class ChatLLM(LLM):
    """LangChain LLM wrapper around a locally loaded Llama checkpoint."""

    max_token: int = 3000
    temperature: float = 0.75
    top_p: float = 0.9
    tokenizer: object = None
    model: object = None

    class Config:
        """Configuration for this pydantic object."""
        extra = Extra.forbid
    def from_model_id(
        self,
        model_id: str,
        device_map: Optional[Dict[str, int]] = None,
    ):
        """Load the tokenizer and model for ``model_id``."""
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            trust_remote_code=True,
        )
        if torch.cuda.is_available():
            num_gpus = torch.cuda.device_count()
            if num_gpus < 2 and device_map is None:
                # Single GPU: quantize to 8-bit and place everything on
                # GPU 0 (bitsandbytes needs a device map at load time).
                self.model = LlamaForCausalLM.from_pretrained(
                    model_id,
                    trust_remote_code=True,
                    torch_dtype=torch.float16,
                    load_in_8bit=True,
                    device_map={"": 0},
                )
            else:
                # Multiple GPUs: build (or accept) a layer-to-GPU map and
                # let from_pretrained shard the 8-bit weights directly.
                if device_map is None:
                    device_map = auto_configure_device_map(num_gpus)
                self.model = LlamaForCausalLM.from_pretrained(
                    model_id,
                    torch_dtype=torch.float16,
                    load_in_8bit=True,
                    device_map=device_map,
                    trust_remote_code=True,
                )
        else:
            # CPU fallback: 8-bit quantization requires CUDA, so load in
            # full precision instead.
            self.model = LlamaForCausalLM.from_pretrained(
                model_id,
                torch_dtype=torch.float32,
                trust_remote_code=True,
            )
        self.model = self.model.eval()
    @property
    def _llm_type(self) -> str:
        return "ChatLLM"
    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
    ) -> str:
        # LlamaForCausalLM has no ChatGLM-style ``.chat`` helper, so run a
        # plain tokenize -> generate -> decode round trip.
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        output_ids = self.model.generate(
            **inputs,
            max_new_tokens=self.max_token,
            do_sample=True,
            temperature=self.temperature,
            top_p=self.top_p,
        )
        # Decode only the newly generated tokens, not the echoed prompt.
        response = self.tokenizer.decode(
            output_ids[0][inputs["input_ids"].shape[1]:],
            skip_special_tokens=True,
        )
        torch_gc()
        if stop is not None:
            response = enforce_stop_tokens(response, stop)
        return response
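
# Minimal usage sketch. The model id below is a placeholder, not part of
# the original code; substitute any local or Hub Llama checkpoint.
if __name__ == "__main__":
    llm = ChatLLM()
    llm.from_model_id("path/to/llama-7b")  # hypothetical checkpoint path
    print(llm.invoke("Explain what a device map does in one sentence."))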