File size: 8,076 Bytes
88f1511
 
 
 
 
 
 
 
f0c45b2
88f1511
f0c45b2
88f1511
 
672cd19
88f1511
672cd19
88f1511
8d5955c
e02c5de
 
88f1511
 
 
672cd19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e02c5de
672cd19
 
 
88f1511
 
 
 
 
 
672cd19
 
 
bd11797
 
 
 
 
4270309
830ed7d
672cd19
88f1511
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a572663
 
 
 
 
 
 
 
 
 
 
 
c5710ac
 
88f1511
 
 
a572663
88f1511
a572663
88f1511
a572663
 
88f1511
a572663
88f1511
a572663
 
 
 
 
c5710ac
a572663
 
c5710ac
 
 
 
 
 
 
 
e02c5de
c5710ac
 
88f1511
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a572663
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88f1511
 
a572663
88f1511
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a572663
88f1511
 
 
 
 
 
a572663
88f1511
 
 
 
 
 
 
 
 
 
 
 
 
a572663
88f1511
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
672cd19
88f1511
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
from threading import Thread
from typing import Iterator

import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# Upper bound of the "Max new tokens" slider.
MAX_MAX_NEW_TOKENS = 256
# Initial slider value; it must not exceed MAX_MAX_NEW_TOKENS, otherwise the
# slider would start outside its own range.
DEFAULT_MAX_NEW_TOKENS = 256
# Prompts longer than this many tokens are meant to be trimmed before generation.
MAX_INPUT_TOKEN_LENGTH = 10240

# Markdown header shown at the top of the page.
DESCRIPTION = """\
# CLEX-7B-Chat-16K

This Space demonstrates the model [CLEX-7B-Chat-16K](https://huggingface.co./DAMO-NLP-SG/CLEX-7B-Chat-16K), a Llama-2-7B model fine-tuned using our [CLEX](https://arxiv.org/abs/2310.16450) method. Feel free to play with it, or duplicate to run generations without a queue! If you want to run your own service, you can also [deploy the model on Inference Endpoints](https://huggingface.co./inference-endpoints).

Due to GPU memory limits, the web demo currently supports a maximum input length of 10k tokens; running the demo locally (with more GPU memory) is highly recommended.

Support for PDF input is tentative.

"""

# LICENSE = """
# <p/>

# ---
# As a derivative work of [Llama-2-7b-chat](https://huggingface.co./meta-llama/Llama-2-7b-chat) by Meta,
# this demo is governed by the original [license](https://huggingface.co./spaces/huggingface-projects/llama-2-7b-chat/blob/main/LICENSE.txt) and [acceptable use policy](https://huggingface.co./spaces/huggingface-projects/llama-2-7b-chat/blob/main/USE_POLICY.md).
# """


# Citation notice (Markdown with a BibTeX code fence); rendered beneath the
# chat UI by the gr.Blocks layout at the end of this file.
CITE = """
If you find our project useful, hope you can star our repo and cite our paper as follows:
```
@article{damonlpsg2023clex,
  author = {Chen, Guanzheng and Li, Xin and Meng, Zaiqiao and Liang, Shangsong and Bing, Lidong},
  title = {CLEX: Continuous Length Extrapolation for Large Language Models},
  year = 2023,
  journal = {arXiv preprint arXiv:2310.16450},
  url = {https://arxiv.org/abs/2310.16450}
}
```
"""

# Generation moves inputs to "cuda", so warn visitors when no GPU is present.
DESCRIPTION = (
    DESCRIPTION
    if torch.cuda.is_available()
    else DESCRIPTION + "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
)


# Model and tokenizer are loaded once, at import time, and shared by all requests.
model_id = "DAMO-NLP-SG/CLEX-7b-Chat-16K"
from transformers import AutoModelForCausalLM
# Local module; presumably the CLEX-patched Llama implementation — verify.
from modeling_llama import LlamaForCausalLM

model = LlamaForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
    low_cpu_mem_usage=True,
    device_map="auto",
)
# use_fast=False selects the slow (pure-Python) tokenizer implementation.
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True, use_fast=False)
# NOTE(review): presumably this only affects tokenizer.apply_chat_template, which
# is not used here — prompts are built with FastChat's vicuna template.
tokenizer.use_default_system_prompt = False

import PyPDF2
from io import BytesIO

def process_pdf(input_pdf):
    """Return the text of every page of an uploaded PDF, concatenated in page order.

    Args:
        input_pdf: Either a filesystem path (``str`` or ``os.PathLike``) or an
            object whose ``.name`` attribute holds the on-disk path, such as the
            temp-file wrapper a ``gr.File`` input may return.

    Returns:
        The extracted text of all pages, joined with no separator.
    """
    import os

    # Plain paths are used as-is; file wrappers expose their path via `.name`.
    # (pathlib.Path also has `.name`, but that is only the final path component,
    # hence the isinstance check instead of getattr.)
    path = input_pdf if isinstance(input_pdf, (str, os.PathLike)) else input_pdf.name
    reader = PyPDF2.PdfReader(path)
    # `or ""` guards against a page whose extraction yields None.
    return "".join(page.extract_text() or "" for page in reader.pages)



def build_chat(prompt: str = ""):
    """Wrap ``prompt`` in FastChat's ``vicuna`` template as a single user turn.

    Args:
        prompt: The user's message. It is now a parameter because the old body
            read ``prompt`` before assigning it, which raised UnboundLocalError
            on every call. The default keeps zero-argument calls working.

    Returns:
        The formatted prompt string, with an empty assistant turn appended so
        the model continues from there.
    """
    from fastchat.model import get_conversation_template

    conv = get_conversation_template("vicuna")
    conv.append_message(conv.roles[0], prompt)
    conv.append_message(conv.roles[1], None)
    return conv.get_prompt()


import re

def replace_repeated_spaces_and_newlines(text):
    """Normalize whitespace in extracted text.

    Runs of non-newline whitespace (spaces, tabs, carriage returns, ...) become a
    single space. Any whitespace run containing at least one newline becomes a
    single newline, which also drops spaces hugging line breaks.

    Args:
        text: The string to normalize.

    Returns:
        The normalized string.
    """
    # `\s` also matches "\n", so newlines must be excluded here. Otherwise every
    # newline is flattened into a space and the newline rule below never fires.
    text = re.sub(r"[^\S\n]+", " ", text)

    # Collapse e.g. " \n \n\n " into a single "\n".
    text = re.sub(r"\s*\n\s*", "\n", text)

    return text

from fastchat.model import get_conversation_template

@spaces.GPU
def generate(
    message: str,
    chat_history,
    system_prompt: str,
    input_pdf: BytesIO = None,
    max_new_tokens: int = 1024,
    temperature: float = 0.7,
    top_p: float = 1.0,
    top_k: int = 50,
    repetition_penalty: float = 1.0,
) -> Iterator[str]:
    """Stream the model's reply to a single user message.

    The prompt is built with FastChat's ``vicuna`` conversation template. When a
    PDF is supplied, its whitespace-normalized text is appended to ``message``
    between plain-text begin/end markers.

    Args:
        message: The user's message for this turn.
        chat_history: Supplied by ``gr.ChatInterface`` but not used; earlier
            turns are not included in the prompt.
        system_prompt: Replaces the template's system message unless ``None``
            (an empty string sets an empty system message).
        input_pdf: Optional uploaded PDF, passed through to ``process_pdf``.
        max_new_tokens, temperature, top_p, top_k, repetition_penalty:
            Sampling settings forwarded to ``model.generate``.

    Yields:
        The reply generated so far, growing with each streamed chunk.
    """
    if input_pdf is not None:
        pdf_text = replace_repeated_spaces_and_newlines(process_pdf(input_pdf))
        message += f"\nThis is the beginning of a pdf\n{pdf_text}This is the end of a pdf\n"

    conv = get_conversation_template("vicuna")
    if system_prompt is not None:
        conv.set_system_message(system_prompt)
    conv.append_message(conv.roles[0], message)
    # An empty (None) assistant turn leaves the prompt open for the model to continue.
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()

    encoded = tokenizer(prompt, return_tensors="pt", add_special_tokens=False)
    # The token count is the sequence dimension of input_ids. len() on the
    # encoding would only count its keys (input_ids, attention_mask, ...).
    if encoded["input_ids"].shape[-1] > MAX_INPUT_TOKEN_LENGTH:
        # Keep the most recent tokens. Trim every per-token tensor the same way
        # so the attention mask stays aligned with input_ids.
        encoded = {
            name: tensor[:, -MAX_INPUT_TOKEN_LENGTH:] for name, tensor in encoded.items()
        }
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    model_inputs = {name: tensor.to("cuda") for name, tensor in encoded.items()}

    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        model_inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        repetition_penalty=repetition_penalty,
    )
    # Generation runs in a background thread; the streamer hands text back here.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)
    t.join()


