hayas committed
Commit c44e6d0 · 1 Parent(s): 3b0f4d3
.pre-commit-config.yaml ADDED
@@ -0,0 +1,51 @@
+ repos:
+   - repo: https://github.com/pre-commit/pre-commit-hooks
+     rev: v5.0.0
+     hooks:
+       - id: check-executables-have-shebangs
+       - id: check-json
+       - id: check-merge-conflict
+       - id: check-shebang-scripts-are-executable
+       - id: check-toml
+       - id: check-yaml
+       - id: end-of-file-fixer
+       - id: mixed-line-ending
+         args: ["--fix=lf"]
+       - id: requirements-txt-fixer
+       - id: trailing-whitespace
+   - repo: https://github.com/astral-sh/ruff-pre-commit
+     rev: v0.8.4
+     hooks:
+       - id: ruff
+         args: ["--fix"]
+       - id: ruff-format
+         args: ["--line-length", "119"]
+   - repo: https://github.com/pre-commit/mirrors-mypy
+     rev: v1.14.0
+     hooks:
+       - id: mypy
+         args: ["--ignore-missing-imports"]
+         additional_dependencies:
+           [
+             "types-python-slugify",
+             "types-requests",
+             "types-PyYAML",
+             "types-pytz",
+           ]
+   - repo: https://github.com/kynan/nbstripout
+     rev: 0.8.1
+     hooks:
+       - id: nbstripout
+         args:
+           [
+             "--extra-keys",
+             "metadata.interpreter metadata.kernelspec cell.metadata.pycharm",
+           ]
+   - repo: https://github.com/nbQA-dev/nbQA
+     rev: 1.9.1
+     hooks:
+       - id: nbqa-black
+       - id: nbqa-pyupgrade
+         args: ["--py37-plus"]
+       - id: nbqa-isort
+         args: ["--float-to-top"]
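
Reviewer note: the hooks above only run once they are registered in the local checkout. A minimal sketch of installing and exercising them from Python, assuming the pre-commit package is on PATH (`install` and `run --all-files` are standard pre-commit subcommands):

import subprocess

# Register the git hook, then run every configured hook across the whole tree once.
subprocess.run(["pre-commit", "install"], check=True)
subprocess.run(["pre-commit", "run", "--all-files"], check=True)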
.python-version ADDED
@@ -0,0 +1 @@
+ 3.10
.vscode/extensions.json ADDED
@@ -0,0 +1,8 @@
+ {
+   "recommendations": [
+     "ms-python.python",
+     "charliermarsh.ruff",
+     "streetsidesoftware.code-spell-checker",
+     "tamasfe.even-better-toml"
+   ]
+ }
.vscode/settings.json ADDED
@@ -0,0 +1,21 @@
+ {
+   "editor.formatOnSave": true,
+   "files.insertFinalNewline": false,
+   "[python]": {
+     "editor.defaultFormatter": "charliermarsh.ruff",
+     "editor.formatOnType": true,
+     "editor.codeActionsOnSave": {
+       "source.fixAll.ruff": "explicit",
+       "source.organizeImports": "explicit"
+     }
+   },
+   "[jupyter]": {
+     "files.insertFinalNewline": false
+   },
+   "notebook.output.scrolling": true,
+   "notebook.formatOnCellExecution": true,
+   "notebook.formatOnSave.enabled": true,
+   "notebook.codeActionsOnSave": {
+     "source.organizeImports": "explicit"
+   }
+ }
README.md CHANGED
@@ -1,8 +1,8 @@
  ---
  title: Llama 3.1 Swallow 8B Instruct V0.3
- emoji: 🌍
- colorFrom: blue
- colorTo: red
+ emoji:
+ colorFrom: red
+ colorTo: purple
  sdk: gradio
  sdk_version: 5.9.1
  app_file: app.py
app.py ADDED
@@ -0,0 +1,128 @@
+ #!/usr/bin/env python
+
+ import os
+ from collections.abc import Iterator
+ from threading import Thread
+
+ import gradio as gr
+ import spaces
+ import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
+
+ DESCRIPTION = "# Llama 3.1 Swallow 8B Instruct V0.3"
+
+ if not torch.cuda.is_available():
+     DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
+
+ MAX_MAX_NEW_TOKENS = 2048
+ DEFAULT_MAX_NEW_TOKENS = 1024
+ MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
+
+ if torch.cuda.is_available():
+     model_id = "tokyotech-llm/Llama-3.1-Swallow-8B-Instruct-v0.3"
+     model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto", device_map="auto")
+     model.eval()
+     tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+
+ @spaces.GPU
+ @torch.inference_mode()
+ def generate(
+     message: str,
+     chat_history: list[dict],
+     system_prompt: str = "",
+     max_new_tokens: int = 1024,
+     temperature: float = 0.6,
+     top_p: float = 0.9,
+     top_k: int = 50,
+     repetition_penalty: float = 1.0,
+ ) -> Iterator[str]:
+     conversation = []
+     if system_prompt:
+         conversation.append({"role": "system", "content": system_prompt})
+     conversation += chat_history
+     conversation.append({"role": "user", "content": message})
+
+     input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
+     if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
+         input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
+         gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
+     input_ids = input_ids.to(model.device)
+
+     streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
+     generate_kwargs = dict(
+         {"input_ids": input_ids},
+         streamer=streamer,
+         max_new_tokens=max_new_tokens,
+         do_sample=True,
+         top_p=top_p,
+         top_k=top_k,
+         temperature=temperature,
+         num_beams=1,
+         repetition_penalty=repetition_penalty,
+         pad_token_id=tokenizer.eos_token_id,
+     )
+     t = Thread(target=model.generate, kwargs=generate_kwargs)
+     t.start()
+
+     outputs = []
+     for text in streamer:
+         outputs.append(text)
+         yield "".join(outputs)
+
+
+ demo = gr.ChatInterface(
+     fn=generate,
+     additional_inputs_accordion=gr.Accordion(label="詳細設定", open=False),
+     additional_inputs=[
+         gr.Textbox(label="System prompt", value="あなたは誠実で優秀な日本人のアシスタントです。"),
+         gr.Slider(
+             label="Max new tokens",
+             minimum=1,
+             maximum=MAX_MAX_NEW_TOKENS,
+             step=1,
+             value=DEFAULT_MAX_NEW_TOKENS,
+         ),
+         gr.Slider(
+             label="Temperature",
+             minimum=0.1,
+             maximum=2.0,
+             step=0.1,
+             value=0.6,
+         ),
+         gr.Slider(
+             label="Top-p (nucleus sampling)",
+             minimum=0.05,
+             maximum=1.0,
+             step=0.05,
+             value=0.9,
+         ),
+         gr.Slider(
+             label="Top-k",
+             minimum=1,
+             maximum=1000,
+             step=1,
+             value=50,
+         ),
+         gr.Slider(
+             label="Repetition penalty",
+             minimum=1.0,
+             maximum=2.0,
+             step=0.05,
+             value=1.0,
+         ),
+     ],
+     stop_btn=None,
+     examples=[
+         [
+             "東京の紅葉した公園で、東京タワーと高層ビルを背景に、空を舞うツバメと草地に佇むラマが出会う温かな物語を書いてください。"
+         ],
+     ],
+     type="messages",
+     description=DESCRIPTION,
+     css_paths="style.css",
+     fill_height=True,
+ )
+
+ if __name__ == "__main__":
+     demo.launch()
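
Reviewer note: generate() yields the cumulative transcript on every iteration rather than per-token deltas, which is the contract gr.ChatInterface expects from a streaming callback. A minimal sketch of driving it outside Gradio, assuming app.py has been imported on a CUDA machine so model and tokenizer exist (the prompt is only an illustrative example):

# Hypothetical driver: print only the newly generated suffix of each cumulative yield.
last = ""
for text in generate("自己紹介してください。", chat_history=[]):  # "Please introduce yourself."
    print(text[len(last):], end="", flush=True)
    last = text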
pyproject.toml ADDED
@@ -0,0 +1,49 @@
+ [project]
+ name = "llama-3-1-swallow-8b-instruct-v0-3"
+ version = "0.1.0"
+ description = ""
+ readme = "README.md"
+ requires-python = ">=3.10"
+ dependencies = [
+     "accelerate>=1.2.1",
+     "gradio>=5.9.1",
+     "spaces>=0.31.1",
+     "torch==2.4.0",
+     "transformers>=4.47.1",
+ ]
+
+ [tool.ruff]
+ line-length = 119
+
+ [tool.ruff.lint]
+ select = ["ALL"]
+ ignore = [
+     "COM812",  # missing-trailing-comma
+     "D203",  # one-blank-line-before-class
+     "D213",  # multi-line-summary-second-line
+     "E501",  # line-too-long
+     "SIM117",  # multiple-with-statements
+ ]
+ extend-ignore = [
+     "D100",  # undocumented-public-module
+     "D101",  # undocumented-public-class
+     "D102",  # undocumented-public-method
+     "D103",  # undocumented-public-function
+     "D104",  # undocumented-public-package
+     "D105",  # undocumented-magic-method
+     "D107",  # undocumented-public-init
+     "EM101",  # raw-string-in-exception
+     "FBT001",  # boolean-type-hint-positional-argument
+     "FBT002",  # boolean-default-value-positional-argument
+     "PD901",  # pandas-df-variable-name
+     "PGH003",  # blanket-type-ignore
+     "PLR0913",  # too-many-arguments
+     "PLR0915",  # too-many-statements
+     "TRY003",  # raise-vanilla-args
+ ]
+ unfixable = [
+     "F401",  # unused-import
+ ]
+
+ [tool.ruff.format]
+ docstring-code-format = true
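
Reviewer note: with select = ["ALL"] plus two ignore lists, the effective rule set is easiest to audit by asking Ruff to print its resolved configuration. A small sketch, assuming ruff is installed (using app.py as the target is just an example):

import subprocess

# Dump the settings Ruff resolves for app.py under this pyproject.toml.
subprocess.run(["ruff", "check", "--show-settings", "app.py"], check=True)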
requirements.txt ADDED
@@ -0,0 +1,238 @@
+ # This file was autogenerated by uv via the following command:
+ #    uv pip compile pyproject.toml -o requirements.txt
+ accelerate==1.2.1
+     # via llama-3-1-swallow-8b-instruct-v0-3 (pyproject.toml)
+ aiofiles==23.2.1
+     # via gradio
+ annotated-types==0.7.0
+     # via pydantic
+ anyio==4.7.0
+     # via
+     #   gradio
+     #   httpx
+     #   starlette
+ certifi==2024.12.14
+     # via
+     #   httpcore
+     #   httpx
+     #   requests
+ charset-normalizer==3.4.1
+     # via requests
+ click==8.1.8
+     # via
+     #   typer
+     #   uvicorn
+ exceptiongroup==1.2.2
+     # via anyio
+ fastapi==0.115.6
+     # via gradio
+ ffmpy==0.5.0
+     # via gradio
+ filelock==3.16.1
+     # via
+     #   huggingface-hub
+     #   torch
+     #   transformers
+     #   triton
+ fsspec==2024.12.0
+     # via
+     #   gradio-client
+     #   huggingface-hub
+     #   torch
+ gradio==5.9.1
+     # via
+     #   llama-3-1-swallow-8b-instruct-v0-3 (pyproject.toml)
+     #   spaces
+ gradio-client==1.5.2
+     # via gradio
+ h11==0.14.0
+     # via
+     #   httpcore
+     #   uvicorn
+ httpcore==1.0.7
+     # via httpx
+ httpx==0.28.1
+     # via
+     #   gradio
+     #   gradio-client
+     #   safehttpx
+     #   spaces
+ huggingface-hub==0.27.0
+     # via
+     #   accelerate
+     #   gradio
+     #   gradio-client
+     #   tokenizers
+     #   transformers
+ idna==3.10
+     # via
+     #   anyio
+     #   httpx
+     #   requests
+ jinja2==3.1.5
+     # via
+     #   gradio
+     #   torch
+ markdown-it-py==3.0.0
+     # via rich
+ markupsafe==2.1.5
+     # via
+     #   gradio
+     #   jinja2
+ mdurl==0.1.2
+     # via markdown-it-py
+ mpmath==1.3.0
+     # via sympy
+ networkx==3.4.2
+     # via torch
+ numpy==2.2.1
+     # via
+     #   accelerate
+     #   gradio
+     #   pandas
+     #   transformers
+ nvidia-cublas-cu12==12.1.3.1
+     # via
+     #   nvidia-cudnn-cu12
+     #   nvidia-cusolver-cu12
+     #   torch
+ nvidia-cuda-cupti-cu12==12.1.105
+     # via torch
+ nvidia-cuda-nvrtc-cu12==12.1.105
+     # via torch
+ nvidia-cuda-runtime-cu12==12.1.105
+     # via torch
+ nvidia-cudnn-cu12==9.1.0.70
+     # via torch
+ nvidia-cufft-cu12==11.0.2.54
+     # via torch
+ nvidia-curand-cu12==10.3.2.106
+     # via torch
+ nvidia-cusolver-cu12==11.4.5.107
+     # via torch
+ nvidia-cusparse-cu12==12.1.0.106
+     # via
+     #   nvidia-cusolver-cu12
+     #   torch
+ nvidia-nccl-cu12==2.20.5
+     # via torch
+ nvidia-nvjitlink-cu12==12.6.85
+     # via
+     #   nvidia-cusolver-cu12
+     #   nvidia-cusparse-cu12
+ nvidia-nvtx-cu12==12.1.105
+     # via torch
+ orjson==3.10.13
+     # via gradio
+ packaging==24.2
+     # via
+     #   accelerate
+     #   gradio
+     #   gradio-client
+     #   huggingface-hub
+     #   spaces
+     #   transformers
+ pandas==2.2.3
+     # via gradio
+ pillow==11.1.0
+     # via gradio
+ psutil==5.9.8
+     # via
+     #   accelerate
+     #   spaces
+ pydantic==2.10.4
+     # via
+     #   fastapi
+     #   gradio
+     #   spaces
+ pydantic-core==2.27.2
+     # via pydantic
+ pydub==0.25.1
+     # via gradio
+ pygments==2.18.0
+     # via rich
+ python-dateutil==2.9.0.post0
+     # via pandas
+ python-multipart==0.0.20
+     # via gradio
+ pytz==2024.2
+     # via pandas
+ pyyaml==6.0.2
+     # via
+     #   accelerate
+     #   gradio
+     #   huggingface-hub
+     #   transformers
+ regex==2024.11.6
+     # via transformers
+ requests==2.32.3
+     # via
+     #   huggingface-hub
+     #   spaces
+     #   transformers
+ rich==13.9.4
+     # via typer
+ ruff==0.8.5
+     # via gradio
+ safehttpx==0.1.6
+     # via gradio
+ safetensors==0.4.5
+     # via
+     #   accelerate
+     #   transformers
+ semantic-version==2.10.0
+     # via gradio
+ shellingham==1.5.4
+     # via typer
+ six==1.17.0
+     # via python-dateutil
+ sniffio==1.3.1
+     # via anyio
+ spaces==0.31.1
+     # via llama-3-1-swallow-8b-instruct-v0-3 (pyproject.toml)
+ starlette==0.41.3
+     # via
+     #   fastapi
+     #   gradio
+ sympy==1.13.3
+     # via torch
+ tokenizers==0.21.0
+     # via transformers
+ tomlkit==0.13.2
+     # via gradio
+ torch==2.4.0
+     # via
+     #   llama-3-1-swallow-8b-instruct-v0-3 (pyproject.toml)
+     #   accelerate
+ tqdm==4.67.1
+     # via
+     #   huggingface-hub
+     #   transformers
+ transformers==4.47.1
+     # via llama-3-1-swallow-8b-instruct-v0-3 (pyproject.toml)
+ triton==3.0.0
+     # via torch
+ typer==0.15.1
+     # via gradio
+ typing-extensions==4.12.2
+     # via
+     #   anyio
+     #   fastapi
+     #   gradio
+     #   gradio-client
+     #   huggingface-hub
+     #   pydantic
+     #   pydantic-core
+     #   rich
+     #   spaces
+     #   torch
+     #   typer
+     #   uvicorn
+ tzdata==2024.2
+     # via pandas
+ urllib3==2.3.0
+     # via requests
+ uvicorn==0.34.0
+     # via gradio
+ websockets==14.1
+     # via gradio-client
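
Reviewer note: the header records the exact command that produced this file; a sketch of re-running it from Python whenever pyproject.toml changes, assuming uv is on PATH:

import subprocess

# Regenerate requirements.txt from pyproject.toml (command taken from the file header above).
subprocess.run(["uv", "pip", "compile", "pyproject.toml", "-o", "requirements.txt"], check=True)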
style.css ADDED
@@ -0,0 +1,11 @@
+ h1 {
+   text-align: center;
+   display: block;
+ }
+
+ #duplicate-button {
+   margin: auto;
+   color: white;
+   background: #1565c0;
+   border-radius: 100vh;
+ }
uv.lock ADDED
The diff for this file is too large to render. See raw diff