TenzinGayche committed
Commit e9bec21 · verified · 1 Parent(s): 49495f0

Create app.py

Files changed (1)
  1. app.py +55 -0
app.py ADDED
@@ -0,0 +1,55 @@
+ import gradio as gr
+ import torch
+ from transformers import AutoModelForCausalLM, GemmaTokenizerFast, StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
+ from threading import Thread
+
+ # Load the tokenizer and model, and move the model to the GPU
+ tokenizer = GemmaTokenizerFast.from_pretrained("buddhist-nlp/gemma2-mitra-bo-instruct")
+ model = AutoModelForCausalLM.from_pretrained("buddhist-nlp/gemma2-mitra-bo-instruct", torch_dtype=torch.float16).to("cuda:0")
+
+ # Custom stopping criterion: stop generation once an end-of-response token appears
+ class StopOnTokens(StoppingCriteria):
+     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+         # Token IDs that mark the end of a response (adjust based on your model's tokenizer)
+         stop_ids = [29, 0]
+         for stop_id in stop_ids:
+             if input_ids[0][-1] == stop_id:
+                 return True
+         return False
+
+ # Prediction function for the chat interface
+ def predict(message, history):
+     # Append the new user message (with an empty bot reply) to the conversation history
+     history_transformer_format = history + [[message, ""]]
+     stop = StopOnTokens()
+
+     # Concatenate previous turns and the user's input into a single prompt
+     messages = "".join([f"\n### user : {item[0]} \n### bot : {item[1]}" for item in history_transformer_format])
+
+     # Tokenize the prompt
+     model_inputs = tokenizer([messages], return_tensors="pt").to("cuda")
+
+     # Streamer that yields partial output as it is generated
+     streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
+
+     # Generation settings
+     generate_kwargs = dict(
+         model_inputs,
+         streamer=streamer,
+         max_new_tokens=1024,
+         stopping_criteria=StoppingCriteriaList([stop]),
+     )
+
+     # Run generation in a background thread so tokens can be streamed
+     t = Thread(target=model.generate, kwargs=generate_kwargs)
+     t.start()
+
+     # Yield the growing partial message as new tokens arrive
+     partial_message = ""
+     for new_token in streamer:
+         if new_token != '<':  # Skip stray special-token fragments if necessary
+             partial_message += new_token
+             yield partial_message
+
+ # Create and launch the Gradio chat interface
+ gr.ChatInterface(fn=predict, title="Gemma LLM Chatbot", description="Chat with the Gemma model using real-time generation and streaming.").launch(share=True)