Aashi committed on
Commit
7a44780
·
verified ·
1 Parent(s): 3afacd8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +51 -2
app.py CHANGED
@@ -1,5 +1,54 @@
 
 
 
 
 
 
1
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
- demo = gr.load("models/NSTiwari/fine_tuned_science_gemma2b-it")
 
 
 
 
 
 
 
4
 
5
- demo.launch()
 
 
1
+ # import gradio as gr
2
+
3
+ # demo = gr.load("models/NSTiwari/fine_tuned_science_gemma2b-it")
4
+
5
+ # demo.launch()
6
+
7
  import gradio as gr
8
+ import transformers
9
+ from transformers import AutoTokenizer, AutoModelForCausalLM
10
+ import torch
11
+ import time
12
+
13
# Hugging Face Hub repo ID of the fine-tuned model.
# NOTE: the "models/" prefix belongs to gradio's gr.load() namespace only; it is
# NOT part of a Hub repo ID. from_pretrained() expects "<user>/<repo>", so the
# old value "models/NSTiwari/..." would fail to resolve.
model_id = "NSTiwari/fine_tuned_science_gemma2b-it"

# Load tokenizer and model once at import time. device_map="auto" places the
# weights on an accelerator when available; bfloat16 halves the memory footprint.
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, device_map="auto"
)
19
+
20
def inference(input_text):
    """
    Run greedy generation on a science question and time it.

    Args:
        input_text: The question/prompt string.

    Returns:
        dict with "answer" (the model's continuation, special tokens stripped)
        and "latency" (wall-clock time, formatted like "1.23 seconds").
    """
    start_time = time.time()
    inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
    input_length = inputs["input_ids"].shape[1]
    # No gradients are needed for generation; inference_mode saves memory/time.
    with torch.inference_mode():
        outputs = model.generate(
            **inputs,  # forwards attention_mask too, silencing padding warnings
            max_new_tokens=512,  # cap on *generated* tokens, independent of prompt length
            do_sample=False,  # greedy decoding -> deterministic answers
        )
    # Slice off the prompt so only the model's continuation is decoded.
    generated_sequence = outputs[:, input_length:]
    # skip_special_tokens keeps markers like <eos> out of the user-facing text.
    response = tokenizer.decode(generated_sequence[0], skip_special_tokens=True)
    end_time = time.time()
    return {"answer": response, "latency": f"{end_time - start_time:.2f} seconds"}
36
+
37
def gradio_interface(question):
    """
    Bridge between the Gradio UI and inference().

    Takes the question from the input textbox, runs inference, and unpacks the
    result dict into the two output textboxes (answer, latency).
    """
    outcome = inference(question)
    answer = outcome["answer"]
    latency = outcome["latency"]
    return answer, latency
43
 
44
# Gradio interface definition: one question textbox in, two textboxes out
# (the model's answer plus the measured generation latency).
iface = gr.Interface(
    fn=gradio_interface,
    inputs=gr.Textbox(label="Science Question", lines=4),
    outputs=[gr.Textbox(label="Answer"), gr.Textbox(label="Latency")],
    title="Science Q&A with Fine-tuned Model",
    description="Ask a science question and get an answer from the fine-tuned model.",
)
52
 
53
# Start the Gradio web server only when run as a script (not on import).
if __name__ == "__main__":
    iface.launch()