hayas committed on
Commit
1b50cd3
1 Parent(s): 73e298a
Files changed (6) hide show
  1. .pre-commit-config.yaml +55 -0
  2. .vscode/settings.json +26 -0
  3. README.md +1 -1
  4. app.py +140 -0
  5. requirements.txt +8 -0
  6. style.css +10 -0
.pre-commit-config.yaml ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ repos:
2
+ - repo: https://github.com/pre-commit/pre-commit-hooks
3
+ rev: v4.5.0
4
+ hooks:
5
+ - id: check-executables-have-shebangs
6
+ - id: check-json
7
+ - id: check-merge-conflict
8
+ - id: check-shebang-scripts-are-executable
9
+ - id: check-toml
10
+ - id: check-yaml
11
+ - id: end-of-file-fixer
12
+ - id: mixed-line-ending
13
+ args: ["--fix=lf"]
14
+ - id: requirements-txt-fixer
15
+ - id: trailing-whitespace
16
+ - repo: https://github.com/myint/docformatter
17
+ rev: v1.7.5
18
+ hooks:
19
+ - id: docformatter
20
+ args: ["--in-place"]
21
+ - repo: https://github.com/pycqa/isort
22
+ rev: 5.13.2
23
+ hooks:
24
+ - id: isort
25
+ args: ["--profile", "black"]
26
+ - repo: https://github.com/pre-commit/mirrors-mypy
27
+ rev: v1.7.1
28
+ hooks:
29
+ - id: mypy
30
+ args: ["--ignore-missing-imports"]
31
+ additional_dependencies:
32
+ ["types-python-slugify", "types-requests", "types-PyYAML"]
33
+ - repo: https://github.com/psf/black
34
+ rev: 23.12.0
35
+ hooks:
36
+ - id: black
37
+ language_version: python3.10
38
+ args: ["--line-length", "119"]
39
+ - repo: https://github.com/kynan/nbstripout
40
+ rev: 0.6.1
41
+ hooks:
42
+ - id: nbstripout
43
+ args:
44
+ [
45
+ "--extra-keys",
46
+ "metadata.interpreter metadata.kernelspec cell.metadata.pycharm",
47
+ ]
48
+ - repo: https://github.com/nbQA-dev/nbQA
49
+ rev: 1.7.1
50
+ hooks:
51
+ - id: nbqa-black
52
+ - id: nbqa-pyupgrade
53
+ args: ["--py37-plus"]
54
+ - id: nbqa-isort
55
+ args: ["--float-to-top"]
.vscode/settings.json ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "editor.formatOnSave": true,
3
+ "files.insertFinalNewline": false,
4
+ "[python]": {
5
+ "editor.defaultFormatter": "ms-python.black-formatter",
6
+ "editor.formatOnType": true,
7
+ "editor.codeActionsOnSave": {
8
+ "source.organizeImports": "explicit"
9
+ }
10
+ },
11
+ "[jupyter]": {
12
+ "files.insertFinalNewline": false
13
+ },
14
+ "black-formatter.args": [
15
+ "--line-length=119"
16
+ ],
17
+ "isort.args": ["--profile", "black"],
18
+ "flake8.args": [
19
+ "--max-line-length=119"
20
+ ],
21
+ "ruff.lint.args": [
22
+ "--line-length=119"
23
+ ],
24
+ "notebook.output.scrolling": true,
25
+ "notebook.formatOnCellExecution": true
26
+ }
README.md CHANGED
@@ -1,5 +1,5 @@
1
  ---
2
- title: Swallow 13B
3
  emoji: 🐢
4
  colorFrom: purple
5
  colorTo: purple
 
1
  ---
2
+ title: Swallow-13B instruct
3
  emoji: 🐢
4
  colorFrom: purple
5
  colorTo: purple
app.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env python
"""Gradio demo app for the Swallow-13B instruct model."""

import os
from threading import Thread
from typing import Iterator

import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

DESCRIPTION = """# Swallow-13B instruct"""

if torch.cuda.is_available():
    model_name = "tokyotech-llm/Swallow-13b-instruct-hf"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # 8-bit quantization keeps the 13B model within a single-GPU memory budget.
    model = AutoModelForCausalLM.from_pretrained(
        model_name, load_in_8bit=True, low_cpu_mem_usage=True, device_map="auto"
    )
else:
    # The model needs a GPU; warn visitors when running on CPU-only hardware.
    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"

