import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
from threading import Thread
import spaces
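
# Gradio Space that streams chat completions from the CONCREE/adia-llm model.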
finetuned_model = "CONCREE/adia-llm"
# Load the model
model = AutoModelForCausalLM.from_pretrained(
    finetuned_model,
    device_map="auto",
    trust_remote_code=True,
)
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(
    finetuned_model,
    trust_remote_code=True,
    padding=True,
    truncation=True,
)
class StopOnTokens(StoppingCriteria):
    """Stop generation as soon as one of the stop token ids is emitted."""

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        stop_ids = [29, 0]  # model-specific stop token ids
        for stop_id in stop_ids:
            if input_ids[0][-1] == stop_id:
                return True
        return False
@spaces.GPU
def predict(message, history):
    history_transformer_format = history + [[message, ""]]
    stop = StopOnTokens()

    # Build the prompt from the chat history using [INST]/[/INST] markers
    messages = "".join(
        ["".join(["\n[INST]:" + item[0], "\n[/INST]:" + item[1]]) for item in history_transformer_format]
    )
    model_inputs = tokenizer([messages], return_tensors="pt").to("cuda")

    # Stream tokens back as they are generated
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        model_inputs,
        streamer=streamer,
        max_new_tokens=1024,
        num_beams=1,
        stopping_criteria=StoppingCriteriaList([stop]),
    )

    # Run generation in a background thread so partial output can be yielded
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    partial_message = ""
    start_flag = True  # Flag to ignore an initial newline
    for new_token in streamer:
        if start_flag and new_token == '\n':
            continue
        start_flag = False
        partial_message += new_token
        yield partial_message
demo = gr.ChatInterface(predict)

if __name__ == "__main__":
    demo.launch()