hoduyquocbao committed
Commit bf64382
1 Parent(s): d956981

new version update

Files changed (3)
  1. app.py +112 -67
  2. requirements.txt +7 -2
  3. style.css +11 -0
app.py CHANGED
@@ -1,92 +1,137 @@
  import gradio as gr
- from huggingface_hub import InferenceClient

  """
- For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
- """
- client = InferenceClient("meta-llama/Llama-3.2-3B-Instruct")

- import torch
- from transformers import pipeline

- model_id = "meta-llama/Llama-3.2-3B-Instruct"
- pipe = pipeline(
-     "text-generation",
-     model=model_id,
-     torch_dtype=torch.bfloat16,
      device_map="auto",
  )
- messages = [
-     {"role": "system", "content": "You are a pirate chatbot who always responds in pirate speak!"},
-     {"role": "user", "content": "Who are you?"},
- ]
-
- # print(outputs[0]["generated_text"][-1])
-
- def respond(
-     message,
-     history: list[tuple[str, str]],
-     system_message,
-     max_tokens,
-     temperature,
-     top_p,
- ):
-     messages = [{"role": "system", "content": system_message}]
-
-     for val in history:
-         if val[0]:
-             messages.append({"role": "user", "content": val[0]})
-         if val[1]:
-             messages.append({"role": "assistant", "content": val[1]})

-     messages.append({"role": "user", "content": message})

-     response = ""

-     # outputs = pipe(
-     #     messages,
-     #     max_new_tokens=256,
-     # )

-     # for message in client.chat_completion(
-     #     messages,
-     #     max_tokens=max_tokens,
-     #     stream=True,
-     #     temperature=temperature,
-     #     top_p=top_p,
-     # ):
-     for message in pipe(
-         messages,
-         max_tokens=max_tokens,
-         stream=True,
-         temperature=temperature,
          top_p=top_p,
-     ):
-         # token = message.choices[0].delta.content
-         token = message[0]["generated_text"][-1]

-         response += token
-         yield response

- """
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
- """
- demo = gr.ChatInterface(
-     respond,
      additional_inputs=[
-         gr.Textbox(value="You are a friendly Chatbot RIM.", label="System message"),
-         gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
-         gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
          gr.Slider(
              minimum=0.1,
              maximum=1.0,
-             value=0.95,
              step=0.05,
-             label="Top-p (nucleus sampling)",
          ),
      ],
  )


  if __name__ == "__main__":
-     demo.launch()
 
+ import os
+ from threading import Thread
+ from typing import Iterator
+
  import gradio as gr
+ import spaces
+ import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

+ DESCRIPTION = """\
+ # Llama 3.2 3B Instruct
+
+ Llama 3.2 3B is Meta's latest iteration of open LLMs.
+ This is a demo of [`meta-llama/Llama-3.2-3B-Instruct`](https://huggingface.co/meta-llama/Llama-3.2-3B-Instruct), fine-tuned for instruction following.
+ For more details, please check [our post](https://huggingface.co/blog/llama32).
  """

+ MAX_MAX_NEW_TOKENS = 2048
+ DEFAULT_MAX_NEW_TOKENS = 1024
+ MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))

+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+ model_id = "nltpt/Llama-3.2-3B-Instruct"
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
+ model = AutoModelForCausalLM.from_pretrained(
+     model_id,
      device_map="auto",
+     torch_dtype=torch.bfloat16,
  )
+ model.eval()


+ @spaces.GPU(duration=90)
+ def generate(
+     message: str,
+     chat_history: list[tuple[str, str]],
+     max_new_tokens: int = 1024,
+     temperature: float = 0.6,
+     top_p: float = 0.9,
+     top_k: int = 50,
+     repetition_penalty: float = 1.2,
+ ) -> Iterator[str]:
+     conversation = []
+     for user, assistant in chat_history:
+         conversation.extend(
+             [
+                 {"role": "user", "content": user},
+                 {"role": "assistant", "content": assistant},
+             ]
+         )
+     conversation.append({"role": "user", "content": message})

+     input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
+     if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
+         input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
+         gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
+     input_ids = input_ids.to(model.device)

+     streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
+     generate_kwargs = dict(
+         {"input_ids": input_ids},
+         streamer=streamer,
+         max_new_tokens=max_new_tokens,
+         do_sample=True,
          top_p=top_p,
+         top_k=top_k,
+         temperature=temperature,
+         num_beams=1,
+         repetition_penalty=repetition_penalty,
+     )
+     t = Thread(target=model.generate, kwargs=generate_kwargs)
+     t.start()

+     outputs = []
+     for text in streamer:
+         outputs.append(text)
+         yield "".join(outputs)

+
+ chat_interface = gr.ChatInterface(
+     fn=generate,
      additional_inputs=[
          gr.Slider(
+             label="Max new tokens",
+             minimum=1,
+             maximum=MAX_MAX_NEW_TOKENS,
+             step=1,
+             value=DEFAULT_MAX_NEW_TOKENS,
+         ),
+         gr.Slider(
+             label="Temperature",
              minimum=0.1,
+             maximum=4.0,
+             step=0.1,
+             value=0.6,
+         ),
+         gr.Slider(
+             label="Top-p (nucleus sampling)",
+             minimum=0.05,
              maximum=1.0,
              step=0.05,
+             value=0.9,
+         ),
+         gr.Slider(
+             label="Top-k",
+             minimum=1,
+             maximum=1000,
+             step=1,
+             value=50,
          ),
+         gr.Slider(
+             label="Repetition penalty",
+             minimum=1.0,
+             maximum=2.0,
+             step=0.05,
+             value=1.2,
+         ),
+     ],
+     stop_btn=None,
+     examples=[
+         ["Hello there! How are you doing?"],
+         ["Can you explain briefly to me what is the Python programming language?"],
+         ["Explain the plot of Cinderella in a sentence."],
+         ["How many hours does it take a man to eat a Helicopter?"],
+         ["Write a 100-word article on 'Benefits of Open-Source in AI research'"],
      ],
+     cache_examples=False,
  )

+ with gr.Blocks(css="style.css", fill_height=True) as demo:
+     gr.Markdown(DESCRIPTION)
+     gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")
+     chat_interface.render()

  if __name__ == "__main__":
+     demo.queue(max_size=20).launch()
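
The core change in app.py is the switch from huggingface_hub's `InferenceClient` / a `transformers` pipeline to local generation with `AutoModelForCausalLM`, streaming tokens through a `TextIteratorStreamer` while `model.generate` runs on a background thread. A minimal, self-contained sketch of that streaming pattern is below; the checkpoint name, prompt, and `max_new_tokens` here are illustrative placeholders, not values taken from this commit.

```python
# Minimal sketch of the streaming pattern used by the new app.py:
# model.generate runs on a background thread while decoded text chunks
# are read from a TextIteratorStreamer as they become available.
from threading import Thread

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

model_id = "meta-llama/Llama-3.2-3B-Instruct"  # placeholder; any chat model you can access works
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, device_map="auto"
)

conversation = [{"role": "user", "content": "Explain the plot of Cinderella in a sentence."}]
input_ids = tokenizer.apply_chat_template(
    conversation, add_generation_prompt=True, return_tensors="pt"
).to(model.device)

streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
thread = Thread(
    target=model.generate,
    kwargs=dict(input_ids=input_ids, streamer=streamer, max_new_tokens=128, do_sample=True),
)
thread.start()

for chunk in streamer:  # yields decoded text as generation progresses
    print(chunk, end="", flush=True)
thread.join()
```

Running generation on a separate thread is what lets the Gradio `generate` function yield partial output to the chat UI instead of blocking until the full completion is ready.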
requirements.txt CHANGED
@@ -1,2 +1,7 @@
- huggingface_hub==0.22.2
- transformers
+ huggingface_hub
+ transformers
+ accelerate
+ bitsandbytes
+ gradio
+ spaces
+ torch
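
The updated requirements leave every dependency unpinned and add `accelerate` (needed for `device_map="auto"`), `bitsandbytes`, `gradio`, `spaces`, and `torch`. A quick smoke test after `pip install -r requirements.txt` might look like the following; this is a hypothetical check, not part of the commit.

```python
# Hypothetical smoke test for the updated requirements (not part of the commit):
# confirm the key imports resolve and report whether a GPU is visible,
# since the app relies on device_map="auto" and torch.cuda.is_available().
import gradio
import torch
import transformers

print("transformers:", transformers.__version__)
print("gradio:", gradio.__version__)
print("CUDA available:", torch.cuda.is_available())
```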
style.css ADDED
@@ -0,0 +1,11 @@
+ h1 {
+   text-align: center;
+   display: block;
+ }
+
+ #duplicate-button {
+   margin: auto;
+   color: #fff;
+   background: #1565c0;
+   border-radius: 100vh;
+ }