# Multipurpose-AI-Agent-Development / langchain_transformers.py
from __future__ import annotations
import os
import torch
import logging
from typing import Any, Dict, List, Mapping, Optional
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens
from langchain_core.outputs import Generation, LLMResult
from langchain_core.pydantic_v1 import Extra
from transformers import AutoTokenizer, LlamaForCausalLM
logger = logging.getLogger(__name__)
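# Silence the Hugging Face tokenizers fork/parallelism warning.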
os.environ["TOKENIZERS_PARALLELISM"] = "false"
DEVICE = "cuda"
DEVICE_ID = "0"
CUDA_DEVICE = f"{DEVICE}:{DEVICE_ID}" if DEVICE_ID else DEVICE
def torch_gc():
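    """Free cached CUDA memory (and IPC handles) on the configured device."""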
if torch.cuda.is_available():
with torch.cuda.device(CUDA_DEVICE):
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
def auto_configure_device_map(num_gpus: int) -> Dict[str, int]:
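    """Spread a 32-layer Llama decoder across `num_gpus` GPUs.

    Embeddings, final norm and LM head stay on GPU 0 and are counted as two
    layers' worth of budget, so GPU 0 gets proportionally fewer decoder layers.
    Module names follow LlamaForCausalLM's layout (model.embed_tokens,
    model.layers.N, model.norm, lm_head).
    """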
num_trans_layers = 32
per_gpu_layers = 32 / num_gpus
    device_map = {'model.embed_tokens': 0,
                  'model.norm': 0, 'lm_head': 0}
used = 2
gpu_target = 0
for i in range(num_trans_layers):
if used >= per_gpu_layers:
gpu_target += 1
used = 0
assert gpu_target < num_gpus
        device_map[f'model.layers.{i}'] = gpu_target
used += 1
return device_map
class ChatLLM(LLM):
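    """LangChain LLM wrapper around a locally loaded Hugging Face Llama model."""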
max_token: int = 3000
temperature: float = 0.75
    top_p: float = 0.9
tokenizer: object = None
model: object = None
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
def __init__(self):
super().__init__()
def from_model_id(
self,
model_id,
device_map: Optional[Dict[str, int]] = None
):
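        """Load the tokenizer and model for `model_id`, optionally sharded over GPUs."""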
self.tokenizer = AutoTokenizer.from_pretrained(
model_id,
trust_remote_code=True
)
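        # Choose a loading strategy based on how many CUDA devices are visible.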
if torch.cuda.is_available():
num_gpus = torch.cuda.device_count()
if num_gpus < 2 and device_map is None:
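                # Single GPU (and no explicit map): keep the whole model on one device, 8-bit quantized.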
self.model = (
LlamaForCausalLM.from_pretrained(
model_id,
trust_remote_code=True,
torch_dtype=torch.float16,
load_in_8bit=True,
load_in_4bit=False,
use_flash_attention_2=False)
)
else:
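                # Several GPUs: build a layer-to-device map and shard the model
                # with accelerate's dispatch_model.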
from accelerate import dispatch_model
model = LlamaForCausalLM.from_pretrained(model_id,
torch_dtype=torch.float16,
load_in_8bit=True,
load_in_4bit=False,
use_flash_attention_2=False,
trust_remote_code=True)
if device_map is None:
device_map = auto_configure_device_map(num_gpus)
self.model = dispatch_model(model, device_map=device_map)
        else:
            # No CUDA device available: bitsandbytes 8-bit quantization needs a GPU,
            # so fall back to a plain full-precision CPU load.
            self.model = LlamaForCausalLM.from_pretrained(
                model_id,
                torch_dtype=torch.float32,
                trust_remote_code=True,
            )
self.model = self.model.eval()
@property
def _llm_type(self) -> str:
return "ChatLLM"
    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        # LlamaForCausalLM exposes generate() rather than a ChatGLM-style
        # chat() helper, so tokenize the prompt and decode only the newly
        # generated tokens.
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        output_ids = self.model.generate(
            **inputs,
            max_new_tokens=self.max_token,
            temperature=self.temperature,
            top_p=self.top_p,
            do_sample=True,
        )
        response = self.tokenizer.decode(
            output_ids[0][inputs["input_ids"].shape[-1]:],
            skip_special_tokens=True,
        )
        torch_gc()
        if stop is not None:
            response = enforce_stop_tokens(response, stop)
        return response
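
# Minimal usage sketch. The model identifier below is a placeholder assumption,
# not part of this repository; any local or Hub Llama checkpoint should work,
# provided enough GPU memory is available for 8-bit loading.
if __name__ == "__main__":
    llm = ChatLLM()
    llm.from_model_id("meta-llama/Llama-2-7b-chat-hf")  # hypothetical checkpoint
    print(llm.invoke("Briefly explain what LangChain does."))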