joaogante HF staff committed on
Commit
6508f4c
Β·
1 Parent(s): 6d8dfa2

fdsp config

Browse files
Files changed (1) hide show
  1. app.py +17 -7
app.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import os
2
  import multiprocessing as mp
3
 
@@ -54,7 +55,7 @@ DEFAULT_TRAINING_ARGS = \
54
  --auto_find_batch_size True"""
55
 
56
 
57
- def train_medusa_heads(model_id: str, training_args: str):
58
  all_training_args = FIXED_TRAINING_ARGS.format(
59
  model_id=model_id,
60
  output_dir=OUTPUT_DIR,
@@ -67,13 +68,16 @@ def train_medusa_heads(model_id: str, training_args: str):
67
  for arg in all_training_args.split("\n"):
68
  all_training_arg_list += arg.split(" ")
69
 
70
- parser = distributed_run.get_args_parser()
 
71
  print(all_training_arg_list)
 
 
72
  args = parser.parse_args(all_training_arg_list)
73
  distributed_run.run(args)
74
 
75
 
76
- def run(model_id: str, training_args: str) -> str:
77
  print(f"\n\n\nNEW RUN: {model_id}")
78
  api = HfApi()
79
  model_name = model_id.split("/")[-1]
@@ -110,10 +114,10 @@ def run(model_id: str, training_args: str) -> str:
110
 
111
  # Run the medusa heads creation
112
  try:
113
- proc = mp.Process(target=train_medusa_heads, args=(model_id, training_args))
114
  proc.start()
115
  proc.join()
116
- print("Medusa heads training success βœ…")
117
  except Exception as e:
118
  print("Error ❌\n", e)
119
  return f"""
@@ -178,8 +182,14 @@ with gr.Blocks(title=title) as demo:
178
  with gr.Row() as r:
179
  with gr.Column() as c:
180
  model_id = gr.Text(max_lines=1, label="model_id")
181
- with gr.Accordion("Training arguments", open=False):
182
  training_args = gr.Textbox(DEFAULT_TRAINING_ARGS, interactive=True, lines=14, label="training_args")
 
 
 
 
 
 
183
  with gr.Row() as c:
184
  clean = gr.ClearButton()
185
  submit = gr.Button("Submit", variant="primary")
@@ -187,6 +197,6 @@ with gr.Blocks(title=title) as demo:
187
  with gr.Column() as d:
188
  status_box = gr.Markdown()
189
 
190
- submit.click(run, inputs=[model_id, training_args], outputs=status_box, concurrency_limit=1)
191
 
192
  demo.queue(max_size=10).launch(show_api=True)
 
1
+ import json
2
  import os
3
  import multiprocessing as mp
4
 
 
55
  --auto_find_batch_size True"""
56
 
57
 
58
+ def train_medusa_heads(model_id: str, training_args: str, fdsp_config: str):
59
  all_training_args = FIXED_TRAINING_ARGS.format(
60
  model_id=model_id,
61
  output_dir=OUTPUT_DIR,
 
68
  for arg in all_training_args.split("\n"):
69
  all_training_arg_list += arg.split(" ")
70
 
71
+ if fdsp_config != "":
72
+ all_training_arg_list += ["--fdsp_config", json.loads(fdsp_config)]
73
  print(all_training_arg_list)
74
+
75
+ parser = distributed_run.get_args_parser()
76
  args = parser.parse_args(all_training_arg_list)
77
  distributed_run.run(args)
78
 
79
 
80
+ def run(model_id: str, training_args: str, fdsp_config: str) -> str:
81
  print(f"\n\n\nNEW RUN: {model_id}")
82
  api = HfApi()
83
  model_name = model_id.split("/")[-1]
 
114
 
115
  # Run the medusa heads creation
116
  try:
117
+ proc = mp.Process(target=train_medusa_heads, args=(model_id, training_args, fdsp_config))
118
  proc.start()
119
  proc.join()
120
+ print("Medusa heads training process completed (it might have crashed!)")
121
  except Exception as e:
122
  print("Error ❌\n", e)
123
  return f"""
 
182
  with gr.Row() as r:
183
  with gr.Column() as c:
184
  model_id = gr.Text(max_lines=1, label="model_id")
185
+ with gr.Accordion("Training arguments (advanced)", open=False):
186
  training_args = gr.Textbox(DEFAULT_TRAINING_ARGS, interactive=True, lines=14, label="training_args")
187
+ fdsp_config = gr.Textbox(
188
+ placeholder="e.g. \'{\"value1\":\"key1\"}\'. leave empty if fdsp is not used. ",
189
+ interactive=True,
190
+ lines=1,
191
+ label="fdsp_config"
192
+ )
193
  with gr.Row() as c:
194
  clean = gr.ClearButton()
195
  submit = gr.Button("Submit", variant="primary")
 
197
  with gr.Column() as d:
198
  status_box = gr.Markdown()
199
 
200
+ submit.click(run, inputs=[model_id, training_args, fdsp_config], outputs=status_box, concurrency_limit=1)
201
 
202
  demo.queue(max_size=10).launch(show_api=True)