from threading import Thread
from transformers import TextStreamer, TextIteratorStreamer
from unsloth import FastLanguageModel
import torch
import gradio as gr

max_seq_length = 2048 # Choose any! We auto support RoPE Scaling internally!
dtype = None # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
load_in_4bit = True # Use 4bit quantization to reduce memory usage. Can be False.

model_name = "Danielrahmai1991/llama32_ganjoor_adapt_basic_model_16bit_v1"
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = model_name,
    max_seq_length = max_seq_length,
    dtype = dtype,
    load_in_4bit = load_in_4bit,
    trust_remote_code=True,
    # token = "hf_...", # use one if using gated models like meta-llama/Llama-2-7b-hf
)
FastLanguageModel.for_inference(model)
print("model loaded")

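# Helpers for language detection, translation, and spell-checking:
# user prompts are translated to English before generation and the model's
# streamed output is translated back to the detected source language.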
import re
from deep_translator import GoogleTranslator, single_detection
from pyaspeller import YandexSpeller
def error_correct_pyspeller(sample_text):
    """ grammer correction of input text"""
    speller = YandexSpeller()
    fixed = speller.spelled(sample_text)
    return fixed

def postprocessing(inp_text: str):
    """Post-process the LLM response: strip HTML tags, cut at the first '##' marker, and fix spelling."""
    inp_text = re.sub('<[^>]+>', '', inp_text)   # remove any HTML/special tags
    inp_text = inp_text.split('##', 1)[0]        # keep only the text before the first '##'
    inp_text = error_correct_pyspeller(inp_text)
    return inp_text
    


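# Main generation callback: detect the prompt language, translate it to English,
# run streamed generation on a background thread, and translate the partial
# output back to the user's language as it is produced.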
def llm_run(prompt, max_length, top_p, temperature, top_k, messages):
    print("prompt, max_length, top_p, temperature, top_k, messages:", prompt, max_length, top_p, temperature, top_k, messages)
    lang = single_detection(prompt, api_key='4ab77f25578d450f0902fb42c66d5e11')
    if lang == 'en':
        prompt = error_correct_pyspeller(prompt)
    # Translate the (possibly corrected) prompt to English before passing it to the model.
    en_translated = GoogleTranslator(source='auto', target='en').translate(prompt)
    messages.append({"role": "user", "content": en_translated})
    print("messages", messages)
    input_ids = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt = True,
        return_tensors = "pt",
    ).to(model.device)  # keep inputs on the same device as the model

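    # TextIteratorStreamer exposes generated tokens as an iterator, consumed in the loop below.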
    streamer = TextIteratorStreamer(
        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )
    generate_kwargs = dict(
        max_length=int(max_length), top_p=float(top_p), do_sample=True,
        top_k=int(top_k), streamer=streamer, temperature=float(temperature), repetition_penalty=1.2
    )

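    # Run generation in a background thread so the streamer can be consumed here
    # and partial results yielded to the Gradio UI as they arrive.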
    t = Thread(target=model.generate, args=(input_ids,), kwargs=generate_kwargs)
    t.start()

    generated_text = []

    for text in streamer:
        generated_text.append(text)
        print('generated_text: ', generated_text)
        # Re-translate the accumulated text back to the detected language on each update and stream it to the UI.
        yield GoogleTranslator(source='auto', target=lang).translate("".join(generated_text))
    
    messages.append({"role": "assistant", "content": "".join(generated_text)})   

def clear_memory(messages):
    messages.clear()
    return "Memory cleaned."
    

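# Gradio UI: a prompt box plus generation controls on the left, streamed output on
# the right; the conversation history is kept in a gr.State list shared across calls.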
with gr.Blocks(theme=gr.themes.Default(primary_hue=gr.themes.colors.orange, secondary_hue=gr.themes.colors.pink)) as demo:
    stored_message = gr.State([])
    with gr.Row():
        with gr.Column(scale=2):
            prompt_text = gr.Textbox(lines=7, label="Prompt", scale=2)
            with gr.Row():
                btn1 = gr.Button("Submit", scale=1)
                btn2 = gr.Button("Clear", scale=1)
                btn3 = gr.Button("Clean Memory", scale=2)
        with gr.Column(scale=2):
            out_text = gr.Text(lines=15, label="Output", scale=2)
    btn1.click(fn=llm_run,
               inputs=[
                   prompt_text,
                   gr.Textbox(label="Max-length generation", value=500),
                   gr.Slider(0.0, 1.0, label="Top-P value", value=0.90),
                   gr.Slider(0.0, 1.0, label="Temperature value", value=0.65),
                   gr.Textbox(label="Top-K", value=50),
                   stored_message
               ],
               outputs=out_text)
    btn2.click(lambda: [None, None], outputs=[prompt_text, out_text])
    btn3.click(fn=clear_memory, inputs=[stored_message], outputs=[out_text])
    
# demo = gr.Interface(fn=llm_run, inputs=["text"], outputs="text")
demo.launch(debug=True, share=True)