## Taken from the QLoRA Guanaco demo on Gradio
# https://github.com/artidoro/qlora
# https://colab.research.google.com/drive/17XEqL1JcmVWjHkT-WczdYkJlNINacwG7?usp=sharing
import torch
from threading import Thread
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    LlamaTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
    TextIteratorStreamer,
)
model_name = './nyc-savvy'
m = AutoModelForCausalLM.from_pretrained(model_name)
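# The QLoRA demo loads the base model quantized to fit on smaller GPUs. A hedged
# alternative load, assuming bitsandbytes is installed (skip if the full-precision
# load above fits in memory):
# m = AutoModelForCausalLM.from_pretrained(model_name, load_in_4bit=True, device_map='auto')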
# Llama-family checkpoints (including this fine-tune) need the slow LlamaTokenizer
if 'llama' in model_name or 'savvy' in model_name:
    tok = LlamaTokenizer.from_pretrained(model_name)
else:
    tok = AutoTokenizer.from_pretrained(model_name)
tok.bos_token_id = 1  # Llama's <s> BOS token has id 1
stop_token_ids = [0]  # token id 0 (<unk> in the Llama vocab) signals the end of a reply

class StopOnTokens(StoppingCriteria):
    """Halt generation as soon as the last sampled token is a stop token."""

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        for stop_id in stop_token_ids:
            if input_ids[0][-1] == stop_id:
                return True
        return False

stop = StopOnTokens()
max_new_tokens = 1536
messages = "A chat between a curious human and an assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n"
messages += "### Human: What museums should I visit? - My kids are aged 12 and 5"
messages += "### Assistant: "
input_ids = tok(messages, return_tensors="pt").input_ids
input_ids = input_ids.to(m.device)
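# A hedged sketch (build_prompt is not part of the original demo): the same
# Guanaco-style format generalized to multi-turn histories of (human, assistant)
# pairs, ending with "### Assistant: " to cue the model's next reply.
def build_prompt(system: str, history: list[tuple[str, str]], question: str) -> str:
    prompt = system
    for human, assistant in history:
        prompt += f"### Human: {human}\n### Assistant: {assistant}\n"
    return prompt + f"### Human: {question}\n### Assistant: "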
temperature = 0.7        # > 0 enables sampling; lower values are more deterministic
top_p = 0.9              # nucleus sampling: keep the smallest token set covering 90% probability
top_k = 0                # 0 disables top-k filtering
repetition_penalty = 1.1 # > 1.0 discourages repeating earlier tokens
op = m.generate(
    input_ids=input_ids,
    max_new_tokens=max_new_tokens,
    temperature=temperature,
    do_sample=temperature > 0.0,
    top_p=top_p,
    top_k=top_k,
    repetition_penalty=repetition_penalty,
    stopping_criteria=StoppingCriteriaList([stop]),
)
for line in op:
    print(tok.decode(line))
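# The original Gradio demo streams tokens as they are generated rather than
# waiting for the full reply. A minimal sketch of the same idea, reusing the
# model, tokenizer, and sampling parameters defined above:
streamer = TextIteratorStreamer(tok, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
generation_kwargs = dict(
    input_ids=input_ids,
    max_new_tokens=max_new_tokens,
    temperature=temperature,
    do_sample=temperature > 0.0,
    top_p=top_p,
    top_k=top_k,
    repetition_penalty=repetition_penalty,
    stopping_criteria=StoppingCriteriaList([stop]),
    streamer=streamer,
)
thread = Thread(target=m.generate, kwargs=generation_kwargs)
thread.start()
for new_text in streamer:
    print(new_text, end="", flush=True)
thread.join()
print()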