Fix some bugs
Browse files
config.json
DELETED
@@ -1,56 +0,0 @@
|
|
1 |
-
{
|
2 |
-
"architectures": [
|
3 |
-
"T5ForConditionalGeneration"
|
4 |
-
],
|
5 |
-
"d_ff": 3072,
|
6 |
-
"d_kv": 64,
|
7 |
-
"d_model": 768,
|
8 |
-
"decoder_start_token_id": 0,
|
9 |
-
"dropout_rate": 0.1,
|
10 |
-
"eos_token_id": 1,
|
11 |
-
"feed_forward_proj": "relu",
|
12 |
-
"gradient_checkpointing": false,
|
13 |
-
"initializer_factor": 1.0,
|
14 |
-
"is_encoder_decoder": true,
|
15 |
-
"layer_norm_epsilon": 1e-06,
|
16 |
-
"model_type": "t5",
|
17 |
-
"n_positions": 512,
|
18 |
-
"num_decoder_layers": 12,
|
19 |
-
"num_heads": 12,
|
20 |
-
"num_layers": 12,
|
21 |
-
"output_past": true,
|
22 |
-
"pad_token_id": 0,
|
23 |
-
"relative_attention_num_buckets": 32,
|
24 |
-
"task_specific_params": {
|
25 |
-
"summarization": {
|
26 |
-
"early_stopping": true,
|
27 |
-
"length_penalty": 2.0,
|
28 |
-
"max_length": 200,
|
29 |
-
"min_length": 30,
|
30 |
-
"no_repeat_ngram_size": 3,
|
31 |
-
"num_beams": 4,
|
32 |
-
"prefix": "summarize: "
|
33 |
-
},
|
34 |
-
"translation_en_to_de": {
|
35 |
-
"early_stopping": true,
|
36 |
-
"max_length": 300,
|
37 |
-
"num_beams": 4,
|
38 |
-
"prefix": "translate English to German: "
|
39 |
-
},
|
40 |
-
"translation_en_to_fr": {
|
41 |
-
"early_stopping": true,
|
42 |
-
"max_length": 300,
|
43 |
-
"num_beams": 4,
|
44 |
-
"prefix": "translate English to French: "
|
45 |
-
},
|
46 |
-
"translation_en_to_ro": {
|
47 |
-
"early_stopping": true,
|
48 |
-
"max_length": 300,
|
49 |
-
"num_beams": 4,
|
50 |
-
"prefix": "translate English to Romanian: "
|
51 |
-
}
|
52 |
-
},
|
53 |
-
"transformers_version": "4.9.0.dev0",
|
54 |
-
"use_cache": true,
|
55 |
-
"vocab_size": 32128
|
56 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
events.out.tfevents.1625682591.t1v-n-a0c138ef-w-0.124617.3.v2
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:5c069e81c193f5ba7a9c8cff114c5522e13d8efd16e2e8c055c880bf5010f334
|
3 |
-
size 736165
|
|
|
|
|
|
|
|
flax_model.msgpack
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:12aea1d6f15b37764f5615dcb6d6bc6cc56e7d74cd3ce88cdd0469817b5a9c29
|
3 |
-
size 891625348
|
|
|
|
|
|
|
|
src/{preparaing_recipe_nlg_dataset.py → create_dataset.py}
RENAMED
@@ -114,6 +114,7 @@ def main():
|
|
114 |
|
115 |
return {
|
116 |
"inputs": ner,
|
|
|
117 |
"targets": f"title: {title} <section> ingredients: {ingredients} <section> directions: {steps}"
|
118 |
}
|
119 |
|
|
|
114 |
|
115 |
return {
|
116 |
"inputs": ner,
|
117 |
+
# "targets": f"title: {title} <section> ingredients: {ingredients} <section> directions: {steps}"
|
118 |
"targets": f"title: {title} <section> ingredients: {ingredients} <section> directions: {steps}"
|
119 |
}
|
120 |
|
src/run.sh
CHANGED
@@ -5,6 +5,7 @@ export LANG=C.UTF-8
|
|
5 |
|
6 |
export OUTPUT_DIR=/home/m3hrdadfi/code/t5-recipe-generation
|
7 |
export MODEL_NAME_OR_PATH=t5-base
|
|
|
8 |
export NUM_BEAMS=3
|
9 |
|
10 |
export TRAIN_FILE=/home/m3hrdadfi/code/data/train.csv
|
|
|
5 |
|
6 |
export OUTPUT_DIR=/home/m3hrdadfi/code/t5-recipe-generation
|
7 |
export MODEL_NAME_OR_PATH=t5-base
|
8 |
+
# export MODEL_NAME_OR_PATH=flax-community/t5-recipe-generation
|
9 |
export NUM_BEAMS=3
|
10 |
|
11 |
export TRAIN_FILE=/home/m3hrdadfi/code/data/train.csv
|
src/run_recipe_nlg_flax.py
CHANGED
@@ -21,6 +21,7 @@ Fine-tuning the library models for recipe-generation.
|
|
21 |
import logging
|
22 |
import os
|
23 |
import random
|
|
|
24 |
import sys
|
25 |
import time
|
26 |
from dataclasses import dataclass, field
|
@@ -375,7 +376,7 @@ def main():
|
|
375 |
data_files["test"] = data_args.test_file
|
376 |
extension = data_args.test_file.split(".")[-1]
|
377 |
|
378 |
-
|
379 |
dataset = load_dataset(
|
380 |
extension,
|
381 |
data_files=data_files,
|
@@ -551,10 +552,30 @@ def main():
|
|
551 |
bleu = load_metric("sacrebleu")
|
552 |
wer = load_metric("wer")
|
553 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
554 |
def postprocess_text(preds, labels):
|
555 |
-
preds = [pred.strip() for pred in preds]
|
556 |
-
labels_bleu = [[label.strip()] for label in labels]
|
557 |
-
labels_wer = [label.strip() for label in labels]
|
558 |
|
559 |
return preds, [labels_bleu, labels_wer]
|
560 |
|
@@ -846,11 +867,6 @@ def main():
|
|
846 |
push_to_hub=training_args.push_to_hub,
|
847 |
commit_message=f"Saving weights and logs of step {cur_step}",
|
848 |
)
|
849 |
-
tokenizer.save_pretrained(
|
850 |
-
training_args.output_dir,
|
851 |
-
push_to_hub=training_args.push_to_hub,
|
852 |
-
commit_message=f"Saving tokenizer step {cur_step}",
|
853 |
-
)
|
854 |
|
855 |
|
856 |
if __name__ == "__main__":
|
|
|
21 |
import logging
|
22 |
import os
|
23 |
import random
|
24 |
+
import re
|
25 |
import sys
|
26 |
import time
|
27 |
from dataclasses import dataclass, field
|
|
|
376 |
data_files["test"] = data_args.test_file
|
377 |
extension = data_args.test_file.split(".")[-1]
|
378 |
|
379 |
+
logger.info(data_files)
|
380 |
dataset = load_dataset(
|
381 |
extension,
|
382 |
data_files=data_files,
|
|
|
552 |
bleu = load_metric("sacrebleu")
|
553 |
wer = load_metric("wer")
|
554 |
|
555 |
+
def skip_special_tokens_text(text):
|
556 |
+
new_text = []
|
557 |
+
for word in text.split():
|
558 |
+
word = word.strip()
|
559 |
+
if word:
|
560 |
+
if word not in special_tokens:
|
561 |
+
new_text.append(word)
|
562 |
+
|
563 |
+
return " ".join(new_text)
|
564 |
+
|
565 |
+
def skip_special_tokens_texts(texts):
|
566 |
+
if isinstance(texts, list):
|
567 |
+
new_texts = [skip_special_tokens_text(text) for text in texts]
|
568 |
+
elif isinstance(texts, str):
|
569 |
+
new_texts = skip_special_tokens_text(texts)
|
570 |
+
else:
|
571 |
+
new_texts = []
|
572 |
+
|
573 |
+
return new_texts
|
574 |
+
|
575 |
def postprocess_text(preds, labels):
|
576 |
+
preds = [skip_special_tokens_texts(pred.strip()) for pred in preds]
|
577 |
+
labels_bleu = [[skip_special_tokens_texts(label.strip())] for label in labels]
|
578 |
+
labels_wer = [skip_special_tokens_texts(label.strip()) for label in labels]
|
579 |
|
580 |
return preds, [labels_bleu, labels_wer]
|
581 |
|
|
|
867 |
push_to_hub=training_args.push_to_hub,
|
868 |
commit_message=f"Saving weights and logs of step {cur_step}",
|
869 |
)
|
|
|
|
|
|
|
|
|
|
|
870 |
|
871 |
|
872 |
if __name__ == "__main__":
|