alexghergh committed
Commit 1146dae
Parent: 2c2e591

Update inference.py

Files changed (1)
  1. inference.py +4 -6
inference.py CHANGED
@@ -6,17 +6,15 @@ from peft import PeftModel, PeftConfig
 import torch
 
 orig_checkpoint = 'google/gemma-2b'
-checkpoint = 'checkpoint-4000'
+checkpoint = '.'
 HF_TOKEN = ''
 PROMPT = 'Salut, ca sa imi schimb buletinul pot sa'
 
-seq_len = 2048
+seq_len = 256
 
 # load original model first
 tokenizer = AutoTokenizer.from_pretrained(orig_checkpoint, token=HF_TOKEN)
-
-config = PeftConfig.from_pretrained(checkpoint)
-model = AutoModelForCausalLM.from_pretrained(config.base_model_name_or_path, token=HF_TOKEN)
+model = AutoModelForCausalLM.from_pretrained(orig_checkpoint, token=HF_TOKEN)
 
 # then merge trained QLoRA weights
 model = PeftModel.from_pretrained(model, checkpoint)
@@ -28,4 +26,4 @@ model = model.cuda()
 inputs = tokenizer.encode(PROMPT, return_tensors="pt").cuda()
 outputs = model.generate(inputs, max_new_tokens=seq_len)
 
-print(tokenizer.decode(outputs[0]))
+print(tokenizer.decode(outputs[0]))
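
Pieced together from the two hunks, the post-commit script looks roughly like the sketch below. It is a reconstruction, not the verbatim file: lines 1-4 of inference.py fall outside the diff context, so the transformers import is assumed, as are the unshown lines between the PeftModel call and model.cuda().

# Sketch of inference.py after this commit. Assumption: the transformers
# import below is not visible in the diff; the adapter files are taken to
# live in the repo root ('.').
from transformers import AutoTokenizer, AutoModelForCausalLM  # assumed import
from peft import PeftModel, PeftConfig  # PeftConfig is now unused but still imported

orig_checkpoint = 'google/gemma-2b'
checkpoint = '.'   # adapter weights now load from the repo root
HF_TOKEN = ''      # Gemma is gated; a Hugging Face access token goes here
PROMPT = 'Salut, ca sa imi schimb buletinul pot sa'

seq_len = 256

# load original model first
tokenizer = AutoTokenizer.from_pretrained(orig_checkpoint, token=HF_TOKEN)
model = AutoModelForCausalLM.from_pretrained(orig_checkpoint, token=HF_TOKEN)

# then merge trained QLoRA weights
model = PeftModel.from_pretrained(model, checkpoint)
model = model.cuda()

inputs = tokenizer.encode(PROMPT, return_tensors="pt").cuda()
outputs = model.generate(inputs, max_new_tokens=seq_len)

print(tokenizer.decode(outputs[0]))

If generation speed matters, PEFT's merge_and_unload() can fold the LoRA deltas into the base weights after loading (model = model.merge_and_unload()), so the forward pass no longer routes through the adapter layers.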