chrisociepa committed
Commit 794a411
1 Parent(s): 46f82ea

Update app.py

Files changed (1):
  app.py +54 -32
app.py CHANGED
@@ -1,39 +1,61 @@
  import gradio as gr
  import torch
- from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

- model_name = "Azurro/APT-1B-Base"

  tokenizer = AutoTokenizer.from_pretrained(model_name)
- model = AutoModelForCausalLM.from_pretrained(model_name)

- generator = pipeline(
-     "text-generation",
-     model=model,
-     tokenizer=tokenizer,
-     torch_dtype=torch.bfloat16,
-     device_map="auto",
- )
-
- def generate_text(prompt, max_length, temperature, top_k, top_p, beams):
-     output = generator(prompt,
-                        max_length=max_length,
-                        temperature=temperature,
-                        top_k=top_k,
-                        do_sample=True,
-                        top_p=top_p,
-                        num_beams=beams)
-     return output[0]['generated_text']
-
- input_text = gr.inputs.Textbox(label="Input Text")
- max_length = gr.inputs.Slider(1, 100, step=1, default=30, label="Max Length")
- temperature = gr.inputs.Slider(0.1, 1.0, step=0.1, default=0.8, label="Temperature")
- top_k = gr.inputs.Slider(1, 200, step=1, default=10, label="Top K")
- top_p = gr.inputs.Slider(0.1, 2.0, step=0.1, default=0.95, label="Top P")
- beams = gr.inputs.Slider(1, 20, step=1, default=1, label="Beams")
-
- outputs = gr.outputs.Textbox(label="Generated Text")
-
- iface = gr.Interface(generate_text, inputs=[input_text, max_length, temperature, top_k, top_p, beams], outputs=outputs)
- iface.queue(concurrency_count=1)
- iface.launch(max_threads=100)

  import gradio as gr
  import torch
+ import time
+ from transformers import AutoTokenizer, LlamaForCausalLM, PreTrainedTokenizerFast, pipeline
 
+ model_name = "Azurro/APT3-1B-Instruct-v1"
 
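+ # Load the tokenizer and half-precision weights once at startup so every request reuses them.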
  tokenizer = AutoTokenizer.from_pretrained(model_name)
+ model = LlamaForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)
 
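+ # Build the instruct prompt, run generation, and log simple latency/throughput statistics.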
+ def generate_text(prompt, max_length, temperature, top_k, top_p):
+     # Wrap the user input in the [INST] instruction template the model expects.
+     prompt = f'<s>[INST] {prompt.strip()} [/INST]'
+     # <s> is already part of the template, so skip the tokenizer's special tokens.
+     input_ids = tokenizer(prompt, return_tensors='pt', add_special_tokens=False).input_ids.to(model.device)
+     start_time = time.time()
+     output = model.generate(
+         inputs=input_ids,
+         max_new_tokens=max_length,
+         temperature=temperature,
+         top_k=top_k,
+         do_sample=(temperature > 0),  # fall back to greedy decoding at temperature 0
+         top_p=top_p,
+         num_beams=1,
+         bos_token_id=1,
+         eos_token_id=2,
+         pad_token_id=3,
+         repetition_penalty=1.1
+     )
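+     # Note: the token counts below include the prompt, since generate returns the full sequence.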
+     elapsed_time = time.time() - start_time
+     decoded_output = tokenizer.decode(output[0])
+     input_tokens_count = len(input_ids[0])
+     input_chars_count = len(prompt)
+     output_tokens_count = len(output[0])
+     output_chars_count = len(decoded_output)
+     gen_speed = output_tokens_count / elapsed_time
+     # Drop the echoed prompt and the trailing </s> from the decoded text.
+     decoded_output = decoded_output[len(prompt):].replace('</s>', '').strip()
+     print(f"Input tokens: {input_tokens_count} (chars: {input_chars_count}), Output tokens: {output_tokens_count} (chars: {output_chars_count}), Gen Time: {elapsed_time:.2f} secs ({gen_speed:.2f} toks/sec)")
+     print(f"{'*'*10} Input {'*'*10}\n{prompt}")
+     print(f"{'*'*10} Output {'*'*10}\n{decoded_output}")
+     print(f"{'*'*30}")
+     return decoded_output, input_tokens_count, input_chars_count, output_tokens_count, output_chars_count, gen_speed
 
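+ # The five inputs map to generate_text's parameters; the six outputs mirror its return tuple.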
+ demo = gr.Interface(
+     fn=generate_text,
+     inputs=[
+         gr.inputs.Textbox(label="Input Text"),
+         gr.inputs.Slider(1, 1000, step=1, default=100, label="Max Length"),
+         gr.inputs.Slider(0.0, 1.5, step=0.1, default=0.6, label="Temperature"),
+         gr.inputs.Slider(1, 400, step=1, default=200, label="Top K"),
+         gr.inputs.Slider(0.0, 1.0, step=0.05, default=0.95, label="Top P")
+     ],
+     outputs=[
+         gr.outputs.Textbox(label="Generated Text"),
+         gr.outputs.Textbox(label="Input Tokens Count"),
+         gr.outputs.Textbox(label="Input Characters Count"),
+         gr.outputs.Textbox(label="Output Tokens Count"),
+         gr.outputs.Textbox(label="Output Characters Count"),
+         gr.outputs.Textbox(label="Generation speed in tokens per second"),
+     ]
+ )
+ demo.queue(concurrency_count=1)
+ demo.launch(max_threads=20)
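+ # A queue with concurrency_count=1 processes one generation at a time on the shared model.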