translation / app.py
TenzinGayche's picture
Update app.py
25cc8b2 verified
import os
from threading import Thread, Event
from typing import Iterator
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, GemmaTokenizerFast, TextIteratorStreamer
from gradio import ChatMessage
# Constants and model initialization
path = "TenzinGayche/tpo_v1.0.0_dpo_2_3ep_ft"
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
# Load the model and tokenizer
tokenizer = GemmaTokenizerFast.from_pretrained(path)
model = AutoModelForCausalLM.from_pretrained(path, torch_dtype=torch.float16).to("cuda")
model.config.sliding_window = 4096
model.eval()
model.config.use_cache = True
stop_event = Event()
def stream_translation(user_message: str, messages: list) -> Iterator[list]:
stop_event.clear()
message = user_message.replace('\n', ' ')
# Initialize the chat history if empty
if not messages:
messages = []
# Add user message if not already present
if not messages or (isinstance(messages[-1], dict) and messages[-1]["role"] != "user"):
messages.append({"role": "user", "content": message})
# Prepare input for the model
conversation = [
{"role": "user", "content": f"Please translate the following into English: {message} Translation:"}
]
input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
gr.Warning(f"Input trimmed as it exceeded {MAX_INPUT_TOKEN_LENGTH} tokens.")
input_ids = input_ids.to(model.device)
streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
# Generation parameters
generate_kwargs = dict(
input_ids=input_ids,
streamer=streamer,
max_new_tokens=2000,
)
# Start generation in a separate thread
Thread(target=model.generate, kwargs=generate_kwargs).start()
# Initialize response tracking
thought_buffer = ""
translation_buffer = ""
in_translation = False
accumulated_text = ""
# Add initial thinking message
messages.append({
"role": "assistant",
"content": "",
"metadata": {"title": "🤔 Thinking...", "status": "pending"}
})
yield messages
for text in streamer:
accumulated_text += text
# Check for the marker in the accumulated text
if "#Final Translation:" in accumulated_text and not in_translation:
# Split at the marker and handle both parts
parts = accumulated_text.split("#Final Translation:", 1)
thought_buffer = parts[0].strip()
translation_start = parts[1] if len(parts) > 1 else ""
# Complete the thinking phase
messages[-1] = {
"role": "assistant",
"content": thought_buffer,
"metadata": {"title": "🤔 Thought Process", "status": "done"},
"collapsed": True
}
yield messages
thought_buffer=""
# Start translation phase as a normal message
in_translation = True
messages.append({
"role": "assistant",
"content": translation_start.strip() # No metadata for normal response
})
translation_buffer = translation_start
yield messages
continue
if in_translation:
translation_buffer += text
messages[-1] = {
"role": "assistant",
"content": translation_buffer.strip() # No metadata for normal response
}
else:
thought_buffer += text
messages[-1] = {
"role": "assistant",
"content": thought_buffer.strip(),
"metadata": {"title": "🤔 Thinking...", "status": "pending"}
}
yield messages
with gr.Blocks(title="Monlam Translation System", theme=gr.themes.Soft()) as demo:
gr.Markdown("# 💭 Samloe Melong Translate")
gr.Markdown("It's a proof of concept. The model first generates detailed reasoning and then provides the translation. It only works for Tibetan to English (for now)!!")
chatbot = gr.Chatbot(
type="messages",
show_label=False,
render_markdown=True,
scale=1
)
with gr.Row():
input_box = gr.Textbox(
lines=3,
label="Enter Tibetan text",
placeholder="Type Tibetan text here...",
show_label=True,
)
submit_btn = gr.Button("Translate", variant="primary", scale=0.15)
examples = [
["རྟག་པར་མངོན་ཞེན་གྱིས་བསླང་བའམ། །གཉེན་པོ་ཡིས་ནི་བསླང་བ་ཉིད། །ཡོན་ཏན་དང་ནི་ཕན་འདོགས་ཞིང་། །སྡུག་བསྔལ་བ་ལ་དགེ་ཆེན་འགྱུར། "],
["ད་ཆ་ཨ་རིའི་ཚོང་རའི་ནང་དུ་གླེང་གཞི་ཤུགས་ཆེར་འགྱུར་བཞིན་པའི་ Deep Seek ཞེས་རྒྱ་ནག་གི་མི་བཟོས་རིག་ནུས་མཉེན་ཆས་དེས་བོད་ནང་དུ་དེ་སྔ་ནས་དམ་དྲག་ཤུགས་ཆེ་ཡོད་པའི་ཐོག་ད་དུང་ཤུགས་ཆེ་རུ་གཏོང་སྲིད་པ་གསུངས་སོང་།"],
["མཉེན་ཆས་འདི་བཞིན་ཨ་རི་དང་རྒྱ་ནག་གཉིས་དབར་ཚོང་འབྲེལ་བཀག་སྡོམ་གྱི་གནད་དོན་ཁྲོད་ཨ་རིའི་མི་བཟོས་རིག་ནུས་ཀྱི་ Chips ཅིབ་སེ་མ་ལག་རྒྱ་ནག་ནང་དུ་ཚོང་འགྲེམ་བཀག་སྡོམ་བྱས་མིན་ལ་མ་ལྟོས་པར། ཚོང་འབྲེལ་བཀག་སྡོམ་གྱི་སྔོན་ཚུད་ནས་རྒྱ་ནག་གི་ཉོ་ཚོང་བྱས་པའི་ཅིབ་སེ་མ་ལག་དོན་ཐེངས་རྙིང་པའི་ཐོག་བཟོ་བསྐྲུན་བྱས་པ་དང་། ཨ་སྒོར་ཐེར་འབུམ་མང་པོའི་འགྲོ་གྲོན་ཐོག་བཟོ་བསྐྲུན་བྱས་པའི་ AI འམ་མི་བཟོས་རིག་ནུས་ཀྱི་མཉེན་ཆས་གཞན་དང་མི་འདྲ་བར་ Deep seek མཉེན་ཆས་དེ་བཞིན་ཨ་སྒོར་ས་ཡ་ ༦ ཁོ་ནའི་འགྲོ་གྲོན་ཐོག་བཟོ་བསྐྲུན་བྱས་པའི་གནད་དོན་སོགས་ཀྱི་རྐྱེན་པས་ཁ་སང་ཨ་རིའི་ཚོང་རའི་ནང་དུ་མི་བཟོས་རིག་ནུས་མཉེན་ཆས་འཕྲུལ་རིག་གི་ Chips ཅིབ་སེ་མ་ལག་བཟོ་བསྐྲུན་བྱས་མཁན་ NVidia ལྟ་བུར་ཨ་སྒོར་ཐེར་འབུམ་ ༦ མིན་ཙམ་གྱི་གྱོན་རྒུད་ཕོག་པའི་གནས་ཚུལ་བྱུང་ཡོད་འདུག"],
["དེ་ཡང་དེ་རིང་ BBC དང་ Reuters སོགས་རྒྱལ་སྤྱིའི་བརྒྱུད་ལམ་ཁག་གི་གནས་ཚུལ་སྤེལ་བར་གཞིགས་ན། རྒྱ་ནག་གི་ Huangzhou གྲོང་ཁྱེར་ནང་དུ་བཟོ་བསྐྲུན་བྱས་པའི་ Deep Seek མི་བཟོས་རིག་ནུས་མཉེན་ཆས་དེ་བཞིན་ ChatGPT དང་ Gemini སོགས་མི་བཟོས་རིག་ནུས་ཀྱི་མཉེན་ཆས་གཞན་དང་བསྡུར་ན་མགྱོགས་ཚད་དང་ནུས་པའི་ཆ་ནས་གཅིག་མཚུངས་ཡོད་པ་མ་ཟད། མཉེན་ཆས་དེ་ཉིད་རིན་མེད་ཡིན་པའི་ཆ་ནས་ཨ་རི་དང་ཨིན་ཡུལ། དེ་བཞིན་རྒྱ་ནག་གསུམ་གྱི་ནང་དུ་སྐུ་ཤུ་རྟགས་ཅན་འཕྲུལ་རིག་གི་ App Store གཉེན་ཆས་བངས་མཛོད་ནང་ནས་ Deep Seek དེ་ཉིད་མཉེན་ཆས་ཕབ་ལེན་མང་ཤོས་བྱས་པ་ཞིག་ཆགས་ཡོད་པ་རེད་འདུག"],
]
gr.Examples(
examples=examples,
inputs=[input_box],
label="Try these examples",
examples_per_page=3
)
# Connect components with correct inputs
submit_btn.click(
fn=stream_translation,
inputs=[input_box, chatbot],
outputs=chatbot
)
if __name__ == "__main__":
demo.queue(max_size=20).launch(share=True)