# def generate_with_pdf(
#     message: str,
#     chat_history,
#     system_prompt: str,
#     input_pdf: BytesIO = None,
#     max_new_tokens: int = 1024,
#     temperature: float = 0.6,
#     top_p: float = 0.9,
#     top_k: int = 50,
#     repetition_penalty: float = 1.2,
# ) -> Iterator[str]:
#     if input_pdf is not None:
#         pdf_text = process_pdf(input_pdf)
#         # print(pdf_text)
#         message += f"\nThis is the beginning of a pdf\n{pdf_text}This is the end of a pdf\n"
#     yield from generate(
#         message,
#         chat_history,
#         system_prompt,
#         max_new_tokens,
#         temperature,
#         top_p,
#         top_k,
#         repetition_penalty
#     )

# Chat UI. gr.ChatInterface calls generate(message, history, *additional_inputs),
# so the order of additional_inputs must match generate's parameters.
chat_interface = gr.ChatInterface(
    fn=generate,
    additional_inputs=[
        gr.Textbox(label="System prompt", lines=6),
        # gr.File filters uploads via `file_types`; it has no `accept` argument.
        gr.File(label="PDF File", file_types=[".pdf"]),
        gr.Slider(
            label="Max new tokens",
            minimum=1,
            maximum=MAX_MAX_NEW_TOKENS,
            step=1,
            # Clamp so the initial value can never lie above the slider's maximum.
            value=min(DEFAULT_MAX_NEW_TOKENS, MAX_MAX_NEW_TOKENS),
        ),
        gr.Slider(
            label="Temperature",
            minimum=0.1,
            maximum=4.0,
            step=0.1,
            value=0.7,
        ),
        gr.Slider(
            label="Top-p (nucleus sampling)",
            minimum=0.05,
            maximum=1.0,
            step=0.05,
            value=1.0,
        ),
        gr.Slider(
            label="Top-k",
            minimum=1,
            maximum=1000,
            step=1,
            value=50,
        ),
        gr.Slider(
            label="Repetition penalty",
            minimum=1.0,
            maximum=2.0,
            step=0.05,
            value=1.0,
        ),
    ],
    stop_btn=None,
    examples=[
        ["Hello there! How are you doing?"],
        ["Can you explain briefly to me what is the Python programming language?"],
        ["Explain the plot of Cinderella in a sentence."],
        ["How many hours does it take a man to eat a Helicopter?"],
        ["Write a 100-word article on 'Benefits of Open-Source in AI research'"],
    ],
)



with gr.Blocks(css="style.css") as demo:
    # Page header: project description plus a button to duplicate the Space.
    gr.Markdown(value=DESCRIPTION)
    gr.DuplicateButton(
        value="Duplicate Space for private use",
        elem_id="duplicate-button",
    )
    # The pre-built chat UI, followed by the citation notice.
    chat_interface.render()
    gr.Markdown(value=CITE)

if __name__ == "__main__":
    # Allow at most 20 queued requests; do not create a public share link.
    demo.queue(max_size=20)
    demo.launch(share=False)