import os from threading import Thread, Event from typing import Iterator import gradio as gr import torch from transformers import AutoModelForCausalLM, GemmaTokenizerFast, TextIteratorStreamer from gradio import ChatMessage # Constants and model initialization path = "TenzinGayche/tpo_v1.0.0_dpo_2_3ep_ft" MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096")) # Load the model and tokenizer tokenizer = GemmaTokenizerFast.from_pretrained(path) model = AutoModelForCausalLM.from_pretrained(path, torch_dtype=torch.float16).to("cuda") model.config.sliding_window = 4096 model.eval() model.config.use_cache = True stop_event = Event() def stream_translation(user_message: str, messages: list) -> Iterator[list]: stop_event.clear() message = user_message.replace('\n', ' ') # Initialize the chat history if empty if not messages: messages = [] # Add user message if not already present if not messages or (isinstance(messages[-1], dict) and messages[-1]["role"] != "user"): messages.append({"role": "user", "content": message}) # Prepare input for the model conversation = [ {"role": "user", "content": f"Please translate the following into English: {message} Translation:"} ] input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt") if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH: input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:] gr.Warning(f"Input trimmed as it exceeded {MAX_INPUT_TOKEN_LENGTH} tokens.") input_ids = input_ids.to(model.device) streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True) # Generation parameters generate_kwargs = dict( input_ids=input_ids, streamer=streamer, max_new_tokens=2000, ) # Start generation in a separate thread Thread(target=model.generate, kwargs=generate_kwargs).start() # Initialize response tracking thought_buffer = "" translation_buffer = "" in_translation = False accumulated_text = "" # Add initial thinking message messages.append({ "role": "assistant", "content": "", "metadata": {"title": "🤔 Thinking...", "status": "pending"} }) yield messages for text in streamer: accumulated_text += text # Check for the marker in the accumulated text if "#Final Translation:" in accumulated_text and not in_translation: # Split at the marker and handle both parts parts = accumulated_text.split("#Final Translation:", 1) thought_buffer = parts[0].strip() translation_start = parts[1] if len(parts) > 1 else "" # Complete the thinking phase messages[-1] = { "role": "assistant", "content": thought_buffer, "metadata": {"title": "🤔 Thought Process", "status": "done"}, "collapsed": True } yield messages thought_buffer="" # Start translation phase as a normal message in_translation = True messages.append({ "role": "assistant", "content": translation_start.strip() # No metadata for normal response }) translation_buffer = translation_start yield messages continue if in_translation: translation_buffer += text messages[-1] = { "role": "assistant", "content": translation_buffer.strip() # No metadata for normal response } else: thought_buffer += text messages[-1] = { "role": "assistant", "content": thought_buffer.strip(), "metadata": {"title": "🤔 Thinking...", "status": "pending"} } yield messages with gr.Blocks(title="Monlam Translation System", theme=gr.themes.Soft()) as demo: gr.Markdown("# 💭 Samloe Melong Translate") gr.Markdown("It's a proof of concept. The model first generates detailed reasoning and then provides the translation. It only works for Tibetan to English (for now)!!") chatbot = gr.Chatbot( type="messages", show_label=False, render_markdown=True, scale=1 ) with gr.Row(): input_box = gr.Textbox( lines=3, label="Enter Tibetan text", placeholder="Type Tibetan text here...", show_label=True, ) submit_btn = gr.Button("Translate", variant="primary", scale=0.15) examples = [ ["རྟག་པར་མངོན་ཞེན་གྱིས་བསླང་བའམ། །གཉེན་པོ་ཡིས་ནི་བསླང་བ་ཉིད། །ཡོན་ཏན་དང་ནི་ཕན་འདོགས་ཞིང་། །སྡུག་བསྔལ་བ་ལ་དགེ་ཆེན་འགྱུར། "], ["ད་ཆ་ཨ་རིའི་ཚོང་རའི་ནང་དུ་གླེང་གཞི་ཤུགས་ཆེར་འགྱུར་བཞིན་པའི་ Deep Seek ཞེས་རྒྱ་ནག་གི་མི་བཟོས་རིག་ནུས་མཉེན་ཆས་དེས་བོད་ནང་དུ་དེ་སྔ་ནས་དམ་དྲག་ཤུགས་ཆེ་ཡོད་པའི་ཐོག་ད་དུང་ཤུགས་ཆེ་རུ་གཏོང་སྲིད་པ་གསུངས་སོང་།"], ["མཉེན་ཆས་འདི་བཞིན་ཨ་རི་དང་རྒྱ་ནག་གཉིས་དབར་ཚོང་འབྲེལ་བཀག་སྡོམ་གྱི་གནད་དོན་ཁྲོད་ཨ་རིའི་མི་བཟོས་རིག་ནུས་ཀྱི་ Chips ཅིབ་སེ་མ་ལག་རྒྱ་ནག་ནང་དུ་ཚོང་འགྲེམ་བཀག་སྡོམ་བྱས་མིན་ལ་མ་ལྟོས་པར། ཚོང་འབྲེལ་བཀག་སྡོམ་གྱི་སྔོན་ཚུད་ནས་རྒྱ་ནག་གི་ཉོ་ཚོང་བྱས་པའི་ཅིབ་སེ་མ་ལག་དོན་ཐེངས་རྙིང་པའི་ཐོག་བཟོ་བསྐྲུན་བྱས་པ་དང་། ཨ་སྒོར་ཐེར་འབུམ་མང་པོའི་འགྲོ་གྲོན་ཐོག་བཟོ་བསྐྲུན་བྱས་པའི་ AI འམ་མི་བཟོས་རིག་ནུས་ཀྱི་མཉེན་ཆས་གཞན་དང་མི་འདྲ་བར་ Deep seek མཉེན་ཆས་དེ་བཞིན་ཨ་སྒོར་ས་ཡ་ ༦ ཁོ་ནའི་འགྲོ་གྲོན་ཐོག་བཟོ་བསྐྲུན་བྱས་པའི་གནད་དོན་སོགས་ཀྱི་རྐྱེན་པས་ཁ་སང་ཨ་རིའི་ཚོང་རའི་ནང་དུ་མི་བཟོས་རིག་ནུས་མཉེན་ཆས་འཕྲུལ་རིག་གི་ Chips ཅིབ་སེ་མ་ལག་བཟོ་བསྐྲུན་བྱས་མཁན་ NVidia ལྟ་བུར་ཨ་སྒོར་ཐེར་འབུམ་ ༦ མིན་ཙམ་གྱི་གྱོན་རྒུད་ཕོག་པའི་གནས་ཚུལ་བྱུང་ཡོད་འདུག"], ["དེ་ཡང་དེ་རིང་ BBC དང་ Reuters སོགས་རྒྱལ་སྤྱིའི་བརྒྱུད་ལམ་ཁག་གི་གནས་ཚུལ་སྤེལ་བར་གཞིགས་ན། རྒྱ་ནག་གི་ Huangzhou གྲོང་ཁྱེར་ནང་དུ་བཟོ་བསྐྲུན་བྱས་པའི་ Deep Seek མི་བཟོས་རིག་ནུས་མཉེན་ཆས་དེ་བཞིན་ ChatGPT དང་ Gemini སོགས་མི་བཟོས་རིག་ནུས་ཀྱི་མཉེན་ཆས་གཞན་དང་བསྡུར་ན་མགྱོགས་ཚད་དང་ནུས་པའི་ཆ་ནས་གཅིག་མཚུངས་ཡོད་པ་མ་ཟད། མཉེན་ཆས་དེ་ཉིད་རིན་མེད་ཡིན་པའི་ཆ་ནས་ཨ་རི་དང་ཨིན་ཡུལ། དེ་བཞིན་རྒྱ་ནག་གསུམ་གྱི་ནང་དུ་སྐུ་ཤུ་རྟགས་ཅན་འཕྲུལ་རིག་གི་ App Store གཉེན་ཆས་བངས་མཛོད་ནང་ནས་ Deep Seek དེ་ཉིད་མཉེན་ཆས་ཕབ་ལེན་མང་ཤོས་བྱས་པ་ཞིག་ཆགས་ཡོད་པ་རེད་འདུག"], ] gr.Examples( examples=examples, inputs=[input_box], label="Try these examples", examples_per_page=3 ) # Connect components with correct inputs submit_btn.click( fn=stream_translation, inputs=[input_box, chatbot], outputs=chatbot ) if __name__ == "__main__": demo.queue(max_size=20).launch(share=True)