chrisociepa committed on
Commit
46f82ea
1 Parent(s): b128b31

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -2
app.py CHANGED
@@ -1,4 +1,5 @@
1
  import gradio as gr
 
2
  from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
3
 
4
  model_name = "Azurro/APT-1B-Base"
@@ -10,6 +11,7 @@ generator = pipeline(
10
  "text-generation",
11
  model=model,
12
  tokenizer=tokenizer,
 
13
  device_map="auto",
14
  )
15
 
@@ -24,7 +26,7 @@ def generate_text(prompt, max_length, temperature, top_k, top_p, beams):
24
  return output[0]['generated_text']
25
 
26
  input_text = gr.inputs.Textbox(label="Input Text")
27
- max_length = gr.inputs.Slider(1, 200, step=1, default=100, label="Max Length")
28
  temperature = gr.inputs.Slider(0.1, 1.0, step=0.1, default=0.8, label="Temperature")
29
  top_k = gr.inputs.Slider(1, 200, step=1, default=10, label="Top K")
30
  top_p = gr.inputs.Slider(0.1, 2.0, step=0.1, default=0.95, label="Top P")
@@ -32,4 +34,6 @@ beams = gr.inputs.Slider(1, 20, step=1, default=1, label="Beams")
32
 
33
  outputs = gr.outputs.Textbox(label="Generated Text")
34
 
35
- gr.Interface(generate_text, inputs=[input_text, max_length, temperature, top_k, top_p, beams], outputs=outputs).launch()
 
 
 
1
  import gradio as gr
2
+ import torch
3
  from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
4
 
5
  model_name = "Azurro/APT-1B-Base"
 
11
  "text-generation",
12
  model=model,
13
  tokenizer=tokenizer,
14
+ torch_dtype=torch.bfloat16,
15
  device_map="auto",
16
  )
17
 
 
26
  return output[0]['generated_text']
27
 
28
  input_text = gr.inputs.Textbox(label="Input Text")
29
+ max_length = gr.inputs.Slider(1, 100, step=1, default=30, label="Max Length")
30
  temperature = gr.inputs.Slider(0.1, 1.0, step=0.1, default=0.8, label="Temperature")
31
  top_k = gr.inputs.Slider(1, 200, step=1, default=10, label="Top K")
32
  top_p = gr.inputs.Slider(0.1, 2.0, step=0.1, default=0.95, label="Top P")
 
34
 
35
  outputs = gr.outputs.Textbox(label="Generated Text")
36
 
37
+ iface = gr.Interface(generate_text, inputs=[input_text, max_length, temperature, top_k, top_p, beams], outputs=outputs)
38
+ iface.queue(concurrency_count=1)
39
+ iface.launch(max_threads=100)