# Prompts longer than this many tokens are rejected before generation starts.
MAX_INPUT_TOKENS = 2048
25
+
26
+ PROMPT_DICT = {
27
+ "prompt_input": (
28
+ "以下に、あるタスクを説明する指示があり、それに付随する入力が更なる文脈を提供しています。"
29
+ "リクエストを適切に完了するための回答を記述してください。\n\n"
30
+ "### 指示:\n{instruction}\n\n### 入力:\n{input}\n\n### 応答:"
31
+ ),
32
+ "prompt_no_input": (
33
+ "以下に、あるタスクを説明する指示があります。" "リクエストを適切に完了するための回答を記述してください。\n\n" "### 指示:\n{instruction}\n\n### 応答:"
34
+ ),
35
+ }
36
+
37
+
38
+ def create_prompt(instruction: str, input_text: str | None = None) -> str:
39
+ """Generates a prompt based on the given instruction and an optional input.
40
+ If input is provided, it uses the 'prompt_input' template from PROMPT_DICT.
41
+ If no input is provided, it uses the 'prompt_no_input' template.
42
+
43
+ Args:
44
+ instruction (str): The instruction describing the task.
45
+ input_text (str, optional): Additional input providing context for the task. Default is None.
46
+
47
+ Returns:
48
+ str: The generated prompt.
49
+ """
50
+ if input_text:
51
+ # Use the 'prompt_input' template when additional input is provided
52
+ return PROMPT_DICT["prompt_input"].format(instruction=instruction, input=input_text)
53
+ else:
54
+ # Use the 'prompt_no_input' template when no additional input is provided
55
+ return PROMPT_DICT["prompt_no_input"].format(instruction=instruction)
56
+
57
+
58
@spaces.GPU
@torch.inference_mode()
def run(
    instruction: str,
    input_text: str | None = None,
    max_new_tokens: int = 256,
    temperature: float = 0.99,
    top_p: float = 0.95,
) -> Iterator[str]:
    """Stream a generated response for the given instruction.

    Yields the accumulated output text each time a new chunk arrives from
    the background generation thread.

    Raises:
        gr.Error: If the tokenized prompt exceeds MAX_INPUT_TOKENS.
    """
    # An empty textbox is treated the same as "no additional input".
    prompt = create_prompt(instruction, input_text or None)
    input_ids = tokenizer.encode(prompt, add_special_tokens=False, return_tensors="pt")
    if input_ids.shape[-1] > MAX_INPUT_TOKENS:
        raise gr.Error(f"Input exceeds maximum number of tokens ({MAX_INPUT_TOKENS})")

    # model.generate runs in a worker thread; the streamer hands chunks back
    # to this generator as they are produced.
    streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = {
        "input_ids": input_ids.to(model.device),
        "streamer": streamer,
        "max_new_tokens": max_new_tokens,
        "temperature": temperature,
        "top_p": top_p,
        "do_sample": True,
    }
    worker = Thread(target=model.generate, kwargs=generate_kwargs)
    worker.start()

    accumulated: list[str] = []
    for chunk in streamer:
        accumulated.append(chunk)
        yield "".join(accumulated)
91
+
92
+
93
def process_example(instruction: str, input_text: str) -> Iterator[str]:
    """Adapter used by gr.Examples: delegate to run() and relay its stream."""
    for partial_output in run(instruction, input_text):
        yield partial_output
95
+
96
+
97
# --- Gradio UI -------------------------------------------------------------
with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    gr.DuplicateButton(
        value="Duplicate Space for private use",
        elem_id="duplicate-button",
        visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
    )

    with gr.Row():
        # Left column: inputs and generation controls.
        with gr.Column():
            instruction = gr.Textbox(label="Instruction", lines=5)
            input_text = gr.Textbox(label="Input (optional)", lines=5)
            run_button = gr.Button()

            with gr.Accordion(label="Advanced Options", open=False):
                max_new_tokens = gr.Slider(label="Max New Tokens", minimum=1, maximum=1024, step=1, value=256)
                temperature = gr.Slider(label="Temperature", minimum=0.0, maximum=2.0, step=0.01, value=0.99)
                top_p = gr.Slider(label="Top P", minimum=0.0, maximum=1.0, step=0.01, value=0.95)

        # Right column: streamed model output.
        with gr.Column():
            output = gr.Textbox(label="Output", lines=10)

    run_button.click(
        fn=run,
        inputs=[instruction, input_text, max_new_tokens, temperature, top_p],
        outputs=output,
        api_name="run",
    )

    gr.Examples(
        examples=[
            ["以下のトピックに関する詳細な情報を提供してください。", "東京工業大学の主なキャンパスについて教えてください。"],
            ["以下のトピックに関する詳細な情報を提供してください。", "夢オチとは何かについて教えてください。"],
            ["暴れん坊将軍って誰のことですか?", ""],
        ],
        inputs=[instruction, input_text, max_new_tokens, temperature, top_p],
        outputs=output,
        fn=process_example,
        cache_examples=os.getenv("CACHE_EXAMPLES") == "1",
        api_name=False,
    )

if __name__ == "__main__":
    demo.queue(max_size=20).launch()
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ accelerate==0.25.0
2
+ bitsandbytes==0.41.2.post2
3
+ protobuf==4.25.1
4
+ scipy==1.11.4
5
+ sentencepiece==0.1.99
6
+ spaces==0.19.2
7
+ torch==2.0.0
8
+ transformers==4.36.2
style.css ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ h1 {
2
+ text-align: center;
3
+ }
4
+
5
+ #duplicate-button {
6
+ margin: auto;
7
+ color: #fff;
8
+ background: #1565c0;
9
+ border-radius: 100vh;
10
+ }