MartinHeHeHe committed
Commit 692c83c · Parent: 4bf5c3b

Update app.py

Files changed (1)
  1. app.py +43 -127
app.py CHANGED
@@ -1,141 +1,57 @@
  import os
- from threading import Thread
  from typing import Iterator

- import gradio as gr
- import spaces
- import torch
- from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
+ from text_generation import Client

- MAX_MAX_NEW_TOKENS = 2048
- DEFAULT_MAX_NEW_TOKENS = 1024
- MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
+ model_id = 'HuggingFaceH4/zephyr-7b-beta'

- DESCRIPTION = """\
- # Llama-2 7B Chat
- This Space demonstrates model [Llama-2-7b-chat](https://huggingface.co/meta-llama/Llama-2-7b-chat) by Meta, a Llama 2 model with 7B parameters fine-tuned for chat instructions. Feel free to play with it, or duplicate to run generations without a queue! If you want to run your own service, you can also [deploy the model on Inference Endpoints](https://huggingface.co/inference-endpoints).
- 🔎 For more details about the Llama 2 family of models and how to use them with `transformers`, take a look [at our blog post](https://huggingface.co/blog/llama2).
- 🔨 Looking for an even more powerful model? Check out the [13B version](https://huggingface.co/spaces/huggingface-projects/llama-2-13b-chat) or the large [70B model demo](https://huggingface.co/spaces/ysharma/Explore_llamav2_with_TGI).
- """
+ API_URL = "https://api-inference.huggingface.co/models/" + model_id
+ HF_TOKEN = os.environ.get("HF_READ_TOKEN", None)

- LICENSE = """
- <p/>
- ---
- As a derivate work of [Llama-2-7b-chat](https://huggingface.co/meta-llama/Llama-2-7b-chat) by Meta,
- this demo is governed by the original [license](https://huggingface.co/spaces/huggingface-projects/llama-2-7b-chat/blob/main/LICENSE.txt) and [acceptable use policy](https://huggingface.co/spaces/huggingface-projects/llama-2-7b-chat/blob/main/USE_POLICY.md).
- """
-
- if not torch.cuda.is_available():
-     DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
-
-
- if torch.cuda.is_available():
-     model_id = "meta-llama/Llama-2-7b-chat-hf"
-     model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")
-     tokenizer = AutoTokenizer.from_pretrained(model_id)
-     tokenizer.use_default_system_prompt = False
-
-
- @spaces.GPU
- def generate(
-     message: str,
-     chat_history: list[tuple[str, str]],
-     system_prompt: str,
-     max_new_tokens: int = 1024,
-     temperature: float = 0.6,
-     top_p: float = 0.9,
-     top_k: int = 50,
-     repetition_penalty: float = 1.2,
- ) -> Iterator[str]:
-     conversation = []
-     if system_prompt:
-         conversation.append({"role": "system", "content": system_prompt})
-     for user, assistant in chat_history:
-         conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
-     conversation.append({"role": "user", "content": message})
-
-     input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
-     if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
-         input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
-         gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
-     input_ids = input_ids.to(model.device)
+ client = Client(
+     API_URL,
+     headers={"Authorization": f"Bearer {HF_TOKEN}"},
+ )
+ EOS_STRING = "</s>"
+ EOT_STRING = "<EOT>"
+
+
+ def get_prompt(message: str, chat_history: list[tuple[str, str]],
+                system_prompt: str) -> str:
+     texts = [f'<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n']
+     # The first user input is _not_ stripped
+     do_strip = False
+     for user_input, response in chat_history:
+         user_input = user_input.strip() if do_strip else user_input
+         do_strip = True
+         texts.append(f'{user_input} [/INST] {response.strip()} </s><s>[INST] ')
+     message = message.strip() if do_strip else message
+     texts.append(f'{message} [/INST]')
+     return ''.join(texts)
+
+
+ def run(message: str,
+         chat_history: list[tuple[str, str]],
+         system_prompt: str,
+         max_new_tokens: int = 1024,
+         temperature: float = 0.1,
+         top_p: float = 0.9,
+         top_k: int = 50) -> Iterator[str]:
+     prompt = get_prompt(message, chat_history, system_prompt)

-     streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
      generate_kwargs = dict(
-         {"input_ids": input_ids},
-         streamer=streamer,
          max_new_tokens=max_new_tokens,
          do_sample=True,
          top_p=top_p,
          top_k=top_k,
          temperature=temperature,
-         num_beams=1,
-         repetition_penalty=repetition_penalty,
      )
-     t = Thread(target=model.generate, kwargs=generate_kwargs)
-     t.start()
-
-     outputs = []
-     for text in streamer:
-         outputs.append(text)
-         yield "".join(outputs)
-
-
- chat_interface = gr.ChatInterface(
-     fn=generate,
-     additional_inputs=[
-         gr.Textbox(label="System prompt", lines=6),
-         gr.Slider(
-             label="Max new tokens",
-             minimum=1,
-             maximum=MAX_MAX_NEW_TOKENS,
-             step=1,
-             value=DEFAULT_MAX_NEW_TOKENS,
-         ),
-         gr.Slider(
-             label="Temperature",
-             minimum=0.1,
-             maximum=4.0,
-             step=0.1,
-             value=0.6,
-         ),
-         gr.Slider(
-             label="Top-p (nucleus sampling)",
-             minimum=0.05,
-             maximum=1.0,
-             step=0.05,
-             value=0.9,
-         ),
-         gr.Slider(
-             label="Top-k",
-             minimum=1,
-             maximum=1000,
-             step=1,
-             value=50,
-         ),
-         gr.Slider(
-             label="Repetition penalty",
-             minimum=1.0,
-             maximum=2.0,
-             step=0.05,
-             value=1.2,
-         ),
-     ],
-     stop_btn=None,
-     examples=[
-         ["Hello there! How are you doing?"],
-         ["Can you explain briefly to me what is the Python programming language?"],
-         ["Explain the plot of Cinderella in a sentence."],
-         ["How many hours does it take a man to eat a Helicopter?"],
-         ["Write a 100-word article on 'Benefits of Open-Source in AI research'"],
-     ],
- )
-
- with gr.Blocks(css="style.css") as demo:
-     gr.Markdown(DESCRIPTION)
-     gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")
-     chat_interface.render()
-     gr.Markdown(LICENSE)
-
- if __name__ == "__main__":
-     demo.queue(max_size=20).launch(server_port=4444)
+     stream = client.generate_stream(prompt, **generate_kwargs)
+     output = ""
+     for response in stream:
+         if any([end_token in response.token.text for end_token in [EOS_STRING, EOT_STRING]]):
+             return output
+         else:
+             output += response.token.text
+             yield output
+     return output
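
A minimal sketch of how the new `run` generator could be reattached to the Gradio UI that this commit deletes from app.py. None of this wiring is part of the commit: the layout mirrors the removed `gr.ChatInterface` block, except that the temperature default follows `run`'s new 0.1 and the repetition-penalty slider is dropped because `run` no longer accepts that argument.

# Hypothetical UI wiring appended to the new app.py; `run` is the streaming
# generator defined above, and gradio is assumed to still be installed in the Space.
import gradio as gr

chat_interface = gr.ChatInterface(
    fn=run,
    additional_inputs=[
        gr.Textbox(label="System prompt", lines=6),
        gr.Slider(label="Max new tokens", minimum=1, maximum=2048, step=1, value=1024),
        gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.1),
        gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.9),
        gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50),
    ],
    stop_btn=None,
    examples=[["Hello there! How are you doing?"]],
)

if __name__ == "__main__":
    chat_interface.queue(max_size=20).launch()

Because `run` yields progressively longer output strings, just as the old `generate` did, `gr.ChatInterface` can stream its responses without further changes.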