Spaces:
Paused
Paused
fsdp config
Browse files
app.py
CHANGED
@@ -1,3 +1,4 @@
|
|
|
|
1 |
import os
|
2 |
import multiprocessing as mp
|
3 |
|
@@ -54,7 +55,7 @@ DEFAULT_TRAINING_ARGS = \
|
|
54 |
--auto_find_batch_size True"""
|
55 |
|
56 |
|
57 |
-
def train_medusa_heads(model_id: str, training_args: str):
|
58 |
all_training_args = FIXED_TRAINING_ARGS.format(
|
59 |
model_id=model_id,
|
60 |
output_dir=OUTPUT_DIR,
|
@@ -67,13 +68,16 @@ def train_medusa_heads(model_id: str, training_args: str):
|
|
67 |
for arg in all_training_args.split("\n"):
|
68 |
all_training_arg_list += arg.split(" ")
|
69 |
|
70 |
-
|
|
|
71 |
print(all_training_arg_list)
|
|
|
|
|
72 |
args = parser.parse_args(all_training_arg_list)
|
73 |
distributed_run.run(args)
|
74 |
|
75 |
|
76 |
-
def run(model_id: str, training_args: str) -> str:
|
77 |
print(f"\n\n\nNEW RUN: {model_id}")
|
78 |
api = HfApi()
|
79 |
model_name = model_id.split("/")[-1]
|
@@ -110,10 +114,10 @@ def run(model_id: str, training_args: str) -> str:
|
|
110 |
|
111 |
# Run the medusa heads creation
|
112 |
try:
|
113 |
-
proc = mp.Process(target=train_medusa_heads, args=(model_id, training_args))
|
114 |
proc.start()
|
115 |
proc.join()
|
116 |
-
print("Medusa heads training
|
117 |
except Exception as e:
|
118 |
print("Error β\n", e)
|
119 |
return f"""
|
@@ -178,8 +182,14 @@ with gr.Blocks(title=title) as demo:
|
|
178 |
with gr.Row() as r:
|
179 |
with gr.Column() as c:
|
180 |
model_id = gr.Text(max_lines=1, label="model_id")
|
181 |
-
with gr.Accordion("Training arguments", open=False):
|
182 |
training_args = gr.Textbox(DEFAULT_TRAINING_ARGS, interactive=True, lines=14, label="training_args")
|
|
|
|
|
|
|
|
|
|
|
|
|
183 |
with gr.Row() as c:
|
184 |
clean = gr.ClearButton()
|
185 |
submit = gr.Button("Submit", variant="primary")
|
@@ -187,6 +197,6 @@ with gr.Blocks(title=title) as demo:
|
|
187 |
with gr.Column() as d:
|
188 |
status_box = gr.Markdown()
|
189 |
|
190 |
-
submit.click(run, inputs=[model_id, training_args], outputs=status_box, concurrency_limit=1)
|
191 |
|
192 |
demo.queue(max_size=10).launch(show_api=True)
|
|
|
1 |
+
import json
|
2 |
import os
|
3 |
import multiprocessing as mp
|
4 |
|
|
|
55 |
--auto_find_batch_size True"""
|
56 |
|
57 |
|
58 |
+
def train_medusa_heads(model_id: str, training_args: str, fdsp_config: str):
|
59 |
all_training_args = FIXED_TRAINING_ARGS.format(
|
60 |
model_id=model_id,
|
61 |
output_dir=OUTPUT_DIR,
|
|
|
68 |
for arg in all_training_args.split("\n"):
|
69 |
all_training_arg_list += arg.split(" ")
|
70 |
|
71 |
+
if fdsp_config != "":
|
72 |
+
all_training_arg_list += ["--fdsp_config", json.loads(fdsp_config)]
|
73 |
print(all_training_arg_list)
|
74 |
+
|
75 |
+
parser = distributed_run.get_args_parser()
|
76 |
args = parser.parse_args(all_training_arg_list)
|
77 |
distributed_run.run(args)
|
78 |
|
79 |
|
80 |
+
def run(model_id: str, training_args: str, fdsp_config: str) -> str:
|
81 |
print(f"\n\n\nNEW RUN: {model_id}")
|
82 |
api = HfApi()
|
83 |
model_name = model_id.split("/")[-1]
|
|
|
114 |
|
115 |
# Run the medusa heads creation
|
116 |
try:
|
117 |
+
proc = mp.Process(target=train_medusa_heads, args=(model_id, training_args, fdsp_config))
|
118 |
proc.start()
|
119 |
proc.join()
|
120 |
+
print("Medusa heads training process completed (it might have crashed!)")
|
121 |
except Exception as e:
|
122 |
print("Error β\n", e)
|
123 |
return f"""
|
|
|
182 |
with gr.Row() as r:
|
183 |
with gr.Column() as c:
|
184 |
model_id = gr.Text(max_lines=1, label="model_id")
|
185 |
+
with gr.Accordion("Training arguments (advanced)", open=False):
|
186 |
training_args = gr.Textbox(DEFAULT_TRAINING_ARGS, interactive=True, lines=14, label="training_args")
|
187 |
+
fdsp_config = gr.Textbox(
|
188 |
+
placeholder="e.g. \'{\"value1\":\"key1\"}\'. leave empty if fdsp is not used. ",
|
189 |
+
interactive=True,
|
190 |
+
lines=1,
|
191 |
+
label="fdsp_config"
|
192 |
+
)
|
193 |
with gr.Row() as c:
|
194 |
clean = gr.ClearButton()
|
195 |
submit = gr.Button("Submit", variant="primary")
|
|
|
197 |
with gr.Column() as d:
|
198 |
status_box = gr.Markdown()
|
199 |
|
200 |
+
submit.click(run, inputs=[model_id, training_args, fdsp_config], outputs=status_box, concurrency_limit=1)
|
201 |
|
202 |
demo.queue(max_size=10).launch(show_api=True)
|