ZharfaOpen-0309 / eval.py
Habib · [INIT] First Checkpoint (2024-03-09) · 13561bb
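"""Interactive evaluation script for the ZharfaOpen-0309 checkpoint: loads the
open-zharfa model and answers prompts typed on stdin, using the OpenChat
"GPT4 Correct" prompt format."""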
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
model_id = "/share/models/open-zharfa"

tokenizer = AutoTokenizer.from_pretrained(model_id, add_eos_token=True)

# Load the base model in fp16 and let accelerate place layers across available devices.
base_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    low_cpu_mem_usage=True,
    return_dict=True,
    torch_dtype=torch.float16,
    device_map="auto",
)
base_model.generation_config.do_sample = True

# Pad with the UNK token rather than EOS so padding is never confused with
# end-of-sequence during generation.
tokenizer.pad_token = tokenizer.unk_token
tokenizer.padding_side = "right"
def get_completion_merged(query: str, model, tokenizer) -> str:
    # OpenChat-style "GPT4 Correct" prompt format; keep it on one line so no
    # stray newlines from a triple-quoted template leak into the prompt.
    prompt = f"GPT4 Correct User: {query}<|end_of_turn|>GPT4 Correct Assistant:"

    encodeds = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
    # Use model.device rather than a hard-coded "cuda:0", since device_map="auto"
    # decides the placement.
    model_inputs = encodeds.to(model.device)

    generated_ids = model.generate(
        **model_inputs,
        max_new_tokens=1000,
        do_sample=True,
        temperature=0.5,
        pad_token_id=tokenizer.unk_token_id,
    )
    decoded = tokenizer.batch_decode(generated_ids)
    return decoded[0]
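
# A hedged alternative sketch: build the prompt with the tokenizer's own chat
# template instead of a hand-written string. Assumption: this checkpoint ships
# a chat template (OpenChat-style models usually do); verify before relying on it.
def get_completion_templated(query: str, model, tokenizer) -> str:
    messages = [{"role": "user", "content": query}]
    # add_generation_prompt=True appends the assistant prefix so the model
    # continues the turn as the assistant.
    model_inputs = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, return_tensors="pt"
    ).to(model.device)
    generated_ids = model.generate(
        model_inputs,
        max_new_tokens=1000,
        do_sample=True,
        temperature=0.5,
        pad_token_id=tokenizer.unk_token_id,
    )
    return tokenizer.batch_decode(generated_ids)[0]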
# Simple interactive loop for manual evaluation; interrupt with Ctrl+C to exit.
while True:
    q = input("q : ")
    result = get_completion_merged(query=q, model=base_model, tokenizer=tokenizer)
    print(result)
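
# The original imports also pulled in BitsAndBytesConfig and PEFT utilities
# (LoraConfig, PeftModel, prepare_model_for_kbit_training, get_peft_model)
# without using them. Kept here, commented out, is a minimal sketch of the
# presumably intended quantized-base-plus-adapter loading path. Assumptions:
# 4-bit NF4 quantization and the adapter directory "/share/models/open-zharfa-lora",
# which is a hypothetical path.
#
# from transformers import BitsAndBytesConfig
# from peft import PeftModel
#
# bnb_config = BitsAndBytesConfig(
#     load_in_4bit=True,
#     bnb_4bit_quant_type="nf4",
#     bnb_4bit_compute_dtype=torch.float16,
# )
# quantized_base = AutoModelForCausalLM.from_pretrained(
#     model_id,
#     quantization_config=bnb_config,
#     device_map="auto",
# )
# # Attach the trained LoRA weights on top of the quantized base.
# model_with_adapter = PeftModel.from_pretrained(
#     quantized_base, "/share/models/open-zharfa-lora"
# )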