import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, GenerationConfig
from peft import PeftModel, PeftConfig

from model import Model


class KoAlpaca(Model):
    """KoAlpaca causal LM with a PEFT adapter, loaded with 4-bit bitsandbytes quantization.

    The base model named in the adapter's ``PeftConfig`` is loaded in 4-bit NF4
    (double quantization, bfloat16 compute), the PEFT adapter is applied on top,
    and generation settings are read from a local ``GenerationConfig`` file.
    """

    # Token inside INPUT_FORMAT that is substituted with the user's question.
    PLACEHOLDER = "{input}"

    def __init__(
        self,
        name: str = 'KoAlpaca',
        peft_model_id: str = "4n3mone/Komuchat-koalpaca-polyglot-12.8B",
        gen_config_dir: str = './models/koalpaca',
        gen_config_file: str = 'gen_config.json',
    ):
        """Load tokenizer, quantized base model, PEFT adapter and generation config.

        Args:
            name: Display name for this model wrapper.
            peft_model_id: Hub id (or local path) of the PEFT adapter; its
                ``PeftConfig.base_model_name_or_path`` selects the base model.
            gen_config_dir: Directory holding the generation config file. A
                relative path is resolved against the current working directory.
            gen_config_file: File name of the generation config inside
                ``gen_config_dir``.
        """
        self.name = name
        config = PeftConfig.from_pretrained(peft_model_id)
        self.bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
        )
        # device_map='auto' lets accelerate place layers across available devices;
        # use device_map={"": 0} instead to pin the whole model to GPU 0.
        self.model = AutoModelForCausalLM.from_pretrained(
            config.base_model_name_or_path,
            quantization_config=self.bnb_config,
            device_map='auto',
        )
        self.model = PeftModel.from_pretrained(self.model, peft_model_id)
        self.tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
        self.gen_config = GenerationConfig.from_pretrained(gen_config_dir, gen_config_file)
        # KoAlpaca prompt template; PLACEHOLDER is replaced by the question text.
        self.INPUT_FORMAT = "### 질문: {input}\n\n### 답변:"
        self.model.eval()

    def generate(self, inputs: str) -> str:
        """Generate an answer for a single question.

        Args:
            inputs: The user's question text.

        Returns:
            The decoded text generated after the prompt, with special tokens
            removed and surrounding whitespace stripped.
        """
        prompt = self.INPUT_FORMAT.replace(self.PLACEHOLDER, inputs)
        encoded = self.tokenizer(
            prompt,
            return_tensors='pt',
            return_token_type_ids=False,
        ).to(self.model.device)
        output_ids = self.model.generate(**encoded, generation_config=self.gen_config)
        # Causal-LM generate returns prompt + continuation; keep only the new tokens
        # so the answer never depends on how the prompt's trailing space tokenizes.
        prompt_len = encoded['input_ids'].shape[-1]
        answer_ids = output_ids[0][prompt_len:]
        return self.tokenizer.decode(answer_ids, skip_special_tokens=True).strip()