from __future__ import annotations
import logging
import os
from typing import Any, Dict, List, Optional

import torch
from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.pydantic_v1 import Extra
from transformers import AutoTokenizer, LlamaForCausalLM


logger = logging.getLogger(__name__)

os.environ["TOKENIZERS_PARALLELISM"] = "false"  # silence the HF tokenizers fork/parallelism warning

DEVICE = "cuda"
DEVICE_ID = "0"
CUDA_DEVICE = f"{DEVICE}:{DEVICE_ID}" if DEVICE_ID else DEVICE


def torch_gc():
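    """Free cached CUDA memory and collect shared CUDA IPC handles on the configured device."""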
    if torch.cuda.is_available():
        with torch.cuda.device(CUDA_DEVICE):
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()

def auto_configure_device_map(num_gpus: int) -> Dict[str, int]:
    # Spread the 32 decoder layers of a Llama checkpoint (e.g. 7B) across the
    # available GPUs, keeping the embedding, final norm and output head on
    # GPU 0. Those three modules count as two extra "layer slots", hence the
    # +2 when sizing each GPU's share and the initial used = 2.
    num_trans_layers = 32
    per_gpu_layers = (num_trans_layers + 2) / num_gpus
    device_map = {
        'model.embed_tokens': 0,
        'model.norm': 0,
        'lm_head': 0,
    }

    used = 2
    gpu_target = 0
    for i in range(num_trans_layers):
        if used >= per_gpu_layers:
            gpu_target += 1
            used = 0
        assert gpu_target < num_gpus
        device_map[f'model.layers.{i}'] = gpu_target
        used += 1

    return device_map
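

# Illustrative note (not part of the original module): with two GPUs the map
# above keeps the embedding, final norm, lm_head and layers 0-14 on cuda:0 and
# places layers 15-31 on cuda:1, e.g.
# >>> m = auto_configure_device_map(2)
# >>> m['model.layers.14'], m['model.layers.15']
# (0, 1)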


class ChatLLM(LLM):
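    """LangChain LLM wrapper that serves a local Llama checkpoint loaded via transformers."""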
    max_token: int = 3000
    temperature: float = 0.75
    top_p: float = 0.9
    tokenizer: object = None
    model: object = None

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    def __init__(self):
        super().__init__()

    
    def from_model_id(
        self,
        model_id: str,
        device_map: Optional[Dict[str, int]] = None,
    ):
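        """Load the tokenizer and model weights onto the available device(s).

        Note: this mutates the current instance in place rather than
        returning a new one.
        """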
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            trust_remote_code=True
        )
        if torch.cuda.is_available():
            num_gpus = torch.cuda.device_count()
            if num_gpus < 2 and device_map is None:
                # Single GPU: place the 8-bit quantised weights directly on it.
                self.model = LlamaForCausalLM.from_pretrained(
                    model_id,
                    trust_remote_code=True,
                    torch_dtype=torch.float16,
                    load_in_8bit=True,
                    device_map="auto",
                )
            else:
                # Multiple GPUs: spread the decoder layers across devices.
                # 8-bit weights must be placed while they are loaded, so the
                # device map is handed to from_pretrained rather than
                # dispatching the model afterwards.
                if device_map is None:
                    device_map = auto_configure_device_map(num_gpus)
                self.model = LlamaForCausalLM.from_pretrained(
                    model_id,
                    trust_remote_code=True,
                    torch_dtype=torch.float16,
                    load_in_8bit=True,
                    device_map=device_map,
                )
        else:
            # CPU fallback: bitsandbytes 8-bit quantisation requires CUDA,
            # so load unquantised weights instead.
            self.model = LlamaForCausalLM.from_pretrained(
                model_id,
                trust_remote_code=True,
                torch_dtype=torch.float32,
            )
        self.model = self.model.eval()

    @property
    def _llm_type(self) -> str:
        return "ChatLLM"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        # LlamaForCausalLM has no .chat() helper, so tokenize, generate and
        # decode explicitly, stripping the prompt tokens from the output.
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        output_ids = self.model.generate(
            **inputs,
            max_new_tokens=self.max_token,
            do_sample=True,
            temperature=self.temperature,
            top_p=self.top_p,
        )
        response = self.tokenizer.decode(
            output_ids[0][inputs["input_ids"].shape[-1]:],
            skip_special_tokens=True,
        )
        torch_gc()
        if stop is not None:
            response = enforce_stop_tokens(response, stop)
        return response
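

# Minimal usage sketch (illustrative only; the checkpoint path below is a
# placeholder): build the wrapper, load weights with from_model_id, then call
# it like any other LangChain LLM.
if __name__ == "__main__":
    llm = ChatLLM()
    llm.from_model_id("path/to/llama-2-7b-chat")  # hypothetical local checkpoint
    print(llm.invoke("Briefly describe what this model can do."))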