m3hrdadfi committed
Commit: 6d1f4f4
Parent: b6de320

Fix some bugs
config.json DELETED
@@ -1,56 +0,0 @@
-{
-  "architectures": [
-    "T5ForConditionalGeneration"
-  ],
-  "d_ff": 3072,
-  "d_kv": 64,
-  "d_model": 768,
-  "decoder_start_token_id": 0,
-  "dropout_rate": 0.1,
-  "eos_token_id": 1,
-  "feed_forward_proj": "relu",
-  "gradient_checkpointing": false,
-  "initializer_factor": 1.0,
-  "is_encoder_decoder": true,
-  "layer_norm_epsilon": 1e-06,
-  "model_type": "t5",
-  "n_positions": 512,
-  "num_decoder_layers": 12,
-  "num_heads": 12,
-  "num_layers": 12,
-  "output_past": true,
-  "pad_token_id": 0,
-  "relative_attention_num_buckets": 32,
-  "task_specific_params": {
-    "summarization": {
-      "early_stopping": true,
-      "length_penalty": 2.0,
-      "max_length": 200,
-      "min_length": 30,
-      "no_repeat_ngram_size": 3,
-      "num_beams": 4,
-      "prefix": "summarize: "
-    },
-    "translation_en_to_de": {
-      "early_stopping": true,
-      "max_length": 300,
-      "num_beams": 4,
-      "prefix": "translate English to German: "
-    },
-    "translation_en_to_fr": {
-      "early_stopping": true,
-      "max_length": 300,
-      "num_beams": 4,
-      "prefix": "translate English to French: "
-    },
-    "translation_en_to_ro": {
-      "early_stopping": true,
-      "max_length": 300,
-      "num_beams": 4,
-      "prefix": "translate English to Romanian: "
-    }
-  },
-  "transformers_version": "4.9.0.dev0",
-  "use_cache": true,
-  "vocab_size": 32128
-}
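Note: the deleted file mirrors the stock t5-base configuration (including the summarization/translation task_specific_params), so an equivalent config can be regenerated from the upstream checkpoint rather than tracked here. A minimal sketch, assuming access to the Hugging Face Hub; the output directory name is illustrative:

    from transformers import T5Config

    # Regenerate an equivalent config from the upstream t5-base checkpoint;
    # the task_specific_params (task prefixes, beam settings) ship with
    # t5-base itself.
    config = T5Config.from_pretrained("t5-base")
    config.save_pretrained("t5-recipe-generation")  # writes config.json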
events.out.tfevents.1625682591.t1v-n-a0c138ef-w-0.124617.3.v2 DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:5c069e81c193f5ba7a9c8cff114c5522e13d8efd16e2e8c055c880bf5010f334
-size 736165
flax_model.msgpack DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:12aea1d6f15b37764f5615dcb6d6bc6cc56e7d74cd3ce88cdd0469817b5a9c29
-size 891625348
src/{preparaing_recipe_nlg_dataset.py → create_dataset.py} RENAMED
@@ -114,6 +114,7 @@ def main():
 
         return {
             "inputs": ner,
+            # "targets": f"title: {title} <section> ingredients: {ingredients} <section> directions: {steps}"
             "targets": f"title: {title} <section> ingredients: {ingredients} <section> directions: {steps}"
         }
 
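Note: the mapping above emits one text-to-text pair per RecipeNLG row; the field values below are made up to show the resulting shape (only the `<section>`-delimited target format comes from the diff):

    # Illustrative values only; the real ner/title/ingredients/steps come
    # from the RecipeNLG columns processed earlier in the script.
    ner = "apples, sugar, butter"
    title = "Baked Apples"
    ingredients = "2 apples 1 tbsp sugar 1 tsp butter"
    steps = "Core the apples. Fill with sugar and butter. Bake until soft."

    record = {
        "inputs": ner,
        "targets": f"title: {title} <section> ingredients: {ingredients} <section> directions: {steps}",
    }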
src/run.sh CHANGED
@@ -5,6 +5,7 @@ export LANG=C.UTF-8
 
 export OUTPUT_DIR=/home/m3hrdadfi/code/t5-recipe-generation
 export MODEL_NAME_OR_PATH=t5-base
+# export MODEL_NAME_OR_PATH=flax-community/t5-recipe-generation
 export NUM_BEAMS=3
 
 export TRAIN_FILE=/home/m3hrdadfi/code/data/train.csv
src/run_recipe_nlg_flax.py CHANGED
@@ -21,6 +21,7 @@ Fine-tuning the library models for recipe-generation.
 import logging
 import os
 import random
+import re
 import sys
 import time
 from dataclasses import dataclass, field
@@ -375,7 +376,7 @@ def main():
             data_files["test"] = data_args.test_file
             extension = data_args.test_file.split(".")[-1]
 
-        print(data_files)
+        logger.info(data_files)
         dataset = load_dataset(
             extension,
             data_files=data_files,
@@ -551,10 +552,30 @@ def main():
     bleu = load_metric("sacrebleu")
     wer = load_metric("wer")
 
+    def skip_special_tokens_text(text):
+        new_text = []
+        for word in text.split():
+            word = word.strip()
+            if word:
+                if word not in special_tokens:
+                    new_text.append(word)
+
+        return " ".join(new_text)
+
+    def skip_special_tokens_texts(texts):
+        if isinstance(texts, list):
+            new_texts = [skip_special_tokens_text(text) for text in texts]
+        elif isinstance(texts, str):
+            new_texts = skip_special_tokens_text(texts)
+        else:
+            new_texts = []
+
+        return new_texts
+
     def postprocess_text(preds, labels):
-        preds = [pred.strip() for pred in preds]
-        labels_bleu = [[label.strip()] for label in labels]
-        labels_wer = [label.strip() for label in labels]
+        preds = [skip_special_tokens_texts(pred.strip()) for pred in preds]
+        labels_bleu = [[skip_special_tokens_texts(label.strip())] for label in labels]
+        labels_wer = [skip_special_tokens_texts(label.strip()) for label in labels]
 
         return preds, [labels_bleu, labels_wer]
 
@@ -846,11 +867,6 @@ def main():
                         push_to_hub=training_args.push_to_hub,
                         commit_message=f"Saving weights and logs of step {cur_step}",
                     )
-                    tokenizer.save_pretrained(
-                        training_args.output_dir,
-                        push_to_hub=training_args.push_to_hub,
-                        commit_message=f"Saving tokenizer step {cur_step}",
-                    )
 
 
 if __name__ == "__main__":
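Note: the new postprocessing strips special-token words before the sacrebleu/WER metrics are computed. `special_tokens` is not defined anywhere in this diff; the sketch below assumes it holds the tokenizer's special-token strings plus the recipe `<section>` separator:

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("t5-base")
    # Assumed definition of special_tokens; the diff does not show it.
    special_tokens = tokenizer.all_special_tokens + ["<section>"]

    def skip_special_tokens_text(text):
        # Same logic as the added helper: drop words that are special tokens.
        return " ".join(w for w in text.split() if w and w not in special_tokens)

    print(skip_special_tokens_text("title: baked apples <section> ingredients: apples </s>"))
    # -> "title: baked apples ingredients: apples"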