Saving train state of step 1000
checkpoint-1000-epoch-0/model.safetensors CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:0872c0c697bae76960041656d4e8354cc69b767929e06c0f35615b5e9bc6ed4c
 size 3024943976
checkpoint-1000-epoch-0/optimizer.bin CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:697196fece1aa67a773dc71ee5d622ea128185dbda531c1796ea76634c55d988
 size 955529338
distil-whisper/events.out.tfevents.1705598735.c066756f484e.20967.0 ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:08d3c29719c5def41e6b7b23aae8e54d6851e848fa50a028b1940416efea4dee
+size 12458
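The three files above are Git LFS pointer stubs; the actual binaries live in LFS storage. The checkpoint-1000-epoch-0/ directory (model.safetensors plus optimizer.bin) is consistent with a training state written by Accelerate's save_state. Below is a minimal sketch of saving and resuming such a state with Accelerate, using placeholder model and optimizer objects rather than the actual distillation setup:

import torch
from accelerate import Accelerator

accelerator = Accelerator()

# Placeholders standing in for the student model and its optimizer.
model = torch.nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
model, optimizer = accelerator.prepare(model, optimizer)

# Writes model.safetensors, optimizer.bin, RNG states, etc. into the directory.
accelerator.save_state("checkpoint-1000-epoch-0")

# Resuming later restores the model weights and optimizer state from the same directory.
accelerator.load_state("checkpoint-1000-epoch-0")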
run_distillation.py CHANGED
@@ -458,7 +458,7 @@ def log_pred(
 ):
     """Helper function to log target/predicted transcriptions to weights and biases (wandb)."""
     if accelerator.is_main_process:
-        wandb_tracker = accelerator.get_tracker("wandb")
+        # wandb_tracker = accelerator.get_tracker("wandb")
         # pretty name for current step: step 50000 -> step 50k
         cur_step_pretty = f"{int(step // 1000)}k" if step > 1000 else step
         prefix_pretty = prefix.replace("/", "-")
@@ -466,23 +466,23 @@ def log_pred(
         # convert str data to a wandb compatible format
         str_data = [[label_str[i], pred_str[i], norm_label_str[i], norm_pred_str[i]] for i in range(len(pred_str))]
         # log as a table with the appropriate headers
-        wandb_tracker.log_table(
-            table_name=f"predictions/{prefix_pretty}-step-{cur_step_pretty}",
-            columns=["Target", "Pred", "Norm Target", "Norm Pred"],
-            data=str_data[:num_lines],
-            step=step,
-        )
+        # wandb_tracker.log_table(
+        #     table_name=f"predictions/{prefix_pretty}-step-{cur_step_pretty}",
+        #     columns=["Target", "Pred", "Norm Target", "Norm Pred"],
+        #     data=str_data[:num_lines],
+        #     step=step,
+        # )
 
         # log incorrect normalised predictions
         str_data = np.asarray(str_data)
         str_data_incorrect = str_data[str_data[:, -2] != str_data[:, -1]]
         # log as a table with the appropriate headers
-        wandb_tracker.log_table(
-            table_name=f"incorrect_predictions/{prefix_pretty}-step-{cur_step_pretty}",
-            columns=["Target", "Pred", "Norm Target", "Norm Pred"],
-            data=str_data_incorrect[:num_lines],
-            step=step,
-        )
+        # wandb_tracker.log_table(
+        #     table_name=f"incorrect_predictions/{prefix_pretty}-step-{cur_step_pretty}",
+        #     columns=["Target", "Pred", "Norm Target", "Norm Pred"],
+        #     data=str_data_incorrect[:num_lines],
+        #     step=step,
+        # )
 
 
 def convert_dataset_str_to_list(
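The run_distillation.py change disables the Weights & Biases transcription tables by commenting out the tracker lookup and both log_table calls inside log_pred. For reference, here is a minimal sketch of what such a table log looks like when written against the wandb client directly; the project name, table key, and sample rows are illustrative only and not taken from this commit:

import wandb

# Project name is a placeholder, not the run configuration used here.
run = wandb.init(project="distil-whisper-logging-demo")

# Rows mirror the [Target, Pred, Norm Target, Norm Pred] layout built in log_pred.
str_data = [
    ["the cat sat on the mat", "the cat sat on a mat", "the cat sat on the mat", "the cat sat on a mat"],
    ["hello world", "hello world", "hello world", "hello world"],
]
table = wandb.Table(columns=["Target", "Pred", "Norm Target", "Norm Pred"], data=str_data)
run.log({"predictions/eval-step-1k": table}, step=1000)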