ariahmed committed · Commit e489264 · verified · 1 Parent(s): 5f93c59

Upload folder using huggingface_hub
.github/workflows/update_space.yml ADDED
@@ -0,0 +1,28 @@
1
+ name: Run Python script
2
+
3
+ on:
4
+ push:
5
+ branches:
6
+ - main
7
+
8
+ jobs:
9
+ build:
10
+ runs-on: ubuntu-latest
11
+
12
+ steps:
13
+ - name: Checkout
14
+ uses: actions/checkout@v2
15
+
16
+ - name: Set up Python
17
+ uses: actions/setup-python@v2
18
+ with:
19
+ python-version: '3.9'
20
+
21
+ - name: Install Gradio
22
+ run: python -m pip install gradio
23
+
24
+ - name: Log in to Hugging Face
25
+ run: python -c 'import huggingface_hub; huggingface_hub.login(token="${{ secrets.hf_token }}")'
26
+
27
+ - name: Deploy to Spaces
28
+ run: gradio deploy
.gitignore ADDED
@@ -0,0 +1,174 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # poetry
98
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102
+ #poetry.lock
103
+
104
+ # pdm
105
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106
+ #pdm.lock
107
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108
+ # in version control.
109
+ # https://pdm.fming.dev/#use-with-ide
110
+ .pdm.toml
111
+
112
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113
+ __pypackages__/
114
+
115
+ # Celery stuff
116
+ celerybeat-schedule
117
+ celerybeat.pid
118
+
119
+ # SageMath parsed files
120
+ *.sage.py
121
+
122
+ # Environments
123
+ .env
124
+ .venv
125
+ env/
126
+ venv/
127
+ ENV/
128
+ env.bak/
129
+ venv.bak/
130
+
131
+ # Spyder project settings
132
+ .spyderproject
133
+ .spyproject
134
+
135
+ # Rope project settings
136
+ .ropeproject
137
+
138
+ # mkdocs documentation
139
+ /site
140
+
141
+ # mypy
142
+ .mypy_cache/
143
+ .dmypy.json
144
+ dmypy.json
145
+
146
+ # Pyre type checker
147
+ .pyre/
148
+
149
+ # pytype static type analyzer
150
+ .pytype/
151
+
152
+ # Cython debug symbols
153
+ cython_debug/
154
+
155
+ # PyCharm
156
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
159
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160
+ #.idea/
161
+ .DS_Store
162
+ train.csv
163
+ test.csv
164
+ Kurd-Spell/
165
+ tokenizer
166
+ sn_project
167
+ notes.md
168
+
169
+
170
+ # Data dir
171
+ data/*
172
+ !data/words.json
173
+ !data/asosoft_benchmark.csv
174
+ !data/Sorani-Arabic.csv
.gradio/certificate.pem ADDED
@@ -0,0 +1,31 @@
1
+ -----BEGIN CERTIFICATE-----
2
+ MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw
3
+ TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh
4
+ cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4
5
+ WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu
6
+ ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY
7
+ MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc
8
+ h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+
9
+ 0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U
10
+ A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW
11
+ T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH
12
+ B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC
13
+ B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv
14
+ KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn
15
+ OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn
16
+ jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw
17
+ qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI
18
+ rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV
19
+ HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq
20
+ hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL
21
+ ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ
22
+ 3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK
23
+ NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5
24
+ ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur
25
+ TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC
26
+ jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc
27
+ oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq
28
+ 4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA
29
+ mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d
30
+ emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc=
31
+ -----END CERTIFICATE-----
README.md CHANGED
@@ -1,12 +1,84 @@
1
  ---
2
- title: Kurd Spell App
3
- emoji: 📊
4
- colorFrom: purple
5
- colorTo: purple
6
  sdk: gradio
7
  sdk_version: 5.4.0
8
- app_file: app.py
9
- pinned: false
10
  ---
11
 
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
1
  ---
2
+ title: kurd-spell-app
3
+ app_file: app.py
 
 
4
  sdk: gradio
5
  sdk_version: 5.4.0
 
 
6
  ---
7
+ # Central Kurdish Neural Spell Corrector
8
+ <p align="center">
9
+ <img src="https://www.razhan.ai/_next/image?url=/static/images/projects/spell-checker.webp&w=1200&q=75" alt="Banner Image" height="240" width="1200">
10
+ <br>
11
+ <a href="https://huggingface.co/razhan/bart-kurd-spell-base">
12
+ [🔥 Best model]
13
+ </a>
14
+ <a href="https://huggingface.co/models?search=bart-kurd-spell">
15
+ [📀 Models]
16
+ </a>
17
+ <a href="https://huggingface.co/spaces/razhan/Kurd-Spell">
18
+ [🤗 Demo]
19
+ </a>
20
+ </p>
21
+
22
+
23
+
24
+
25
+ > **Note:** The documentation for this project is currently being written. I am working hard to make this project easily hackable so people can add new heuristics and train more models.
26
+
27
+ This repository contains a collection of neural spell correctors for the Central Kurdish language. The models are trained on an extensive corpus of synthetically generated data and can correct a wide range of spelling errors, from simple typos to grammatical mistakes.
28
+
29
+
30
+ Using various heuristics, we generate a rich dataset by mapping sequences containing misspellings to the correct sequence. We do this by randomly inserting valid characters, deleting characters or patterns, substituting characters with random ones or their keyboard neighbors, swapping two adjacent characters, shuffling sentences, and replacing specific predefined patterns with targeted alternatives.
31
+
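Below is a minimal, hypothetical sketch of such a distortion function; the repository's actual implementation is `get_text_distorter` in `prepare_data/processors.py`, and the per-character `ratio` here corresponds to the distortion ratio varied in the experiments below. The function name, probability split, and abridged alphabet are illustrative only; keyboard-neighbour substitution and the predefined pattern replacements are omitted for brevity.

```python
import random

# Abridged alphabet; the full list lives in prepare_data/constants.py (KURDISH_CHARS).
KURDISH_CHARS = "ئابپتجچحخدرڕزژسشعغفڤقکگلڵمنهەوۆیێ"

def distort(text: str, ratio: float = 0.10) -> str:
    """Return a noisy copy of `text`; `ratio` is the per-character distortion probability."""
    out = []
    for ch in text:
        r = random.random()
        if r >= ratio:                      # most characters are left untouched
            out.append(ch)
        elif r < ratio * 0.25:              # deletion
            continue
        elif r < ratio * 0.50:              # random insertion before the character
            out.append(random.choice(KURDISH_CHARS))
            out.append(ch)
        elif r < ratio * 0.75:              # substitution with a random character
            out.append(random.choice(KURDISH_CHARS))
        else:                               # swap with the previous character
            out.insert(max(len(out) - 1, 0), ch)
    return "".join(out)

# Each clean sentence is paired with a distorted copy to form (noisy, clean) training rows:
# pairs = [(distort(s, ratio=0.05), s) for s in sentences]
```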
32
+
33
+
34
+ ## Experiments
35
+ The error injection framework in `prepare_data` offers a method to inject errors according to a distortion ratio. I conducted the following experiments to determine the optimal ratio that allows the model to achieve the lowest Word Error Rate (WER) and Character Error Rate (CER) on the synthetic test set.
36
+ | Model Name | Dataset Distortion| CER | WER |
37
+ |------------------------------------------------------------------|-------------------|-------|--------|
38
+ | [bart-base](razhan/bart-kurd-spell-base-05) | 5% | 5.39% | 34.73% |
39
+ | [bart-base](razhan/bart-kurd-spell-base-05) | 10% | 2.15% | 11.19% |
40
+ | [bart-base](https://huggingface.co/razhan/bart-kurd-spell-base-05_10)| Mixed (5% + 10%)| **1.54%** | **8.31%** |
41
+ | [bart-base](https://huggingface.co/razhan/bart-kurd-spell-base) | 15% | 2.17% | 12.3% |
42
+
43
+
44
+ ## Evaluation on ASOSOFT Spelling Benchmark
45
+ The benchmark from this [project](https://github.com/AsoSoft/Central-Kurdish-Spelling-dataset) is designed exclusively for single-word spelling corrections. The script `create_asosoft_benchmark.py` processes each word of the Amani dataset by searching for sentences that contain the correct spelling, checking that the sentence is not already included in `train.csv`, and replacing the correct word with the provided misspelling. This is a hacky way to obtain a gold-standard benchmark. The current best-performing model achieves the following results:
46
+
47
+ | Metric | Value |
48
+ |----------|--------|
49
+ | CER | 9.6545 |
50
+ | WER | 21.7558|
51
+ | Bleu | 68.1724|
52
+
53
+ ## Evaluation on Sorani Script Normalization Benchmark
54
+ The final generated dataset is also concatenated with the training data from the [Script Normalization for Unconventional Writing](https://github.com/sinaahmadi/ScriptNormalization/tree/main) project. Therefore, the model not only corrects spelling but also normalizes unconventional writing, i.e. text that uses the writing system of one language to write another language.
55
+
56
+ They employ a similar approach to generate their data, but it is unwise to evaluate a model on the synthetic test set, since the model can memorize the underlying patterns from training. Hence they provide a gold-standard benchmark for Central Kurdish and use `Bleu` & `chrF` to measure their model's performance.
57
+
58
+ | Model | Bleu | chrF |
59
+ |-----------------------|-------|-------|
60
+ | Script Normalization | 12.7 | 69.6 |
61
+ | Bart-kurd-spell-base | 13.8 | 73.9 |
62
+
63
+ > Keep in mind that both of these models have seen the same data for script normalization; ours performs slightly better thanks to the additional spell-correction data.
64
+
65
+
66
+ ## Train a New Model
67
+ Since the problem is framed as mapping a sequence containing misspellings to a correct sequence, we can train different encoder-decoder models such as T5.
68
+ 1. Run [`train_tokenizer.py`](train_tokenizer.py) to build a tokenizer for your chosen model with the `--tokenizer_name` argument.
69
+ 2. Create `data.txt` and put it in [`data`](data) dir. Check [`inspect_data.ipynb`](inspect_data.ipynb).
70
+ 3. Check the arguments of [`prepare_data/process_data.py`](prepare_data/process_data.py) and run it to produce `train.csv` and `test.csv`.
71
+ 4. Change the arguments in [`train.sh`](train.sh) if you want to train a model other than BART. If you want to train T5, you also need to add `--source_prefix "correct: "`.
72
+ 5. Evaluate the model on both [`data/asosoft_benchmark.csv`](data/asosoft_benchmark.csv) and [`data/Sorani-Arabic.csv`](data/Sorani-Arabic.csv) using [`eval.sh`](eval.sh); a minimal WER/CER computation is also sketched below.
73
+
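For reference, the WER/CER that `eval.sh` reports through `run_summarization.py` can be approximated directly with the Hugging Face `evaluate` package. This is a sketch rather than part of the repository; it assumes the benchmark's `text` (noisy) / `summary` (reference) columns and the published `razhan/bart-kurd-spell-base` checkpoint.

```python
import evaluate                       # pip install evaluate jiwer
import pandas as pd
from transformers import pipeline

corrector = pipeline("text2text-generation", model="razhan/bart-kurd-spell-base", max_length=1024)
df = pd.read_csv("data/asosoft_benchmark.csv")   # columns: text (noisy), summary (reference)

predictions = [corrector(t)[0]["generated_text"] for t in df["text"]]
references = df["summary"].tolist()

# Both metrics return fractions; multiply by 100 for the percentages reported in the tables.
print("WER:", evaluate.load("wer").compute(predictions=predictions, references=references))
print("CER:", evaluate.load("cer").compute(predictions=predictions, references=references))
```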
74
+ ## Observations
75
+ Different heuristics could be added to the pipeline, for example replacing ر at the start of every word with ڕ, or replacing ك with ک. Both of these patterns occur quite often in Central Kurdish text online, but they can be handled with rules instead of being learned from data, so it is more practical to address them with rule-based tools such as [`KLPT`](https://github.com/sinaahmadi/klpt); a sketch of such a rule follows.
76
+
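As an illustration only (the function name is hypothetical and this snippet is not part of the repository), such rules boil down to a couple of substitutions; KLPT's preprocessing covers normalizations of this kind more thoroughly.

```python
import re

def apply_rule_based_fixes(text: str) -> str:
    # Arabic kaf -> Kurdish keheh, wherever it appears.
    text = text.replace("ك", "ک")
    # Plain reh at the beginning of a word -> Kurdish reh (ڕ).
    return re.sub(r"\bر", "ڕ", text)
```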
77
+ If you can think of more heuristics, they can easily be added to the pipeline in the [`get_text_distorter`](prepare_data/processors.py#L111) function.
78
+
79
+ PRs with additional models, evaluation, or data generation heuristics are welcome! 👍
80
 
81
+ ## References
82
+ - https://arxiv.org/abs/1910.13461
83
+ - https://www.researchsquare.com/article/rs-2974359/v1
84
+ - https://arxiv.org/abs/2305.16407
app.py ADDED
@@ -0,0 +1,79 @@
1
+ import gradio as gr
2
+ from difflib import Differ
3
+ from transformers import pipeline
4
+
5
+ model_id = "razhan/bart-kurd-spell-base"
6
+ # spell_corrector = pipeline("text2text-generation", model=model_id, return_all_scores=True)
7
+ spell_corrector = pipeline("text2text-generation", model=model_id, max_length=1024)
8
+
9
+
10
+
11
+ def correct_spell(text):
12
+ d = Differ()
13
+ if text is None:
14
+ text = ""
15
+ corrected = spell_corrector(text)[0]['generated_text']
16
+
17
+ return [
18
+ (token[2:], token[0] if token[0] != " " else None)
19
+ for token in d.compare(text, corrected)
20
+ ], corrected
21
+
22
+
23
+
24
+ demo = gr.Interface(
25
+ correct_spell,
26
+ [
27
+ gr.Textbox(
28
+ label="Input text",
29
+ info="Initial text to be corrected",
30
+ lines=3,
31
+ value="نوووسینێکی ڕااست بێهەڵە",
32
+ rtl=True
33
+ ),
34
+ ],
35
+ outputs=[
36
+ gr.HighlightedText(
37
+ label="Diff",
38
+ combine_adjacent=True,
39
+ show_legend=True,
40
+ color_map={"-": "pink", "+": "green"},
41
+ rtl=True,
42
+ # container=True,
43
+ elem_id="kurdi"
44
+ ),
45
+ gr.Textbox(label="Corrected Text", rtl=True, container=True)
46
+ ],
47
+ examples=[
48
+ "حکومەتلە گفتوگۆحانی پەرلەماندا لەسەربودجەی نوێ ڕایگەیاند کە لە دەنگدانلەسەر بودجە بەردەوام دەبێت",
49
+ "ژنەڤ کاندغدێکی کورد نەشتەرگەری بۆکەا",
50
+ "فەستبخەرکرانی سێ هاووڵاتی لە شاری بۆکانلە لاین هێزە ئەمنییکەانەوە",
51
+ "ئەم وێنجانەی وخارەوەش چەند ێونەیەکی دەزپێرکاوی مۆبایلەکەن",
52
+ "خۆگزە توانیبام ژیان لە دیداری یەکەی ژاچگرێ بدەم",
53
+ "هەرفەرمانبەرێک بەناشچایستە پلەی نوەزیفیوەرگرتبێتلێیدەسەرنێتەەو",
54
+ "ماوەیەکەدەست ەب ئاامدەکسری کرا٦وە بۆ بەڕێوەچوونی ەششەمین فیستیڤاڵینێودەوڵەتیی هەولێرب ۆ شانۆ",
55
+ "ەڵم ئارەزوومە کە فیلمێک لە سەرحۆریەکانی ێجەریای نێوچیڕۆکەکانیشەوان عەرەبیەوە بەرخهەم بهێنم",
56
+ "پارەی ئەلکتترۆنیکی هیان راوی دیجیتاڵ جۆرە راوێکە کە تەنیا بە شێوەی ئەلیکترۆنیکی لەبەردەستەایە"
57
+
58
+
59
+ ],
60
+ title="Central Kurdish Neural Spell Correction",
61
+ # description="This is made as a fun side project, it's not to be relied on for production.",
62
+ css="""
63
+ #kurdi {
64
+ text-align: right;
65
+ }
66
+ """,
67
+ theme=gr.themes.Base(
68
+ primary_hue="pink",
69
+ secondary_hue="stone",
70
+ text_size=gr.themes.sizes.text_lg,
71
+ spacing_size=gr.themes.sizes.spacing_lg,
72
+ radius_size=gr.themes.sizes.radius_lg,
73
+ font=gr.themes.GoogleFont("Noto Sans"),
74
+
75
+ ),
76
+ allow_flagging='auto'
77
+ )
78
+ if __name__ == "__main__":
79
+ demo.launch()
ckb_helpers.py ADDED
@@ -0,0 +1,455 @@
1
+ import re
2
+ from klpt.preprocess import Preprocess
3
+ from klpt.tokenize import Tokenize
4
+
5
+ import unicodedata
6
+
7
+ preprocessor_ckb = Preprocess("Sorani", "Arabic", numeral="Arabic")
8
+ tokenizer_ckb = Tokenize("Sorani", "Arabic")
9
+
10
+
11
+ unify_numbers = {
12
+ "٠|۰": "0",
13
+ "١|۱": "1",
14
+ "٢|۲": "2",
15
+ "٣|۳": "3",
16
+ "٤|۴": "4",
17
+ "٥|۵": "5",
18
+ "٦|۶": "6",
19
+ "٧|۷": "7",
20
+ "٨|۸": "8",
21
+ "٩|۹": "9"
22
+ }
23
+
24
+ # Taken from AsoSoft library
25
+ def number_to_word(text):
26
+ # convert numbers to latin
27
+ for k, v in unify_numbers.items():
28
+ text = re.sub(k, v, text)
29
+
30
+ text = re.sub(r"([0-9]{1,3})[,،](?=[0-9]{3})", r"\1", text); # remove thousand separator 12,345,678 => 12345678
31
+ text = re.sub(r"(?<![0-9])-([0-9]+)", r"ناقس \1", text); # negative
32
+ text = text.replace("٪", "%") # Replace arabic percent sign with latin
33
+ text = re.sub(r"(?<![0-9])% ?([0-9]+)", r"لە سەددا \1", text); # percent sign before
34
+ text = re.sub(r"([0-9]+) ?%", r"\1 لە سەد", text); # percent sign after
35
+ text = re.sub(r"\$ ?([0-9]+(\.[0-9]+)?)", r"\1 دۆلار", text) # $ currency
36
+ text = re.sub(r"£ ?([0-9]+(\.[0-9]+)?)", r"\1 پاوەن", text) # £ currency
37
+ text = re.sub(r"€ ?([0-9]+(\.[0-9]+)?)", r"\1 یۆرۆ", text) # € currency
38
+
39
+ # convert float numbers
40
+ text = re.sub(r"([0-9]+)\.([0-9]+)", lambda x: float_name(x.group(1), x.group(2)), text)
41
+
42
+ # convert remaining integer numbers
43
+ text = re.sub(r"([0-9]+)", lambda match: integer_name(match.group(1)), text)
44
+
45
+ return text
46
+
47
+ def float_name(integerPart, decimalPart):
48
+ zeros = re.search("^0+", decimalPart)
49
+ point = " پۆینت "
50
+ if(zeros):
51
+ point = point + re.sub("0", " سفر ", zeros[0])
52
+ return integer_name(integerPart) + point + integer_name(decimalPart)
53
+
54
+ ones = ["", "یەک", "دوو", "سێ", "چوار", "پێنج", "شەش", "حەوت", "هەشت", "نۆ"]
55
+ teens = [ "دە", "یازدە", "دوازدە", "سێزدە", "چواردە", "پازدە", "شازدە", "حەڤدە", "هەژدە", "نۆزدە" ]
56
+ tens = [ "", "", "بیست", "سی", "چل", "پەنجا", "شەست", "هەفتا", "هەشتا", "نەوەد"]
57
+ hundreds = ["", "سەد", "دووسەد", "سێسەد", "چوارسەد", "پێنسەد", "شەشسەد", "حەوتسەد", "هەشتسەد", "نۆسەد"]
58
+ thousands = ["", " هەزار", " ملیۆن", " ملیار", " ترلیۆن", " کوادرلیۆن", " کوینتلیۆن"]
59
+
60
+ def integer_name(inputInteger):
61
+ output = ""
62
+ if (inputInteger != "0"):
63
+ temp = inputInteger
64
+ for i in range(0, len(inputInteger), 3):
65
+ matched_numbers = re.findall(r"[0-9]{1,3}$", temp)
66
+ currentThree = matched_numbers[0] if matched_numbers else ""
67
+
68
+ temp = temp[:len(temp) - len(currentThree)]
69
+ currentThree = currentThree.rjust(3, '0')
70
+ C = int(currentThree[0])
71
+ X = int(currentThree[1])
72
+ I = int(currentThree[2])
73
+ conjunction1 = " و " if (C != 0) and (X != 0 or I != 0) else ""
74
+ conjunction2 = " و " if X != 0 and I != 0 else ""
75
+
76
+ if (X == 1):
77
+ currentThree = hundreds[C] + conjunction1 + teens[I]
78
+ else:
79
+ currentThree = hundreds[C] + conjunction1 + tens[X] + conjunction2 + ones[I]
80
+
81
+ currentThree += "" if currentThree == "" else thousands[i // 3]
82
+
83
+ conjunction3 = "" if output == "" else " و "
84
+ if (currentThree != ""):
85
+ output = currentThree + conjunction3 + output
86
+ output = output.replace("یەک هەزار", "هەزار")
87
+ else: # if input number = 0
88
+ output = "سفر"
89
+ return output
90
+
91
+
92
+
93
+
94
+ def replace_words_in_corpus(sentence):
95
+ modified_corpus = []
96
+
97
+ words = sentence.split()
98
+ modified_words = []
99
+
100
+ for word in words:
101
+ if word in word_replacements:
102
+ modified_words.append(word_replacements[word])
103
+ else:
104
+ modified_words.append(word)
105
+
106
+ modified_sentence = " ".join(modified_words)
107
+
108
+ return modified_sentence
109
+
110
+ # put this in a json file
111
+ word_replacements = {
112
+ "ههڵاڵەەي": "هەڵاڵەی",
113
+ "وهەمهەمه": "وهەمهەمه",
114
+ "ئهباتههوه": "ئەباتەوە",
115
+ "بەخءرایی": "بەخێرایی",
116
+ "ئیثانۆڵ": "ئیسانۆڵ",
117
+ "عەبدوڵڵاهـ": "عەبدوڵڵا",
118
+ "کولاهـ": "کولاه",
119
+ "ئاھ": "ئاه",
120
+ }
121
+
122
+
123
+ char_replacements = {
124
+ '\u200e': '',
125
+ '\u200f': '',
126
+ '\u200c': '',
127
+ 'õ': '',
128
+ 'ھ': 'ه'
129
+ }
130
+ def apply_char_replacements(text: str):
131
+
132
+ for old, new in char_replacements.items():
133
+ text = text.replace(old, new)
134
+ return text
135
+
136
+
137
+ def remove_arabic_alphabets(text: str):
138
+ """
139
+ Removes ``Arabic`` words and digits from a ``text``
140
+
141
+ Args:
142
+ text (str): Sorani text
143
+ Returns:
144
+ str: ``str`` object with arabic alphabets removed
145
+ """
146
+ characters = "ءآأؤإئابةتثجحخدذرزسشصضطظعغـفقكلمنهوىيًٌٍَُِّْٰٱ"
147
+ table = str.maketrans({key: None for key in characters})
148
+ return text.translate(table)
149
+
150
+
151
+
152
+ def filtered_arabic_characters():
153
+ kurdish_characters = set("ئابپتجچحخدرڕزژسشعغفڤقکگلڵمنهەوووۆیێ")
154
+ arabic_characters = set("ءآأؤإئابةتثجحخدذرزسشصضطظعغـفقكلمنهوىيًٌٍَُِّْٰٱ")
155
+
156
+ # Create a new set of Arabic characters without the Kurdish characters
157
+ filtered_arabic_characters = arabic_characters - kurdish_characters
158
+
159
+ return ''.join(filtered_arabic_characters)
160
+
161
+
162
+
163
+ def is_arabic_string(text):
164
+ """Returns True if the text contains any Arabic characters, False otherwise."""
165
+ # arabic_characters = set("ءآأؤإئابةتثجحخدذرزسشصضطظعغـفقكلمنهوىيًٌٍَُِّْٰٱ")
166
+ arabic_characters = filtered_arabic_characters()
167
+ for ch in text:
168
+ if ch in arabic_characters:
169
+ return True
170
+ return False
171
+
172
+ def contains_arabic(text):
173
+ arabic_characters = filtered_arabic_characters()
174
+ return any(char in arabic_characters for char in text)
175
+
176
+
177
+ def is_english_string(text):
178
+ """Returns True if the text contains only English characters, False otherwise."""
179
+ english_pattern = re.compile(r'[a-zA-Z]')
180
+ return bool(english_pattern.search(text))
181
+
182
+
183
+ def remove_english_alphabets(text: str):
184
+ """
185
+ Removes ``English`` words and digits from a ``text``
186
+ """
187
+ characters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
188
+ table = str.maketrans({key: None for key in characters})
189
+ return text.translate(table)
190
+
191
+
192
+
193
+
194
+ def resolve_ae(text):
195
+ """
196
+ This function takes a text input in Central Kurdish (Sorani) script and performs a series of character replacements
197
+ to standardize variations in the script. Specifically, it addresses cases where the character 'ە' (Arabic letter
198
+ AE) may be used in different contexts.
199
+ """
200
+ # First replace all occurrences of 'ه' with 'ە'
201
+ text = re.sub("ه", "ە", text)
202
+ # Replace specific combinations with 'ها', 'هێ', and 'ه'
203
+ text = re.sub("ەا", "ها", text) # Replace ەا with ها
204
+ text = re.sub("ەێ", "هێ", text) # Replace ەێ with هێ
205
+ text = re.sub("ەۆ", "هۆ", text) # Replace ەۆ with هۆ
206
+
207
+ # Replace ە (AE) at the beginning of a word with ه (HEH)
208
+ text = re.sub(r"\b(ە\w*)", lambda match: "ه" + match.group(1)[1:], text)
209
+
210
+ # Replace ALEF+AE with ALEF+HEH
211
+ text = re.sub("اە", "اه", text)
212
+
213
+ # Special words should go here before the replacement of 'ە' at the end of the word
214
+ # Special case: گەهـ or گاهـ but without the tatweel since tatweel is not a phoneme in Kurdish and it will be a class for the model
215
+ text = re.sub(r'\bگەە[-ـ]?\b', "گەه", text)
216
+
217
+ # Replace 'ەە' at the beginning and end with 'هە'
218
+ text = re.sub(r"\bەە|ەە\b", "هە", text)
219
+
220
+ # Special case if two AEs come before ۆ it should be replaced with AE+HEH
221
+ text = re.sub(r"ەە(?=ۆ)", "ەه", text)
222
+
223
+ # Special case if two AEs come after either و or ب or ئ or ڕ or ق or ز they should be replaced with AE+HEH
224
+ text = re.sub(r"(?<=\b[بوئڕقزژ])ەە", "ەه", text)
225
+ # The following special case should happen after the previous special case and before the following special case
226
+ # Special case when two words are joined with waw and the AEs after the waw become HEH+AE
227
+ text = re.sub(r'(?<=و)ەە(?=\w)', "هە", text)
228
+
229
+ # Replace Three AEs with AE+HEH+AE (This has to be run before the following special case so words like لەهەوادا will not be ruined)
230
+ text = re.sub(r"(?<=\w)ەەە(?=\w)", "ەهە", text)
231
+
232
+ # Special case if two AEs are in the middle of a word and come before YEH ی or TCHEH چ or و they will be replaced with AE+HEH if the YEH or TCHEH are not at the END of the word
233
+ text = re.sub(r"(?<=\w)ەە(?=[چیو]\B)", "ەه", text)
234
+
235
+ # Replace 'ەە'AE+AE in the middle of a word with HEH+AE
236
+ text = re.sub(r"(?<=\w)ەە(?=\w)", "هە", text)
237
+
238
+ # Replace two AE with spaces in between with AE HEH
239
+ text = re.sub("ە ە", "ە ه", text)
240
+
241
+ # Replace all HEH DOACHASHMEE with HEH
242
+ # text = text.replace('ھ', 'ە')
243
+ return text
244
+
245
+ clean_punctuation = re.compile(r"(?<!\d)[.,;:'?!\/](?!\d)")
246
+ def remove_punctuation(text):
247
+ """Remove all punctuation from string, except if it's between digits"""
248
+ return clean_punctuation.sub("", text)
249
+
250
+
251
+ def extract_punctuation(text):
252
+ # Initialize an empty string to store the extracted punctuation
253
+ extracted_punctuation = ""
254
+
255
+ # Iterate through each character in the input text
256
+ for char in text:
257
+ # Check if the character is categorized as punctuation
258
+ if unicodedata.category(char).startswith('P'):
259
+ extracted_punctuation += char # Add it to the result
260
+
261
+ return set(extracted_punctuation)
262
+
263
+
264
+
265
+ ARABIC_PUCTUATIONS = "،؛۔٫٪؟"
266
+ CKB_PUNCTUATIONS = "!.:;?،؛؟«»" + ARABIC_PUCTUATIONS
267
+ KURDISH_CHARS = set(f"{CKB_PUNCTUATIONS}ئابپتجچحخدرڕزژسشعغفڤقکگلڵمنهەوووۆیێ٠١٢٣٤٥٦٧٨٩ ")
268
+
269
+ def contains_non_kurdish_characters(text):
270
+ # kurdish_characters = set("ئابپتجچحخدرڕزژسشعغفڤقکگلڵمنهەوووۆیێ٠١٢٣٤٥٦٧٨٩ ")
271
+ kurdish_characters = set(f"{CKB_PUNCTUATIONS}ئابپتجچحخدرڕزژسشعغفڤقکگلڵمنهەوووۆیێ٠١٢٣٤٥٦٧٨٩ ")
272
+ non_kurdish_chars = set(text) - kurdish_characters
273
+
274
+ return len(non_kurdish_chars) > 0
275
+
276
+
277
+ def keep_kurdish_characters(text):
278
+ kurdish_characters = set(f"{CKB_PUNCTUATIONS}ئابپتجچحخدرڕزژسشعغفڤقکگلڵمنهەوووۆیێ٠١٢٣٤٥٦٧٨٩ ")
279
+
280
+ cleaned_text = ''.join(char for char in text if char in kurdish_characters)
281
+ return cleaned_text
282
+
283
+
284
+
285
+ def remove_emojis(text):
286
+ emoji_pattern = re.compile("["
287
+ "\U0001F600-\U0001F64F" # Emoticons
288
+ "\U0001F300-\U0001F5FF" # Symbols & Pictographs
289
+ "\U0001F680-\U0001F6FF" # Transport & Map Symbols
290
+ "\U0001F700-\U0001F77F" # Alchemical Symbols
291
+ "\U0001F780-\U0001F7FF" # Geometric Shapes Extended
292
+ "\U0001F800-\U0001F8FF" # Supplemental Arrows-C
293
+ "\U0001F900-\U0001F9FF" # Supplemental Symbols and Pictographs
294
+ "\U0001FA00-\U0001FA6F" # Chess Symbols
295
+ "\U0001FA70-\U0001FAFF" # Symbols and Pictographs Extended-A
296
+ "\U00002702-\U000027B0" # Dingbats
297
+ "]+", flags=re.UNICODE)
298
+ return emoji_pattern.sub(r'', text)
299
+
300
+
301
+ def remove_language_families(text):
302
+ patterns = [
303
+ "[\u1100-\u11FF\u2E80-\u4DBF\u4E00-\u9FFF\uAC00-\uD7AF]+", # Asian scripts
304
+ "[\u0000-\u024F]+", # Basic Latin and Latin-1 Supplement
305
+ "[\u0400-\u04FF]+", # Cyrillic
306
+ "[\u0370-\u03FF]+", # Greek
307
+ "[\u0900-\u097F]+", # Devanagari
308
+ r"\u0B80-\u0BFF", # Tamil
309
+ r"\u4E00-\u9FFF", # Han
310
+ r"\u10A0-\u10FF", # Georgian
311
+ r"\u0C80-\u0CFF" # Kannada
312
+ ]
313
+
314
+ combined_pattern = re.compile("|".join(patterns))
315
+
316
+ cleaned_text = combined_pattern.sub(r'', text)
317
+ return cleaned_text
318
+
319
+
320
+ clean_punctuation = re.compile(r"(?<!\d)[.,;:'?!،.؟؛:](?!\d)")
321
+ def remove_punctuation(text):
322
+ """Remove all punctuation from string, except if it's between digits"""
323
+ return clean_punctuation.sub("", text)
324
+
325
+ def contains_repeated_ngram(window, n):
326
+ ngrams = generate_ngrams(window, n)
327
+ ngram_set = set(ngrams)
328
+ return len(ngrams) != len(ngram_set)
329
+
330
+
331
+ def generate_ngrams(text, n):
332
+ words = text.split()
333
+ output = []
334
+ for i in range(len(words)- n+1):
335
+ output.append(tuple(words[i:i+n]))
336
+ return output
337
+
338
+ def remove_repeated_ngram(text, n):
339
+ words = text.split()
340
+ output = []
341
+ for i in range(len(words)- n+1):
342
+ if not contains_repeated_ngram(" ".join(words[i:i+n]), n):
343
+ output.append(words[i])
344
+ return " ".join(output)
345
+
346
+ def normalize_punctuations(text: str) -> str:
347
+ # Replace , with ،
348
+ text = text.replace(',', '،')
349
+ # Replace ? with ؟
350
+ text = text.replace('?', '؟')
351
+ # Replace two or three of the same punctuation marks with a single one
352
+ text = re.sub(r'([.,;:?!،؛؟])\1{1,2}', r'\1', text)
353
+
354
+
355
+ # Replace double opening and closing parentheses with guillemets
356
+ text = re.sub(r'\(\(', '«', text)
357
+ text = re.sub(r'\)\)', '»', text)
358
+
359
+ # Normalize space around the guillemets and other punctuation marks
360
+ text = re.sub(r'\s*«\s*', ' «', text)
361
+ text = re.sub(r'\s*»\s*', '» ', text)
362
+
363
+ # Additional punctuation normalization
364
+ text = re.sub(r'\s*([,،؟])\s*', r'\1 ', text)
365
+
366
+ # Ensure there is no space before a guillemet at the beginning of the text or after a
367
+ # guillemet at the end of the text
368
+ text = re.sub(r'^\s*«', '«', text)
369
+ text = re.sub(r'»\s*$', '»', text)
370
+
371
+ # If multiple punctuation marks come after each other only keep the first one
372
+ # text = re.sub(r'([.!?؟،؛])\1+', r'\1', text)
373
+
374
+ # if the same punctuation mark is repeated, only keep the first one
375
+ text = re.sub(r'([.!?؟،؛])\1+', r'\1', text)
376
+
377
+ # if punctuation marks come after each other with space between them like: ? ? ? keep the first one remove the rest
378
+ text = re.sub(r'([.!?؟،؛])\s\1+', r'\1', text)
379
+ # Trim leading and trailing spaces and return the normalized text
380
+ text = text.strip()
381
+ return text
382
+
383
+
384
+ def fix_sentence(sentence):
385
+
386
+ if sentence.startswith('"') and sentence.endswith('"'):
387
+ # we can remove trailing quotation marks as they do not affect the sentence
388
+ sentence = sentence[1:-1]
389
+
390
+ if sentence[-1] not in [".", "?", "!"]:
391
+ # append a full-stop to sentences that do not end in punctuation
392
+ sentence = sentence + "."
393
+ # sentence = sentence[:-1].translate(str.maketrans('', '', string.punctuation)) + sentence[-1]
394
+ return sentence
395
+
396
+
397
+ def add_period_abbreviations(text):
398
+
399
+ abbreviations = set(["پ", "د"]) # Add more abbreviations as needed
400
+
401
+ # Define a regular expression pattern to match a letter followed by a space and then a word character
402
+ pattern = re.compile(r'([{}]) (?=\w)'.format(''.join(abbreviations)))
403
+
404
+ # Use regex to add periods after the specified abbreviations with a space after the period
405
+ text = pattern.sub(r'\1. ', text)
406
+
407
+ # Add periods after each letter if "د" and "خ" appear together
408
+ text = re.sub(r'د\sخ|خ ?د|د\.?خ|خ\.?د', 'د. خ.', text)
409
+
410
+ # Abbreviated dates
411
+ # text = re.sub(r'\b(پ\. ز)\b', r'\1.', text)
412
+
413
+ return text
414
+
415
+
416
+ def process_text(text):
417
+ # text = replace_words_in_corpus(text)
418
+ text = resolve_ae(text)
419
+ # text = number_to_word(text)
420
+ text = preprocessor_ckb.preprocess(text)
421
+ # text = normalizer(text).strip()
422
+ text = remove_emojis(text)
423
+ text = normalize_punctuations(text)
424
+ text = fix_sentence(text)
425
+ text = apply_char_replacements(text)
426
+ return text
427
+
428
+ if __name__ == "__main__":
429
+ # text = "لە ساڵی 1999دا بڕی 40% لە پارەکەیان واتە $102.1 یان وەرگرت. 'õ'\u200c\u200f\u200e'ھ'"
430
+
431
+ # print(process_text(text))
432
+ # print(contains_non_kurdish_characters(text))
433
+ # text = "دەقی«کوردی » و ڕێنووس ،((خاڵبەندی )) چۆنە ؟"
434
+ # correct = "دەقی «کوردی» و ڕێنووس، «خاڵبەندی» چۆنە؟"
435
+ # print("Before punctuation normalization:", text)
436
+ # print("After punctuation normalization:", normalize_punctuations(text))
437
+ # print("Correct:\t\t\t", correct)
438
+ # print(normalize_punctuations(text) == correct)
439
+ # print(normalize_punctuations("ڕەوا بورهان 4 تەمموز ، کوردستانی سلێمانی?!!"))
440
+ # print(normalize_punctuations("یانەی کوردژین تکایە چۆن بە شی سە ڕە کی و لاوە کی بۆ مالپە ڕە کە م زیاد بکە م؟؟ ؟ ؟ لە سکرێپە یتی ژومیلە"))
441
+ # with open('data/data.ckb.txt', 'r', encoding='utf-8') as src_file:
442
+ # source_data = src_file.read()
443
+
444
+ # unified_data = normalize_punctuations(source_data)
445
+
446
+ # # Save the unified data to a new file
447
+ # with open('data/unified_data.txt', 'w', encoding='utf-8') as file:
448
+ # file.writelines(unified_data)
449
+
450
+ # print("Unified data saved to unified_data.txt")
451
+
452
+ text = "Hello ((Friend)) Hello , Friend World"
453
+ # print(remove_repeated_ngram(text, 2))
454
+ # print(remove_repeated_ngrams(text, ))
455
+ print(process_text(text))
create_asosoft_benchmark.py ADDED
@@ -0,0 +1,31 @@
1
+ import pandas as pd
2
+ from tqdm import tqdm
3
+ from ckb_helpers import *
4
+ df = pd.read_csv('data/asotest.csv')
5
+
6
+ data_df = pd.read_csv('data/data.txt', names=['text'])
7
+ train_df = pd.read_csv('train.csv')
8
+
9
+
10
+ data = []
11
+ pbar = tqdm(df.itertuples(), total=len(df))
12
+
13
+ for row in pbar:
14
+ incorrect_word = row.text
15
+ correct_word = row.summary
16
+
17
+ # look up sentences from data_df that contain correct_word and keep only those rows that are not in train_df
18
+ sentences = data_df[data_df['text'].str.contains(correct_word, case=False, na=False)]
19
+ sentences = sentences[~sentences.text.isin(train_df.summary)]
20
+
21
+ pbar.set_description(f"Rows found after cross checking train data: {len(sentences)} for {correct_word}")
22
+ for r in sentences.head(1).itertuples():
23
+ new_sentence = r.text.replace(correct_word, incorrect_word)
24
+ data.append({"text": new_sentence, "summary": process_text(r.text)})
25
+ # drop that row so the final dataset doesn't include the same sentence for two incorrect words
26
+ data_df.drop(index=r.Index, axis=0, inplace=True)
27
+
28
+
29
+
30
+ df = pd.DataFrame(data)
31
+ df.to_csv('asosoft_spell.csv', index=False)
data/Sorani-Arabic.csv ADDED
@@ -0,0 +1,101 @@
1
+ text,summary
2
+ عةمرت نةمينئ نةت توانى يةك بزمار بيبةى,عەمرت نەمینێ نەتوانی یەک بزمار ببەی.
3
+ بةرهةمةكاني ئةفين ئاسؤ,بەرهەمەکانی ئەڤین ئاسۆ
4
+ دانا بويته فليمي كجةن,دانا بووەتە فیلمی کچان.
5
+ به و برچه الی هاوریان,بەو پرچە ئەلێی هاوڕێیان
6
+ ئاخر خواية تاواني ئةم منالة نةكبةتة جية,ئاخر خودایە تاوانی ئەم مناڵە نەگبەتە چییە؟
7
+ به رنامه يه ك بودروست كردني بي ئابروي,بەرنامەیەک بۆ دروستکردنی بێئابڕویی
8
+ سةيري پيكةنينةكةي سونيا,سەیری پێکەنینەکەی سۆنیا
9
+ ئةبي كي كچ بدات بةوانة,ئەبێ کێ کچ بدات بەوانە؟
10
+ هةريم قةرةناوي,هەرێم قەرەناوی
11
+ خه تاي دانا بوو بو ئالان ياري بكردايه له كه ل ي ئه ي برده وه,خەتای دانا بوو، بۆ ئاڵان یاریی بکردایە لەگەڵی ئەی بردەوە
12
+ هه ى حه مرى خؤتو كه ناكه تو ميوانه كانت نه مينئ,هەی عەمری خۆت و کەناڵەکەت و میوانەکانت نەمێنێت
13
+ جوئن مه دن واعیساب که ن سئ کچ به رنامه پئشکه ش ده که ن,جوێن مەدەن، وا حیساب بکەن ٣ کچ بەرنامە پێشکەش دەکەن
14
+ بیره زن بلی توچیت دایه له میکیاج,پیرەژن، بلێ تۆ چیت داوە لە مکیاج!
15
+ جا توپه بةفرى ويكةت با باشتر بوو,جا تۆپەڵە بەفری پێکەتبا باشتر بوو
16
+ هةموو نةوعة كةوتنيكي تاقي كردةوة,هەموو نەوعە کەوتنێکی تاقی کردەوە
17
+ به س بيم بلين جواني ئه م مه يمونه له كويايه,بەس پێم بڵێن جوانیی ئەم مەیمونە لە کوێدایە؟
18
+ بؤ كؤمنته كان ناخوينيه وه,بۆ کۆمێنتەکان ناخوێنیتەوە؟
19
+ چند ناخؤشه تةريق بيتةوه,چەند ناخۆشە تەریق بیتەوە
20
+ دوو مه يموني هيناوه به ده مي يه ك پيده كه نن,دوو مەیموونی هێناوە بەدەمی یەکتر پێئەکەنن
21
+ باخوا يارا شئرا شئررررر,بە خوا یارا شێرە شێر.
22
+ ده ي رسقي ئه مانيش ببرن,دەی ڕزقی ئەمانیش ببڕن
23
+ فقيرا ملي شكا,فەقیرە ملی شکا
24
+ خوزكه جونكيله ش بان,خۆزگە جوانکیلەش بان
25
+ وةلاي تةسميل مةوة ريال كةي بةشةرة بةس بةرشة بةشةرة,وەلاهی تەسمیل مەبە، ڕیاڵ کەی بەشەرە، بەس بەرشە بەشەرە.
26
+ شەربت هەنار اخوی,شەربەت هەنار ئەخۆی؟
27
+ ده ست جاوت خوشيبي,دەست و چاوت خۆش بێت
28
+ ئه م حه مه يسك قورسه به جي وا زه عيف بوه,ئەم حەمەیە ئێسک قورسە، بەچی وا زەعیف بووە؟
29
+ يةعني ئاوي تةماتةك ئةوةي دةوي,یەعنی ئاوی تەماتەک ئەوەی دەوێ
30
+ كومپانباي دزيني ئوتومبيل يش په يدابوو,کۆمپانیای دزینی ئۆتۆمبێلیش پەیدا بوو
31
+ صةلاح بالابةرز,سەلاح باڵابەرز
32
+ له و بيكه نه نه ي ئه لي ته قه ده كا,لەو پێکەنینەی! ئەلێی تەقە دەکات
33
+ خوا که سیک رسوا بکه وه ک ئه مه ی لی ئه کا,خودا کەسێک ڕیسوا بکات، وەک ئەمەی لێ ئەکات
34
+ كوناح نيه سازان ميرديك كه جه لى هه بيت,گوناح نییە سازان مێردێکی کەچەڵی هەبێت
35
+ هةرچي شيتو پاتالة لةم نيتة تةرةماشةية كةسيكي عاقلمان نةبيني,هەرچی شێت و پاتاڵە لەم نێتە تەڕەماشەیە، کەسێکی ئاقڵمان نەبینی
36
+ سازان خوي كردوتة عةنكةبوت جامانةى لةسةره,سازان خۆی کردۆتە عەنکەبووت، جامانەی لەسەرە.
37
+ كوره كان بس نازانن ياري بكه ن,کوڕەکان بەس نازانن یاری بکەن
38
+ حاجي سةيفةديت زور بيژي ديارة,حاجی سەیفەدین ، زۆر بیژی دیارە
39
+ جند بةبي يشي قسةت كرد,چەند بەبێ ئیشی قسەت کرد.
40
+ له ئيستاوه ئزانم كي ايباته وه ديارة نيگار يه كمة,لە ئێستاوە ئەزانم کێ ئەیباتەوە، دیارە نیگار یەکەمە
41
+ ئؤنده به فره ى بي نه كه ت ئؤنده به دارو ديوار كه ت,ئەوەندە بەفرەی پێنەکەوت، ئەوەندەی بە دارودیوار کەوت
42
+ سرنجي بؤق راكيشي جونكه هه رله بؤق اجيت,سەرنجی بۆق ڕائەکێشیت، چونکە هەر لە بۆق ئەچیت.
43
+ لاي هه ديك بياو ا��افره تانه به جاك ده زانن,لای هەندێک پیاو، ئەو ئافرەتانە بە چاک دەزانن
44
+ لاندكرؤز كو دةدزری،،،،,لاندگرۆزەر کو دەدزرێ؟
45
+ هة ناره يه كة م دابي,هەنارە یەکەم دەبی
46
+ بة قسةى من دةكةى هةسته بةخوت بروة مالى با دةرت نةكةن,بەقسەی من دەکەی، هەستە بۆ خۆت بڕۆوە ماڵێ، با دەرت نەکەن
47
+ ئةو كچة رزاي زور قورسة,ئەو کچە ڕەزای زۆر قوڕسە
48
+ جاخؤدوتؤپةلى تئ گرن باشترة,جا خۆ دوو تۆپەڵەی تێگرن باشترە
49
+ روى باوكى ئه وه ره شبيت (ته مارا)ى كرد به بيشكه شكار,ڕووی باوکی ئەوە ڕەش بێت (تەمارا) ی کردە پێشکەشکار.
50
+ له پيرلؤ خوئيريتر تؤيت,لە پێڕلۆ خوێڕیتر تۆیت.
51
+ تخوا ئةمة شتة سةيري دةكةن,تخوا ئەمە شتە سەیری دەکەن
52
+ لة داخي دواني وا حةزةكةم ئةم ولاتة جئ بيلم,لەداخی دووانی وا، حەز ئەکەم ئەم وڵاتە جێبێلم
53
+ ئينشةلا كوراكاو دهبةنةوة,ئینشائەڵڵا کوڕەکان دەبەنەوە
54
+ ام كجه بيويسته ببريته نه خؤشخانه ى ده رونى,ئەم کچە پێویستە ببرێتە نەخۆشخانەی دەروونی
55
+ ئه ى بو ئاسايشى هه وليئر دزه گه وره كان ده ستگير نا كريئن.,ئەی بۆ ئاساییشی هەولێر دزەگەورەکان دەزگیر ناکەن؟
56
+ هةر ماعدة مابوو تداخلى بكةن,هەر ماعیدە مابوو تەداخولی بکەن
57
+ به خواتابليي به رنامه يه كي هيجه جاجلوبه ركي ناشيرين,بەخوا تابڵێی بەرنامەیەکی هیچە، جا جلوبەرگی ناشرین
58
+ وةالله شتةكم لة دةست بواية ريك كةنالةكةم دادةخست,وەڵڵا شتەکم لەدەست بوایە ڕێک کەناڵەکەم دادەخست.
59
+ توخوا اوة بةرنامةية,تخوا ئەوە بەرنامەیە؟
60
+ تةنانةت ليرةش غةدرتان لة كةركوك كرد عةمرتان نةمينى,تەنانەت لێرەش غەدرتان لە کەرکوک کرد، عەمرتان نەمینێ.
61
+ بروا بكةن ئةمن بةس تةماشاي ريكارم كرد,بڕوا بکەن ئەمن بەس تەماشای ڕێکارم کرد
62
+ ئةمة فیلبو حكم تاوانبارة جونكة كوتی نابی كةس قةسةبةكات,ئەمە فێڵ بوو عەکەم تاوانبارە، چونکە وتی نابێ کەس قسە بکات.
63
+ داناش وةك يارا بةس فشةفش دةكات هيجيش ناباتةوة,داناش وەک یارا بەس فشەفش دەکات، هیچیش ناباتەوە
64
+ خواية شوكرم بةبةشت لةباتى باران فيتنة دةبري,خودایە شوکرم بە بەشت، لەجیاتیی باران ، فیتنە دەبارێ
65
+ گةردةلول بةخيوي كردوم بيمنتةتم لة رةشةبا,گەردەلول بەخێوی کردووم، بێمنەتتم لە ڕەشەبا.
66
+ اليي شيره به مه ييه كه يه دارماسيحه,ئەلێی شیرە پەمەیەکەیە، دارماسیحە
67
+ زؤر ركم لة پيشكةشكارةكةية اسلوبيى قسةى زؤر ناشرينة,زۆر ڕقم لەپێشکەشکارەکەیە، ئسلوبی قسەکردنی زۆر ناشرینە
68
+ بيكه نينه كه شي له هى به شه ر ناجى,پێکەنینەکەشی لە هیی بەشەر ناچێت
69
+ جاوةرة ئةوكفتةي بخؤي بةونينؤكةوة,جا وەرە ئەو کفتەی بخۆی بەو نینۆکەوە
70
+ امه له مه ريخ ده زي,ئەمە لەمەریخ دەژی
71
+ نرخي بۆ خۆمان جەندە,نرخی بۆ خۆمان چەندە؟
72
+ خواي بتكات قورباني كجيكي عه شاير,خودا بتکات بە قوربانی کچێکی عەشایەر
73
+ خوت فيره قسه بكه اوجا به رنامه ى پيشكه ش بكه,خۆت فێرە قسە بکە، ئەوجا بەرنامە پێشکەش بکە
74
+ كورةتووخواتؤشةرم لةخؤت ناكةي,کورە تخوا تۆ شەرم لە خۆت ناکەی؟
75
+ دروبوو اوقسه يه,درۆ بوو ئەو قسەیە
76
+ ئه مه يه كم جارمه به بينم دعبا كوراني بلى كه ناشرينه,ئەمە یەکەم جارمە ببینم دەعبا گۆرانی بڵێ، کە ناشرینە!
77
+ وةره خو نةكوشژة لةبةرئةم ريكلامه,وەرە خۆت مەکوژە لەبەر ئەم ڕیکلامە
78
+ کە ناڵێ بێ رە وشتە کان,کەناڵی بێڕەوشتەکان.
79
+ برؤ كن بروا ستايلي بابلؤكت بكا,بڕۆ کن بڕوا ستایلی با بلۆکت بکات
80
+ عةزةلات فشةيه شةرتي جاو و بروويە,عەزەلات فشەیە، شەرتی چاو و برۆیە
81
+ ام ريكلام بيتامه ش تواو نابي,ئەم ڕیک��امە بێتامەش تەواو نابێ
82
+ دةستان خوش بيت هةر سةركةوتو بن,دەستتان خۆش بێت، هەر سەرکەوتوو بن
83
+ وةلا عةينةن راستية امن بروام كرد,وەڵڵا ئەڵێی ڕاستییە، ئەمن بڕوام کرد.
84
+ بوية كورد هةموي لةناو جوو كةسي تةندروست نماية,بۆیە کورد هەمووی لەناو چوو، کەسی تەندروست نەمایە
85
+ باشترين راهينه ر له ميزووي توبي بي,باشترین ڕاهێنەر لە مێژووی تۆپی پێ
86
+ ژن دوژمني ژنه له به رچي ژن قه بول ناكت پياو ژنيكي دي بينت,ژن دوژمنی ژنە، ژەبەرچی ژن قەبول ناکات پیاو ژنێکی تر بێنێت؟
87
+ چاڤي به ميسي وتوه تؤله ي ئيمه له ريال بكه وه,چاڤی بە مێسی وتوە: تۆڵەی ئێمە لە ڕیاڵ بکەوە
88
+ هةموو روز دةموو جاو ئةم شتة ناشرينة ئةبيني, هەموو ڕۆژێ دەموو چاو ئەم شتە ناشرینە ئەبینی
89
+ نيشانةي ژني خراپة,نیشانەی ژنی خراپە
90
+ ده رئن اوه ي او كفتانه خاري له طه واريه,دەرێن ئەوەی ئەو کفتانەی خوارد لە تەوارییە
91
+ شكلئ اوئ ديكه ي ده ت,شکڵی ئەوەی دیکە دەدات
92
+ وةرزي تر من ارؤم من ببينن لة وي,وەرزێکی تر من ئەڕۆم، من ببینن لەوێ
93
+ چه تاليشي پيه چاوه ري شوتي كاله كي بؤببه ن,چەتاڵیشی پێیە، چاوەڕێیە شووتی و کاڵەکی بۆ ببەن
94
+ زورناخوشة دةني,زۆر ناخۆشە دەنگی
95
+ كةواتة سى ريمةكة باتوشى گوناحنةبى,کەواتە سەیری مەکە باتوشی گوناح نەبیت
96
+ دانا بةو زةعيفيةي خوى باشتره له تو,دانا بەو زەعیفییەی خۆی باشترە لە تو
97
+ وه لله كجه كي قشتؤكه يه بي ده كري خؤشم ده وي من زور,وەڵڵا کچەکی قشتۆکەیە، پێی دەکرێ، خۆشم دەوێ من زۆر
98
+ بةراستي بةرنامةكة ئةمجارة زوور جياوازة و جوانه,بەڕاستی بەرنامەکە ئەم جارە زۆر جیاوازە و جوانە
99
+ سيناريوه فيشه ك ته قاندن ئاو هايه,سیناریۆیە، فیشەکتەقاندن ئەوهایە؟
100
+ وةلا ئةوة هةمووي كوري فةقيرة شتيكي زؤر خراب اكةن,وەڵڵا ئەوە هەمووی کوڕی فەقیرە، شتێکی زۆر خراپ ئەکەن
101
+ سودي أيوه جيه لؤ ميلت,سوودی ئێوە چییە بۆ میللەت؟
data/asosoft_benchmark.csv ADDED
The diff for this file is too large to render. See raw diff
 
data/words.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ [
2
+
3
+ ]
eval.sh ADDED
@@ -0,0 +1,13 @@
1
+ # Evaluation
2
+ python3 run_summarization.py \
3
+ --model_name_or_path "razhan/bart-kurd-spell-base" \
4
+ --do_eval \
5
+ --validation_file data/asosoft_benchmark.csv \
6
+ --output_dir /tmp \
7
+ --overwrite_output_dir \
8
+ --per_device_eval_batch_size=32 \
9
+ --predict_with_generate \
10
+ --logging_steps="1" \
11
+ --max_target_length=1024 \
12
+ --max_source_length=1024 \
13
+ --report_to="none"
inspect_data.ipynb ADDED
@@ -0,0 +1,407 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "from datasets import load_dataset\n",
10
+ "import pandas as pd\n",
11
+ "from utils import *\n"
12
+ ]
13
+ },
14
+ {
15
+ "cell_type": "code",
16
+ "execution_count": null,
17
+ "metadata": {},
18
+ "outputs": [],
19
+ "source": [
20
+ "rste = load_dataset(\"razhan/rste\", split=\"train\")"
21
+ ]
22
+ },
23
+ {
24
+ "cell_type": "code",
25
+ "execution_count": null,
26
+ "metadata": {},
27
+ "outputs": [],
28
+ "source": [
29
+ "df_rste = rste.to_pandas()\n",
30
+ "df = df_rste\n",
31
+ "pd.set_option('display.max_colwidth', None)"
32
+ ]
33
+ },
34
+ {
35
+ "cell_type": "code",
36
+ "execution_count": null,
37
+ "metadata": {},
38
+ "outputs": [],
39
+ "source": []
40
+ },
41
+ {
42
+ "cell_type": "code",
43
+ "execution_count": null,
44
+ "metadata": {},
45
+ "outputs": [],
46
+ "source": [
47
+ "df['clean_text'] = df['text'].apply(process_text)\n",
48
+ "\n",
49
+ "# df['contains_non_kurdish'] = df[\"text\"].apply(contains_non_kurdish_characters)\n",
50
+ "# print(df['contains_non_kurdish'].sum())\n",
51
+ "# df[df['contains_non_kurdish'] == True]"
52
+ ]
53
+ },
54
+ {
55
+ "cell_type": "code",
56
+ "execution_count": null,
57
+ "metadata": {},
58
+ "outputs": [],
59
+ "source": [
60
+ "df['clean_text'] = df['clean_text'].apply(keep_kurdish_characters)"
61
+ ]
62
+ },
63
+ {
64
+ "cell_type": "code",
65
+ "execution_count": null,
66
+ "metadata": {},
67
+ "outputs": [],
68
+ "source": []
69
+ },
70
+ {
71
+ "cell_type": "code",
72
+ "execution_count": null,
73
+ "metadata": {},
74
+ "outputs": [],
75
+ "source": [
76
+ "# df[df['contains_non_kurdish'] == False]['clean_text'].to_csv(\"data/data.ckb.txt\", index=False, header=False)\n",
77
+ "df['clean_text'].to_csv(\"data/data.ckb.txt\", index=False, header=False)"
78
+ ]
79
+ },
80
+ {
81
+ "cell_type": "code",
82
+ "execution_count": null,
83
+ "metadata": {},
84
+ "outputs": [],
85
+ "source": [
86
+ "df[df['text'].str.contains('ھ')]\n",
87
+ "indices_with_substring = df[df['text'].str.contains('ھ')].index\n",
88
+ "# print(indices_with_substring)\n",
89
+ "df.loc[indices_with_substring]"
90
+ ]
91
+ },
92
+ {
93
+ "cell_type": "code",
94
+ "execution_count": null,
95
+ "metadata": {},
96
+ "outputs": [],
97
+ "source": [
98
+ "all_text = ''.join(df[\"text\"])\n",
99
+ "\n",
100
+ "unique_characters = set(all_text)\n",
101
+ "\n",
102
+ "print(\"Unique characters:\", unique_characters)\n",
103
+ "print(\"Number of unique characters:\", len(unique_characters))"
104
+ ]
105
+ },
106
+ {
107
+ "cell_type": "code",
108
+ "execution_count": null,
109
+ "metadata": {},
110
+ "outputs": [],
111
+ "source": [
112
+ "all_text = ''.join(df[\"clean_text\"])\n",
113
+ "# all_text = ''.join(df[df['contains_non_kurdish'] == False]['clean_text'])\n",
114
+ "unique_characters = set(all_text)\n",
115
+ "unique_punctuations = extract_punctuation(all_text)\n",
116
+ "print(\"Unique characters:\", unique_characters)\n",
117
+ "print(\"Number of unique characters:\", len(unique_characters))\n",
118
+ "print(\"Unique punctuations:\", unique_punctuations)"
119
+ ]
120
+ },
121
+ {
122
+ "cell_type": "code",
123
+ "execution_count": null,
124
+ "metadata": {},
125
+ "outputs": [],
126
+ "source": [
127
+ "df['contains_non_kurdish'] = df[\"text\"].apply(contains_non_kurdish_characters)"
128
+ ]
129
+ },
130
+ {
131
+ "cell_type": "code",
132
+ "execution_count": null,
133
+ "metadata": {},
134
+ "outputs": [],
135
+ "source": [
136
+ "len(unique_punctuations)"
137
+ ]
138
+ },
139
+ {
140
+ "cell_type": "code",
141
+ "execution_count": null,
142
+ "metadata": {},
143
+ "outputs": [],
144
+ "source": [
145
+ "# df = pd.read_csv(\"asosoft_test_punc.csv\")\n",
146
+ "# df['summary'] = df['summary'].apply(process_text)\n",
147
+ "# df.to_csv(\"asosoft_test_clean.csv\", index=False)"
148
+ ]
149
+ },
150
+ {
151
+ "cell_type": "code",
152
+ "execution_count": null,
153
+ "metadata": {},
154
+ "outputs": [],
155
+ "source": [
156
+ "df"
157
+ ]
158
+ },
159
+ {
160
+ "cell_type": "code",
161
+ "execution_count": null,
162
+ "metadata": {},
163
+ "outputs": [],
164
+ "source": []
165
+ },
166
+ {
167
+ "cell_type": "code",
168
+ "execution_count": null,
169
+ "metadata": {},
170
+ "outputs": [],
171
+ "source": [
172
+ "oscar_dataset = load_dataset(\"oscar-corpus/OSCAR-2301\", language=\"ckb\", split='train', token=True)\n",
173
+ "wiki_dataset = load_dataset(\"wikipedia\", language=\"ckb\", date=\"20231120\", split='train', beam_runner='DirectRunner')\n",
174
+ "\n",
175
+ "df_oscar = oscar_dataset.to_pandas()\n",
176
+ "df_wiki = wiki_dataset.to_pandas()\n",
177
+ "df = pd.concat([df_oscar, df_wiki], ignore_index=True)\n"
178
+ ]
179
+ },
180
+ {
181
+ "cell_type": "code",
182
+ "execution_count": null,
183
+ "metadata": {},
184
+ "outputs": [],
185
+ "source": [
186
+ "df"
187
+ ]
188
+ },
189
+ {
190
+ "cell_type": "code",
191
+ "execution_count": null,
192
+ "metadata": {},
193
+ "outputs": [],
194
+ "source": [
195
+ "df[\"text\"] = df[\"text\"].apply(process_text)"
196
+ ]
197
+ },
198
+ {
199
+ "cell_type": "code",
200
+ "execution_count": null,
201
+ "metadata": {},
202
+ "outputs": [],
203
+ "source": [
204
+ "# text = df[\"text\"].str.cat(sep=\"\\n\")"
205
+ ]
206
+ },
207
+ {
208
+ "cell_type": "code",
209
+ "execution_count": null,
210
+ "metadata": {},
211
+ "outputs": [],
212
+ "source": [
213
+ "df['clean_text'] = df['text'].apply(keep_kurdish_characters)"
214
+ ]
215
+ },
216
+ {
217
+ "cell_type": "code",
218
+ "execution_count": null,
219
+ "metadata": {},
220
+ "outputs": [],
221
+ "source": [
222
+ "df[\"clean_text\"] = df[\"clean_text\"].apply(process_text)"
223
+ ]
224
+ },
225
+ {
226
+ "cell_type": "code",
227
+ "execution_count": null,
228
+ "metadata": {},
229
+ "outputs": [],
230
+ "source": [
231
+ "df['contains_non_kurdish'] = df[\"clean_text\"].apply(contains_non_kurdish_characters)\n",
232
+ "print(df['contains_non_kurdish'].sum())\n",
233
+ "# df[df['contains_non_kurdish'] == True].iloc[0]['clean_text']"
234
+ ]
235
+ },
236
+ {
237
+ "cell_type": "code",
238
+ "execution_count": null,
239
+ "metadata": {},
240
+ "outputs": [],
241
+ "source": [
242
+ "df['repeated_ngram'] = df['clean_text'].apply(lambda x: contains_repeated_ngram(x, 10))\n",
243
+ "print(df['repeated_ngram'].sum())"
244
+ ]
245
+ },
246
+ {
247
+ "cell_type": "code",
248
+ "execution_count": null,
249
+ "metadata": {},
250
+ "outputs": [],
251
+ "source": [
252
+ "# drop rows where repeated_ngram are True\n",
253
+ "df = df[df['repeated_ngram'] == False]\n"
254
+ ]
255
+ },
256
+ {
257
+ "cell_type": "code",
258
+ "execution_count": null,
259
+ "metadata": {},
260
+ "outputs": [],
261
+ "source": [
262
+ "df[df['repeated_ngram'] == True].iloc[0]['clean_text']"
263
+ ]
264
+ },
265
+ {
266
+ "cell_type": "code",
267
+ "execution_count": null,
268
+ "metadata": {},
269
+ "outputs": [],
270
+ "source": [
271
+ "all_text = \"\".join(df[\"clean_text\"])\n",
272
+ "# all_text = ''.join(df[df['contains_non_kurdish'] == False]['clean_text'])\n",
273
+ "unique_characters = set(all_text)\n",
274
+ "unique_punctuations = extract_punctuation(all_text)\n",
275
+ "print(\"Unique characters:\", unique_characters)\n",
276
+ "print(\"Number of unique characters:\", len(unique_characters))\n",
277
+ "print(\"Unique punctuations:\", unique_punctuations)"
278
+ ]
279
+ },
280
+ {
281
+ "cell_type": "code",
282
+ "execution_count": null,
283
+ "metadata": {},
284
+ "outputs": [],
285
+ "source": [
286
+ "from nltk.tokenize import sent_tokenize\n",
287
+ "data = []\n",
288
+ "for i, row in df.iterrows():\n",
289
+ " sentences = tokenizer_ckb.sent_tokenize(row['clean_text'])\n",
290
+ " # sentences = row['clean_text'].split('\\n')\n",
291
+ " # sentences = sent_tokenize(row['clean_text'])\n",
292
+ " sentences = [sent_tokenize(s) for s in sentences]\n",
293
+ " # flatten list of lists\n",
294
+ " sentences = [item for sublist in sentences for item in sublist]\n",
295
+ " # split on period and keep the period\n",
296
+ " sentences = [s.split('.') for s in sentences]\n",
297
+ " sentences = [item for sublist in sentences for item in sublist]\n",
298
+ "\n",
299
+ " sentences = [s + '.' for s in sentences]\n",
300
+ " data.extend(sentences)\n",
301
+ " # if i == 5:\n",
302
+ " # break"
303
+ ]
304
+ },
305
+ {
306
+ "cell_type": "code",
307
+ "execution_count": null,
308
+ "metadata": {},
309
+ "outputs": [],
310
+ "source": [
311
+ "print(len(data))"
312
+ ]
313
+ },
314
+ {
315
+ "cell_type": "code",
316
+ "execution_count": null,
317
+ "metadata": {},
318
+ "outputs": [],
319
+ "source": [
320
+ "# longest line in data\n",
321
+ "max_line = max(data, key=len)"
322
+ ]
323
+ },
324
+ {
325
+ "cell_type": "code",
326
+ "execution_count": null,
327
+ "metadata": {},
328
+ "outputs": [],
329
+ "source": [
330
+ "len(max_line.split())"
331
+ ]
332
+ },
333
+ {
334
+ "cell_type": "code",
335
+ "execution_count": null,
336
+ "metadata": {},
337
+ "outputs": [],
338
+ "source": [
339
+ "max_line"
340
+ ]
341
+ },
342
+ {
343
+ "cell_type": "code",
344
+ "execution_count": null,
345
+ "metadata": {},
346
+ "outputs": [],
347
+ "source": [
348
+ "# calulate the length of each line in the data and take the average\n",
349
+ "lengths = [len(line.split()) for line in data]\n",
350
+ "avg_length = sum(lengths) / len(lengths)\n",
351
+ "print(avg_length)"
352
+ ]
353
+ },
354
+ {
355
+ "cell_type": "code",
356
+ "execution_count": null,
357
+ "metadata": {},
358
+ "outputs": [],
359
+ "source": [
360
+ "# give me all the lines above 20 words\n",
361
+ "long_lines = [line for line in data if len(line.split()) > 25]\n",
362
+ "print(len(long_lines))"
363
+ ]
364
+ },
365
+ {
366
+ "cell_type": "code",
367
+ "execution_count": null,
368
+ "metadata": {},
369
+ "outputs": [],
370
+ "source": [
371
+ "# Write sentences to file\n",
372
+ "with open(\"data/oscar_wiki.ckb.txt\", \"w\") as f:\n",
373
+ " for sentence in data:\n",
374
+ " f.write(sentence + \"\\n\")"
375
+ ]
376
+ },
377
+ {
378
+ "cell_type": "code",
379
+ "execution_count": null,
380
+ "metadata": {},
381
+ "outputs": [],
382
+ "source": []
383
+ }
384
+ ],
385
+ "metadata": {
386
+ "kernelspec": {
387
+ "display_name": "Python 3",
388
+ "language": "python",
389
+ "name": "python3"
390
+ },
391
+ "language_info": {
392
+ "codemirror_mode": {
393
+ "name": "ipython",
394
+ "version": 3
395
+ },
396
+ "file_extension": ".py",
397
+ "mimetype": "text/x-python",
398
+ "name": "python",
399
+ "nbconvert_exporter": "python",
400
+ "pygments_lexer": "ipython3",
401
+ "version": "3.11.6"
402
+ },
403
+ "orig_nbformat": 4
404
+ },
405
+ "nbformat": 4,
406
+ "nbformat_minor": 2
407
+ }
prepare_data/constants.py ADDED
@@ -0,0 +1,25 @@
1
+ ARABIC_CHARS = 'دصضذطكثنتالبيسجحإأآشظمغفقةىرؤءئزوخهع'
2
+ KURDISH_CHARS = 'ئابپتجچحخدرڕزژسشعغفڤقکگلڵمنهەوووۆیێ'
3
+ VALID_PUNCS = '\?؟\.\\\/,،«»\-:'
4
+
5
+ ARABIC_PUCTUATIONS = "،؛۔٫٪؟"
6
+ CKB_PUNCTUATIONS = "!.:;?،؛؟«»"
7
+
8
+ NUMBERS = '٠١٢٣٤٥٦٧٨٩'
9
+ SPECIAL = ' '
10
+
11
+ NORMLIZER_MAPPER = {
12
+ 'ﻹ': 'لإ',
13
+ 'ﻷ': 'لأ',
14
+ 'ﻵ': 'لآ',
15
+ 'ﻻ': 'لا'
16
+ }
17
+ VALID_CHARS = KURDISH_CHARS + SPECIAL + NUMBERS + CKB_PUNCTUATIONS
18
+
19
+
20
+ KEYBOARD_KEYS = [
21
+ 'قوەرتیئحۆپ',
22
+ 'اسدفگهژکل',
23
+ 'زخجڤبنم'
24
+ ]
25
+ KEYBOARD_BLANK = '_'
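
A short, hedged sketch of how VALID_CHARS can be used to keep only the characters defined above (the same idea ValidCharsKeeper implements in processes.py); re.escape is an extra safety measure added here, not part of the original file.

import re
from constants import VALID_CHARS

def keep_valid(line: str) -> str:
    # replace anything outside VALID_CHARS with a space
    return re.sub(f'[^{re.escape(VALID_CHARS)}]', ' ', line)
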
prepare_data/generate_dataset.py ADDED
@@ -0,0 +1,178 @@
1
+ import random
2
+ from string import punctuation
3
+ import re
4
+ import os
5
+ from transformers import AutoTokenizer
6
+ from tqdm import tqdm
7
+ from typing import List
8
+ from constants import KURDISH_CHARS, KEYBOARD_BLANK, KEYBOARD_KEYS, NUMBERS
9
+
10
+
11
+ def tokenizer_check_if_text_too_long(text, tokenizer, max_length):
12
+ data = tokenizer.batch_encode_plus([text], max_length=max_length, truncation=True, return_overflowing_tokens=True)
13
+ if len(data["input_ids"]) > 1:
14
+ return True
15
+ else:
16
+ return False#, len(data["input_ids"][0])
17
+
18
+ def delete_characters(text, char_delete_percentage=0.01):
19
+ modifyed_line = []
20
+ for char in text:
21
+ if random.random() > char_delete_percentage or char in NUMBERS:
22
+ modifyed_line.append(char)
23
+ return "".join(modifyed_line)
24
+
25
+ def insert_characters(text, augmentation_probability=0.01):
26
+ modifyed_line = []
27
+ for char in text:
28
+ if random.random() <= augmentation_probability and char not in NUMBERS:
29
+ modifyed_line.append(random.choice(KURDISH_CHARS))
30
+ modifyed_line.append(char)
31
+ return "".join(modifyed_line)
32
+
33
+ def replace_characters(text, augmentation_probability=0.01):
34
+ modifyed_line = []
35
+ for char in text:
36
+ if random.random() <= augmentation_probability and char not in NUMBERS:
37
+ modifyed_line.append(random.choice(KURDISH_CHARS))
38
+ else:
39
+ modifyed_line.append(char)
40
+ return "".join(modifyed_line)
41
+
42
+ def random_neighbor_replace(line: str, keyboard_rows: List[str], blank: str) -> str:
43
+ lines = keyboard_rows
44
+ n_rows = len(keyboard_rows)
45
+ _mapper = {}
46
+
47
+ def __get_left(row_idx: int, col_idx: int) -> List[str]:
48
+ if col_idx == 0:
49
+ return []
50
+ return [lines[row_idx][col_idx - 1]]
51
+
52
+ def __get_right(row_idx: int, col_idx: int) -> List[str]:
53
+ if col_idx == (len(lines[row_idx]) - 1):
54
+ return []
55
+ return lines[row_idx][col_idx + 1]
56
+
57
+ def __get_upper(row_idx: int, col_idx: int) -> List[str]:
58
+ if row_idx == 0:
59
+ return []
60
+ line = lines[row_idx - 1]
61
+ start = max(0, col_idx - 1)
62
+ end = min(len(line), col_idx + 2)
63
+ return list(line[start: end])
64
+
65
+ def __get_lower(row_idx: int, col_idx: int) -> List[str]:
66
+ if row_idx == (n_rows - 1):
67
+ return []
68
+ line = lines[row_idx + 1]
69
+ start = max(0, col_idx - 1)
70
+ end = min(len(line), col_idx + 2)
71
+ return list(line[start: end])
72
+
73
+ funcs = [__get_left, __get_right, __get_upper, __get_lower]
74
+ for row_idx in range(n_rows):
75
+ for col_idx in range(len(lines[row_idx])):
76
+ items = []
77
+ for func in funcs:
78
+ items.extend(func(row_idx, col_idx))
79
+ items = list(filter(lambda x: x != blank, items))
80
+ char = lines[row_idx][col_idx]
81
+ _mapper[char] = items.copy()
82
+
83
+ def get_char(char: str) -> str:
84
+ if char not in _mapper:
85
+ return char
86
+ return random.choice(_mapper[char])
87
+
88
+ length = len(line)
89
+ if length == 0:
90
+ length = 1
91
+ idx = random.randint(0, length - 1)
92
+ return line[:idx] + get_char(line[idx]) + line[idx + 1:]
93
+ def lower_case_words(text, augmentation_probability=0.5):
94
+ modifyed_line = []
95
+ for word in text.split():
96
+ if word[0].islower() == False and random.random() <= augmentation_probability:
97
+ word = word.lower()
98
+ modifyed_line.append(word)
99
+ return " ".join(modifyed_line)
100
+
101
+
102
+ clean_chars = re.compile(r'[^A-Za-zöäüÖÄÜß,.!?’\'$%€0-9\(\)\- ]', re.MULTILINE)
103
+ def cleanup(text):
104
+ text = clean_chars.sub('', text)
105
+ #print("bug: somehow all numbers are removed - this is might be due to this regex")
106
+ #exit()
107
+ #text = text.replace("\n", "")
108
+ #text = text.replace('"','\\"')
109
+ return text
110
+
111
+ clean_punctuation = re.compile(r"(?<!\d)[.,;:'?؟.!()؟،»«](?!\d)")
112
+ def remove_punctuation(text):
113
+ """Remove all punctuation from string, except if it's between NUMBERS"""
114
+ return clean_punctuation.sub("", text)
115
+
116
+ def combine_sentences(text, sentences, augmentation_probability = 1):
117
+ if random.random() < augmentation_probability:
118
+ sentences_to_sample = random.randint(0,10)
119
+ augmentation_sentences = random.sample(sentences,sentences_to_sample)
120
+ return text + " " + " ".join(augmentation_sentences)
121
+ else:
122
+ return text
123
+
124
+ def delete_word(text, augmentation_probability = 0.001):
125
+ if random.random() < augmentation_probability:
126
+ words = text.split()
127
+ if len(words) < 3:
128
+ # do not delete word in short text, as there will be no context to guess the word
129
+ return text
130
+ word_to_remove = random.randint(0,len(words)-1)
131
+ words.pop(word_to_remove)
132
+ return " ".join(words)
133
+ else:
134
+ return text
135
+
136
+
137
+ if __name__ == "__main__":
138
+ data_file = "data/data.txt" #"data/en.wikidump.processed.24m.txt" #
139
+ language = "ckb" # "wikidump.24m.en"
140
+ num_lines = sum(1 for line in open(data_file,'r'))
141
+ print("Number of lines:",num_lines)
142
+ with open(data_file,'r') as file:
143
+ sentences = file.readlines(int(num_lines*0.5))
144
+ # sentences = [cleanup(sentence) for sentence in sentences]
145
+
146
+ # tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
147
+ tokenizer = AutoTokenizer.from_pretrained("./tokenizer")
148
+ with open(language+".csv","w",encoding='utf-8') as output:
149
+ with open(data_file,'r') as file:
150
+ for line in tqdm(file, total=num_lines):
151
+ # line = cleanup(line)
152
+ if len(line) < 1:
153
+ continue
154
+ line = combine_sentences(line,sentences)
155
+ if tokenizer_check_if_text_too_long(line,tokenizer,max_length=1024):
156
+ print(f"skipping line as its too long ({len(line)}):\n"+line)
157
+ continue
158
+
159
+ if random.random() > 0.02:
160
+ # leave 2% of the data untouched so the model
161
+ # learns not to over-correct clean text
162
+ new_line = delete_word(line)
163
+ new_line = delete_characters(new_line)
164
+ new_line = insert_characters(new_line)
165
+ new_line = replace_characters(new_line)
166
+ new_line = random_neighbor_replace(new_line, KEYBOARD_KEYS, KEYBOARD_BLANK)
167
+ new_line = remove_punctuation(new_line)
168
+ else:
169
+ new_line = line
170
+ output.write(f'"{new_line.strip()}","{line.strip()}"\n')
171
+ os.system(f"echo \"text,summary\" > {language}.train.csv")
172
+ num_lines = sum(1 for line in open(f"{language}.csv",'r'))
173
+ os.system(f"head -n {num_lines-2000} {language}.csv >> {language}.train.csv")
174
+ os.system(f"echo \"text,summary\" > {language}.test.csv")
175
+ os.system(f"tail -n 2000 {language}.csv >> {language}.test.csv")
176
+
177
+
178
+
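
A hedged usage sketch of the character-level augmentations above applied to one line; it assumes generate_dataset.py and constants.py are importable from the same directory, and the sample sentence is only a placeholder.

import random
from constants import KEYBOARD_BLANK, KEYBOARD_KEYS
from generate_dataset import (
    delete_characters, insert_characters, random_neighbor_replace,
    remove_punctuation, replace_characters,
)

random.seed(0)
clean = "ئەمە ڕستەیەکی تاقیکردنەوەیە."  # placeholder CKB sentence
noisy = delete_characters(clean)
noisy = insert_characters(noisy)
noisy = replace_characters(noisy)
noisy = random_neighbor_replace(noisy, KEYBOARD_KEYS, KEYBOARD_BLANK)
noisy = remove_punctuation(noisy)
print(f'"{noisy}","{clean}"')  # (distorted, clean) pair, as written to the CSV
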
prepare_data/helpers.py ADDED
@@ -0,0 +1,87 @@
1
+ from functools import lru_cache
2
+ import json
3
+ import math
4
+ import re
5
+ from typing import List, Union
6
+ from pathlib import Path
7
+ import torch
8
+ from torch import Tensor
9
+
10
+
11
+ def load_text_file(
12
+ file_path: Union[Path, str],
13
+ encoding='utf-8',
14
+ *args, **kwargs
15
+ ) -> str:
16
+ with open(file_path, 'r', encoding=encoding) as f:
17
+ data = f.read()
18
+ return data
19
+
20
+
21
+ def save_text_file(
22
+ file_path: Union[Path, str],
23
+ data: str,
24
+ encoding='utf-8'
25
+ ) -> str:
26
+ with open(file_path, 'w', encoding=encoding) as f:
27
+ data = f.write(data)
28
+ return data
29
+
30
+
31
+ def remove_long_spaces(line: str) -> str:
32
+ return re.sub(r'\s{2,}', ' ', line)
33
+
34
+
35
+ @lru_cache(maxsize=2)
36
+ def get_positionals(max_length: int, d_model: int) -> Tensor:
37
+ """Create Positionals tensor to be added to the input
38
+ Args:
39
+ max_length (int): The maximum length of the positionals sequence.
40
+ d_model (int): The dimensionality of the positionals sequence.
41
+ Returns:
42
+ Tensor: Positional tensor
43
+ """
44
+ result = torch.zeros(max_length, d_model, dtype=torch.float)
45
+ for pos in range(max_length):
46
+ for i in range(0, d_model, 2):
47
+ denominator = pow(10000, 2 * i / d_model)
48
+ result[pos, i] = math.sin(pos / denominator)
49
+ result[pos, i + 1] = math.cos(pos / denominator)
50
+ return result
51
+
52
+
53
+ def load_json(file_path: Union[Path, str]) -> Union[dict, list]:
54
+ with open(file_path, 'r') as f:
55
+ data = json.load(f)
56
+ return data
57
+
58
+
59
+ def save_json(
60
+ file_path: Union[Path, str], data: Union[dict, list]
61
+ ) -> None:
62
+ with open(file_path, 'w') as f:
63
+ json.dump(data, f)
64
+
65
+
66
+ def get_freq_dict(data: List[str]) -> dict:
67
+ freq = {}
68
+ for item in data:
69
+ for word in item.split(' '):
70
+ if word in freq:
71
+ freq[word] += 1
72
+ else:
73
+ freq[word] = 1
74
+ return freq
75
+
76
+
77
+ def load_state(state_path: Union[Path, str]):
78
+ state = torch.load(state_path)
79
+ model = state['model']
80
+ model = {
81
+ key.replace('module.', ''): value
82
+ for key, value in model.items()
83
+ }
84
+ optimizer = state['optimizer']
85
+ epoch = state['epoch']
86
+ steps = state['steps']
87
+ return model, optimizer, epoch, steps
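
A small usage sketch for get_positionals above: it returns a (max_length, d_model) sinusoidal table that can be broadcast-added to an embedded batch; the shapes here are illustrative.

import torch
from helpers import get_positionals

pos = get_positionals(max_length=128, d_model=64)  # Tensor of shape (128, 64)
x = torch.zeros(2, 128, 64)                        # (batch, seq_len, d_model)
x = x + pos.unsqueeze(0)                           # broadcast over the batch
print(pos.shape, x.shape)
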
prepare_data/interfaces.py ADDED
@@ -0,0 +1,88 @@
1
+ from abc import ABC, abstractmethod, abstractproperty
2
+
3
+
4
+ class IProcess(ABC):
5
+
6
+ @abstractmethod
7
+ def execute():
8
+ pass
9
+
10
+
11
+ class IProcessor(ABC):
12
+
13
+ @abstractmethod
14
+ def run():
15
+ pass
16
+
17
+ @abstractmethod
18
+ def dist_run():
19
+ pass
20
+
21
+
22
+ class ITokenizer(ABC):
23
+
24
+ @abstractmethod
25
+ def ids2tokens(self):
26
+ pass
27
+
28
+ @abstractmethod
29
+ def tokenize(self):
30
+ pass
31
+
32
+ @abstractmethod
33
+ def set_tokenizer(self):
34
+ pass
35
+
36
+ @abstractmethod
37
+ def save_tokenizer(self):
38
+ pass
39
+
40
+ @abstractmethod
41
+ def load_tokenizer(self):
42
+ pass
43
+
44
+ @abstractmethod
45
+ def add_token(self):
46
+ pass
47
+
48
+ @abstractmethod
49
+ def preprocess_tokens(self):
50
+ pass
51
+
52
+ @abstractmethod
53
+ def batch_tokenizer(self):
54
+ pass
55
+
56
+ @abstractproperty
57
+ def vocab_size(self):
58
+ pass
59
+
60
+ @abstractmethod
61
+ def get_tokens(self):
62
+ pass
63
+
64
+
65
+ class ILogger(ABC):
66
+
67
+ @abstractmethod
68
+ def log_step():
69
+ pass
70
+
71
+ @abstractmethod
72
+ def log():
73
+ pass
74
+
75
+ @abstractmethod
76
+ def set_rank():
77
+ pass
78
+
79
+ @abstractmethod
80
+ def log_img():
81
+ pass
82
+
83
+
84
+ class IPredictor(ABC):
85
+
86
+ @abstractmethod
87
+ def predict():
88
+ pass
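
A minimal sketch of implementing the IProcess interface above, in the same style as the concrete processes in processes.py; the Lowercaser class is purely illustrative.

from typing import List
from interfaces import IProcess

class Lowercaser(IProcess):
    def execute(self, lines: List[str]) -> List[str]:
        # lowercase every line; the real processes filter, clean, or distort text
        return [line.lower() for line in lines]

print(Lowercaser().execute(["Hello World", "ABC"]))  # ['hello world', 'abc']
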
prepare_data/process_data.py ADDED
@@ -0,0 +1,193 @@
1
+ from argparse import ArgumentParser
2
+ import os
3
+ import re
4
+ import time
5
+ from processors import FilesProcessor, get_text_distorter
6
+ from processes import (
7
+ CharsRemover,
8
+ LengthFilter,
9
+ LinesSplitter,
10
+ LoadFile,
11
+ NumbersFilter,
12
+ OOVFilter,
13
+ RepeatedCharsCollapsor,
14
+ # SoloCharFilter,
15
+ SpacesRemover,
16
+ ValidCharsKeeper,
17
+ WordsFilter,
18
+ WordsNumberFilter,
19
+ CharsNormalizer,
20
+ TokenizerLengthFilter,
21
+ )
22
+ from helpers import load_json, save_text_file
23
+ from typing import Union, List
24
+ from pathlib import Path
25
+ import constants
26
+ import pandas as pd
27
+
28
+
29
+ def get_paths(
30
+ main_dir: Union[Path, str]
31
+ ) -> List[Union[Path, str]]:
32
+ paths = [
33
+ os.path.join(main_dir, file)
34
+ for file in os.listdir(main_dir)
35
+ ]
36
+ return paths
37
+
38
+ def get_path(
39
+ file_path: Union[Path, str]
40
+ ) -> List[Union[Path, str]]:
41
+ if os.path.isfile(file_path):
42
+ return [file_path]
43
+ else:
44
+ raise FileNotFoundError
45
+
46
+
47
+ def get_file_processor(args):
48
+ words = load_json(args.execlude_words_files)
49
+ processes = [
50
+ LoadFile(),
51
+ *[LinesSplitter(sep=sep) for sep in args.sep],
52
+ RepeatedCharsCollapsor(args.max_rep_chars),
53
+ NumbersFilter(),
54
+ # SoloCharFilter(),
55
+ WordsFilter(words),
56
+ ValidCharsKeeper(constants.VALID_CHARS),
57
+ SpacesRemover(),
58
+ WordsNumberFilter(args.min_words, args.max_words),
59
+ # TokenizerLengthFilter(),
60
+ LengthFilter(args.min_len, args.max_len)
61
+ ]
62
+ return FilesProcessor(processes)
63
+
64
+
65
+ def post_process(data: List[str]) -> List[str]:
66
+ lines = []
67
+ for item in data:
68
+ lines.extend(item)
69
+ lines = list(set(lines))
70
+ # lines = OOVFilter(args.max_oov).execute(lines)
71
+ return lines
72
+
73
+
74
+ clean_punctuation = re.compile(r"(?<!\d)[!.:;?،؛؟«»۔٫٪](?!\d)")
75
+
76
+ def remove_punctuation(text):
77
+ """Remove all punctuation from string, except if it's between digits"""
78
+ return clean_punctuation.sub("", text)
79
+
80
+
81
+
82
+ def get_argparser():
83
+ parser = ArgumentParser()
84
+ parser.add_argument(
85
+ '--sep', default=[
86
+ '\n',
87
+ # '\t', '.', '،', ',', '=', ':', '-', '\\', '/'
88
+ ], nargs='+', type=str,
89
+ help='The separator to be used to split the lines on'
90
+ )
91
+ parser.add_argument(
92
+ '--min_len', default=5, type=int,
93
+ help='The minimum line length to keep'
94
+ )
95
+ parser.add_argument(
96
+ '--max_len', default=1020, type=int,
97
+ help='The maximum line length to keep'
98
+ )
99
+ parser.add_argument(
100
+ '--dist_run', default=False, action='store_true'
101
+ )
102
+ parser.add_argument(
103
+ '--data_path', default='data/data.txt'
104
+ )
105
+ parser.add_argument(
106
+ '--save_path', default='data/clean_data.txt'
107
+ )
108
+ parser.add_argument(
109
+ '--max_rep_chars', default=2
110
+ )
111
+ parser.add_argument(
112
+ '--execlude_words_files', default='data/words.json'
113
+ )
114
+ parser.add_argument(
115
+ '--max_oov', default=100, type=int
116
+ )
117
+ parser.add_argument(
118
+ '--min_words', default=3, type=int
119
+ )
120
+ parser.add_argument(
121
+ '--max_words', default=100, type=int
122
+ )
123
+ parser.add_argument(
124
+ '--dist_ratios', default=[0.05, 0.1, 0.15]
125
+ )
126
+ parser.add_argument(
127
+ '--remove_punc', default=False, action='store_true', help='Remove punctuation of the distorted lines'
128
+ )
129
+ return parser
130
+
131
+
132
+ def main(args) -> None:
133
+ fp = get_file_processor(args)
134
+ files = get_path(args.data_path)
135
+ print('Started!')
136
+ start = time.time()
137
+ if args.dist_run is True:
138
+ print('dist run')
139
+ data = fp.dist_run(files)
140
+ else:
141
+ data = fp.run(files)
142
+ end = time.time()
143
+ print(f'Files Processing completed in {end - start}')
144
+ data = post_process(data)
145
+ sentences = data[: len(data) // 2]
146
+ print("Length of data after post processing", len(data))
147
+ df = None
148
+ for i, ratio in enumerate(args.dist_ratios):
149
+ distorter = get_text_distorter(ratio, sentences)
150
+ # TODO: leave 2 percent of the sentences untouched so the model does not over-correct clean text
151
+
152
+ dist = list(map(distorter.run, data))
153
+ if df is None:
154
+ df = pd.DataFrame({
155
+ 'clean': data,
156
+ f'distorted_{ratio}': dist
157
+ })
158
+ else:
159
+ df[f'distorted_{ratio}'] = dist
160
+ if args.remove_punc is True:
161
+ print("Removing punctuations for the distorted lines")
162
+ for ratio in args.dist_ratios:
163
+ df[f'distorted_{ratio}'] = df[f'distorted_{ratio}'].apply(
164
+ remove_punctuation
165
+ )
166
+ df.to_csv('data/data.csv', encoding='utf-8')
167
+ # save_text_file(args.save_path, '\n'.join(data))
168
+
169
+
170
+ if __name__ == '__main__':
171
+ parser = get_argparser()
172
+ args = parser.parse_args()
173
+ main(args)
174
+ num_lines = sum(1 for line in open(f"data/data.csv",'r'))
175
+ os.system(f"echo \"text,summary\" > train.csv")
176
+ # # Only change the first $ variable for different distortion ratios
177
+ # os.system(f"awk -F',' 'NR>1 && NR<={num_lines-50000} {{print $4 \",\" $2}}' data/data.csv >> train.csv")
178
+ # os.system(f"awk -F',' 'NR>1 && NR<={num_lines-50000} {{print $3 \",\" $2}}' data/data.csv >> train.csv")
179
+ os.system(f"awk -F',' 'NR>1 && NR<={num_lines-50000} {{print $5 \",\" $2}}' data/data.csv | sed 's/\"//g' >> train.csv")
180
+ os.system(f"awk -F',' 'NR>1 && NR<={num_lines-50000} {{print $4 \",\" $2}}' data/data.csv | sed 's/\"//g' >> train.csv")
181
+ os.system(f"awk -F',' 'NR>1 && NR<={num_lines-50000} {{print $3 \",\" $2}}' data/data.csv | sed 's/\"//g' >> train.csv")
182
+
183
+ os.system(f"echo \"text,summary\" > test.csv")
184
+ # os.system(f"tail -n 50000 data/data.csv | awk -F',' '{{print $4 \",\" $2}}' >> test.csv")
185
+ # os.system(f"tail -n 50000 data/data.csv | awk -F',' '{{print $3 \",\" $2}}' >> test.csv")
186
+ os.system(f"awk -F',' 'NR>{num_lines-50000} {{print $5 \",\" $2}}' data/data.csv | sed 's/\"//g' >> test.csv")
187
+ os.system(f"awk -F',' 'NR>{num_lines-50000} {{print $4 \",\" $2}}' data/data.csv | sed 's/\"//g' >> test.csv")
188
+ os.system(f"awk -F',' 'NR>{num_lines-50000} {{print $3 \",\" $2}}' data/data.csv | sed 's/\"//g' >> test.csv")
189
+
190
+
191
+
192
+
193
+
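
A rough pure-pandas sketch of the awk-based split in the __main__ block above, assuming data/data.csv was written by main() with a 'clean' column plus one 'distorted_<ratio>' column per ratio, and holding out the last 50000 rows of each pair as test; this approximates, rather than reproduces, the shell commands.

import pandas as pd

df = pd.read_csv("data/data.csv", index_col=0)
ratios = [0.05, 0.1, 0.15]
parts = [
    df[[f"distorted_{r}", "clean"]].rename(
        columns={f"distorted_{r}": "text", "clean": "summary"}
    )
    for r in ratios
]
pd.concat(p.iloc[:-50000] for p in parts).to_csv("train.csv", index=False)
pd.concat(p.iloc[-50000:] for p in parts).to_csv("test.csv", index=False)
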
prepare_data/processes.py ADDED
@@ -0,0 +1,335 @@
1
+ import random
2
+ import re
3
+ from typing import List, Union
4
+ from interfaces import IProcess
5
+ from helpers import get_freq_dict, load_text_file, remove_long_spaces
6
+ from transformers import AutoTokenizer
7
+
8
+ class LoadFile(IProcess):
9
+
10
+ def execute(self, file_path: str):
11
+ return load_text_file(
12
+ file_path
13
+ )
14
+
15
+
16
+ class LinesSplitter(IProcess):
17
+ def __init__(self, sep: str) -> None:
18
+ super().__init__()
19
+ self.sep = sep
20
+
21
+ def split(self, line):
22
+ return line.split(self.sep)
23
+
24
+ def execute(self, data: Union[List[str], str]) -> List[str]:
25
+ if isinstance(data, str):
26
+ return data.split(self.sep)
27
+ results = []
28
+ for lines in map(self.split, data):
29
+ results.extend(lines)
30
+ return results
31
+
32
+
33
+ class LengthFilter(IProcess):
34
+ def __init__(
35
+ self, min_length: int, max_length: int
36
+ ) -> None:
37
+ super().__init__()
38
+ self.min_length = min_length
39
+ self.max_length = max_length
40
+
41
+ def execute(self, lines: List[str]):
42
+ return list(filter(
43
+ lambda x: self.min_length <= len(x) <= self.max_length, lines
44
+ ))
45
+
46
+
47
+ class WordsNumberFilter(IProcess):
48
+ def __init__(self, min_words: int, max_words: int) -> None:
49
+ super().__init__()
50
+ self.min_words = min_words
51
+ self.max_words = max_words
52
+
53
+ def _is_valid(self, line: str) -> bool:
54
+ return self.min_words < line.count(' ') < self.max_words
55
+
56
+ def execute(self, lines: List[str]):
57
+ return list(filter(self._is_valid, lines))
58
+
59
+ class TokenizerLengthFilter(IProcess):
60
+ def __init__(self, max_length: int = 1024) -> None:
61
+ super().__init__()
62
+ self.max_length = max_length
63
+ self.tokenizer = AutoTokenizer.from_pretrained("./tokenizer")
64
+
65
+ def _is_valid(self, line: str) -> bool:
66
+ data = self.tokenizer.batch_encode_plus([line], max_length=self.max_length, truncation=True, return_overflowing_tokens=True)
67
+ # keep the line only if it fits within max_length (no overflow chunks)
68
+ return len(data["input_ids"]) == 1
71
+
72
+ def execute(self, lines: List[str]):
73
+ return list(filter(self._is_valid, lines))
74
+
75
+
76
+ class WordsFilter(IProcess):
77
+ def __init__(self, words: List[str]) -> None:
78
+ super().__init__()
79
+ self.words = set(words)
80
+
81
+ def _not_contain(self, line: str) -> bool:
82
+ return not any((
83
+ word in line for word in self.words
84
+ ))
85
+
86
+ def execute(self, lines: List[str]):
87
+ return list(filter(self._not_contain, lines))
88
+
89
+
90
+ class SoloCharFilter(IProcess):
91
+
92
+ def _not_contain(self, line: str) -> bool:
93
+ return re.search('^. | . | .$', line) is None
94
+
95
+ def execute(self, lines: List[str]):
96
+ return list(filter(self._not_contain, lines))
97
+
98
+
99
+ class NumbersFilter(IProcess):
100
+
101
+ def _not_contain(self, line: str) -> bool:
102
+ return re.search('[0-9]+', line) is None
103
+
104
+ def execute(self, lines: List[str]):
105
+ return list(filter(self._not_contain, lines))
106
+
107
+
108
+ class OOVFilter(IProcess):
109
+ def __init__(self, max_oov: int) -> None:
110
+ super().__init__()
111
+ self.max_oov = max_oov
112
+ self.__freq = {}
113
+
114
+ def _is_valid(self, line: str):
115
+ counter = 0
116
+ for word in line.split(' '):
117
+ counter += (self.__freq[word] == 1)
118
+ return counter < self.max_oov
119
+
120
+ def execute(self, lines: List[str]):
121
+ self.__freq = get_freq_dict(lines)
122
+ return list(filter(self._is_valid, lines))
123
+
124
+ # text = ["کوردستان وڵاتی کوردانە هەی هەی هەی هەی", "کورد بوون گەوادیە", "ژیان سەختە"]
125
+ # result = OOVFilter(5).execute(text)
126
+ # print(result)
127
+
128
+
129
+ class CharsRemover(IProcess):
130
+ def __init__(self, chars: str) -> None:
131
+ super().__init__()
132
+ self.pat = f'[{chars}]'
133
+
134
+ def remove(self, line: str) -> str:
135
+ return re.sub(self.pat, '', line)
136
+
137
+ def execute(self, lines: List[str]) -> List[str]:
138
+ return list(map(self.remove, lines))
139
+
140
+
141
+ class RepeatedCharsCollapsor(IProcess):
142
+ def __init__(self, max_repeteion: int) -> None:
143
+ super().__init__()
144
+ self.pat = r"(.)\1{}".format(f"{{{2},}}")
145
+
146
+ def collaps(self, line: str) -> str:
147
+ return re.sub(self.pat, r"\1" * 1, line)
148
+
149
+ def execute(self, lines: List[str]) -> List[str]:
150
+ return list(map(self.collaps, lines))
151
+
152
+
153
+ class ValidCharsKeeper(IProcess):
154
+ def __init__(self, valid_chars: str, rep_with=' ') -> None:
155
+ super().__init__()
156
+ self.valid_chars = valid_chars
157
+ self.rep_with = rep_with
158
+ self.pat = f'[^{self.valid_chars}]'
159
+
160
+ def __keep(self, line: str) -> str:
161
+ return re.sub(self.pat, self.rep_with, line)
162
+
163
+ def execute(self, lines: List[str]) -> List[str]:
164
+ return list(map(self.__keep, lines))
165
+
166
+
167
+ class SpacesRemover(IProcess):
168
+
169
+ def __remove(self, line: str) -> str:
170
+ return remove_long_spaces(line).strip()
171
+
172
+ def execute(self, lines: List[str]):
173
+ return list(map(self.__remove, lines))
174
+
175
+
176
+ class RandomCharsInjector(IProcess):
177
+ def __init__(self, chars: str) -> None:
178
+ super().__init__()
179
+ self.chars = chars
180
+
181
+ def get_char(self) -> str:
182
+ return random.choice(self.chars)
183
+
184
+ def execute(self, line: str):
185
+ length = len(line)
186
+ idx = random.randint(0, length - 1)
187
+ return line[:idx] + self.get_char() + line[idx:]
188
+
189
+ class PunctuationRemover(IProcess):
190
+ def __init__(self) -> None:
191
+ super().__init__()
192
+ self.clean_punctuation = re.compile(r"(?<!\d)[.,;:'?!،.؟؛:»«](?!\d)")
193
+
194
+ def __remove_punctuation(self, text: str):
195
+ """Remove all punctuation from string, except if it's between digits"""
196
+ return self.clean_punctuation.sub("", text)
197
+
198
+ def execute(self, line: str):
199
+ return self.__remove_punctuation(line)
200
+
201
+
202
+ class RandomCharsSwapper(IProcess):
203
+
204
+ def execute(self, line: str) -> str:
205
+ length = len(line)
206
+ idx = random.randint(0, length - 2)
207
+ return line[:idx] + line[idx + 1] + line[idx] + line[idx + 2:]
208
+
209
+
210
+ class RandomCharRemover(IProcess):
211
+
212
+ def execute(self, line: str) -> str:
213
+ length = len(line)
214
+ idx = random.randint(0, length - 1)
215
+ return line[:idx] + line[idx + 1:]
216
+
217
+
218
+ class RandomWordsCollapsor(IProcess):
219
+
220
+ def execute(self, line: str) -> str:
221
+ indices = [
222
+ i for i, char in enumerate(line)
223
+ if char == ' '
224
+ ]
225
+ if len(indices) == 0:
226
+ return line
227
+ idx = random.choice(indices)
228
+ return line[: idx] + line[idx + 1:]
229
+
230
+
231
+ class RandomNeighborReplacer(IProcess):
232
+
233
+ def __init__(self, keyboard_rows: List[str], blank: str) -> None:
234
+ super().__init__()
235
+ self.lines = keyboard_rows
236
+ self.blank = blank
237
+ self.n_rows = len(keyboard_rows)
238
+ self._mapper = {}
239
+ self.set_mapper()
240
+
241
+ def __get_left(
242
+ self, row_idx: int, col_idx: int
243
+ ) -> List[str]:
244
+ if col_idx == 0:
245
+ return []
246
+ return [self.lines[row_idx][col_idx - 1]]
247
+
248
+ def __get_right(
249
+ self, row_idx: int, col_idx: int
250
+ ) -> List[str]:
251
+ if col_idx == (len(self.lines[row_idx]) - 1):
252
+ return []
253
+ return self.lines[row_idx][col_idx + 1]
254
+
255
+ def __get_upper(
256
+ self, row_idx: int, col_idx: int
257
+ ) -> List[str]:
258
+ if row_idx == 0:
259
+ return []
260
+ line = self.lines[row_idx - 1]
261
+ start = max(0, col_idx - 1)
262
+ end = min(len(line), col_idx + 2)
263
+ return list(line[start: end])
264
+
265
+ def __get_lower(
266
+ self, row_idx: int, col_idx: int
267
+ ) -> List[str]:
268
+ if row_idx == (self.n_rows - 1):
269
+ return []
270
+ line = self.lines[row_idx + 1]
271
+ start = max(0, col_idx - 1)
272
+ end = min(len(line), col_idx + 2)
273
+ return list(line[start: end])
274
+
275
+ def set_mapper(self) -> None:
276
+ funcs = [
277
+ self.__get_left,
278
+ self.__get_right,
279
+ self.__get_upper,
280
+ self.__get_lower
281
+ ]
282
+ for row_idx in range(self.n_rows):
283
+ for col_idx in range(len(self.lines[row_idx])):
284
+ items = []
285
+ for func in funcs:
286
+ items.extend(func(row_idx, col_idx))
287
+ items = list(
288
+ filter(lambda x: x != self.blank, items)
289
+ )
290
+ char = self.lines[row_idx][col_idx]
291
+ self._mapper[char] = items.copy()
292
+
293
+ def get_char(self, char: str) -> str:
294
+ if char not in self._mapper:
295
+ return char
296
+ return random.choice(self._mapper[char])
297
+
298
+ def execute(self, line: str) -> str:
299
+ length = len(line)
300
+ idx = random.randint(0, length - 1)
301
+ return line[:idx] + self.get_char(line[idx]) + line[idx + 1:]
302
+
303
+
304
+ class CharsNormalizer(IProcess):
305
+
306
+ def __init__(self, mapper: dict) -> None:
307
+ super().__init__()
308
+ self.mapper = mapper
309
+
310
+ def _normalize(self, line: str) -> str:
311
+ for key, value in self.mapper.items():
312
+ line = line.replace(key, value)
313
+ return line
314
+
315
+ def execute(self, lines: List[str]):
316
+ return list(map(self._normalize, lines))
317
+
318
+ class SentencePermutation(IProcess):
319
+
320
+ def __init__(self, sentences: List[str], augmentation_probability: float = 1) -> None:
321
+ super().__init__()
322
+ self.sentences = sentences
323
+ self.augmentation_probability = augmentation_probability
324
+
325
+ def _combine(self, text: str) -> str:
326
+ if random.random() < self.augmentation_probability:
327
+ sentences_to_sample = random.randint(0,10)
328
+ augmentation_sentences = random.sample(self.sentences, sentences_to_sample)
329
+ return text + " " + " ".join(augmentation_sentences)
330
+ else:
331
+ return text
332
+
333
+ def execute(self, line: str) -> str:
334
+ # return [self._combine(line) for line in lines]
335
+ return self._combine(line)
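
A short usage sketch chaining a few of the single-line processes above (each execute takes one string and returns one string); the sample sentence is a placeholder.

import random
from processes import PunctuationRemover, RandomCharRemover, RandomCharsSwapper

random.seed(0)
line = "ژیان جوانە، بەڵام کورتە."  # placeholder CKB sentence
for proc in (RandomCharsSwapper(), RandomCharRemover(), PunctuationRemover()):
    line = proc.execute(line)
print(line)
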
prepare_data/processors.py ADDED
@@ -0,0 +1,125 @@
1
+ from threading import Thread
2
+ import constants
3
+ from pathlib import Path
4
+ import random
5
+ from typing import Union, Any, List
6
+ from interfaces import IProcess, IProcessor
7
+ from processes import (
8
+ RandomCharRemover,
9
+ RandomCharsInjector,
10
+ RandomCharsSwapper,
11
+ RandomNeighborReplacer,
12
+ RandomWordsCollapsor,
13
+ PunctuationRemover,
14
+ SentencePermutation,
15
+ )
16
+
17
+
18
+ class FilesProcessor(IProcessor):
19
+ def __init__(
20
+ self, processes: List[IProcess],
21
+ n_dist: int = 32
22
+ ) -> None:
23
+ self.processes = processes
24
+ self.n_dist = n_dist
25
+ self.__dist = False
26
+ self.__cache = []
27
+
28
+ def file_run(self, file: Union[str, Path]) -> Any:
29
+ result = file
30
+ for process in self.processes:
31
+ result = process.execute(result)
32
+ return result
33
+
34
+ def run(
35
+ self,
36
+ files: List[Union[str, Path]]
37
+ ) -> Any:
38
+ result = list(map(self.file_run, files))
39
+ if self.__dist is True:
40
+ self.__cache.append(result)
41
+ return
42
+ return result
43
+
44
+ def _divde(self, data: List[Any]):
45
+ items_per_div = len(data) // self.n_dist
46
+ divs = []
47
+ for i in range(self.n_dist):
48
+ start = i * items_per_div
49
+ end = (i + 1) * items_per_div
50
+ if i == (self.n_dist - 1):
51
+ end = len(data)
52
+ divs.append(data[start: end])
53
+ return divs
54
+
55
+ def dist_run(
56
+ self,
57
+ files: List[Union[str, Path]]
58
+ ) -> Any:
59
+ self.__dist = True
60
+ self.__cache = []
61
+ divs = self._divde(files)
62
+ threads = []
63
+ for div in divs:
64
+ t = Thread(target=self.run, args=(div,))
65
+ t.start()
66
+ threads.append(t)
67
+ for t in threads:
68
+ t.join()
69
+ self.__dist = False
70
+ results = []
71
+ for item in self.__cache:
72
+ results.extend(item)
73
+ self.__cache = []
74
+ return results
75
+
76
+
77
+ class TextDistorter(IProcessor):
78
+ def __init__(
79
+ self, ratio: float, processes: List[IProcess]
80
+ ) -> None:
81
+ super().__init__()
82
+ self.ratio = ratio
83
+ self.processes = processes
84
+
85
+ def run(self, line: str) -> str:
86
+ length = len(line)
87
+ n = int(self.ratio * length)
88
+ for _ in range(n):
89
+ line = random.choice(self.processes).execute(line)
90
+ return line
91
+
92
+ def dist_run(self):
93
+ # TODO
94
+ pass
95
+
96
+
97
+ class TextProcessor(IProcessor):
98
+ def __init__(self, processes: List[IProcess]) -> None:
99
+ super().__init__()
100
+ self.processes = processes
101
+
102
+ def run(self, sentence: str):
103
+ for process in self.processes:
104
+ sentence = process.execute(sentence)
105
+ return sentence
106
+
107
+ def dist_run(self, sentence: str) -> str:
108
+ return self.run(sentence)
109
+
110
+
111
+ def get_text_distorter(ratio, sentences: List[str]):
112
+
113
+ return TextDistorter(
114
+ ratio=ratio,
115
+ processes=[
116
+ SentencePermutation(sentences),
117
+ RandomCharsInjector(constants.KURDISH_CHARS),
118
+ RandomCharsSwapper(),
119
+ RandomCharRemover(),
120
+ RandomWordsCollapsor(),
121
+ RandomNeighborReplacer(
122
+ constants.KEYBOARD_KEYS, constants.KEYBOARD_BLANK
123
+ )
124
+ ]
125
+ )
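
A hedged end-to-end sketch of get_text_distorter as used in process_data.main(); the ratio and sample lines are illustrative, and the pool is made larger than ten items because SentencePermutation samples up to ten sentences from it.

from processors import get_text_distorter

clean = ["دەقی یەکەم.", "دەقی دووەم.", "دەقی سێیەم."]  # placeholder lines
pool = clean * 4  # SentencePermutation may sample up to 10 sentences
distorter = get_text_distorter(0.1, pool)
for line in clean:
    print(line, "->", distorter.run(line))
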
requirements.txt ADDED
@@ -0,0 +1,9 @@
1
+ accelerate >= 0.12.0
2
+ datasets >= 1.8.0
3
+ sentencepiece != 0.1.92
4
+ protobuf
5
+ nltk
6
+ py7zr
7
+ torch >= 2.0.1
8
+ evaluate
9
+ jiwer
run_summarization.py ADDED
@@ -0,0 +1,793 @@
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2021 The HuggingFace Team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+ Fine-tuning the library models for sequence to sequence.
18
+ """
19
+ # You can also adapt this script to your own sequence to sequence task. Pointers for this are left as comments.
20
+
21
+ import logging
22
+ import os
23
+ import sys
24
+ import warnings
25
+ from dataclasses import dataclass, field
26
+ from typing import Optional
27
+
28
+ import datasets
29
+ import evaluate
30
+ import nltk # Here to have a nice missing dependency error message early on
31
+ import numpy as np
32
+ from datasets import load_dataset
33
+ from filelock import FileLock
34
+
35
+ import transformers
36
+ from transformers import (
37
+ AutoConfig,
38
+ AutoModelForSeq2SeqLM,
39
+ AutoTokenizer,
40
+ DataCollatorForSeq2Seq,
41
+ HfArgumentParser,
42
+ MBart50Tokenizer,
43
+ MBart50TokenizerFast,
44
+ MBartTokenizer,
45
+ MBartTokenizerFast,
46
+ Seq2SeqTrainer,
47
+ Seq2SeqTrainingArguments,
48
+ set_seed,
49
+ )
50
+ from transformers.trainer_utils import get_last_checkpoint
51
+ from transformers.utils import check_min_version, is_offline_mode, send_example_telemetry
52
+ from transformers.utils.versions import require_version
53
+
54
+
55
+ # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
56
+ check_min_version("4.33.0.dev0")
57
+
58
+ require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")
59
+
60
+ logger = logging.getLogger(__name__)
61
+
62
+ try:
63
+ nltk.data.find("tokenizers/punkt")
64
+ except (LookupError, OSError):
65
+ if is_offline_mode():
66
+ raise LookupError(
67
+ "Offline mode: run this script without TRANSFORMERS_OFFLINE first to download nltk data files"
68
+ )
69
+ with FileLock(".lock") as lock:
70
+ nltk.download("punkt", quiet=True)
71
+
72
+ # A list of all multilingual tokenizers which require the lang attribute.
73
+ MULTILINGUAL_TOKENIZERS = [MBartTokenizer, MBartTokenizerFast, MBart50Tokenizer, MBart50TokenizerFast]
74
+
75
+
76
+ @dataclass
77
+ class ModelArguments:
78
+ """
79
+ Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
80
+ """
81
+
82
+ model_name_or_path: str = field(
83
+ metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
84
+ )
85
+ config_name: Optional[str] = field(
86
+ default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
87
+ )
88
+ tokenizer_name: Optional[str] = field(
89
+ default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
90
+ )
91
+ cache_dir: Optional[str] = field(
92
+ default=None,
93
+ metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"},
94
+ )
95
+ use_fast_tokenizer: bool = field(
96
+ default=True,
97
+ metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
98
+ )
99
+ model_revision: str = field(
100
+ default="main",
101
+ metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
102
+ )
103
+ token: str = field(
104
+ default=None,
105
+ metadata={
106
+ "help": (
107
+ "The token to use as HTTP bearer authorization for remote files. If not specified, will use the token "
108
+ "generated when running `huggingface-cli login` (stored in `~/.huggingface`)."
109
+ )
110
+ },
111
+ )
112
+ use_auth_token: bool = field(
113
+ default=None,
114
+ metadata={
115
+ "help": "The `use_auth_token` argument is deprecated and will be removed in v4.34. Please use `token`."
116
+ },
117
+ )
118
+ trust_remote_code: bool = field(
119
+ default=False,
120
+ metadata={
121
+ "help": (
122
+ "Whether or not to allow for custom models defined on the Hub in their own modeling files. This option"
123
+ "should only be set to `True` for repositories you trust and in which you have read the code, as it will"
124
+ "execute code present on the Hub on your local machine."
125
+ )
126
+ },
127
+ )
128
+ resize_position_embeddings: Optional[bool] = field(
129
+ default=None,
130
+ metadata={
131
+ "help": (
132
+ "Whether to automatically resize the position embeddings if `max_source_length` exceeds "
133
+ "the model's position embeddings."
134
+ )
135
+ },
136
+ )
137
+
138
+
139
+ @dataclass
140
+ class DataTrainingArguments:
141
+ """
142
+ Arguments pertaining to what data we are going to input our model for training and eval.
143
+ """
144
+
145
+ lang: Optional[str] = field(default=None, metadata={"help": "Language id for summarization."})
146
+
147
+ dataset_name: Optional[str] = field(
148
+ default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
149
+ )
150
+ dataset_config_name: Optional[str] = field(
151
+ default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
152
+ )
153
+ text_column: Optional[str] = field(
154
+ default=None,
155
+ metadata={"help": "The name of the column in the datasets containing the full texts (for summarization)."},
156
+ )
157
+ summary_column: Optional[str] = field(
158
+ default=None,
159
+ metadata={"help": "The name of the column in the datasets containing the summaries (for summarization)."},
160
+ )
161
+ train_file: Optional[str] = field(
162
+ default=None, metadata={"help": "The input training data file (a jsonlines or csv file)."}
163
+ )
164
+ validation_file: Optional[str] = field(
165
+ default=None,
166
+ metadata={
167
+ "help": (
168
+ "An optional input evaluation data file to evaluate the metrics (rouge) on (a jsonlines or csv file)."
169
+ )
170
+ },
171
+ )
172
+ test_file: Optional[str] = field(
173
+ default=None,
174
+ metadata={
175
+ "help": "An optional input test data file to evaluate the metrics (rouge) on (a jsonlines or csv file)."
176
+ },
177
+ )
178
+ overwrite_cache: bool = field(
179
+ default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
180
+ )
181
+ preprocessing_num_workers: Optional[int] = field(
182
+ default=None,
183
+ metadata={"help": "The number of processes to use for the preprocessing."},
184
+ )
185
+ max_source_length: Optional[int] = field(
186
+ default=1024,
187
+ metadata={
188
+ "help": (
189
+ "The maximum total input sequence length after tokenization. Sequences longer "
190
+ "than this will be truncated, sequences shorter will be padded."
191
+ )
192
+ },
193
+ )
194
+ max_target_length: Optional[int] = field(
195
+ default=128,
196
+ metadata={
197
+ "help": (
198
+ "The maximum total sequence length for target text after tokenization. Sequences longer "
199
+ "than this will be truncated, sequences shorter will be padded."
200
+ )
201
+ },
202
+ )
203
+ val_max_target_length: Optional[int] = field(
204
+ default=None,
205
+ metadata={
206
+ "help": (
207
+ "The maximum total sequence length for validation target text after tokenization. Sequences longer "
208
+ "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`."
209
+ "This argument is also used to override the ``max_length`` param of ``model.generate``, which is used "
210
+ "during ``evaluate`` and ``predict``."
211
+ )
212
+ },
213
+ )
214
+ pad_to_max_length: bool = field(
215
+ default=False,
216
+ metadata={
217
+ "help": (
218
+ "Whether to pad all samples to model maximum sentence length. "
219
+ "If False, will pad the samples dynamically when batching to the maximum length in the batch. More "
220
+ "efficient on GPU but very bad for TPU."
221
+ )
222
+ },
223
+ )
224
+ max_train_samples: Optional[int] = field(
225
+ default=None,
226
+ metadata={
227
+ "help": (
228
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
229
+ "value if set."
230
+ )
231
+ },
232
+ )
233
+ max_eval_samples: Optional[int] = field(
234
+ default=None,
235
+ metadata={
236
+ "help": (
237
+ "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
238
+ "value if set."
239
+ )
240
+ },
241
+ )
242
+ max_predict_samples: Optional[int] = field(
243
+ default=None,
244
+ metadata={
245
+ "help": (
246
+ "For debugging purposes or quicker training, truncate the number of prediction examples to this "
247
+ "value if set."
248
+ )
249
+ },
250
+ )
251
+ num_beams: Optional[int] = field(
252
+ default=None,
253
+ metadata={
254
+ "help": (
255
+ "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, "
256
+ "which is used during ``evaluate`` and ``predict``."
257
+ )
258
+ },
259
+ )
260
+ ignore_pad_token_for_loss: bool = field(
261
+ default=True,
262
+ metadata={
263
+ "help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not."
264
+ },
265
+ )
266
+ source_prefix: Optional[str] = field(
267
+ default="", metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
268
+ )
269
+
270
+ forced_bos_token: Optional[str] = field(
271
+ default=None,
272
+ metadata={
273
+ "help": (
274
+ "The token to force as the first generated token after the decoder_start_token_id."
275
+ "Useful for multilingual models like mBART where the first generated token"
276
+ "needs to be the target language token (Usually it is the target language token)"
277
+ )
278
+ },
279
+ )
280
+
281
+ def __post_init__(self):
282
+ if (
283
+ self.dataset_name is None
284
+ and self.train_file is None
285
+ and self.validation_file is None
286
+ and self.test_file is None
287
+ ):
288
+ raise ValueError("Need either a dataset name or a training, validation, or test file.")
289
+ else:
290
+ if self.train_file is not None:
291
+ extension = self.train_file.split(".")[-1]
292
+ assert extension in ["csv", "json"], "`train_file` should be a csv or a json file."
293
+ if self.validation_file is not None:
294
+ extension = self.validation_file.split(".")[-1]
295
+ assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
296
+ if self.test_file is not None:
297
+ extension = self.test_file.split(".")[-1]
298
+ assert extension in ["csv", "json"], "`test_file` should be a csv or a json file."
299
+ if self.val_max_target_length is None:
300
+ self.val_max_target_length = self.max_target_length
301
+
302
+
303
+ summarization_name_mapping = {
304
+ "amazon_reviews_multi": ("review_body", "review_title"),
305
+ "big_patent": ("description", "abstract"),
306
+ "cnn_dailymail": ("article", "highlights"),
307
+ "orange_sum": ("text", "summary"),
308
+ "pn_summary": ("article", "summary"),
309
+ "psc": ("extract_text", "summary_text"),
310
+ "samsum": ("dialogue", "summary"),
311
+ "thaisum": ("body", "summary"),
312
+ "xglue": ("news_body", "news_title"),
313
+ "xsum": ("document", "summary"),
314
+ "wiki_summary": ("article", "highlights"),
315
+ "multi_news": ("document", "summary"),
316
+ }
317
+
318
+
319
+ def main():
320
+ # See all possible arguments in src/transformers/training_args.py
321
+ # or by passing the --help flag to this script.
322
+ # We now keep distinct sets of args, for a cleaner separation of concerns.
323
+
324
+ parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))
325
+ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
326
+ # If we pass only one argument to the script and it's the path to a json file,
327
+ # let's parse it to get our arguments.
328
+ model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
329
+ else:
330
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
331
+
332
+ if model_args.use_auth_token is not None:
333
+ warnings.warn("The `use_auth_token` argument is deprecated and will be removed in v4.34.", FutureWarning)
334
+ if model_args.token is not None:
335
+ raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
336
+ model_args.token = model_args.use_auth_token
337
+
338
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
339
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
340
+ send_example_telemetry("run_summarization", model_args, data_args)
341
+
342
+ # Setup logging
343
+ logging.basicConfig(
344
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
345
+ datefmt="%m/%d/%Y %H:%M:%S",
346
+ handlers=[logging.StreamHandler(sys.stdout)],
347
+ )
348
+
349
+ if training_args.should_log:
350
+ # The default of training_args.log_level is passive, so we set log level at info here to have that default.
351
+ transformers.utils.logging.set_verbosity_info()
352
+
353
+ log_level = training_args.get_process_log_level()
354
+ logger.setLevel(log_level)
355
+ datasets.utils.logging.set_verbosity(log_level)
356
+ transformers.utils.logging.set_verbosity(log_level)
357
+ transformers.utils.logging.enable_default_handler()
358
+ transformers.utils.logging.enable_explicit_format()
359
+
360
+ # Log on each process the small summary:
361
+ logger.warning(
362
+ f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
363
+ + f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16}"
364
+ )
365
+ logger.info(f"Training/evaluation parameters {training_args}")
366
+
367
+ if data_args.source_prefix is None and model_args.model_name_or_path in [
368
+ "t5-small",
369
+ "t5-base",
370
+ "t5-large",
371
+ "t5-3b",
372
+ "t5-11b",
373
+ ]:
374
+ logger.warning(
375
+ "You're running a t5 model but didn't provide a source prefix, which is the expected, e.g. with "
376
+ "`--source_prefix 'summarize: ' `"
377
+ )
378
+
379
+ # Detecting last checkpoint.
380
+ last_checkpoint = None
381
+ if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
382
+ last_checkpoint = get_last_checkpoint(training_args.output_dir)
383
+ if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
384
+ raise ValueError(
385
+ f"Output directory ({training_args.output_dir}) already exists and is not empty. "
386
+ "Use --overwrite_output_dir to overcome."
387
+ )
388
+ elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
389
+ logger.info(
390
+ f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
391
+ "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
392
+ )
393
+
394
+ # Set seed before initializing model.
395
+ set_seed(training_args.seed)
396
+
397
+ # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below)
398
+ # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
399
+ # (the dataset will be downloaded automatically from the datasets Hub).
400
+ #
401
+ # For CSV/JSON files this script will use the first column for the full texts and the second column for the
402
+ # summaries (unless you specify column names for this with the `text_column` and `summary_column` arguments).
403
+ #
404
+ # In distributed training, the load_dataset function guarantee that only one local process can concurrently
405
+ # download the dataset.
406
+ if data_args.dataset_name is not None:
407
+ # Downloading and loading a dataset from the hub.
408
+ raw_datasets = load_dataset(
409
+ data_args.dataset_name,
410
+ data_args.dataset_config_name,
411
+ cache_dir=model_args.cache_dir,
412
+ token=model_args.token,
413
+ )
414
+ else:
415
+ data_files = {}
416
+ if data_args.train_file is not None:
417
+ data_files["train"] = data_args.train_file
418
+ extension = data_args.train_file.split(".")[-1]
419
+ if data_args.validation_file is not None:
420
+ data_files["validation"] = data_args.validation_file
421
+ extension = data_args.validation_file.split(".")[-1]
422
+ if data_args.test_file is not None:
423
+ data_files["test"] = data_args.test_file
424
+ extension = data_args.test_file.split(".")[-1]
425
+ raw_datasets = load_dataset(
426
+ extension,
427
+ data_files=data_files,
428
+ cache_dir=model_args.cache_dir,
429
+ token=model_args.token,
430
+ )
431
+ # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
432
+ # https://huggingface.co/docs/datasets/loading_datasets.html.
433
+
434
+ # Load pretrained model and tokenizer
435
+ #
436
+ # Distributed training:
437
+ # The .from_pretrained methods guarantee that only one local process can concurrently
438
+ # download model & vocab.
439
+ config = AutoConfig.from_pretrained(
440
+ model_args.config_name if model_args.config_name else model_args.model_name_or_path,
441
+ cache_dir=model_args.cache_dir,
442
+ revision=model_args.model_revision,
443
+ token=model_args.token,
444
+ trust_remote_code=model_args.trust_remote_code,
445
+ )
446
+ tokenizer = AutoTokenizer.from_pretrained(
447
+ model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
448
+ cache_dir=model_args.cache_dir,
449
+ use_fast=model_args.use_fast_tokenizer,
450
+ revision=model_args.model_revision,
451
+ token=model_args.token,
452
+ trust_remote_code=model_args.trust_remote_code,
453
+ )
454
+ model = AutoModelForSeq2SeqLM.from_pretrained(
455
+ model_args.model_name_or_path,
456
+ from_tf=bool(".ckpt" in model_args.model_name_or_path),
457
+ config=config,
458
+ cache_dir=model_args.cache_dir,
459
+ revision=model_args.model_revision,
460
+ token=model_args.token,
461
+ trust_remote_code=model_args.trust_remote_code,
462
+ )
463
+
464
+ # We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
465
+ # on a small vocab and want a smaller embedding size, remove this test.
466
+ embedding_size = model.get_input_embeddings().weight.shape[0]
467
+ if len(tokenizer) > embedding_size:
468
+ model.resize_token_embeddings(len(tokenizer))
469
+
470
+ if model.config.decoder_start_token_id is None and isinstance(tokenizer, (MBartTokenizer, MBartTokenizerFast)):
471
+ if isinstance(tokenizer, MBartTokenizer):
472
+ model.config.decoder_start_token_id = tokenizer.lang_code_to_id[data_args.lang]
473
+ else:
474
+ model.config.decoder_start_token_id = tokenizer.convert_tokens_to_ids(data_args.lang)
475
+
476
+ if model.config.decoder_start_token_id is None:
477
+ raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")
478
+
479
+ if (
480
+ hasattr(model.config, "max_position_embeddings")
481
+ and model.config.max_position_embeddings < data_args.max_source_length
482
+ ):
483
+ if model_args.resize_position_embeddings is None:
484
+ logger.warning(
485
+ "Increasing the model's number of position embedding vectors from"
486
+ f" {model.config.max_position_embeddings} to {data_args.max_source_length}."
487
+ )
488
+ model.resize_position_embeddings(data_args.max_source_length)
489
+ elif model_args.resize_position_embeddings:
490
+ model.resize_position_embeddings(data_args.max_source_length)
491
+ else:
492
+ raise ValueError(
493
+ f"`--max_source_length` is set to {data_args.max_source_length}, but the model only has"
494
+ f" {model.config.max_position_embeddings} position encodings. Consider either reducing"
495
+ f" `--max_source_length` to {model.config.max_position_embeddings} or to automatically resize the"
496
+ " model's position encodings by passing `--resize_position_embeddings`."
497
+ )
498
+
499
+ prefix = data_args.source_prefix if data_args.source_prefix is not None else ""
500
+
501
+ # Preprocessing the datasets.
502
+ # We need to tokenize inputs and targets.
503
+ if training_args.do_train:
504
+ if "train" not in raw_datasets:
505
+ raise ValueError("--do_train requires a train dataset")
506
+ column_names = raw_datasets["train"].column_names
507
+ elif training_args.do_eval:
508
+ if "validation" not in raw_datasets:
509
+ raise ValueError("--do_eval requires a validation dataset")
510
+ column_names = raw_datasets["validation"].column_names
511
+ elif training_args.do_predict:
512
+ if "test" not in raw_datasets:
513
+ raise ValueError("--do_predict requires a test dataset")
514
+ column_names = raw_datasets["test"].column_names
515
+ else:
516
+ logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.")
517
+ return
518
+
519
+ if isinstance(tokenizer, tuple(MULTILINGUAL_TOKENIZERS)):
520
+ assert (
521
+ data_args.lang is not None
522
+ ), f"{tokenizer.__class__.__name__} is a multilingual tokenizer which requires --lang argument"
523
+
524
+ tokenizer.src_lang = data_args.lang
525
+ tokenizer.tgt_lang = data_args.lang
526
+
527
+ # For multilingual translation models like mBART-50 and M2M100 we need to force the target language token
528
+ # as the first generated token. We ask the user to explicitly provide this as --forced_bos_token argument.
529
+ forced_bos_token_id = (
530
+ tokenizer.lang_code_to_id[data_args.forced_bos_token] if data_args.forced_bos_token is not None else None
531
+ )
532
+ model.config.forced_bos_token_id = forced_bos_token_id
533
+
534
+ # Get the column names for input/target.
535
+ dataset_columns = summarization_name_mapping.get(data_args.dataset_name, None)
536
+ if data_args.text_column is None:
537
+ text_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
538
+ else:
539
+ text_column = data_args.text_column
540
+ if text_column not in column_names:
541
+ raise ValueError(
542
+ f"--text_column' value '{data_args.text_column}' needs to be one of: {', '.join(column_names)}"
543
+ )
544
+ if data_args.summary_column is None:
545
+ summary_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
546
+ else:
547
+ summary_column = data_args.summary_column
548
+ if summary_column not in column_names:
549
+ raise ValueError(
550
+ f"--summary_column' value '{data_args.summary_column}' needs to be one of: {', '.join(column_names)}"
551
+ )
552
+
553
+ # Temporarily set max_target_length for training.
554
+ max_target_length = data_args.max_target_length
555
+ padding = "max_length" if data_args.pad_to_max_length else False
556
+
557
+ if training_args.label_smoothing_factor > 0 and not hasattr(model, "prepare_decoder_input_ids_from_labels"):
558
+ logger.warning(
559
+ "label_smoothing is enabled but the `prepare_decoder_input_ids_from_labels` method is not defined for"
560
+ f"`{model.__class__.__name__}`. This will lead to loss being calculated twice and will take up more memory"
561
+ )
562
+
563
+ def preprocess_function(examples):
564
+ # remove pairs where at least one record is None
565
+
566
+ inputs, targets = [], []
567
+ for i in range(len(examples[text_column])):
568
+ if examples[text_column][i] and examples[summary_column][i]:
569
+ inputs.append(examples[text_column][i])
570
+ targets.append(examples[summary_column][i])
571
+
572
+ inputs = [prefix + inp for inp in inputs]
573
+ model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, padding=padding, truncation=True)
574
+
575
+ # Tokenize targets with the `text_target` keyword argument
576
+ labels = tokenizer(text_target=targets, max_length=max_target_length, padding=padding, truncation=True)
577
+
578
+ # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore
579
+ # padding in the loss.
580
+ if padding == "max_length" and data_args.ignore_pad_token_for_loss:
581
+ labels["input_ids"] = [
582
+ [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"]
583
+ ]
584
+
585
+ model_inputs["labels"] = labels["input_ids"]
586
+ return model_inputs
587
+
588
+ if training_args.do_train:
589
+ train_dataset = raw_datasets["train"]
590
+ if data_args.max_train_samples is not None:
591
+ max_train_samples = min(len(train_dataset), data_args.max_train_samples)
592
+ train_dataset = train_dataset.select(range(max_train_samples))
593
+ with training_args.main_process_first(desc="train dataset map pre-processing"):
594
+ train_dataset = train_dataset.map(
595
+ preprocess_function,
596
+ batched=True,
597
+ num_proc=data_args.preprocessing_num_workers,
598
+ remove_columns=column_names,
599
+ load_from_cache_file=not data_args.overwrite_cache,
600
+ desc="Running tokenizer on train dataset",
601
+ )
602
+
603
+ if training_args.do_eval:
604
+ max_target_length = data_args.val_max_target_length
605
+ eval_dataset = raw_datasets["validation"]
606
+ if data_args.max_eval_samples is not None:
607
+ max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
608
+ eval_dataset = eval_dataset.select(range(max_eval_samples))
609
+ with training_args.main_process_first(desc="validation dataset map pre-processing"):
610
+ eval_dataset = eval_dataset.map(
611
+ preprocess_function,
612
+ batched=True,
613
+ num_proc=data_args.preprocessing_num_workers,
614
+ remove_columns=column_names,
615
+ load_from_cache_file=not data_args.overwrite_cache,
616
+ desc="Running tokenizer on validation dataset",
617
+ )
618
+
619
+ if training_args.do_predict:
620
+ max_target_length = data_args.val_max_target_length
621
+ predict_dataset = raw_datasets["test"]
622
+ if data_args.max_predict_samples is not None:
623
+ max_predict_samples = min(len(predict_dataset), data_args.max_predict_samples)
624
+ predict_dataset = predict_dataset.select(range(max_predict_samples))
625
+ with training_args.main_process_first(desc="prediction dataset map pre-processing"):
626
+ predict_dataset = predict_dataset.map(
627
+ preprocess_function,
628
+ batched=True,
629
+ num_proc=data_args.preprocessing_num_workers,
630
+ remove_columns=column_names,
631
+ load_from_cache_file=not data_args.overwrite_cache,
632
+ desc="Running tokenizer on prediction dataset",
633
+ )
634
+
635
+ # Data collator
636
+ label_pad_token_id = -100 if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
637
+ data_collator = DataCollatorForSeq2Seq(
638
+ tokenizer,
639
+ model=model,
640
+ label_pad_token_id=label_pad_token_id,
641
+ pad_to_multiple_of=8 if training_args.fp16 else None,
642
+ )
643
+
644
+ # Metric
645
+ cer = evaluate.load("cer")
646
+ wer = evaluate.load("wer")
647
+ bleu = evaluate.load("bleu")
648
+ chrf = evaluate.load("chrf")
649
+
650
+
651
+ def postprocess_text(preds, labels):
652
+ preds = [pred.strip() for pred in preds]
653
+ labels = [label.strip() for label in labels]
654
+
655
+ # rougeLSum expects newline after each sentence
656
+ preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds]
657
+ labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels]
658
+
659
+ return preds, labels
660
+
661
+ def compute_metrics(eval_preds):
662
+ preds, labels = eval_preds
663
+ if isinstance(preds, tuple):
664
+ preds = preds[0]
665
+ # Replace -100s used for padding as we can't decode them
666
+ preds = np.where(preds != -100, preds, tokenizer.pad_token_id)
667
+ decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
668
+ labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
669
+ decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
670
+
671
+ # Some simple post-processing
672
+ decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)
673
+ result = {}
674
+ result['cer'] = cer.compute(predictions=decoded_preds, references=decoded_labels)
675
+ result['wer'] = wer.compute(predictions=decoded_preds, references=decoded_labels)
676
+ result['bleu'] = bleu.compute(predictions=decoded_preds, references=decoded_labels)['bleu']
677
+ result['chrF'] = chrf.compute(predictions=decoded_preds, references=decoded_labels)['score']
678
+ result = {k: v if k == 'chrF' else round(v * 100, 4) for k, v in result.items()}
679
+
680
+ prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
681
+ result["gen_len"] = np.mean(prediction_lens)
682
+ return result
683
+
684
+ # Override the decoding parameters of Seq2SeqTrainer
685
+ training_args.generation_max_length = (
686
+ training_args.generation_max_length
687
+ if training_args.generation_max_length is not None
688
+ else data_args.val_max_target_length
689
+ )
690
+ training_args.generation_num_beams = (
691
+ data_args.num_beams if data_args.num_beams is not None else training_args.generation_num_beams
692
+ )
693
+
694
+ # Initialize our Trainer
695
+ trainer = Seq2SeqTrainer(
696
+ model=model,
697
+ args=training_args,
698
+ train_dataset=train_dataset if training_args.do_train else None,
699
+ eval_dataset=eval_dataset if training_args.do_eval else None,
700
+ tokenizer=tokenizer,
701
+ data_collator=data_collator,
702
+ compute_metrics=compute_metrics if training_args.predict_with_generate else None,
703
+ )
704
+
705
+ # Training
706
+ if training_args.do_train:
707
+ checkpoint = None
708
+ if training_args.resume_from_checkpoint is not None:
709
+ checkpoint = training_args.resume_from_checkpoint
710
+ elif last_checkpoint is not None:
711
+ checkpoint = last_checkpoint
712
+ train_result = trainer.train(resume_from_checkpoint=checkpoint)
713
+ trainer.save_model() # Saves the tokenizer too for easy upload
714
+
715
+ metrics = train_result.metrics
716
+ max_train_samples = (
717
+ data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
718
+ )
719
+ metrics["train_samples"] = min(max_train_samples, len(train_dataset))
720
+
721
+ trainer.log_metrics("train", metrics)
722
+ trainer.save_metrics("train", metrics)
723
+ trainer.save_state()
724
+
725
+ # Evaluation
726
+ results = {}
727
+ if training_args.do_eval:
728
+ logger.info("*** Evaluate ***")
729
+ if isinstance(eval_dataset, dict):
730
+ metrics = {}
731
+ for eval_ds_name, eval_ds in eval_dataset.items():
732
+ dataset_metrics = trainer.evaluate(eval_dataset=eval_ds, metric_key_prefix=f"eval_{eval_ds_name}")
733
+ metrics.update(dataset_metrics)
734
+ else:
735
+ metrics = trainer.evaluate(metric_key_prefix="eval")
736
+ max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
737
+ metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
738
+
739
+ trainer.log_metrics("eval", metrics)
740
+ trainer.save_metrics("eval", metrics)
741
+
742
+ if training_args.do_predict:
743
+ logger.info("*** Predict ***")
744
+
745
+ predict_results = trainer.predict(predict_dataset, metric_key_prefix="predict")
746
+ metrics = predict_results.metrics
747
+ max_predict_samples = (
748
+ data_args.max_predict_samples if data_args.max_predict_samples is not None else len(predict_dataset)
749
+ )
750
+ metrics["predict_samples"] = min(max_predict_samples, len(predict_dataset))
751
+
752
+ trainer.log_metrics("predict", metrics)
753
+ trainer.save_metrics("predict", metrics)
754
+
755
+ if trainer.is_world_process_zero():
756
+ if training_args.predict_with_generate:
757
+ predictions = predict_results.predictions
758
+ predictions = np.where(predictions != -100, predictions, tokenizer.pad_token_id)
759
+ predictions = tokenizer.batch_decode(
760
+ predictions, skip_special_tokens=True, clean_up_tokenization_spaces=True
761
+ )
762
+ predictions = [pred.strip() for pred in predictions]
763
+ output_prediction_file = os.path.join(training_args.output_dir, "generated_predictions.txt")
764
+ with open(output_prediction_file, "w") as writer:
765
+ writer.write("\n".join(predictions))
766
+
767
+ kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "summarization"}
768
+ if data_args.dataset_name is not None:
769
+ kwargs["dataset_tags"] = data_args.dataset_name
770
+ if data_args.dataset_config_name is not None:
771
+ kwargs["dataset_args"] = data_args.dataset_config_name
772
+ kwargs["dataset"] = f"{data_args.dataset_name} {data_args.dataset_config_name}"
773
+ else:
774
+ kwargs["dataset"] = data_args.dataset_name
775
+
776
+ if data_args.lang is not None:
777
+ kwargs["language"] = data_args.lang
778
+
779
+ if training_args.push_to_hub:
780
+ trainer.push_to_hub(**kwargs)
781
+ else:
782
+ trainer.create_model_card(**kwargs)
783
+
784
+ return results
785
+
786
+
787
+ def _mp_fn(index):
788
+ # For xla_spawn (TPUs)
789
+ main()
790
+
791
+
792
+ if __name__ == "__main__":
793
+ main()
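
Note (not part of the committed files): the compute_metrics block above scales CER, WER and BLEU to percentages while chrF is already reported on a 0-100 scale by `evaluate`. A minimal, hypothetical sanity check of that scaling with toy strings is sketched below; it uses the same metric names the script loads, wraps the references in single-reference lists for BLEU/chrF (the documented format, whereas the script passes flat string lists), and assumes `evaluate`, `jiwer`, `sacrebleu` and `nltk` are installed.

import evaluate

cer = evaluate.load("cer")
wer = evaluate.load("wer")
bleu = evaluate.load("bleu")
chrf = evaluate.load("chrf")

# Toy predictions/references (placeholder strings, not data from this repo)
preds = ["this is a tset sentence", "another example"]
refs = ["this is a test sentence", "another example"]

result = {
    "cer": cer.compute(predictions=preds, references=refs),
    "wer": wer.compute(predictions=preds, references=refs),
    "bleu": bleu.compute(predictions=preds, references=[[r] for r in refs])["bleu"],
    "chrF": chrf.compute(predictions=preds, references=[[r] for r in refs])["score"],
}
# Same scaling as compute_metrics: chrF keeps its 0-100 scale, the rest become percentages
result = {k: v if k == "chrF" else round(v * 100, 4) for k, v in result.items()}
print(result)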
train.sh ADDED
@@ -0,0 +1,59 @@
+# Train BART
+python run_summarization.py \
+    --model_name_or_path "facebook/bart-base" \
+    --config_name "facebook/bart-base" \
+    --tokenizer_name ./tokenizer \
+    --do_train \
+    --do_eval \
+    --evaluation_strategy="epoch" \
+    --group_by_length \
+    --num_train_epochs=10 \
+    --train_file train.csv \
+    --validation_file test.csv \
+    --preprocessing_num_workers="20" \
+    --output_dir ./bart-kurd-spell-base/ \
+    --overwrite_output_dir \
+    --per_device_train_batch_size=320 \
+    --per_device_eval_batch_size=256 \
+    --gradient_accumulation_steps=1 \
+    --predict_with_generate \
+    --logging_steps="100" \
+    --save_total_limit="1" \
+    --save_strategy="epoch" \
+    --report_to="wandb" \
+    --run_name="Bart Spell" \
+    --max_target_length=1024 \
+    --max_source_length=1024 \
+    --fp16 \
+    --save_safetensors \
+    --push_to_hub
+
+# Train T5
+# python3 run_summarization.py \
+#     --source_prefix "correct: " \
+#     --model_name_or_path "google/flan-t5-small" \
+#     --config_name "google/flan-t5-small" \
+#     --tokenizer_name ./tokenizer \
+#     --do_train \
+#     --do_eval \
+#     --evaluation_strategy="epoch" \
+#     --group_by_length \
+#     --num_train_epochs=5 \
+#     --train_file train.csv \
+#     --validation_file test.csv \
+#     --preprocessing_num_workers="12" \
+#     --output_dir ./t5-kurd-spell-base/ \
+#     --overwrite_output_dir \
+#     --per_device_train_batch_size=64 \
+#     --per_device_eval_batch_size=64 \
+#     --gradient_accumulation_steps=1 \
+#     --predict_with_generate \
+#     --logging_steps="100" \
+#     --save_total_limit="1" \
+#     --save_strategy="epoch" \
+#     --report_to="none" \
+#     --run_name="T5 Spell" \
+#     --max_target_length=1024 \
+#     --max_source_length=1024 \
+#     --push_to_hub
+# # --fp16 \
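
Note (not part of the committed files): once train.sh finishes, the fine-tuned BART checkpoint ends up in the --output_dir shown above. A hypothetical inference sketch is given below; the local path and the max_length value are taken from the flags in train.sh and are assumptions, not something the repo itself ships.

from transformers import pipeline

# Assumes training completed and the checkpoint directory from --output_dir exists locally
corrector = pipeline("text2text-generation", model="./bart-kurd-spell-base")

text = "a misspelled Central Kurdish sentence goes here"  # placeholder input
print(corrector(text, max_length=1024)[0]["generated_text"])  # corrected output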
train_tokenizer.py ADDED
@@ -0,0 +1,39 @@
+from datasets import load_dataset
+from transformers import AutoTokenizer
+import argparse
+
+
+parser = argparse.ArgumentParser()
+
+parser.add_argument("--tokenizer_name", default="facebook/bart-base", help="Name or path of the pretrained tokenizer to train the new one from")
+parser.add_argument("--output_dir", default="tokenizer", type=str, help="Output directory for the trained tokenizer (also used as the repo id when pushing to the Hub)")
+parser.add_argument("--push_to_hub", default=False, action="store_true", help="Push the trained tokenizer to the Hub",)
+
+args = parser.parse_args()
+
+
+dataset = load_dataset("oscar-corpus/OSCAR-2301", "ckb", split="train", token=True)
+
+def get_training_corpus(batch_size=1000):
+    for start_idx in range(0, len(dataset), batch_size):
+        samples = dataset[start_idx : start_idx + batch_size]
+        yield samples["text"]
+
+training_corpus = get_training_corpus()
+
+tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name)
+
+tokenizer = tokenizer.train_new_from_iterator(
+    training_corpus, vocab_size=len(tokenizer),
+    special_tokens_map={
+        "eos_token": "</s>",
+        "bos_token": "<s>",
+        "unk_token": "<unk>",
+        "pad_token": "<pad>",
+        "mask_token": "<mask>",
+    },
+)
+
+
+tokenizer.save_pretrained(args.output_dir, push_to_hub=args.push_to_hub)
+
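
Note (not part of the committed files): train_tokenizer.py saves the retrained tokenizer to its default --output_dir of "tokenizer", which is the same ./tokenizer directory train.sh passes as --tokenizer_name. A hypothetical round-trip check of that output is sketched below; the sample string is a placeholder, not real data from this repo.

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("./tokenizer")

sample = "a Central Kurdish sample sentence goes here"  # placeholder text
ids = tokenizer(sample)["input_ids"]
print(tokenizer.convert_ids_to_tokens(ids))              # inspect the learned subword pieces
print(tokenizer.decode(ids, skip_special_tokens=True))   # should reproduce the sample text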