from __future__ import annotations

import json
import os

from huggingface_hub import hf_hub_download
from llama_cpp import Llama

from ..index_func import *
from ..presets import *
from ..utils import *
from .base_model import BaseLLMModel

SYS_PREFIX = "<<SYS>>\n"
SYS_POSTFIX = "\n<</SYS>>\n\n"
INST_PREFIX = "<s>[INST] "
INST_POSTFIX = " "
OUTPUT_PREFIX = "[/INST] "
OUTPUT_POSTFIX = "</s>"
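
# Example prompt as assembled by _get_llama_style_input below, given a history
# of one system message, one user turn, and one assistant reply:
#
#   <<SYS>>
#   You are a helpful assistant.
#   <</SYS>>
#
#   <s>[INST] Hello! [/INST] Hi there.</s>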


def download(repo_id, filename, retry=10):
    # Reuse a previously downloaded model if it is recorded in the local index.
    if os.path.exists("./models/downloaded_models.json"):
        with open("./models/downloaded_models.json", "r") as f:
            downloaded_models = json.load(f)
        if repo_id in downloaded_models:
            return downloaded_models[repo_id]["path"]
    else:
        downloaded_models = {}
    while retry > 0:
        try:
            model_path = hf_hub_download(
                repo_id=repo_id,
                filename=filename,
                cache_dir="models",
                resume_download=True,
            )
            downloaded_models[repo_id] = {"path": model_path}
            with open("./models/downloaded_models.json", "w") as f:
                json.dump(downloaded_models, f)
            break
        except Exception as e:
            print(f"Error downloading model, retrying... ({e})")
            retry -= 1
    if retry == 0:
        raise Exception("Error downloading model, please try again later.")
    return model_path
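
# Illustrative usage (the repo and file names below are examples, not
# guaranteed to exist on the Hub):
#   path = download("TheBloke/Llama-2-7B-Chat-GGUF", "llama-2-7b-chat.Q4_K_M.gguf")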


class LLaMA_Client(BaseLLMModel):
    def __init__(self, model_name, lora_path=None, user_name="") -> None:
        super().__init__(model_name=model_name, user=user_name)

        self.max_generation_token = 1000
        if model_name in MODEL_METADATA:
            path_to_model = download(
                MODEL_METADATA[model_name]["repo_id"],
                MODEL_METADATA[model_name]["filelist"][0],
            )
        else:
            dir_to_model = os.path.join("models", model_name)
            # look for any .gguf file in the dir_to_model directory and its subdirectories
            path_to_model = None
            for root, dirs, files in os.walk(dir_to_model):
                for file in files:
                    if file.endswith(".gguf"):
                        path_to_model = os.path.join(root, file)
                        break
                if path_to_model is not None:
                    break
            if path_to_model is None:
                raise FileNotFoundError(
                    f"No .gguf model file found under {dir_to_model}"
                )
        self.system_prompt = ""

        if lora_path is not None:
            lora_path = os.path.join("lora", lora_path)
            self.model = Llama(model_path=path_to_model, lora_path=lora_path)
        else:
            self.model = Llama(model_path=path_to_model)

    def _get_llama_style_input(self):
        # Flatten the chat history into a single Llama-2 style prompt string.
        context = []
        for conv in self.history:
            if conv["role"] == "system":
                context.append(SYS_PREFIX + conv["content"] + SYS_POSTFIX)
            elif conv["role"] == "user":
                context.append(
                    INST_PREFIX + conv["content"] + INST_POSTFIX + OUTPUT_PREFIX
                )
            else:
                context.append(conv["content"] + OUTPUT_POSTFIX)
        return "".join(context)

    def get_answer_at_once(self):
        context = self._get_llama_style_input()
        response = self.model(
            context,
            max_tokens=self.max_generation_token,
            stop=[],
            echo=False,
            stream=False,
        )
        # llama-cpp-python returns a completion dict; extract the generated text
        # and the token count instead of returning the raw dict.
        text = response["choices"][0]["text"]
        return text, response["usage"]["total_tokens"]

    def get_answer_stream_iter(self):
        context = self._get_llama_style_input()
        # Stream the completion chunk by chunk, stopping at any template marker.
        completion_iter = self.model(
            context,
            max_tokens=self.max_generation_token,
            stop=[SYS_PREFIX, SYS_POSTFIX, INST_PREFIX, OUTPUT_PREFIX, OUTPUT_POSTFIX],
            echo=False,
            stream=True,
        )
        partial_text = ""
        for chunk in completion_iter:
            partial_text += chunk["choices"][0]["text"]
            yield partial_text
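
# Minimal usage sketch (assumes either a MODEL_METADATA entry for the model
# name, or a .gguf file placed under models/<model_name>/):
#
#   client = LLaMA_Client("llama-2-7b-chat")
#   client.history = [{"role": "user", "content": "Hello!"}]
#   for partial in client.get_answer_stream_iter():
#       print(partial)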