File size: 755 Bytes
abd6171
 
 
 
 
 
 
 
1146dae
abd6171
 
 
1146dae
abd6171
 
 
1146dae
abd6171
 
 
 
 
 
 
 
 
 
 
1146dae
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
)
from peft import PeftModel, PeftConfig
import torch

orig_checkpoint = 'google/gemma-2b'  # base model the LoRA adapter was trained on
checkpoint = '.'                     # directory containing the trained adapter weights
HF_TOKEN = ''                        # Hugging Face access token (gemma is a gated model)
PROMPT = 'Salut, ca sa imi schimb buletinul pot sa'

seq_len = 256  # maximum number of NEW tokens to generate

# Prefer GPU, but fall back to CPU so the script still runs without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# load original model first
# `HF_TOKEN or None`: an empty placeholder string must not be sent as a token.
tokenizer = AutoTokenizer.from_pretrained(orig_checkpoint, token=HF_TOKEN or None)
model = AutoModelForCausalLM.from_pretrained(orig_checkpoint, token=HF_TOKEN or None)

# then merge trained QLoRA weights
model = PeftModel.from_pretrained(model, checkpoint)
# FIX: merge_and_unload() RETURNS the merged base model; the original call
# discarded that return value, so generation kept running through the
# un-merged PEFT wrapper instead of the merged model.
model = model.merge_and_unload()

model = model.to(device)
model.eval()  # inference only — disable dropout etc.

# generate normally (no autograd state needed during generation)
inputs = tokenizer.encode(PROMPT, return_tensors="pt").to(device)
with torch.inference_mode():
    outputs = model.generate(inputs, max_new_tokens=seq_len)

print(tokenizer.decode(outputs[0]))