import gradio as gr
import torch
from transformers import AutoModelForCausalLM, GemmaTokenizerFast, StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
from threading import Thread
# Load tokenizer and model
tokenizer = GemmaTokenizerFast.from_pretrained("buddhist-nlp/gemma2-mitra-bo-instruct")
model = AutoModelForCausalLM.from_pretrained("buddhist-nlp/gemma2-mitra-bo-instruct", torch_dtype=torch.float16).to('cuda:0')
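
# Note: this snippet assumes a CUDA-capable GPU is available; on a CPU-only
# machine the torch_dtype=torch.float16 argument and the .to('cuda:0') call
# would need to be dropped or replaced with an appropriate device.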
# Define custom stopping criteria
class StopOnTokens(StoppingCriteria):
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Define stop tokens (adjust based on your model's tokenizer)
        stop_ids = [29, 0]  # These should be the token IDs for end of response or similar tokens
        for stop_id in stop_ids:
            if input_ids[0][-1] == stop_id:
                return True
        return False
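
# Note: the IDs above (29, 0) are carried over from the original snippet and may
# not match this model; for a Gemma-2-based tokenizer the end-of-turn ID could be
# verified with tokenizer.eos_token_id or tokenizer.convert_tokens_to_ids("<end_of_turn>").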
# Define the prediction function for the chat interface
def predict(message, history):
    # Append the new user message to the history as a [user, bot] pair,
    # with the bot reply left empty until it is generated
    history_transformer_format = history + [[message, ""]]
    stop = StopOnTokens()

    # Concatenate previous turns and the user's input into a single prompt
    messages = "".join(
        [f"\n### user : {item[0]} \n### bot : {item[1]}" for item in history_transformer_format]
    )

    # Tokenize the prompt
    model_inputs = tokenizer([messages], return_tensors="pt").to("cuda")

    # Set up the streamer for partial message output
    streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)

    # Generation settings; the custom stopping criteria defined above are passed
    # in here (the original snippet instantiated StopOnTokens but never used it)
    generate_kwargs = dict(
        model_inputs,
        streamer=streamer,
        max_new_tokens=1024,
        stopping_criteria=StoppingCriteriaList([stop]),
    )

    # Run generation in a separate thread so tokens can be streamed as they arrive
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    # Stream partial messages as they are generated
    partial_message = ""
    for new_token in streamer:
        if new_token != '<':  # Skip stray '<' fragments (e.g., partial special-token markers)
            partial_message += new_token
            yield partial_message
# Create the chat interface using Gradio
gr.ChatInterface(fn=predict, title="Gemma LLM Chatbot", description="Chat with the Gemma model using real-time generation and streaming.").launch(share=True)
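
# Note: share=True publishes the app on a temporary public gradio.live URL;
# omit it to serve the interface locally only.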