Text Generation
Transformers
Safetensors
Persian
English
mistral
conversational
text-generation-inference
Inference Endpoints
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Path to the merged model checkpoint.
model_id = "/share/models/open-zharfa"

tokenizer = AutoTokenizer.from_pretrained(model_id, add_eos_token=True)

# Load the base model in fp16 and let accelerate place it across the available devices.
base_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    low_cpu_mem_usage=True,
    return_dict=True,
    torch_dtype=torch.float16,
    device_map="auto",
)

base_model.generation_config.do_sample = True

# Pad with the unk token (rather than eos) so the eos token is not masked out during generation.
tokenizer.pad_token = tokenizer.unk_token
tokenizer.padding_side = "right"

def get_completion_merged(query: str, model, tokenizer) -> str:
    device = "cuda:0"

    # OpenChat-style prompt format; stray whitespace is kept out of the template
    # so it does not end up in the tokenized prompt.
    prompt_template = "GPT4 Correct User: {query}<|end_of_turn|>GPT4 Correct Assistant:"
    prompt = prompt_template.format(query=query)

    encodeds = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
    model_inputs = encodeds.to(device)

    # Sample up to 1000 new tokens, padding with the unk token to match the tokenizer setup above.
    generated_ids = model.generate(
        **model_inputs,
        max_new_tokens=1000,
        do_sample=True,
        temperature=0.5,
        pad_token_id=tokenizer.unk_token_id,
    )
    decoded = tokenizer.batch_decode(generated_ids)
    return decoded[0]

# Simple interactive loop: read a query from stdin and print the full decoded output.
while True:
    q = input('q : ')
    result = get_completion_merged(query=q, model=base_model, tokenizer=tokenizer)
    print(result)
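
The decoded string still contains the prompt and the special tokens. As a minimal sketch (assuming the same OpenChat-style markers used in the prompt template above, and a hypothetical helper name), the assistant's reply could be isolated before printing:

def extract_reply(decoded: str) -> str:
    # Keep only the text after the assistant marker, then drop the end-of-turn marker.
    reply = decoded.split("GPT4 Correct Assistant:")[-1]
    return reply.replace("<|end_of_turn|>", "").strip()

# Usage inside the loop: print(extract_reply(result))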