Add app
- Pipfile +24 -0
- README.md +3 -7
- app.py +48 -0
- backend/__init__.py +0 -0
- backend/__pycache__/__init__.cpython-38.pyc +0 -0
- backend/__pycache__/aragpt.cpython-38.pyc +0 -0
- backend/__pycache__/modeling_gpt2.cpython-38.pyc +0 -0
- backend/__pycache__/preprocess.cpython-38.pyc +0 -0
- backend/__pycache__/sa_utils.cpython-38.pyc +0 -0
- backend/__pycache__/services.cpython-38.pyc +0 -0
- backend/__pycache__/utils.cpython-38.pyc +0 -0
- backend/aragpt.py +189 -0
- backend/home.py +156 -0
- backend/modeling_gpt2.py +1599 -0
- backend/preprocess.py +736 -0
- backend/processor.py +183 -0
- backend/qa.py +50 -0
- backend/qa_utils.py +163 -0
- backend/sa.py +76 -0
- backend/sa_utils.py +510 -0
- backend/sarcasm.py +21 -0
- backend/services.py +519 -0
- backend/utils.py +64 -0
- packages.txt +2 -0
- requirements.txt +17 -0
- test.py +10 -0
Pipfile
ADDED
@@ -0,0 +1,24 @@
+[[source]]
+url = "https://pypi.org/simple"
+verify_ssl = true
+name = "pypi"
+
+[packages]
+streamlit = "==0.84.2"
+arabic-reshaper = "==2.1.3"
+python-bidi = "==0.4.2"
+pyarabic = "*"
+farasapy = "==0.0.14"
+emoji = "==1.4.2"
+awesome-streamlit = "*"
+torch = "==1.9.0"
+transformers = "==4.10.0"
+psutil = "==5.8.0"
+fuzzysearch = "==0.7.3"
+more-itertools = "==8.9.0"
+cookiecutter = "*"
+
+[dev-packages]
+
+[requires]
+python_version = "3.8"
README.md
CHANGED
@@ -1,13 +1,9 @@
 ---
-title:
-emoji:
+title: Arabic NLP Demo
+emoji: ⌨
 colorFrom: purple
 colorTo: green
 sdk: streamlit
-sdk_version: 1.9.0
 app_file: app.py
-pinned:
-license: unlicense
+pinned: true
 ---
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces#reference
app.py
ADDED
@@ -0,0 +1,48 @@
+import awesome_streamlit as ast
+import streamlit as st
+
+from backend.utils import get_current_ram_usage, ga
+
+import backend.aragpt
+import backend.home
+import backend.processor
+import backend.sa
+import backend.qa
+
+st.set_page_config(
+    page_title="TEST", page_icon="📖", initial_sidebar_state="expanded", layout="wide"
+)
+
+ga(st.__file__)
+
+PAGES = {
+    "Home": backend.home,
+    "Arabic Text Preprocessor": backend.processor,
+    "Arabic Language Generation": backend.aragpt,
+    "Arabic Sentiment Analysis": backend.sa,
+    # "Arabic Sarcasm Detection": backend.sarcasm,
+    "Arabic Question Answering": backend.qa,
+}
+
+
+st.sidebar.title("Navigation")
+selection = st.sidebar.radio("Pages", list(PAGES.keys()))
+
+page = PAGES[selection]
+# with st.spinner(f"Loading {selection} ..."):
+ast.shared.components.write_page(page)
+
+st.sidebar.header("Info")
+st.sidebar.write("Made by [Wissam Antoun](https://twitter.com/wissam_antoun)")
+st.sidebar.write(
+    "Pre-trained models are available on [HF Hub](https://huggingface.co/aubmindlab)"
+)
+st.sidebar.write(
+    "Models source code available on [GitHub](https://github.com/aub-mind/arabert)"
+)
+st.sidebar.write(
+    "App source code available on [GitHub](https://github.com/WissamAntoun/Arabic-NLP-app)"
+)
+if st.sidebar.checkbox("Show RAM usage"):
+    ram = get_current_ram_usage()
+    st.sidebar.write("Ram usage: {:.2f}/{:.2f} GB".format(ram[0], ram[1]))
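Each module in `PAGES` follows the convention that `awesome_streamlit`'s `write_page()` relies on: a page is rendered by calling the module's top-level `write()` function, as `backend/home.py` and `backend/aragpt.py` below do. A minimal sketch of what a new page added to this app would look like (the module name is hypothetical):

```python
# backend/example.py -- hypothetical page module for this app.
# app.py would register it as PAGES["Example"] = backend.example;
# ast.shared.components.write_page() then calls write() to render it.
import streamlit as st


def write():
    st.title("Example page")
    st.write("Page content goes here.")
```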
backend/__init__.py
ADDED
File without changes
backend/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (214 Bytes)
backend/__pycache__/aragpt.cpython-38.pyc
ADDED
Binary file (4.41 kB)
backend/__pycache__/modeling_gpt2.cpython-38.pyc
ADDED
Binary file (42.9 kB)
backend/__pycache__/preprocess.cpython-38.pyc
ADDED
Binary file (17.8 kB)
backend/__pycache__/sa_utils.cpython-38.pyc
ADDED
Binary file (14.5 kB)
backend/__pycache__/services.cpython-38.pyc
ADDED
Binary file (11.6 kB)
backend/__pycache__/utils.cpython-38.pyc
ADDED
Binary file (2.28 kB)
backend/aragpt.py
ADDED
@@ -0,0 +1,189 @@
+import streamlit as st
+from .services import TextGeneration
+from tokenizers import Tokenizer
+from functools import lru_cache
+
+# @st.cache(allow_output_mutation=False, hash_funcs={Tokenizer: str})
+@lru_cache(maxsize=1)
+def load_text_generator():
+    generator = TextGeneration()
+    generator.load()
+    return generator
+
+
+generator = load_text_generator()
+
+qa_prompt = """
+أجب عن السؤال التالي:
+"""
+qa_prompt_post = """ الجواب هو """
+qa_prompt_post_year = """ في سنة: """
+
+
+def write():
+    st.markdown(
+        """
+<h1 style="text-align:left;">Arabic Language Generation</h1>
+        """,
+        unsafe_allow_html=True,
+    )
+
+    # Sidebar
+
+    # Taken from https://huggingface.co/spaces/flax-community/spanish-gpt2/blob/main/app.py
+    st.sidebar.subheader("Configurable parameters")
+
+    model_name = st.sidebar.selectbox(
+        "Model Selector",
+        options=[
+            "AraGPT2-Base",
+            # "AraGPT2-Medium",
+            # "Aragpt2-Large",
+            "AraGPT2-Mega",
+        ],
+        index=0,
+    )
+
+    max_new_tokens = st.sidebar.number_input(
+        "Maximum length",
+        min_value=0,
+        max_value=1024,
+        value=100,
+        help="The maximum length of the sequence to be generated.",
+    )
+    temp = st.sidebar.slider(
+        "Temperature",
+        value=1.0,
+        min_value=0.1,
+        max_value=100.0,
+        help="The value used to modulate the next token probabilities.",
+    )
+    top_k = st.sidebar.number_input(
+        "Top k",
+        value=10,
+        help="The number of highest probability vocabulary tokens to keep for top-k-filtering.",
+    )
+    top_p = st.sidebar.number_input(
+        "Top p",
+        value=0.95,
+        help="If set to float < 1, only the most probable tokens with probabilities that add up to top_p or higher are kept for generation.",
+    )
+    do_sample = st.sidebar.selectbox(
+        "Sampling?",
+        (True, False),
+        help="Whether or not to use sampling; use greedy decoding otherwise.",
+    )
+    num_beams = st.sidebar.number_input(
+        "Number of beams",
+        min_value=1,
+        max_value=10,
+        value=3,
+        help="The number of beams to use for beam search.",
+    )
+    repetition_penalty = st.sidebar.number_input(
+        "Repetition Penalty",
+        min_value=0.0,
+        value=3.0,
+        step=0.1,
+        help="The parameter for repetition penalty. 1.0 means no penalty.",
+    )
+    no_repeat_ngram_size = st.sidebar.number_input(
+        "No Repeat N-Gram Size",
+        min_value=0,
+        value=3,
+        help="If set to int > 0, all ngrams of that size can only occur once.",
+    )
+
+    st.write("#")
+
+    col = st.columns(2)
+
+    col[0].image("images/AraGPT2.png", width=200)
+
+    st.markdown(
+        """
+
+<h3 style="text-align:left;">AraGPT2 is a GPT2 model trained from scratch on 77GB of Arabic text.</h3>
+<h4 style="text-align:left;"> More details in our <a href="https://github.com/aub-mind/arabert/tree/master/aragpt2">repo</a>.</h4>
+
+<p style="text-align:left;"><p>
+<p style="text-align:left;">Use the generation parameters on the sidebar to adjust generation quality.</p>
+<p style="text-align:right;"><p>
+        """,
+        unsafe_allow_html=True,
+    )
+
+    # col[0].write(
+    #     "AraGPT2 is trained from scratch on 77GB of Arabic text. More details in our [repo](https://github.com/aub-mind/arabert/tree/master/aragpt2)."
+    # )
+    # st.write("## Generate Arabic Text")
+
+    st.markdown(
+        """
+<style>
+p, div, input, label, textarea{
+    text-align: right;
+}
+</style>
+        """,
+        unsafe_allow_html=True,
+    )
+
+    prompt = st.text_area(
+        "Prompt",
+        "يحكى أن مزارعا مخادعا قام ببيع بئر الماء الموجود في أرضه لجاره مقابل مبلغ كبير من المال",
+    )
+    if st.button("Generate"):
+        with st.spinner("Generating..."):
+            generated_text = generator.generate(
+                prompt=prompt,
+                model_name=model_name,
+                max_new_tokens=max_new_tokens,
+                temperature=temp,
+                top_k=top_k,
+                top_p=top_p,
+                repetition_penalty=repetition_penalty,
+                do_sample=do_sample,
+                num_beams=num_beams,
+                no_repeat_ngram_size=no_repeat_ngram_size,
+            )
+        st.write(generated_text)
+
+    st.markdown("---")
+    st.subheader("")
+    st.markdown(
+        """
+<p style="text-align:left;"><p>
+<h2 style="text-align:left;">Zero-Shot Question Answering</h2>
+
+<p style="text-align:left;">Adjust the maximum length to closely match the expected output length. Setting the Sampling parameter to False is recommended.</p>
+<p style="text-align:left;"><p>
+        """,
+        unsafe_allow_html=True,
+    )
+
+    question = st.text_input(
+        "Question", "من كان رئيس ألمانيا النازية في الحرب العالمية الثانية ؟"
+    )
+    is_date = st.checkbox("Help the model: Is the answer a date?")
+    if st.button("Answer"):
+
+        prompt2 = qa_prompt + question + qa_prompt_post
+        if is_date:
+            prompt2 += qa_prompt_post_year
+        else:
+            prompt2 += " : "
+        with st.spinner("Thinking..."):
+            answer = generator.generate(
+                prompt=prompt2,
+                model_name=model_name,
+                max_new_tokens=max_new_tokens,
+                temperature=temp,
+                top_k=top_k,
+                top_p=top_p,
+                repetition_penalty=repetition_penalty,
+                do_sample=do_sample,
+                num_beams=num_beams,
+                no_repeat_ngram_size=no_repeat_ngram_size,
+            )
+        st.write(answer)
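For reference, `TextGeneration` comes from `backend/services.py`, which is not shown in this hunk. The sketch below mirrors exactly how `aragpt.py` drives it, under the assumption that `load()` and `generate()` can also be called outside a Streamlit session:

```python
# Minimal sketch of driving the generation service directly, using the same
# keyword arguments that backend/aragpt.py passes to generate().
from backend.services import TextGeneration

generator = TextGeneration()
generator.load()  # load tokenizer + AraGPT2 weights once

text = generator.generate(
    prompt="يحكى أن مزارعا مخادعا قام ببيع بئر الماء الموجود في أرضه لجاره",
    model_name="AraGPT2-Base",
    max_new_tokens=100,
    temperature=1.0,
    top_k=10,
    top_p=0.95,
    repetition_penalty=3.0,
    do_sample=True,
    num_beams=3,
    no_repeat_ngram_size=3,
)
print(text)
```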
backend/home.py
ADDED
@@ -0,0 +1,156 @@
+import streamlit as st
+import awesome_streamlit as ast
+
+
+def write():
+    st.markdown(
+        """
+# Arabic Natural Language Processing
+
+![visitors](https://visitor-badge.glitch.me/badge?page_id=wissamantoun.arabicnlpapp)
+
+
+In this HuggingFace space you will be able to test the different Arabic NLP models that my colleagues at [AUB MIND Lab](https://sites.aub.edu.lb/mindlab/) have built, along with some other applications.
+
+Check the **Navigation bar** to access the apps:
+- Arabic Text Preprocessor: Test how text input is treated by our preprocessor
+- Arabic Language Generation: Generate Arabic text using our AraGPT2 language models
+- Arabic Sentiment Analysis: Test the sentiment analysis model that won the [Arabic Sentiment Analysis competition @ KAUST](https://www.kaggle.com/c/arabic-sentiment-analysis-2021-kaust)
+- Arabic Question Answering: Test our AraELECTRA QA capabilities
+        """
+    )
+    st.markdown("#")
+    col1, col2, col3 = st.columns(3)
+
+    col1.write("## **AraBERT**")
+    col1.image("images/arabert_logo.png", width=200)
+
+    col2.write("## **AraGPT2**")
+    col2.image("images/AraGPT2.png", width=200)
+
+    col3.write("## **AraElectra**")
+    col3.image("images/AraELECTRA.png", width=200)
+
+    st.markdown(
+        """
+
+You can find more details in the source code and paper linked in our repository on GitHub: [repo](https://github.com/aub-mind/arabert).
+
+## Dataset
+
+The pretraining data used for the new **AraBERT** model is also used for **AraGPT2 and AraELECTRA**.
+
+The dataset consists of 77GB or 200,095,961 lines or 8,655,948,860 words or 82,232,988,358 chars (before applying Farasa segmentation).
+
+Our large models were trained on a TPUv3-128 provided by TFRC.
+
+For the new dataset we added the unshuffled OSCAR corpus, after thoroughly filtering it, to the dataset used in AraBERTv1, but without the websites that we previously crawled:
+- OSCAR unshuffled and filtered.
+- [Arabic Wikipedia dump](https://archive.org/details/arwiki-20190201) from 2020/09/01
+- [The 1.5B words Arabic Corpus](https://www.semanticscholar.org/paper/1.5-billion-words-Arabic-Corpus-El-Khair/f3eeef4afb81223df96575adadf808fe7fe440b4)
+- [The OSIAN Corpus](https://www.aclweb.org/anthology/W19-4619)
+- Assafir news articles. A huge thank you to Assafir for the data
+
+## Models
+
+Model | HuggingFace Model Name | Size (MB/Params)| Pre-Segmentation | Hardware | Sequence Length | Batch Size | Num of Steps | Total Time (in Days) |
+---|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:
+AraBERTv0.2-base | [bert-base-arabertv02](https://huggingface.co/aubmindlab/bert-base-arabertv02) | 543MB / 136M | No | TPUv3-8 | 128 /512 | 2560/384 | 1M/ 2M | 36 |
+AraBERTv0.2-large| [bert-large-arabertv02](https://huggingface.co/aubmindlab/bert-large-arabertv02) | 1.38G / 371M | No | TPUv3-128 | 128 /512 | 13440 / 2056 | 250K / 300K | 7 |
+AraBERTv2-base| [bert-base-arabertv2](https://huggingface.co/aubmindlab/bert-base-arabertv2) | 543MB / 136M | Yes | TPUv3-8 |128 /512 | 2560 / 384 | 1M / 2M | 36 |
+AraBERTv2-large| [bert-large-arabertv2](https://huggingface.co/aubmindlab/bert-large-arabertv2) | 1.38G / 371M | Yes | TPUv3-128 |128 /512 | 13440 / 2056| 250K / 300K | 7 |
+AraBERTv0.1-base| [bert-base-arabertv01](https://huggingface.co/aubmindlab/bert-base-arabertv01) | 543MB / 136M | No | TPUv2-8 |128 /512 |128 / 512 | 900K / 300K| 4 |
+AraBERTv1-base| [bert-base-arabert](https://huggingface.co/aubmindlab/bert-base-arabert) | 543MB / 136M | Yes | TPUv2-8 |128 /512 |128 / 512 | 900K / 300K| 4 |
+AraGPT2-base | [aragpt2-base](https://huggingface.co/aubmindlab/aragpt2-base) | 527MB/135M | No | TPUv3-128 | 1024 | 1792 | 125K | 1.5 |
+AraGPT2-medium | [aragpt2-medium](https://huggingface.co/aubmindlab/aragpt2-medium) | 1.38G/370M | No |TPUv3-8 | 1024 | 80 | 1M | 15 |
+AraGPT2-large | [aragpt2-large](https://huggingface.co/aubmindlab/aragpt2-large) | 2.98GB/792M | No |TPUv3-128 | 1024 | 256 | 220k | 3 |
+AraGPT2-mega | [aragpt2-mega](https://huggingface.co/aubmindlab/aragpt2-mega) | 5.5GB/1.46B |No |TPUv3-128 | 1024 | 256 | 800K | 9 |
+AraELECTRA-base-generator | [araelectra-base-generator](https://huggingface.co/aubmindlab/araelectra-base-generator) | 227MB/60M | No | TPUv3-8 | 512 | 256 | 2M | 24
+AraELECTRA-base-discriminator | [araelectra-base-discriminator](https://huggingface.co/aubmindlab/araelectra-base-discriminator) | 516MB/135M | No | TPUv3-8 | 512 | 256 | 2M | 24
+AraBERTv0.2-Twitter-base| [bert-base-arabertv02-twitter](https://huggingface.co/aubmindlab/bert-base-arabertv02-twitter) | 543MB / 136M | No | V100 | *64* | - | - | - |
+AraBERTv0.2-Twitter-large| [bert-large-arabertv02-twitter](https://huggingface.co/aubmindlab/bert-large-arabertv02-twitter) | 1.38G / 371M | No | V100 | *64* | - | - | - |
+
+All models are available on the `HuggingFace` model page under the [aubmindlab](https://huggingface.co/aubmindlab/) name. Checkpoints are available in PyTorch, TF2 and TF1 formats.
+
+# Preprocessing
+
+You can test the Arabic preprocessing pipeline in the Arabic Text Preprocessor page.
+
+It is recommended to apply our preprocessing function before training/testing on any dataset.
+**Install farasapy to segment text for AraBERT v1 & v2: `pip install farasapy`**
+
+```python
+from arabert.preprocess import ArabertPreprocessor
+
+model_name = "aubmindlab/bert-base-arabertv2"
+arabert_prep = ArabertPreprocessor(model_name=model_name)
+
+text = "ولن نبالغ إذا قلنا: إن 'هاتف' أو 'كمبيوتر المكتب' في زمننا هذا ضروري"
+arabert_prep.preprocess(text)
+>>>"و+ لن نبالغ إذا قل +نا : إن ' هاتف ' أو ' كمبيوتر ال+ مكتب ' في زمن +نا هذا ضروري"
+```
+
+You can also use the `unpreprocess()` function to reverse the preprocessing changes: it fixes the spacing around non-alphabetical characters and de-segments the text if the selected model needs pre-segmentation. We highly recommend unpreprocessing the generated output of `AraGPT2`, to make it look more natural.
+```python
+output_text = "و+ لن نبالغ إذا قل +نا : إن ' هاتف ' أو ' كمبيوتر ال+ مكتب ' في زمن +نا هذا ضروري"
+arabert_prep.unpreprocess(output_text)
+>>>"ولن نبالغ إذا قلنا: إن 'هاتف' أو 'كمبيوتر المكتب' في زمننا هذا ضروري"
+```
+
+# If you used this model please cite us as:
+
+## AraBERT
+Google Scholar has our BibTeX wrong (missing name), use this instead
+```
+@inproceedings{antoun2020arabert,
+  title={AraBERT: Transformer-based Model for Arabic Language Understanding},
+  author={Antoun, Wissam and Baly, Fady and Hajj, Hazem},
+  booktitle={LREC 2020 Workshop Language Resources and Evaluation Conference 11--16 May 2020},
+  pages={9}
+}
+```
+## AraGPT2
+```
+@inproceedings{antoun-etal-2021-aragpt2,
+    title = "{A}ra{GPT}2: Pre-Trained Transformer for {A}rabic Language Generation",
+    author = "Antoun, Wissam and
+      Baly, Fady and
+      Hajj, Hazem",
+    booktitle = "Proceedings of the Sixth Arabic Natural Language Processing Workshop",
+    month = apr,
+    year = "2021",
+    address = "Kyiv, Ukraine (Virtual)",
+    publisher = "Association for Computational Linguistics",
+    url = "https://www.aclweb.org/anthology/2021.wanlp-1.21",
+    pages = "196--207",
+}
+```
+
+## AraELECTRA
+```
+@inproceedings{antoun-etal-2021-araelectra,
+    title = "{A}ra{ELECTRA}: Pre-Training Text Discriminators for {A}rabic Language Understanding",
+    author = "Antoun, Wissam and
+      Baly, Fady and
+      Hajj, Hazem",
+    booktitle = "Proceedings of the Sixth Arabic Natural Language Processing Workshop",
+    month = apr,
+    year = "2021",
+    address = "Kyiv, Ukraine (Virtual)",
+    publisher = "Association for Computational Linguistics",
+    url = "https://www.aclweb.org/anthology/2021.wanlp-1.20",
+    pages = "191--195",
+}
+```
+
+
+# Acknowledgments
+Thanks to TensorFlow Research Cloud (TFRC) for the free access to Cloud TPUs, we couldn't have done it without this program, and to the [AUB MIND Lab](https://sites.aub.edu.lb/mindlab/) members for the continuous support. Also thanks to [Yakshof](https://www.yakshof.com/#/) and Assafir for data and storage access. Another thanks to [Habib Rahal](https://www.behance.net/rahalhabib), for putting a face to AraBERT.
+
+# Contacts
+**Wissam Antoun**: [Linkedin](https://www.linkedin.com/in/wissam-antoun-622142b4/) | [Twitter](https://twitter.com/wissam_antoun) | [Github](https://github.com/WissamAntoun) | wfa07 (AT) mail (DOT) aub (DOT) edu | wissam.antoun (AT) gmail (DOT) com
+
+**Fady Baly**: [Linkedin](https://www.linkedin.com/in/fadybaly/) | [Twitter](https://twitter.com/fadybaly) | [Github](https://github.com/fadybaly) | fgb06 (AT) mail (DOT) aub (DOT) edu | baly.fady (AT) gmail (DOT) com
+
+        """
+    )
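Since the home page above states that all checkpoints are published under [aubmindlab](https://huggingface.co/aubmindlab/), here is a minimal sketch of loading the base AraGPT2 checkpoint with the pinned `transformers==4.10.0`. This assumes `aragpt2-base` uses the stock GPT-2 architecture; the Mega variant is Grover-based, which is presumably why this commit vendors the custom `modeling_gpt2.py` below:

```python
# Hedged sketch: load aragpt2-base directly from the HF Hub with stock classes.
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("aubmindlab/aragpt2-base")
model = AutoModelForCausalLM.from_pretrained("aubmindlab/aragpt2-base")

ids = tokenizer("يحكى أن مزارعا مخادعا", return_tensors="pt").input_ids
out = model.generate(ids, max_length=64, do_sample=True, top_p=0.95)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```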
backend/modeling_gpt2.py
ADDED
@@ -0,0 +1,1599 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
|
3 |
+
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
|
17 |
+
"""
|
18 |
+
PyTorch OpenAI GPT-2 model.
|
19 |
+
Adapted from https://github.com/huggingface/transformers/blob/v4.0.1/src/transformers/models/gpt2/modeling_gpt2.py
|
20 |
+
and https://github.com/ghosthamlet/gpt2-ml-torch/blob/master/gpt2_ml_torch/modeling_gpt2.py
|
21 |
+
"""
|
22 |
+
|
23 |
+
|
24 |
+
import logging
|
25 |
+
import os
|
26 |
+
from dataclasses import dataclass
|
27 |
+
from typing import List, Optional, Tuple
|
28 |
+
|
29 |
+
import torch
|
30 |
+
import torch.nn as nn
|
31 |
+
from torch.nn import CrossEntropyLoss, MSELoss
|
32 |
+
from transformers import CONFIG_NAME, WEIGHTS_NAME, GPT2Config, GPT2Model
|
33 |
+
from transformers.activations import ACT2FN
|
34 |
+
from transformers.file_utils import (
|
35 |
+
ModelOutput,
|
36 |
+
add_code_sample_docstrings,
|
37 |
+
add_start_docstrings,
|
38 |
+
add_start_docstrings_to_model_forward,
|
39 |
+
replace_return_docstrings,
|
40 |
+
)
|
41 |
+
from transformers.modeling_outputs import (
|
42 |
+
BaseModelOutputWithPastAndCrossAttentions,
|
43 |
+
CausalLMOutputWithCrossAttentions,
|
44 |
+
SequenceClassifierOutputWithPast,
|
45 |
+
TokenClassifierOutput,
|
46 |
+
)
|
47 |
+
from transformers.modeling_utils import (
|
48 |
+
Conv1D,
|
49 |
+
PreTrainedModel,
|
50 |
+
SequenceSummary,
|
51 |
+
find_pruneable_heads_and_indices,
|
52 |
+
prune_conv1d_layer,
|
53 |
+
)
|
54 |
+
from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
|
55 |
+
|
56 |
+
# THe Difference from Transformers is code under _USE_GROVER
|
57 |
+
_USE_GROVER = True
|
58 |
+
|
59 |
+
logger = logging.getLogger(__name__)
|
60 |
+
|
61 |
+
_CONFIG_FOR_DOC = "GPT2Config"
|
62 |
+
_TOKENIZER_FOR_DOC = "GPT2Tokenizer"
|
63 |
+
|
64 |
+
GPT2_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
65 |
+
"gpt2",
|
66 |
+
"gpt2-medium",
|
67 |
+
"gpt2-large",
|
68 |
+
"gpt2-xl",
|
69 |
+
"distilgpt2",
|
70 |
+
# See all GPT-2 models at https://huggingface.co/models?filter=gpt2
|
71 |
+
]
|
72 |
+
|
73 |
+
logger.setLevel(logging.INFO)
|
74 |
+
console = logging.StreamHandler()
|
75 |
+
console.setLevel(logging.INFO)
|
76 |
+
logger.addHandler(console)
|
77 |
+
|
78 |
+
_GPT2_ML_TF_TO_TORCH = {
|
79 |
+
"LayerNorm_embed_norm": "emb_norm",
|
80 |
+
"pos_embed": "wpe.weight",
|
81 |
+
"word_embed": "wte.weight",
|
82 |
+
"layer": "h",
|
83 |
+
# Most importently This two layer norm must be put on the same position as gpt2-ml
|
84 |
+
# or generated data is bad, just repeat the last token
|
85 |
+
"LayerNorm_mlp_ln0": "ln_1",
|
86 |
+
"LayerNorm_mlp_ln1": "ln_2",
|
87 |
+
"intermediate": "mlp.c_fc",
|
88 |
+
"output": "mlp.c_proj",
|
89 |
+
"query_layer": "attn.c_attn",
|
90 |
+
"key_layer": "attn.c_attn",
|
91 |
+
"value_layer": "attn.c_attn",
|
92 |
+
"context_projection_layer": "attn.c_proj",
|
93 |
+
"gamma": "weight",
|
94 |
+
"kernel": "weight",
|
95 |
+
"beta": "bias",
|
96 |
+
"bias": "bias",
|
97 |
+
}
|
98 |
+
|
99 |
+
|
100 |
+
def convert_gpt2_checkpoint_to_pytorch(
|
101 |
+
gpt2_checkpoint_path, gpt2_config_file, pytorch_dump_folder_path
|
102 |
+
):
|
103 |
+
# Construct model
|
104 |
+
if gpt2_config_file == "":
|
105 |
+
config = GPT2Config()
|
106 |
+
else:
|
107 |
+
config = GPT2Config.from_json_file(gpt2_config_file)
|
108 |
+
model = GPT2Model(config)
|
109 |
+
|
110 |
+
# Load weights from numpy
|
111 |
+
load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path)
|
112 |
+
|
113 |
+
# Save pytorch-model
|
114 |
+
pytorch_weights_dump_path = pytorch_dump_folder_path + "/" + WEIGHTS_NAME
|
115 |
+
pytorch_config_dump_path = pytorch_dump_folder_path + "/" + CONFIG_NAME
|
116 |
+
print("Save PyTorch model to {}".format(pytorch_weights_dump_path))
|
117 |
+
torch.save(model.state_dict(), pytorch_weights_dump_path)
|
118 |
+
print("Save configuration file to {}".format(pytorch_config_dump_path))
|
119 |
+
with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
|
120 |
+
f.write(config.to_json_string())
|
121 |
+
|
122 |
+
|
123 |
+
# XXX: MUST do like: convert_gpt2_checkpoint_to_pytorch('./model.ckpt-100000', './mega.json', './')
|
124 |
+
# https://github.com/tensorflow/models/issues/2675#issuecomment-516595597
|
125 |
+
def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):
|
126 |
+
"""Load tf checkpoints in a pytorch model"""
|
127 |
+
try:
|
128 |
+
import re
|
129 |
+
|
130 |
+
import tensorflow as tf
|
131 |
+
except ImportError:
|
132 |
+
logger.error(
|
133 |
+
"Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
|
134 |
+
"https://www.tensorflow.org/install/ for installation instructions."
|
135 |
+
)
|
136 |
+
raise
|
137 |
+
tf_path = os.path.abspath(gpt2_checkpoint_path)
|
138 |
+
logger.info("Converting TensorFlow checkpoint from {}".format(tf_path))
|
139 |
+
# Load weights from TF model
|
140 |
+
init_vars = tf.train.list_variables(tf_path)
|
141 |
+
names = []
|
142 |
+
arrays = []
|
143 |
+
for name, shape in init_vars:
|
144 |
+
logger.info("Loading TF weight {} with shape {}".format(name, shape))
|
145 |
+
array = tf.train.load_variable(tf_path, name)
|
146 |
+
names.append(name)
|
147 |
+
arrays.append(array.squeeze())
|
148 |
+
|
149 |
+
import copy
|
150 |
+
|
151 |
+
orig_model = copy.deepcopy(model)
|
152 |
+
|
153 |
+
for name, array in zip(names, arrays):
|
154 |
+
name = name[6:] # skip "model/"
|
155 |
+
name = name.split("/")
|
156 |
+
pointer = model
|
157 |
+
|
158 |
+
attn_layer = ""
|
159 |
+
for m_name in name:
|
160 |
+
if re.fullmatch(r"[A-Za-z]+\d+", m_name):
|
161 |
+
scope_names = re.split(r"(\d+)", m_name)
|
162 |
+
else:
|
163 |
+
scope_names = [m_name]
|
164 |
+
sname = scope_names[0]
|
165 |
+
|
166 |
+
if sname == "" or sname == "embeddings":
|
167 |
+
continue
|
168 |
+
elif sname not in _GPT2_ML_TF_TO_TORCH:
|
169 |
+
print("=========================================================")
|
170 |
+
logger.info("Skip var name {}".format(scope_names))
|
171 |
+
pointer = None
|
172 |
+
break
|
173 |
+
else:
|
174 |
+
tname = _GPT2_ML_TF_TO_TORCH[sname]
|
175 |
+
if "." in tname:
|
176 |
+
parent, child = tname.split(".")
|
177 |
+
pointer = getattr(pointer, parent)
|
178 |
+
pointer = getattr(pointer, child)
|
179 |
+
else:
|
180 |
+
pointer = getattr(pointer, tname)
|
181 |
+
|
182 |
+
if tname == "attn.c_attn":
|
183 |
+
attn_layer = sname
|
184 |
+
|
185 |
+
if len(scope_names) >= 2:
|
186 |
+
num = int(scope_names[1])
|
187 |
+
pointer = pointer[num]
|
188 |
+
|
189 |
+
if pointer is None:
|
190 |
+
continue
|
191 |
+
if attn_layer == "":
|
192 |
+
try:
|
193 |
+
assert pointer.shape == array.shape
|
194 |
+
except AssertionError as e:
|
195 |
+
e.args += (pointer.shape, array.shape)
|
196 |
+
raise
|
197 |
+
logger.info(
|
198 |
+
"Initialize PyTorch weight {}, {}, {}".format(
|
199 |
+
name, array.mean(), pointer.mean()
|
200 |
+
)
|
201 |
+
)
|
202 |
+
if attn_layer == "":
|
203 |
+
pointer.data = torch.from_numpy(array)
|
204 |
+
else:
|
205 |
+
shape = pointer.shape
|
206 |
+
d = torch.from_numpy(array)
|
207 |
+
is_bias = len(shape) == 1
|
208 |
+
end = int(shape[0 if is_bias else 1] / 3)
|
209 |
+
m = dict(
|
210 |
+
query_layer=0,
|
211 |
+
key_layer=end,
|
212 |
+
value_layer=end * 2,
|
213 |
+
)
|
214 |
+
start = m[attn_layer]
|
215 |
+
end = start + end
|
216 |
+
if is_bias:
|
217 |
+
pointer.data[start:end] = d
|
218 |
+
else:
|
219 |
+
pointer.data[:, start:end] = d
|
220 |
+
logger.info(
|
221 |
+
"Initialize PyTorch weight {}, {}, {}".format(
|
222 |
+
name, array.mean(), pointer.mean()
|
223 |
+
)
|
224 |
+
)
|
225 |
+
|
226 |
+
for name, params in orig_model.named_parameters():
|
227 |
+
for n, p in model.named_parameters():
|
228 |
+
if name == n:
|
229 |
+
if params.equal(p):
|
230 |
+
print("--------------------------")
|
231 |
+
print(" %s not changed!" % n)
|
232 |
+
return model
|
233 |
+
|
234 |
+
|
235 |
+
class Attention(nn.Module):
|
236 |
+
def __init__(self, nx, n_ctx, config, scale=False, is_cross_attention=False):
|
237 |
+
super().__init__()
|
238 |
+
|
239 |
+
n_state = nx # in Attention: n_state=768 (nx=n_embd)
|
240 |
+
# [switch nx => n_state from Block to Attention to keep identical to TF implem]
|
241 |
+
assert n_state % config.n_head == 0
|
242 |
+
self.register_buffer(
|
243 |
+
"bias",
|
244 |
+
torch.tril(torch.ones((n_ctx, n_ctx), dtype=torch.uint8)).view(
|
245 |
+
1, 1, n_ctx, n_ctx
|
246 |
+
),
|
247 |
+
)
|
248 |
+
self.register_buffer("masked_bias", torch.tensor(-1e4))
|
249 |
+
self.n_head = config.n_head
|
250 |
+
self.split_size = n_state
|
251 |
+
self.scale = scale
|
252 |
+
self.is_cross_attention = is_cross_attention
|
253 |
+
if self.is_cross_attention:
|
254 |
+
self.c_attn = Conv1D(2 * n_state, nx)
|
255 |
+
self.q_attn = Conv1D(n_state, nx)
|
256 |
+
else:
|
257 |
+
self.c_attn = Conv1D(3 * n_state, nx)
|
258 |
+
self.c_proj = Conv1D(n_state, nx)
|
259 |
+
self.attn_dropout = nn.Dropout(config.attn_pdrop)
|
260 |
+
self.resid_dropout = nn.Dropout(config.resid_pdrop)
|
261 |
+
self.pruned_heads = set()
|
262 |
+
|
263 |
+
def prune_heads(self, heads):
|
264 |
+
if len(heads) == 0:
|
265 |
+
return
|
266 |
+
heads, index = find_pruneable_heads_and_indices(
|
267 |
+
heads, self.n_head, self.split_size // self.n_head, self.pruned_heads
|
268 |
+
)
|
269 |
+
index_attn = torch.cat(
|
270 |
+
[index, index + self.split_size, index + (2 * self.split_size)]
|
271 |
+
)
|
272 |
+
|
273 |
+
# Prune conv1d layers
|
274 |
+
self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1)
|
275 |
+
self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0)
|
276 |
+
|
277 |
+
# Update hyper params
|
278 |
+
self.split_size = (self.split_size // self.n_head) * (self.n_head - len(heads))
|
279 |
+
self.n_head = self.n_head - len(heads)
|
280 |
+
self.pruned_heads = self.pruned_heads.union(heads)
|
281 |
+
|
282 |
+
def _attn(
|
283 |
+
self, q, k, v, attention_mask=None, head_mask=None, output_attentions=False
|
284 |
+
):
|
285 |
+
w = torch.matmul(q, k)
|
286 |
+
if self.scale:
|
287 |
+
w = w / (float(v.size(-1)) ** 0.5)
|
288 |
+
nd, ns = w.size(-2), w.size(-1)
|
289 |
+
|
290 |
+
if not self.is_cross_attention:
|
291 |
+
# if only "normal" attention layer implements causal mask
|
292 |
+
mask = self.bias[:, :, ns - nd : ns, :ns]
|
293 |
+
w = torch.where(mask.bool(), w, self.masked_bias.to(w.dtype))
|
294 |
+
|
295 |
+
if attention_mask is not None:
|
296 |
+
# Apply the attention mask
|
297 |
+
w = w + attention_mask
|
298 |
+
|
299 |
+
w = nn.Softmax(dim=-1)(w)
|
300 |
+
w = self.attn_dropout(w)
|
301 |
+
|
302 |
+
# Mask heads if we want to
|
303 |
+
if head_mask is not None:
|
304 |
+
w = w * head_mask
|
305 |
+
|
306 |
+
outputs = [torch.matmul(w, v)]
|
307 |
+
if output_attentions:
|
308 |
+
outputs.append(w)
|
309 |
+
return outputs
|
310 |
+
|
311 |
+
def merge_heads(self, x):
|
312 |
+
x = x.permute(0, 2, 1, 3).contiguous()
|
313 |
+
new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),)
|
314 |
+
return x.view(*new_x_shape) # in Tensorflow implem: fct merge_states
|
315 |
+
|
316 |
+
def split_heads(self, x, k=False):
|
317 |
+
new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head)
|
318 |
+
x = x.view(*new_x_shape) # in Tensorflow implem: fct split_states
|
319 |
+
if k:
|
320 |
+
return x.permute(0, 2, 3, 1) # (batch, head, head_features, seq_length)
|
321 |
+
else:
|
322 |
+
return x.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features)
|
323 |
+
|
324 |
+
def forward(
|
325 |
+
self,
|
326 |
+
hidden_states,
|
327 |
+
layer_past=None,
|
328 |
+
attention_mask=None,
|
329 |
+
head_mask=None,
|
330 |
+
encoder_hidden_states=None,
|
331 |
+
encoder_attention_mask=None,
|
332 |
+
use_cache=False,
|
333 |
+
output_attentions=False,
|
334 |
+
):
|
335 |
+
if encoder_hidden_states is not None:
|
336 |
+
assert hasattr(
|
337 |
+
self, "q_attn"
|
338 |
+
), "If class is used as cross attention, the weights `q_attn` have to be defined. Please make sure to instantiate class with `Attention(..., is_cross_attention=True)`."
|
339 |
+
query = self.q_attn(hidden_states)
|
340 |
+
key, value = self.c_attn(encoder_hidden_states).split(
|
341 |
+
self.split_size, dim=2
|
342 |
+
)
|
343 |
+
attention_mask = encoder_attention_mask
|
344 |
+
else:
|
345 |
+
query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
|
346 |
+
|
347 |
+
query = self.split_heads(query)
|
348 |
+
key = self.split_heads(key, k=True)
|
349 |
+
value = self.split_heads(value)
|
350 |
+
if layer_past is not None:
|
351 |
+
past_key, past_value = (
|
352 |
+
layer_past[0].transpose(-2, -1),
|
353 |
+
layer_past[1],
|
354 |
+
) # transpose back cf below
|
355 |
+
key = torch.cat((past_key, key), dim=-1)
|
356 |
+
value = torch.cat((past_value, value), dim=-2)
|
357 |
+
|
358 |
+
if use_cache is True:
|
359 |
+
present = torch.stack(
|
360 |
+
(key.transpose(-2, -1), value)
|
361 |
+
) # transpose to have same shapes for stacking
|
362 |
+
else:
|
363 |
+
present = (None,)
|
364 |
+
|
365 |
+
attn_outputs = self._attn(
|
366 |
+
query, key, value, attention_mask, head_mask, output_attentions
|
367 |
+
)
|
368 |
+
a = attn_outputs[0]
|
369 |
+
|
370 |
+
a = self.merge_heads(a)
|
371 |
+
a = self.c_proj(a)
|
372 |
+
a = self.resid_dropout(a)
|
373 |
+
|
374 |
+
outputs = [a, present] + attn_outputs[1:]
|
375 |
+
return outputs # a, present, (attentions)
|
376 |
+
|
377 |
+
|
378 |
+
class MLP(nn.Module):
|
379 |
+
def __init__(self, n_state, config): # in MLP: n_state=3072 (4 * n_embd)
|
380 |
+
super().__init__()
|
381 |
+
nx = config.n_embd
|
382 |
+
self.c_fc = Conv1D(n_state, nx)
|
383 |
+
self.c_proj = Conv1D(nx, n_state)
|
384 |
+
self.act = ACT2FN[config.activation_function]
|
385 |
+
self.dropout = nn.Dropout(config.resid_pdrop)
|
386 |
+
|
387 |
+
def forward(self, x):
|
388 |
+
h = self.act(self.c_fc(x))
|
389 |
+
h2 = self.c_proj(h)
|
390 |
+
return self.dropout(h2)
|
391 |
+
|
392 |
+
|
393 |
+
class Block(nn.Module):
|
394 |
+
def __init__(self, n_ctx, config, scale=False):
|
395 |
+
super().__init__()
|
396 |
+
hidden_size = config.n_embd
|
397 |
+
inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
|
398 |
+
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
399 |
+
self.attn = Attention(hidden_size, n_ctx, config, scale)
|
400 |
+
self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
401 |
+
if config.add_cross_attention:
|
402 |
+
self.crossattention = Attention(
|
403 |
+
hidden_size, n_ctx, config, scale, is_cross_attention=True
|
404 |
+
)
|
405 |
+
self.ln_cross_attn = nn.LayerNorm(
|
406 |
+
hidden_size, eps=config.layer_norm_epsilon
|
407 |
+
)
|
408 |
+
self.mlp = MLP(inner_dim, config)
|
409 |
+
|
410 |
+
def forward(
|
411 |
+
self,
|
412 |
+
hidden_states,
|
413 |
+
layer_past=None,
|
414 |
+
attention_mask=None,
|
415 |
+
head_mask=None,
|
416 |
+
encoder_hidden_states=None,
|
417 |
+
encoder_attention_mask=None,
|
418 |
+
use_cache=False,
|
419 |
+
output_attentions=False,
|
420 |
+
):
|
421 |
+
attn_outputs = self.attn(
|
422 |
+
hidden_states,
|
423 |
+
layer_past=layer_past,
|
424 |
+
attention_mask=attention_mask,
|
425 |
+
head_mask=head_mask,
|
426 |
+
use_cache=use_cache,
|
427 |
+
output_attentions=output_attentions,
|
428 |
+
)
|
429 |
+
attn_output = attn_outputs[0] # output_attn: a, present, (attentions)
|
430 |
+
outputs = attn_outputs[1:]
|
431 |
+
# residual connection
|
432 |
+
hidden_states = attn_output + hidden_states
|
433 |
+
|
434 |
+
if encoder_hidden_states is not None:
|
435 |
+
# add one self-attention block for cross-attention
|
436 |
+
assert hasattr(
|
437 |
+
self, "crossattention"
|
438 |
+
), f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`"
|
439 |
+
cross_attn_outputs = self.crossattention(
|
440 |
+
self.ln_cross_attn(hidden_states),
|
441 |
+
attention_mask=attention_mask,
|
442 |
+
head_mask=head_mask,
|
443 |
+
encoder_hidden_states=encoder_hidden_states,
|
444 |
+
encoder_attention_mask=encoder_attention_mask,
|
445 |
+
output_attentions=output_attentions,
|
446 |
+
)
|
447 |
+
attn_output = cross_attn_outputs[0]
|
448 |
+
# residual connection
|
449 |
+
hidden_states = hidden_states + attn_output
|
450 |
+
outputs = (
|
451 |
+
outputs + cross_attn_outputs[2:]
|
452 |
+
) # add cross attentions if we output attention weights
|
453 |
+
|
454 |
+
feed_forward_hidden_states = self.mlp(self.ln_1(hidden_states))
|
455 |
+
# residual connection
|
456 |
+
hidden_states = hidden_states + feed_forward_hidden_states
|
457 |
+
|
458 |
+
hidden_states = self.ln_2(hidden_states)
|
459 |
+
|
460 |
+
outputs = [hidden_states] + outputs
|
461 |
+
return outputs # hidden_states, present, (attentions, cross_attentions)
|
462 |
+
|
463 |
+
|
464 |
+
class GPT2PreTrainedModel(PreTrainedModel):
|
465 |
+
"""
|
466 |
+
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
467 |
+
models.
|
468 |
+
"""
|
469 |
+
|
470 |
+
config_class = GPT2Config
|
471 |
+
load_tf_weights = load_tf_weights_in_gpt2
|
472 |
+
base_model_prefix = "transformer"
|
473 |
+
is_parallelizable = True
|
474 |
+
|
475 |
+
def __init__(self, *inputs, **kwargs):
|
476 |
+
super().__init__(*inputs, **kwargs)
|
477 |
+
|
478 |
+
def _init_weights(self, module):
|
479 |
+
"""Initialize the weights."""
|
480 |
+
if isinstance(module, (nn.Linear, nn.Embedding, Conv1D)):
|
481 |
+
# Slightly different from the TF version which uses truncated_normal for initialization
|
482 |
+
# cf https://github.com/pytorch/pytorch/pull/5617
|
483 |
+
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
484 |
+
if isinstance(module, (nn.Linear, Conv1D)) and module.bias is not None:
|
485 |
+
module.bias.data.zero_()
|
486 |
+
elif isinstance(module, nn.LayerNorm):
|
487 |
+
module.bias.data.zero_()
|
488 |
+
module.weight.data.fill_(1.0)
|
489 |
+
|
490 |
+
|
491 |
+
@dataclass
|
492 |
+
class GPT2DoubleHeadsModelOutput(ModelOutput):
|
493 |
+
"""
|
494 |
+
Base class for outputs of models predicting if two sentences are consecutive or not.
|
495 |
+
|
496 |
+
Args:
|
497 |
+
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when ``labels`` is provided):
|
498 |
+
Language modeling loss.
|
499 |
+
mc_loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`mc_labels` is provided):
|
500 |
+
Multiple choice classification loss.
|
501 |
+
logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices, sequence_length, config.vocab_size)`):
|
502 |
+
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
503 |
+
mc_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices)`):
|
504 |
+
Prediction scores of the multiple choice classification head (scores for each choice before SoftMax).
|
505 |
+
past_key_values (:obj:`List[torch.FloatTensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
|
506 |
+
List of :obj:`torch.FloatTensor` of length :obj:`config.n_layers`, with each tensor of shape :obj:`(2,
|
507 |
+
batch_size, num_heads, sequence_length, embed_size_per_head)`).
|
508 |
+
|
509 |
+
Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see
|
510 |
+
:obj:`past_key_values` input) to speed up sequential decoding.
|
511 |
+
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
512 |
+
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
|
513 |
+
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
514 |
+
|
515 |
+
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
516 |
+
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
517 |
+
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
|
518 |
+
sequence_length, sequence_length)`.
|
519 |
+
|
520 |
+
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
521 |
+
heads.
|
522 |
+
"""
|
523 |
+
|
524 |
+
loss: Optional[torch.FloatTensor] = None
|
525 |
+
mc_loss: Optional[torch.FloatTensor] = None
|
526 |
+
logits: torch.FloatTensor = None
|
527 |
+
mc_logits: torch.FloatTensor = None
|
528 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None
|
529 |
+
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
530 |
+
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
531 |
+
|
532 |
+
|
533 |
+
GPT2_START_DOCSTRING = r"""
|
534 |
+
|
535 |
+
This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic
|
536 |
+
methods the library implements for all its model (such as downloading or saving, resizing the input embeddings,
|
537 |
+
pruning heads etc.)
|
538 |
+
|
539 |
+
This model is also a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`__
|
540 |
+
subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to
|
541 |
+
general usage and behavior.
|
542 |
+
|
543 |
+
Parameters:
|
544 |
+
config (:class:`~transformers.GPT2Config`): Model configuration class with all the parameters of the model.
|
545 |
+
Initializing with a config file does not load the weights associated with the model, only the
|
546 |
+
configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model
|
547 |
+
weights.
|
548 |
+
"""
|
549 |
+
|
550 |
+
GPT2_INPUTS_DOCSTRING = r"""
|
551 |
+
Args:
|
552 |
+
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, input_ids_length)`):
|
553 |
+
:obj:`input_ids_length` = ``sequence_length`` if :obj:`past_key_values` is ``None`` else
|
554 |
+
``past_key_values[0].shape[-2]`` (``sequence_length`` of input past key value states). Indices of input
|
555 |
+
sequence tokens in the vocabulary.
|
556 |
+
|
557 |
+
If :obj:`past_key_values` is used, only ``input_ids`` that do not have their past calculated should be
|
558 |
+
passed as ``input_ids``.
|
559 |
+
|
560 |
+
Indices can be obtained using :class:`~transformers.GPT2Tokenizer`. See
|
561 |
+
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
|
562 |
+
details.
|
563 |
+
|
564 |
+
`What are input IDs? <../glossary.html#input-ids>`__
|
565 |
+
past_key_values (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
|
566 |
+
Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see
|
567 |
+
:obj:`past_key_values` output below). Can be used to speed up sequential decoding. The ``input_ids`` which
|
568 |
+
have their past given to this model should not be passed as ``input_ids`` as they have already been
|
569 |
+
computed.
|
570 |
+
attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
571 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
|
572 |
+
|
573 |
+
- 1 for tokens that are **not masked**,
|
574 |
+
- 0 for tokens that are **masked**.
|
575 |
+
|
576 |
+
`What are attention masks? <../glossary.html#attention-mask>`__
|
577 |
+
token_type_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, input_ids_length)`, `optional`):
|
578 |
+
Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0,
|
579 |
+
1]``:
|
580 |
+
|
581 |
+
- 0 corresponds to a `sentence A` token,
|
582 |
+
- 1 corresponds to a `sentence B` token.
|
583 |
+
|
584 |
+
`What are token type IDs? <../glossary.html#token-type-ids>`_
|
585 |
+
position_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
586 |
+
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0,
|
587 |
+
config.max_position_embeddings - 1]``.
|
588 |
+
|
589 |
+
`What are position IDs? <../glossary.html#position-ids>`_
|
590 |
+
head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
|
591 |
+
Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``:
|
592 |
+
|
593 |
+
- 1 indicates the head is **not masked**,
|
594 |
+
- 0 indicates the head is **masked**.
|
595 |
+
|
596 |
+
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
597 |
+
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
|
598 |
+
This is useful if you want more control over how to convert :obj:`input_ids` indices into associated
|
599 |
+
vectors than the model's internal embedding lookup matrix.
|
600 |
+
|
601 |
+
If :obj:`past_key_values` is used, optionally only the last :obj:`inputs_embeds` have to be input (see
|
602 |
+
:obj:`past_key_values`).
|
603 |
+
use_cache (:obj:`bool`, `optional`):
|
604 |
+
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
|
605 |
+
decoding (see :obj:`past_key_values`).
|
606 |
+
output_attentions (:obj:`bool`, `optional`):
|
607 |
+
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
|
608 |
+
tensors for more detail.
|
609 |
+
output_hidden_states (:obj:`bool`, `optional`):
|
610 |
+
Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
|
611 |
+
more detail.
|
612 |
+
return_dict (:obj:`bool`, `optional`):
|
613 |
+
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
|
614 |
+
"""
|
615 |
+
|
616 |
+
PARALLELIZE_DOCSTRING = r"""
    This is an experimental feature and is subject to change at a moment's notice.

    Uses a device map to distribute attention modules of the model across several devices. If no device map is given,
    it will evenly distribute blocks across all devices.

    Args:
        device_map (:obj:`Dict[int, list]`, optional, defaults to None):
            A dictionary that maps attention modules to devices. Note that the embedding module and LMHead are always
            automatically mapped to the first device (for esoteric reasons). That means that the first device should
            have fewer attention modules mapped to it than other devices. For reference, the gpt2 models have the
            following number of attention modules:

                - gpt2: 12
                - gpt2-medium: 24
                - gpt2-large: 36
                - gpt2-xl: 48

    Example::

        # Here is an example of a device map on a machine with 4 GPUs using gpt2-xl, which has a total of 48 attention modules:
        model = GPT2LMHeadModel.from_pretrained('gpt2-xl')
        device_map = {0: [0, 1, 2, 3, 4, 5, 6, 7, 8],
                      1: [9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21],
                      2: [22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34],
                      3: [35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47]}
        model.parallelize(device_map)
"""
DEPARALLELIZE_DOCSTRING = r"""
    Moves the model to cpu from a model parallel state.

    Example::

        # On a 4 GPU machine with gpt2-large:
        model = GPT2LMHeadModel.from_pretrained('gpt2-large')
        device_map = {0: [0, 1, 2, 3, 4, 5, 6, 7],
                      1: [8, 9, 10, 11, 12, 13, 14, 15],
                      2: [16, 17, 18, 19, 20, 21, 22, 23],
                      3: [24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35]}
        model.parallelize(device_map)  # Splits the model across several devices
        model.deparallelize()  # Puts the model back on cpu and cleans memory by calling torch.cuda.empty_cache()
"""

@add_start_docstrings(
    "The bare GPT2 Model transformer outputting raw hidden-states without any specific head on top.",
    GPT2_START_DOCSTRING,
)
class GPT2Model(GPT2PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.wte = nn.Embedding(config.vocab_size, config.n_embd)
        self.wpe = nn.Embedding(config.n_positions, config.n_embd)
        if _USE_GROVER:
            self.emb_norm = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)

        self.drop = nn.Dropout(config.embd_pdrop)
        self.h = nn.ModuleList(
            [Block(config.n_ctx, config, scale=True) for _ in range(config.n_layer)]
        )
        if not _USE_GROVER:
            self.ln_f = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)

        self.init_weights()

        # Model parallel
        self.model_parallel = False
        self.device_map = None

    @add_start_docstrings(PARALLELIZE_DOCSTRING)
    def parallelize(self, device_map=None):
        # Check validity of device_map
        self.device_map = (
            get_device_map(len(self.h), range(torch.cuda.device_count()))
            if device_map is None
            else device_map
        )
        assert_device_map(self.device_map, len(self.h))
        self.model_parallel = True
        self.first_device = (
            "cpu"
            if "cpu" in self.device_map.keys()
            else "cuda:" + str(min(self.device_map.keys()))
        )
        self.last_device = "cuda:" + str(max(self.device_map.keys()))
        self.wte = self.wte.to(self.first_device)
        self.wpe = self.wpe.to(self.first_device)
        if _USE_GROVER:
            # the Grover variant normalizes the embeddings instead of the final
            # hidden states, so emb_norm lives with the embeddings
            self.emb_norm = self.emb_norm.to(self.first_device)
        # Load onto devices
        for k, v in self.device_map.items():
            for block in v:
                cuda_device = "cuda:" + str(k)
                self.h[block] = self.h[block].to(cuda_device)
        # ln_f to last (the Grover variant has no final LayerNorm)
        if not _USE_GROVER:
            self.ln_f = self.ln_f.to(self.last_device)

    @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
    def deparallelize(self):
        self.model_parallel = False
        self.device_map = None
        self.first_device = "cpu"
        self.last_device = "cpu"
        self.wte = self.wte.to("cpu")
        self.wpe = self.wpe.to("cpu")
        if _USE_GROVER:
            self.emb_norm = self.emb_norm.to("cpu")
        for index in range(len(self.h)):
            self.h[index] = self.h[index].to("cpu")
        if not _USE_GROVER:
            self.ln_f = self.ln_f.to("cpu")
        torch.cuda.empty_cache()

    def get_input_embeddings(self):
        return self.wte

    def set_input_embeddings(self, new_embeddings):
        self.wte = new_embeddings

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
        """
        for layer, heads in heads_to_prune.items():
            self.h[layer].attn.prune_heads(heads)

    @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint="gpt2",
        output_type=BaseModelOutputWithPastAndCrossAttentions,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids=None,
        past_key_values=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time"
            )
        elif input_ids is not None:
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
            batch_size = input_ids.shape[0]
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
            batch_size = inputs_embeds.shape[0]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        # hoisted here so the cross-attention branch below can also use it
        device = input_ids.device if input_ids is not None else inputs_embeds.device

        if token_type_ids is not None:
            token_type_ids = token_type_ids.view(-1, input_shape[-1])
        if position_ids is not None:
            position_ids = position_ids.view(-1, input_shape[-1])

        if past_key_values is None:
            past_length = 0
            past_key_values = [None] * len(self.h)
        else:
            past_length = past_key_values[0][0].size(-2)
        if position_ids is None:
            position_ids = torch.arange(
                past_length,
                input_shape[-1] + past_length,
                dtype=torch.long,
                device=device,
            )
            position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])

        # Attention mask.
        if attention_mask is not None:
            if batch_size <= 0:
                raise ValueError("batch_size has to be defined and > 0")
            attention_mask = attention_mask.view(batch_size, -1)
            # We create a 4D attention mask from a 2D tensor mask.
            # Sizes are [batch_size, 1, 1, to_seq_length]
            # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
            # This attention mask is simpler than the triangular masking of causal attention
            # used in OpenAI GPT; we just need to prepare the broadcast dimension here.
            attention_mask = attention_mask[:, None, None, :]

            # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
            # masked positions, this operation will create a tensor which is 0.0 for
            # positions we want to attend and -10000.0 for masked positions.
            # Since we are adding it to the raw scores before the softmax, this is
            # effectively the same as removing these entirely.
            attention_mask = attention_mask.to(dtype=self.dtype)  # fp16 compatibility
            attention_mask = (1.0 - attention_mask) * -10000.0

        # If a 2D or 3D attention mask is provided for the cross-attention
        # we need to make it broadcastable to [batch_size, num_heads, seq_length, seq_length]
        if self.config.add_cross_attention and encoder_hidden_states is not None:
            (
                encoder_batch_size,
                encoder_sequence_length,
                _,
            ) = encoder_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
            if encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
            encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
        else:
            encoder_attention_mask = None

        # Prepare head mask if needed
        # 1.0 in head_mask indicates we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # head_mask has shape n_layer x batch x n_heads x N x N
        head_mask = self.get_head_mask(head_mask, self.config.n_layer)

        if inputs_embeds is None:
            inputs_embeds = self.wte(input_ids)
        position_embeds = self.wpe(position_ids)
        hidden_states = inputs_embeds + position_embeds

        if token_type_ids is not None:
            token_type_embeds = self.wte(token_type_ids)
            hidden_states = hidden_states + token_type_embeds

        hidden_states = self.drop(hidden_states)
        if _USE_GROVER:
            hidden_states = self.emb_norm(hidden_states)
        output_shape = input_shape + (hidden_states.size(-1),)

        presents = () if use_cache else None
        all_self_attentions = () if output_attentions else None
        all_cross_attentions = (
            () if output_attentions and self.config.add_cross_attention else None
        )
        all_hidden_states = () if output_hidden_states else None
        for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):

            # Model parallel
            if self.model_parallel:
                torch.cuda.set_device(hidden_states.device)
                # Ensure layer_past is on the same device as hidden_states (might not be correct)
                if layer_past is not None:
                    layer_past = tuple(
                        past_state.to(hidden_states.device) for past_state in layer_past
                    )
                # Ensure that attention_mask is always on the same device as hidden_states
                if attention_mask is not None:
                    attention_mask = attention_mask.to(hidden_states.device)
                if isinstance(head_mask, torch.Tensor):
                    head_mask = head_mask.to(hidden_states.device)

            if output_hidden_states:
                all_hidden_states = all_hidden_states + (
                    hidden_states.view(*output_shape),
                )

            if getattr(self.config, "gradient_checkpointing", False):

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        # checkpointing only works with tuple returns, not with lists
                        return tuple(
                            output
                            for output in module(*inputs, use_cache, output_attentions)
                        )

                    return custom_forward

                outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    layer_past,
                    attention_mask,
                    head_mask[i],
                    encoder_hidden_states,
                    encoder_attention_mask,
                )
            else:
                outputs = block(
                    hidden_states,
                    layer_past=layer_past,
                    attention_mask=attention_mask,
                    head_mask=head_mask[i],
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_attention_mask,
                    use_cache=use_cache,
                    output_attentions=output_attentions,
                )

            hidden_states, present = outputs[:2]
            if use_cache is True:
                presents = presents + (present,)

            if output_attentions:
                all_self_attentions = all_self_attentions + (
                    outputs[2 if use_cache else 1],
                )
                if self.config.add_cross_attention:
                    all_cross_attentions = all_cross_attentions + (
                        outputs[3 if use_cache else 2],
                    )

            # Model Parallel: if it's the last layer for that device, put things on the next device
            if self.model_parallel:
                for k, v in self.device_map.items():
                    if i == v[-1] and "cuda:" + str(k) != self.last_device:
                        hidden_states = hidden_states.to("cuda:" + str(k + 1))

        if not _USE_GROVER:
            hidden_states = self.ln_f(hidden_states)

        hidden_states = hidden_states.view(*output_shape)
        # Add last hidden state
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    presents,
                    all_hidden_states,
                    all_self_attentions,
                    all_cross_attentions,
                ]
                if v is not None
            )

        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=presents,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=all_cross_attentions,
        )

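# Usage sketch for the incremental decoding cache (illustrative; assumes a
# standard non-Grover checkpoint such as stock `gpt2` is loadable into this
# class, and the token ids are arbitrary). With `use_cache=True` the model
# returns `past_key_values`, so the second call only needs the newest token.
def _example_past_key_values_reuse():
    import torch

    model = GPT2Model.from_pretrained("gpt2")
    model.eval()
    with torch.no_grad():
        first = model(input_ids=torch.tensor([[318, 428]]), use_cache=True)
        # Only the new token is fed; its position continues from the cache.
        second = model(
            input_ids=torch.tensor([[922]]),
            past_key_values=first.past_key_values,
        )
    return second.last_hidden_state  # shape: (1, 1, config.n_embd)
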
@add_start_docstrings(
    """
    The GPT2 Model transformer with a language modeling head on top (linear layer with weights tied to the input
    embeddings).
    """,
    GPT2_START_DOCSTRING,
)
class GPT2LMHeadModel(GPT2PreTrainedModel):
    _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"lm_head\.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.transformer = GPT2Model(config)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        self.init_weights()

        # Model parallel
        self.model_parallel = False
        self.device_map = None

    @add_start_docstrings(PARALLELIZE_DOCSTRING)
    def parallelize(self, device_map=None):
        self.device_map = (
            get_device_map(len(self.transformer.h), range(torch.cuda.device_count()))
            if device_map is None
            else device_map
        )
        assert_device_map(self.device_map, len(self.transformer.h))
        self.transformer.parallelize(self.device_map)
        self.lm_head = self.lm_head.to(self.transformer.first_device)
        self.model_parallel = True

    @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
    def deparallelize(self):
        self.transformer.deparallelize()
        self.transformer = self.transformer.to("cpu")
        self.lm_head = self.lm_head.to("cpu")
        self.model_parallel = False
        torch.cuda.empty_cache()

    def get_output_embeddings(self):
        return self.lm_head

    def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
        token_type_ids = kwargs.get("token_type_ids", None)
        # only keep the last token of input_ids if past is defined in kwargs
        if past:
            input_ids = input_ids[:, -1].unsqueeze(-1)
            if token_type_ids is not None:
                token_type_ids = token_type_ids[:, -1].unsqueeze(-1)

        attention_mask = kwargs.get("attention_mask", None)
        position_ids = kwargs.get("position_ids", None)

        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past:
                position_ids = position_ids[:, -1].unsqueeze(-1)
        else:
            position_ids = None
        return {
            "input_ids": input_ids,
            "past_key_values": past,
            "use_cache": kwargs.get("use_cache"),
            "position_ids": position_ids,
            "attention_mask": attention_mask,
            "token_type_ids": token_type_ids,
        }

    @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint="gpt2",
        output_type=CausalLMOutputWithCrossAttentions,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids=None,
        past_key_values=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
            ``labels = input_ids``. Indices are selected in ``[-100, 0, ..., config.vocab_size]``. All labels set to
            ``-100`` are ignored (masked); the loss is only computed for labels in ``[0, ..., config.vocab_size]``.
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]

        # Set device for model parallelism
        if self.model_parallel:
            torch.cuda.set_device(self.transformer.first_device)
            hidden_states = hidden_states.to(self.lm_head.weight.device)

        lm_logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
            )

        if not return_dict:
            output = (lm_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithCrossAttentions(
            loss=loss,
            logits=lm_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
            cross_attentions=transformer_outputs.cross_attentions,
        )

    @staticmethod
    def _reorder_cache(
        past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
    ) -> Tuple[Tuple[torch.Tensor]]:
        """
        This function is used to re-order the :obj:`past_key_values` cache if
        :meth:`~transformers.PreTrainedModel.beam_search` or :meth:`~transformers.PreTrainedModel.beam_sample` is
        called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step.
        """
        return tuple(
            tuple(
                past_state.index_select(0, beam_idx.to(past_state.device))
                for past_state in layer_past
            )
            for layer_past in past
        )

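# A minimal greedy-decoding sketch (illustrative; `generate()` from the base
# class is the real entry point, and the checkpoint name and prompt ids are
# assumptions). It makes the `past_key_values` flow of `forward` explicit:
# after the first step, each call consumes a single token plus the cache.
def _example_greedy_decoding(steps: int = 5):
    import torch

    model = GPT2LMHeadModel.from_pretrained("gpt2")
    model.eval()
    input_ids = torch.tensor([[464, 3290]])  # assumed prompt token ids
    past = None
    with torch.no_grad():
        for _ in range(steps):
            out = model(
                input_ids=input_ids if past is None else input_ids[:, -1:],
                past_key_values=past,
                use_cache=True,
            )
            past = out.past_key_values
            next_id = out.logits[:, -1, :].argmax(dim=-1, keepdim=True)
            input_ids = torch.cat([input_ids, next_id], dim=-1)
    return input_ids
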
@add_start_docstrings(
    """
    The GPT2 Model transformer with a language modeling and a multiple-choice classification head on top e.g. for
    RocStories/SWAG tasks. The two heads are two linear layers. The language modeling head has its weights tied to the
    input embeddings; the classification head takes as input the input of a specified classification token index in the
    input sequence.
    """,
    GPT2_START_DOCSTRING,
)
class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        config.num_labels = 1
        self.transformer = GPT2Model(config)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        self.multiple_choice_head = SequenceSummary(config)

        self.init_weights()

        # Model parallel
        self.model_parallel = False
        self.device_map = None

    @add_start_docstrings(PARALLELIZE_DOCSTRING)
    def parallelize(self, device_map=None):
        self.device_map = (
            get_device_map(len(self.transformer.h), range(torch.cuda.device_count()))
            if device_map is None
            else device_map
        )
        assert_device_map(self.device_map, len(self.transformer.h))
        self.transformer.parallelize(self.device_map)
        self.lm_head = self.lm_head.to(self.transformer.first_device)
        self.multiple_choice_head = self.multiple_choice_head.to(
            self.transformer.first_device
        )
        self.model_parallel = True

    @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
    def deparallelize(self):
        self.transformer.deparallelize()
        self.transformer = self.transformer.to("cpu")
        self.lm_head = self.lm_head.to("cpu")
        self.multiple_choice_head = self.multiple_choice_head.to("cpu")
        self.model_parallel = False
        torch.cuda.empty_cache()

    def get_output_embeddings(self):
        return self.lm_head

    def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
        token_type_ids = kwargs.get("token_type_ids", None)
        # only keep the last token of input_ids if past is defined in kwargs
        if past:
            input_ids = input_ids[:, -1].unsqueeze(-1)
            if token_type_ids is not None:
                token_type_ids = token_type_ids[:, -1].unsqueeze(-1)

        attention_mask = kwargs.get("attention_mask", None)
        position_ids = kwargs.get("position_ids", None)

        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past:
                position_ids = position_ids[:, -1].unsqueeze(-1)
        else:
            position_ids = None

        return {
            "input_ids": input_ids,
            "past_key_values": past,
            "use_cache": kwargs.get("use_cache"),
            "position_ids": position_ids,
            "attention_mask": attention_mask,
            "token_type_ids": token_type_ids,
        }

    @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
    @replace_return_docstrings(
        output_type=GPT2DoubleHeadsModelOutput, config_class=_CONFIG_FOR_DOC
    )
    def forward(
        self,
        input_ids=None,
        past_key_values=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        mc_token_ids=None,
        labels=None,
        mc_labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        **kwargs,
    ):
        r"""
        mc_token_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, num_choices)`, `optional`, defaults to the index of the last token of the input):
            Index of the classification token in each input sequence. Selected in the range ``[0, input_ids.size(-1) -
            1]``.
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
            ``labels = input_ids``. Indices are selected in ``[-100, 0, ..., config.vocab_size]``. All labels set to
            ``-100`` are ignored (masked); the loss is only computed for labels in ``[0, ..., config.vocab_size]``.
        mc_labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size)`, `optional`):
            Labels for computing the multiple choice classification loss. Indices should be in ``[0, ...,
            num_choices - 1]`` where `num_choices` is the size of the second dimension of the input tensors. (see
            `input_ids` above)

        Return:

        Example::

            >>> import torch
            >>> from transformers import GPT2Tokenizer, GPT2DoubleHeadsModel

            >>> tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
            >>> model = GPT2DoubleHeadsModel.from_pretrained('gpt2')

            >>> # Add a [CLS] to the vocabulary (we should train it also!)
            >>> num_added_tokens = tokenizer.add_special_tokens({'cls_token': '[CLS]'})

            >>> embedding_layer = model.resize_token_embeddings(len(tokenizer))  # Update the model embeddings with the new vocabulary size

            >>> choices = ["Hello, my dog is cute [CLS]", "Hello, my cat is cute [CLS]"]
            >>> encoded_choices = [tokenizer.encode(s) for s in choices]
            >>> cls_token_location = [tokens.index(tokenizer.cls_token_id) for tokens in encoded_choices]

            >>> input_ids = torch.tensor(encoded_choices).unsqueeze(0)  # Batch size: 1, number of choices: 2
            >>> mc_token_ids = torch.tensor([cls_token_location])  # Batch size: 1

            >>> outputs = model(input_ids, mc_token_ids=mc_token_ids)
            >>> lm_logits = outputs.logits
            >>> mc_logits = outputs.mc_logits

        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = transformer_outputs[0]

        # Set device for model parallelism
        if self.model_parallel:
            torch.cuda.set_device(self.transformer.first_device)
            hidden_states = hidden_states.to(self.lm_head.weight.device)

        lm_logits = self.lm_head(hidden_states)
        mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids).squeeze(-1)

        mc_loss = None
        if mc_labels is not None:
            loss_fct = CrossEntropyLoss()
            mc_loss = loss_fct(
                mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1)
            )
        lm_loss = None
        if labels is not None:
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = CrossEntropyLoss()
            lm_loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
            )

        if not return_dict:
            output = (lm_logits, mc_logits) + transformer_outputs[1:]
            if mc_loss is not None:
                output = (mc_loss,) + output
            return ((lm_loss,) + output) if lm_loss is not None else output

        return GPT2DoubleHeadsModelOutput(
            loss=lm_loss,
            mc_loss=mc_loss,
            logits=lm_logits,
            mc_logits=mc_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )

    @staticmethod
    def _reorder_cache(
        past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
    ) -> Tuple[Tuple[torch.Tensor]]:
        """
        This function is used to re-order the :obj:`past_key_values` cache if
        :meth:`~transformers.PreTrainedModel.beam_search` or :meth:`~transformers.PreTrainedModel.beam_sample` is
        called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step.
        """
        return tuple(
            tuple(
                past_state.index_select(0, beam_idx.to(past_state.device))
                for past_state in layer_past
            )
            for layer_past in past
        )

@add_start_docstrings(
    """
    The GPT2 Model transformer with a sequence classification head on top (linear layer).

    :class:`~transformers.GPT2ForSequenceClassification` uses the last token in order to do the classification, as
    other causal models (e.g. GPT-1) do.

    Since it does classification on the last token, it needs to know the position of the last token. If a
    :obj:`pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each
    row. If no :obj:`pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot
    guess the padding tokens when :obj:`inputs_embeds` are passed instead of :obj:`input_ids`, it does the same (takes
    the last value in each row of the batch).
    """,
    GPT2_START_DOCSTRING,
)
class GPT2ForSequenceClassification(GPT2PreTrainedModel):
    _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"lm_head\.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.transformer = GPT2Model(config)
        self.score = nn.Linear(config.n_embd, self.num_labels, bias=False)

        self.init_weights()

        # Model parallel
        self.model_parallel = False
        self.device_map = None

    @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint="microsoft/dialogrpt",
        output_type=SequenceClassifierOutputWithPast,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids=None,
        past_key_values=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ...,
            config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss);
            if :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]
        logits = self.score(hidden_states)

        if input_ids is not None:
            batch_size, sequence_length = input_ids.shape[:2]
        else:
            batch_size, sequence_length = inputs_embeds.shape[:2]

        assert (
            self.config.pad_token_id is not None or batch_size == 1
        ), "Cannot handle batch sizes > 1 if no padding token is defined."
        if self.config.pad_token_id is None:
            sequence_lengths = -1
        else:
            if input_ids is not None:
                sequence_lengths = (
                    torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1
                )
            else:
                sequence_lengths = -1
                logger.warning(
                    f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
                    "unexpected if using padding tokens in conjunction with `inputs_embeds`."
                )

        pooled_logits = logits[range(batch_size), sequence_lengths]

        loss = None
        if labels is not None:
            if self.num_labels == 1:
                # We are doing regression
                loss_fct = MSELoss()
                loss = loss_fct(pooled_logits.view(-1), labels.to(self.dtype).view(-1))
            else:
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(
                    pooled_logits.view(-1, self.num_labels), labels.view(-1)
                )

        if not return_dict:
            output = (pooled_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )

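# Last-token pooling sketch (illustrative; the pad id, label count and token
# ids are assumptions, and the randomly initialized weights are only good for
# demonstrating shapes). A `pad_token_id` must be set in the config so batched
# inputs can locate each row's last non-padding token, per the docstring above.
def _example_sequence_classification():
    import torch
    from transformers import GPT2Config

    config = GPT2Config(n_layer=2, n_head=2, n_embd=64, num_labels=2, pad_token_id=0)
    model = GPT2ForSequenceClassification(config)
    model.eval()
    input_ids = torch.tensor([[11, 22, 33, 0], [44, 55, 66, 77]])  # row 0 is padded
    attention_mask = (input_ids != config.pad_token_id).long()
    with torch.no_grad():
        logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
    return logits  # shape (2, 2): one row of class scores per sequence
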
@add_start_docstrings(
    """
    GPT2 Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
    Named-Entity-Recognition (NER) tasks.
    """,
    GPT2_START_DOCSTRING,
)
class GPT2ForTokenClassification(GPT2PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.transformer = GPT2Model(config)
        if (
            hasattr(config, "classifier_dropout")
            and config.classifier_dropout is not None
        ):
            classifier_dropout = config.classifier_dropout
        elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None:
            classifier_dropout = config.hidden_dropout
        else:
            classifier_dropout = 0.1
        self.dropout = nn.Dropout(classifier_dropout)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        self.init_weights()

        # Model parallel
        self.model_parallel = False
        self.device_map = None

    @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint="microsoft/DialogRPT-updown",
        output_type=TokenClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids=None,
        past_key_values=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Labels for computing the token classification loss. Indices should be in :obj:`[0, ...,
            config.num_labels - 1]`.
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = transformer_outputs[0]
        hidden_states = self.dropout(hidden_states)
        logits = self.classifier(hidden_states)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            # Only keep the active parts of the loss
            if attention_mask is not None:
                active_loss = attention_mask.view(-1) == 1
                active_logits = logits.view(-1, self.num_labels)
                active_labels = torch.where(
                    active_loss,
                    labels.view(-1),
                    torch.tensor(loss_fct.ignore_index).type_as(labels),
                )
                loss = loss_fct(active_logits, active_labels)
            else:
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            output = (logits,) + transformer_outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )
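# Token-classification sketch (illustrative; config sizes, token ids and label
# count are assumptions). One label is predicted per input position, and the
# attention mask keeps padded positions out of the loss, as in `forward` above.
def _example_token_classification():
    import torch
    from transformers import GPT2Config

    config = GPT2Config(n_layer=2, n_head=2, n_embd=64, num_labels=3)
    model = GPT2ForTokenClassification(config)
    model.eval()
    input_ids = torch.randint(0, config.vocab_size, (2, 6))
    labels = torch.randint(0, config.num_labels, (2, 6))
    out = model(input_ids=input_ids, labels=labels)
    return out.loss, out.logits  # logits shape: (2, 6, 3)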
backend/preprocess.py
ADDED
@@ -0,0 +1,736 @@
import html
import logging
import re
from typing import List

from farasa.segmenter import FarasaSegmenter
import emoji

import pyarabic.araby as araby

ACCEPTED_MODELS = [
    "bert-base-arabertv01",
    "bert-base-arabert",
    "bert-base-arabertv02",
    "bert-base-arabertv2",
    "bert-large-arabertv02",
    "bert-large-arabertv2",
    "araelectra-base",
    "araelectra-base-discriminator",
    "araelectra-base-generator",
    "araelectra-base-artydiqa",
    "aragpt2-base",
    "aragpt2-medium",
    "aragpt2-large",
    "aragpt2-mega",
]

SEGMENTED_MODELS = [
    "bert-base-arabert",
    "bert-base-arabertv2",
    "bert-large-arabertv2",
]

SECOND_GEN_MODELS = [
    "bert-base-arabertv02",
    "bert-base-arabertv2",
    "bert-large-arabertv02",
    "bert-large-arabertv2",
    "araelectra-base",
    "araelectra-base-discriminator",
    "araelectra-base-generator",
    "araelectra-base-artydiqa",
    "aragpt2-base",
    "aragpt2-medium",
    "aragpt2-large",
    "aragpt2-mega",
]

farasa_segmenter = FarasaSegmenter(interactive=True)


class ArabertPreprocessor:
    """
    A preprocessor class that cleans and preprocesses text for all models in the AraBERT repo.
    It can also unprocess the text output of generation models.

    Args:

        model_name (:obj:`str`): model name from the HuggingFace Models page without
            the aubmindlab tag. Defaults to a base Arabic preprocessor if the model name is not found.
            Currently accepted models are:

            - "bert-base-arabertv01": No farasa segmentation.
            - "bert-base-arabert": With farasa segmentation.
            - "bert-base-arabertv02": No farasa segmentation.
            - "bert-base-arabertv2": With farasa segmentation.
            - "bert-large-arabertv02": No farasa segmentation.
            - "bert-large-arabertv2": With farasa segmentation.
            - "araelectra-base": No farasa segmentation.
            - "araelectra-base-discriminator": No farasa segmentation.
            - "araelectra-base-generator": No farasa segmentation.
            - "aragpt2-base": No farasa segmentation.
            - "aragpt2-medium": No farasa segmentation.
            - "aragpt2-large": No farasa segmentation.
            - "aragpt2-mega": No farasa segmentation.

        keep_emojis (:obj:`bool`, `optional`, defaults to :obj:`False`): don't remove emojis while preprocessing.

        remove_html_markup (:obj:`bool`, `optional`, defaults to :obj:`True`): whether to remove html artifacts;
            should be set to False when preprocessing TyDi QA.

        replace_urls_emails_mentions (:obj:`bool`, `optional`, defaults to :obj:`True`): whether to replace urls,
            emails and mentions with special tokens.

        strip_tashkeel (:obj:`bool`, `optional`, defaults to :obj:`True`): remove diacritics (FATHATAN, DAMMATAN,
            KASRATAN, FATHA, DAMMA, KASRA, SUKUN, SHADDA).

        strip_tatweel (:obj:`bool`, `optional`, defaults to :obj:`True`): remove tatweel '\\u0640'.

        insert_white_spaces (:obj:`bool`, `optional`, defaults to :obj:`True`): insert whitespace before and after
            all non-Arabic digits, English digits, Arabic and English alphabet characters and the 2 brackets, then
            insert whitespace between words and numbers or numbers and words.

        remove_non_digit_repetition (:obj:`bool`, `optional`, defaults to :obj:`True`): replace repetitions of more
            than 2 non-digit characters with 2 of that character.

        replace_slash_with_dash (:obj:`bool`, `optional`, defaults to :obj:`None`): automatically set to True for
            AraBERTv02, AraELECTRA and AraGPT2. Set to False to force disable, and True to force enable. Replaces
            "/" with "-", since "/" is missing from the AraBERTv2, AraELECTRA and AraGPT2 vocabularies.

        map_hindi_numbers_to_arabic (:obj:`bool`, `optional`, defaults to :obj:`None`): automatically set to True
            for AraBERTv02, AraELECTRA and AraGPT2. Set to False to force disable, and True to force enable.
            Replaces Hindi numbers with the corresponding Arabic ones, e.g. "١٩٩٥" --> "1995". This behavior is
            present by default in AraBERTv1 and v2 (with pre-segmentation), and fixes an issue caused by a bug
            when inserting white spaces.

        apply_farasa_segmentation (:obj:`bool`, `optional`, defaults to :obj:`None`): automatically set to True for
            AraBERTv2 and AraBERTv1. Set to False to force disable, and True to force enable.

    Returns:

        ArabertPreprocessor: A preprocessor instance

    Example:

        from preprocess import ArabertPreprocessor

        arabert_prep = ArabertPreprocessor("aubmindlab/bert-base-arabertv2")

        arabert_prep.preprocess("SOME ARABIC TEXT")
    """

    def __init__(
        self,
        model_name: str,
        keep_emojis: bool = False,
        remove_html_markup: bool = True,
        replace_urls_emails_mentions: bool = True,
        strip_tashkeel: bool = True,
        strip_tatweel: bool = True,
        insert_white_spaces: bool = True,
        remove_non_digit_repetition: bool = True,
        replace_slash_with_dash: bool = None,
        map_hindi_numbers_to_arabic: bool = None,
        apply_farasa_segmentation: bool = None,
    ):

        model_name = model_name.replace("aubmindlab/", "").replace("wissamantoun/", "")

        if model_name not in ACCEPTED_MODELS:
            logging.warning(
                "Model provided is not in the accepted model list. Preprocessor will default to a base Arabic preprocessor"
            )
            self.model_name = "bert-base-arabertv02"
        else:
            self.model_name = model_name

        if apply_farasa_segmentation is None:
            if self.model_name in SEGMENTED_MODELS:
                self.apply_farasa_segmentation = True
            else:
                self.apply_farasa_segmentation = False
        else:
            if (
                apply_farasa_segmentation == False
                and self.model_name in SEGMENTED_MODELS
            ):
                logging.warning(
                    "The selected model_name requires Farasa pre-segmentation, but apply_farasa_segmentation was set to False!"
                )

            self.apply_farasa_segmentation = apply_farasa_segmentation

        self.keep_emojis = keep_emojis
        self.remove_html_markup = remove_html_markup
        self.replace_urls_emails_mentions = replace_urls_emails_mentions
        self.strip_tashkeel = strip_tashkeel
        self.strip_tatweel = strip_tatweel
        self.insert_white_spaces = insert_white_spaces
        self.remove_non_digit_repetition = remove_non_digit_repetition

        if replace_slash_with_dash is None:
            if self.model_name in SECOND_GEN_MODELS:
                self.replace_slash_with_dash = True
            else:
                self.replace_slash_with_dash = False
        else:
            self.replace_slash_with_dash = replace_slash_with_dash

        if map_hindi_numbers_to_arabic is None:
            if self.model_name in SECOND_GEN_MODELS:
                self.map_hindi_numbers_to_arabic = True
            else:
                self.map_hindi_numbers_to_arabic = False
        else:
            self.map_hindi_numbers_to_arabic = map_hindi_numbers_to_arabic

    def preprocess(self, text: str) -> str:
        """
        Preprocess takes an input text line and applies the same preprocessing used in AraBERT
        pretraining, or according to the settings.

        Args:

            text (:obj:`str`): input text string

        Returns:

            string: A preprocessed string depending on which model was selected
        """
        if (
            self.model_name == "bert-base-arabert"
            or self.model_name == "bert-base-arabertv01"
        ):
            return self._preprocess_v1(
                text,
                do_farasa_tokenization=self.apply_farasa_segmentation,
            )

        if self.model_name in SECOND_GEN_MODELS:
            return self._preprocess_v2(text)

        return self._preprocess_v3(text)

    def unpreprocess(self, text: str, desegment: bool = True) -> str:
        """Re-formats the text to a classic format where punctuation, brackets and parentheses are not separated by whitespace.
        The objective is to make the generated text of any model appear natural and not preprocessed.

        Args:
            text (:obj:`str`): input text to be un-preprocessed
            desegment (:obj:`bool`, optional): whether or not to remove Farasa pre-segmentation first.

        Returns:
            str: The unpreprocessed (and possibly Farasa-desegmented) text.
        """

        if self.apply_farasa_segmentation and desegment:
            text = self.desegment(text)

        # removes the spaces around quotation marks ex: i " ate " an apple --> i "ate" an apple
        # https://stackoverflow.com/a/53436792/5381220
        text = re.sub(white_spaced_double_quotation_regex, '"' + r"\1" + '"', text)
        text = re.sub(white_spaced_single_quotation_regex, "'" + r"\1" + "'", text)
        text = re.sub(white_spaced_back_quotation_regex, "\`" + r"\1" + "\`", text)
        # note: this second pass reuses the back-quotation regex but joins with em-dashes
        text = re.sub(white_spaced_back_quotation_regex, "\—" + r"\1" + "\—", text)

        # during generation, sometimes the models don't put a space after the dot, this handles it
        text = text.replace(".", " . ")
        text = " ".join(text.split())

        # handle decimals
        text = re.sub(r"(\d+) \. (\d+)", r"\1.\2", text)
        text = re.sub(r"(\d+) \, (\d+)", r"\1,\2", text)

        text = re.sub(left_and_right_spaced_chars, r"\1", text)
        text = re.sub(left_spaced_chars, r"\1", text)
        text = re.sub(right_spaced_chars, r"\1", text)

        return text

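    # Note: `white_spaced_double_quotation_regex`, `left_spaced_chars`,
    # `right_spaced_chars` and the other spacing patterns used in
    # `unpreprocess` are module-level regex constants defined further down in
    # this file. Illustrative round trip (values assumed, shown as comments):
    #   prep = ArabertPreprocessor("aubmindlab/aragpt2-base")
    #   prep.unpreprocess('" مرحبا " 3 . 14')  # -> '"مرحبا" 3.14'
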
    def desegment(self, text: str) -> str:
        """
        Use this function if sentence tokenization was done using
        `from arabert.preprocess_arabert import preprocess` with Farasa enabled.
        AraBERT segmentation with Farasa adds a space after the '+' for prefixes,
        and before the '+' for suffixes.

        Example:
        >>> desegment('ال+ دراس +ات')
        الدراسات
        """
        text = text.replace("+ ", "+")
        text = text.replace(" +", "+")
        text = " ".join([self._desegmentword(word) for word in text.split(" ")])
        return text

    def _desegmentword(self, orig_word: str) -> str:
        """
        Word desegmentor that takes a Farasa-segmented word and removes the '+' signs

        Example:
        >>> _desegmentword("ال+يومي+ة")
        اليومية
        """
        word = orig_word.replace("ل+ال+", "لل")
        if "ال+ال" not in orig_word:
            word = word.replace("ل+ال", "لل")
        word = word.replace("+", "")
        word = word.replace("للل", "لل")
        return word

    def _preprocess_v3(self, text: str) -> str:
        text = str(text)
        text = html.unescape(text)
        if self.strip_tashkeel:
            text = araby.strip_tashkeel(text)
        if self.strip_tatweel:
            text = araby.strip_tatweel(text)

        if self.replace_urls_emails_mentions:
            # replace all possible URLs with [رابط]
            for reg in url_regexes:
                text = re.sub(reg, " [رابط] ", text)
            # replace emails with [بريد]
            for reg in email_regexes:
                text = re.sub(reg, " [بريد] ", text)
            # replace mentions with [مستخدم]
            text = re.sub(user_mention_regex, " [مستخدم] ", text)

        if self.remove_html_markup:
            # remove html line breaks
            text = re.sub("<br />", " ", text)
            # remove html markup
            text = re.sub("</?[^>]+>", " ", text)

        if self.map_hindi_numbers_to_arabic:
            text = text.translate(hindi_to_arabic_map)

        # replace repetitions of more than 2 non-digit characters with 2
        if self.remove_non_digit_repetition:
            text = self._remove_non_digit_repetition(text)

        # insert whitespace before and after all non Arabic digits or English digits and alphabet and the 2 brackets
        if self.insert_white_spaces:
            text = re.sub(
                "([^0-9\u0621-\u063A\u0641-\u064A\u0660-\u0669a-zA-Z ])",
                r" \1 ",
                text,
            )

            # re-fix brackets
            text = text.replace("[ رابط ]", "[رابط]")
            text = text.replace("[ بريد ]", "[بريد]")
            text = text.replace("[ مستخدم ]", "[مستخدم]")

            # insert whitespace between words and numbers or numbers and words
            text = re.sub(
                "(\d+)([\u0621-\u063A\u0641-\u064A\u066A-\u066C\u0654-\u0655]+)",
                r" \1 \2 ",
                text,
            )
            text = re.sub(
                "([\u0621-\u063A\u0641-\u064A\u066A-\u066C\u0654-\u0655]+)(\d+)",
                r" \1 \2 ",
                text,
            )

        # remove unwanted characters
        if self.keep_emojis:
            emoji_regex = "".join(list(emoji.UNICODE_EMOJI["en"].keys()))
            rejected_chars_regex2 = "[^%s%s]" % (chars_regexv2, emoji_regex)
            text = re.sub(rejected_chars_regex2, " ", text)
        else:
            text = re.sub(rejected_chars_regexv2, " ", text)

        # remove extra spaces
        text = " ".join(text.replace("\uFE0F", "").split())

        if self.apply_farasa_segmentation:
            if self.keep_emojis:
                new_text = []
                for word in text.split():
                    if word in list(emoji.UNICODE_EMOJI["en"].keys()):
                        new_text.append(word)
                    else:
                        new_text.append(farasa_segmenter.segment(word))
                text = " ".join(new_text)
            else:
                text = farasa_segmenter.segment(text)
            return self._farasa_segment(text)

        # all the other models don't require Farasa segmentation
        return text

    def _preprocess_v2(self, text: str) -> str:
        text = str(text)
        text = html.unescape(text)
        if self.strip_tashkeel:
            text = araby.strip_tashkeel(text)
        if self.strip_tatweel:
            text = araby.strip_tatweel(text)

        if self.replace_urls_emails_mentions:
            # replace all possible URLs with [رابط]
            for reg in url_regexes:
                text = re.sub(reg, " [رابط] ", text)
            # replace emails with [بريد]
            for reg in email_regexes:
                text = re.sub(reg, " [بريد] ", text)
            # replace mentions with [مستخدم]
            text = re.sub(user_mention_regex, " [مستخدم] ", text)

        if self.remove_html_markup:
            # remove html line breaks
            text = re.sub("<br />", " ", text)
            # remove html markup
            text = re.sub("</?[^>]+>", " ", text)

        if self.map_hindi_numbers_to_arabic:
            text = text.translate(hindi_to_arabic_map)

        # replace repetitions of more than 2 non-digit characters with 2
        if self.remove_non_digit_repetition:
            text = self._remove_non_digit_repetition(text)

        # insert whitespace before and after all non Arabic digits or English digits and alphabet and the 2 brackets
        if self.insert_white_spaces:
            text = re.sub(
                "([^0-9\u0621-\u063A\u0641-\u064A\u0660-\u0669a-zA-Z\[\]])",
                r" \1 ",
                text,
            )

            # insert whitespace between words and numbers or numbers and words
            text = re.sub(
                "(\d+)([\u0621-\u063A\u0641-\u064A\u066A-\u066C\u0654-\u0655]+)", r" \1 \2 ", text
            )
            text = re.sub(
                "([\u0621-\u063A\u0641-\u064A\u066A-\u066C\u0654-\u0655]+)(\d+)", r" \1 \2 ", text
            )

        if self.replace_slash_with_dash:
            text = text.replace("/", "-")

        # remove unwanted characters
        if self.keep_emojis:
            emoji_regex = "".join(list(emoji.UNICODE_EMOJI["en"].keys()))
            rejected_chars_regex2 = "[^%s%s]" % (chars_regex, emoji_regex)
            text = re.sub(rejected_chars_regex2, " ", text)
        else:
            text = re.sub(rejected_chars_regex, " ", text)

        # remove extra spaces
        text = " ".join(text.replace("\uFE0F", "").split())

        if (
            self.model_name == "bert-base-arabertv2"
            or self.model_name == "bert-large-arabertv2"
        ):
            if self.keep_emojis:
                new_text = []
                for word in text.split():
                    if word in list(emoji.UNICODE_EMOJI["en"].keys()):
                        new_text.append(word)
                    else:
                        new_text.append(farasa_segmenter.segment(word))
                text = " ".join(new_text)
            else:
                text = farasa_segmenter.segment(text)
            return self._farasa_segment(text)

        # all the other models don't require Farasa segmentation
|
446 |
+
return text
|
447 |
+
|
448 |
+
    def _preprocess_v1(self, text: str, do_farasa_tokenization: bool) -> str:
        """
        AraBERTv1 preprocessing function
        """
        text = str(text)
        if self.strip_tashkeel:
            text = araby.strip_tashkeel(text)

        text = re.sub(r"\d+\/[ء-ي]+\/\d+\]", "", text)
        text = re.sub("ـ", "", text)
        text = re.sub("[«»]", ' " ', text)

        if self.replace_urls_emails_mentions:
            # replace the [رابط] token with a space if you want to clean links
            text = re.sub(regex_url_step1, "[رابط]", text)
            text = re.sub(regex_url_step2, "[رابط]", text)
            text = re.sub(regex_url, "[رابط]", text)
            text = re.sub(regex_email, "[بريد]", text)
            text = re.sub(regex_mention, "[مستخدم]", text)
        text = re.sub("…", r"\.", text).strip()
        text = self._remove_redundant_punct(text)

        if self.replace_urls_emails_mentions:
            text = re.sub(r"\[ رابط \]|\[ رابط\]|\[رابط \]", " [رابط] ", text)
            text = re.sub(r"\[ بريد \]|\[ بريد\]|\[بريد \]", " [بريد] ", text)
            text = re.sub(r"\[ مستخدم \]|\[ مستخدم\]|\[مستخدم \]", " [مستخدم] ", text)

        if self.remove_non_digit_repetition:
            text = self._remove_non_digit_repetition(text)

        if self.insert_white_spaces:
            text = re.sub(
                "([^0-9\u0621-\u063A\u0641-\u0669\u0671-\u0673a-zA-Z\[\]])",
                r" \1 ",
                text,
            )
        if do_farasa_tokenization:
            text = self._tokenize_arabic_words_farasa(text)

        text = " ".join(text.split())

        return text

    def _farasa_segment(self, text: str) -> str:
        line_farasa = text.split()
        segmented_line = []
        for index, word in enumerate(line_farasa):
            if word in ["[", "]"]:
                continue
            if word in ["رابط", "بريد", "مستخدم"] and line_farasa[index - 1] in [
                "[",
                "]",
            ]:
                segmented_line.append("[" + word + "]")
                continue
            if "+" not in word:
                segmented_line.append(word)
                continue
            segmented_word = self._split_farasa_output(word)
            segmented_line.extend(segmented_word)

        return " ".join(segmented_line)

    def _split_farasa_output(self, word: str) -> str:
        segmented_word = []
        temp_token = ""
        for i, c in enumerate(word):
            if c == "+":
                # if the token is KAF, it could be a suffix or a prefix
                if temp_token == "ك":
                    # if we are at the second token, then KAF is surely a prefix
                    if i == 1:
                        segmented_word.append(temp_token + "+")
                        temp_token = ""
                    # if the KAF token is between 2 tokens
                    elif word[i - 2] == "+":
                        # if the previous token is a prefix, then this KAF must be a prefix
                        if segmented_word[-1][-1] == "+":
                            segmented_word.append(temp_token + "+")
                            temp_token = ""
                        # else it is a suffix, since this KAF cannot be a second suffix
                        else:
                            segmented_word.append("+" + temp_token)
                            temp_token = ""
                    # if KAF is at the end, this is handled by the statement after the loop
                elif temp_token in prefix_list:
                    segmented_word.append(temp_token + "+")
                    temp_token = ""
                elif temp_token in suffix_list:
                    segmented_word.append("+" + temp_token)
                    temp_token = ""
                else:
                    segmented_word.append(temp_token)
                    temp_token = ""
                continue
            temp_token += c
        if temp_token != "":
            if temp_token in suffix_list:
                segmented_word.append("+" + temp_token)
            else:
                segmented_word.append(temp_token)
        return segmented_word
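
    # e.g. _split_farasa_output("و+كتاب+ك") -> ["و+", "كتاب", "+ك"]: the leading
    # waw is a known prefix, and the trailing kaf is resolved by the suffix
    # check after the loop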

    def _tokenize_arabic_words_farasa(self, line_input: str) -> str:

        if self.keep_emojis:
            # segment word by word so that emoji tokens are passed through untouched
            line_farasa = []
            for word in line_input.split():
                if word in list(emoji.UNICODE_EMOJI["en"].keys()):
                    line_farasa.append(word)
                else:
                    line_farasa.append(farasa_segmenter.segment(word))
        else:
            line_farasa = farasa_segmenter.segment(line_input).split()

        segmented_line = []
        for index, word in enumerate(line_farasa):
            if word in ["[", "]"]:
                continue
            if word in ["رابط", "بريد", "مستخدم"] and line_farasa[index - 1] in [
                "[",
                "]",
            ]:
                segmented_line.append("[" + word + "]")
                continue
            segmented_word = []
            for token in word.split("+"):
                if token in prefix_list:
                    segmented_word.append(token + "+")
                elif token in suffix_list:
                    segmented_word.append("+" + token)
                else:
                    segmented_word.append(token)
            segmented_line.extend(segmented_word)
        return " ".join(segmented_line)

    def _remove_non_digit_repetition(self, text: str) -> str:
        """
        :param text: the input text from which to remove elongation
        :return: the de-elongated text
        """
        # loop over the number of times the regex matched the text
        # OLD
        # for index_ in range(len(re.findall(regex_tatweel, text))):
        #     elongation = re.search(regex_tatweel, text)
        #     if elongation:
        #         elongation_pattern = elongation.group()
        #         elongation_replacement = elongation_pattern[0]
        #         elongation_pattern = re.escape(elongation_pattern)
        #         text = re.sub(
        #             elongation_pattern, elongation_replacement, text, flags=re.MULTILINE
        #         )
        #     else:
        #         break

        # New
        text = multiple_char_pattern.sub(r"\1\1", text)
        return text
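
    # e.g. "ههههه" -> "هه" and "جداااا" -> "جداا"; digit runs such as "2000"
    # are left untouched because the pattern only matches non-digits (\D)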

    def _remove_redundant_punct(self, text: str) -> str:
        text_ = text
        result = re.search(redundant_punct_pattern, text)
        dif = 0
        while result:
            sub = result.group()
            sub = sorted(set(sub), key=sub.index)
            sub = " " + "".join(list(sub)) + " "
            text = "".join(
                (text[: result.span()[0] + dif], sub, text[result.span()[1] + dif :])
            )
            text_ = "".join(
                (text_[: result.span()[0]], text_[result.span()[1] :])
            ).strip()
            dif = abs(len(text) - len(text_))
            result = re.search(redundant_punct_pattern, text_)
        text = re.sub(r"\s+", " ", text)
        return text.strip()


prefix_list = [
    "ال",
    "و",
    "ف",
    "ب",
    "ك",
    "ل",
    "لل",
    "\u0627\u0644",
    "\u0648",
    "\u0641",
    "\u0628",
    "\u0643",
    "\u0644",
    "\u0644\u0644",
    "س",
]
suffix_list = [
    "ه",
    "ها",
    "ك",
    "ي",
    "هما",
    "كما",
    "نا",
    "كم",
    "هم",
    "هن",
    "كن",
    "ا",
    "ان",
    "ين",
    "ون",
    "وا",
    "ات",
    "ت",
    "ن",
    "ة",
    "\u0647",
    "\u0647\u0627",
    "\u0643",
    "\u064a",
    "\u0647\u0645\u0627",
    "\u0643\u0645\u0627",
    "\u0646\u0627",
    "\u0643\u0645",
    "\u0647\u0645",
    "\u0647\u0646",
    "\u0643\u0646",
    "\u0627",
    "\u0627\u0646",
    "\u064a\u0646",
    "\u0648\u0646",
    "\u0648\u0627",
    "\u0627\u062a",
    "\u062a",
    "\u0646",
    "\u0629",
]
other_tokens = ["[رابط]", "[مستخدم]", "[بريد]"]

# the never_split list is used with the transformers library
prefix_symbols = [x + "+" for x in prefix_list]
suffix_symbols = ["+" + x for x in suffix_list]
never_split_tokens = list(set(prefix_symbols + suffix_symbols + other_tokens))
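
# Usage sketch (an illustrative example; `never_split` is an existing
# BertTokenizer parameter in the transformers library, and the checkpoint
# name is one of the AraBERT models used elsewhere in this Space):
#
#     from transformers import BertTokenizer
#     tok = BertTokenizer.from_pretrained(
#         "aubmindlab/bert-base-arabertv2", never_split=never_split_tokens
#     )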

url_regexes = [
    r"(http(s)?:\/\/.)?(www\.)?[-a-zA-Z0-9@:%._\+~#=]{2,256}\.[a-z]{2,6}\b([-a-zA-Z0-9@:%_\+.~#?&//=]*)",
    r"@(https?|ftp)://(-\.)?([^\s/?\.#-]+\.?)+(/[^\s]*)?$@iS",
    r"http[s]?://[a-zA-Z0-9_\-./~\?=%&]+",
    r"www[a-zA-Z0-9_\-?=%&/.~]+",
    r"[a-zA-Z]+\.com",
    r"(?=http)[^\s]+",
    r"(?=www)[^\s]+",
    r"://",
]
user_mention_regex = r"@[\w\d]+"
email_regexes = [r"[\w-]+@([\w-]+\.)+[\w-]+", r"\S+@\S+"]
redundant_punct_pattern = (
    r"([!\"#\$%\'\(\)\*\+,\.:;\-<=·>?@\[\\\]\^_ـ`{\|}~—٪’،؟`୍“؛”ۚ【»؛\s+«–…‘]{2,})"
)

regex_tatweel = r"(\D)\1{2,}"
multiple_char_pattern = re.compile(r"(\D)\1{2,}", re.DOTALL)

rejected_chars_regex = r"[^0-9\u0621-\u063A\u0640-\u066C\u0671-\u0674a-zA-Z\[\]!\"#\$%\'\(\)\*\+,\.:;\-<=·>?@\[\\\]\^_ـ`{\|}~—٪’،؟`୍“؛”ۚ»؛\s+«–…‘]"
rejected_chars_regexv2 = r"[^0-9\u0621-\u063A\u0641-\u066C\u0671-\u0674a-zA-Z\[\]!\"#\$%\'\(\)\*\+,\.:;\-<=·>?@\[\\\]\^_ـ`{\|}~—٪’،؟`୍“؛”ۚ»؛\s+«–…‘/]"

regex_url_step1 = r"(?=http)[^\s]+"
regex_url_step2 = r"(?=www)[^\s]+"
regex_url = r"(http(s)?:\/\/.)?(www\.)?[-a-zA-Z0-9@:%._\+~#=]{2,256}\.[a-z]{2,6}\b([-a-zA-Z0-9@:%_\+.~#?&//=]*)"
regex_mention = r"@[\w\d]+"
regex_email = r"\S+@\S+"

chars_regex = r"0-9\u0621-\u063A\u0640-\u066C\u0671-\u0674a-zA-Z\[\]!\"#\$%\'\(\)\*\+,\.:;\-<=·>?@\[\\\]\^_ـ`{\|}~—٪’،؟`୍“؛”ۚ»؛\s+«–…‘"
chars_regexv2 = r"0-9\u0621-\u063A\u0640-\u066C\u0671-\u0674a-zA-Z\[\]!\"#\$%\'\(\)\*\+,\.:;\-<=·>?@\[\\\]\^_ـ`{\|}~—٪’،؟`୍“؛”ۚ»؛\s+«–…‘/"

white_spaced_double_quotation_regex = r'\"\s+([^"]+)\s+\"'
white_spaced_single_quotation_regex = r"\'\s+([^']+)\s+\'"
white_spaced_back_quotation_regex = r"\`\s+([^`]+)\s+\`"
white_spaced_em_dash = r"\—\s+([^—]+)\s+\—"

left_spaced_chars = r" ([\]!#\$%\),\.:;\?}٪’،؟”؛…»·])"
right_spaced_chars = r"([\[\(\{“«‘*\~]) "
left_and_right_spaced_chars = r" ([\+\-\<\=\>\@\\\^\_\|\–]) "

hindi_nums = "٠١٢٣٤٥٦٧٨٩"
arabic_nums = "0123456789"
hindi_to_arabic_map = str.maketrans(hindi_nums, arabic_nums)
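
# Usage sketch (a minimal example; `preprocess()` is defined earlier in this
# class and dispatches to the `_preprocess_v*` variants above):
#
#     prep = ArabertPreprocessor(model_name="bert-base-arabertv2")
#     segmented = prep.preprocess("ولن نبالغ إذا قلنا إن الهاتف ضروري")
#     restored = prep.desegment(segmented)  # undoes the Farasa '+' markers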
backend/processor.py
ADDED
@@ -0,0 +1,183 @@
import streamlit as st
import awesome_streamlit as ast
from .preprocess import (
    ArabertPreprocessor,
    white_spaced_back_quotation_regex,
    white_spaced_double_quotation_regex,
    white_spaced_em_dash,
    white_spaced_single_quotation_regex,
    left_and_right_spaced_chars,
    left_spaced_chars,
    right_spaced_chars,
)
import re

MODELS_to_SELECT = [
    "None",
    "bert-base-arabertv01",
    "bert-base-arabert",
    "bert-base-arabertv02",
    "bert-base-arabertv2",
    "bert-large-arabertv02",
    "bert-large-arabertv2",
    "araelectra-base",
    "araelectra-base-discriminator",
    "araelectra-base-generator",
    "araelectra-base-artydiqa",
    "aragpt2-base",
    "aragpt2-medium",
    "aragpt2-large",
    "aragpt2-mega",
]


def unpreprocess(text: str) -> str:
    """Re-formats the text to a classic format where punctuation, brackets, and parentheses are not separated by whitespace.
    The objective is to make the generated text of any model appear natural and not preprocessed.

    Args:
        text (:obj:`str`): input text to be un-preprocessed

    Returns:
        str: The unpreprocessed (and possibly Farasa-desegmented) text.
    """

    text = desegment(text)

    # remove the spaces around quotation marks, e.g.: i " ate " an apple --> i "ate" an apple
    # https://stackoverflow.com/a/53436792/5381220
    text = re.sub(white_spaced_double_quotation_regex, '"' + r"\1" + '"', text)
    text = re.sub(white_spaced_single_quotation_regex, "'" + r"\1" + "'", text)
    text = re.sub(white_spaced_back_quotation_regex, "\`" + r"\1" + "\`", text)
    text = re.sub(white_spaced_em_dash, "—" + r"\1" + "—", text)

    # during generation, the models sometimes don't put a space after the dot; this handles it
    text = text.replace(".", " . ")
    text = " ".join(text.split())

    # handle decimals
    text = re.sub(r"(\d+) \. (\d+)", r"\1.\2", text)
    text = re.sub(r"(\d+) \, (\d+)", r"\1,\2", text)

    text = re.sub(left_and_right_spaced_chars, r"\1", text)
    text = re.sub(left_spaced_chars, r"\1", text)
    text = re.sub(right_spaced_chars, r"\1", text)

    return text
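
# Usage sketch (a rough example; the exact output depends on the regexes
# imported above):
#
#     >>> unpreprocess("و+ قال ال+ رئيس : ' نعم '")
#     "وقال الرئيس: 'نعم'"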


def desegment(text: str) -> str:
    """
    Use this function if sentence tokenization was done using
    `from arabert.preprocess_arabert import preprocess` with Farasa enabled.
    AraBERT segmentation using Farasa adds a space after the '+' for prefixes,
    and a space before the '+' for suffixes.

    Example:
    >>> desegment('ال+ دراس +ات')
    الدراسات
    """
    text = text.replace("+ ", "+")
    text = text.replace(" +", "+")
    text = " ".join([_desegmentword(word) for word in text.split(" ")])
    return text


def _desegmentword(orig_word: str) -> str:
    """
    Word desegmentor that takes a Farasa-segmented word and removes the '+' signs.

    Example:
    >>> _desegmentword("ال+يومي+ة")
    اليومية
    """
    word = orig_word.replace("ل+ال+", "لل")
    if "ال+ال" not in orig_word:
        word = word.replace("ل+ال", "لل")
    word = word.replace("+", "")
    word = word.replace("للل", "لل")
    return word


def write():
    st.markdown(
        """
<h1 style="text-align:left;">Arabic Text Pre-Processor</h1>
""",
        unsafe_allow_html=True,
    )
    st.markdown(
        """
<style>
p, div, input, label {
    text-align: right;
}
</style>
""",
        unsafe_allow_html=True,
    )
    input_text = st.text_input(
        "Text to Pre-Process",
        value="ولن نبالغ إذا قلنا: إن 'هاتف' أو 'كمبيوتر المكتب' في زمننا هذا ضروري",
    )

    st.sidebar.title("Model Selector")
    model_selector = st.sidebar.selectbox(
        """Select None to enable further filters""", options=MODELS_to_SELECT, index=3
    )
    if model_selector == "None":
        keep_emojis = st.sidebar.checkbox("Keep emojis", False)
        remove_html_markup = st.sidebar.checkbox("Remove html markup", True)
        strip_tashkeel = st.sidebar.checkbox("Strip tashkeel", True)
        replace_urls_emails_mentions = st.sidebar.checkbox(
            "Replace urls and emails", True
        )
        strip_tatweel = st.sidebar.checkbox("Strip tatweel", True)
        insert_white_spaces = st.sidebar.checkbox("Insert white spaces", True)
        remove_non_digit_repetition = st.sidebar.checkbox(
            "Remove non-digit repetition", True
        )
        replace_slash_with_dash = st.sidebar.checkbox("Replace slash with dash", None)
        map_hindi_numbers_to_arabic = st.sidebar.checkbox(
            "Map hindi numbers to arabic", None
        )
        apply_farasa_segmentation = st.sidebar.checkbox(
            "Apply farasa segmentation", None
        )

    run_preprocessor = st.button("Run Pre-Processor")

    prep_text = None
    if run_preprocessor:
        if model_selector == "None":
            arabert_preprocessor = ArabertPreprocessor(
                model_selector,
                keep_emojis,
                remove_html_markup,
                replace_urls_emails_mentions,
                strip_tashkeel,
                strip_tatweel,
                insert_white_spaces,
                remove_non_digit_repetition,
                replace_slash_with_dash,
                map_hindi_numbers_to_arabic,
                apply_farasa_segmentation,
            )
        else:
            arabert_preprocessor = ArabertPreprocessor(model_name=model_selector)
        prep_text = arabert_preprocessor._preprocess_v3(input_text)
        st.write(prep_text)

    st.write("-----")
    input_text_unprep = st.text_input(
        "Text to Undo the Pre-Processing",
        value=prep_text
        if prep_text
        else "و+ لن نبالغ إذا قل +نا : إن ' هاتف ' أو ' كمبيوتر ال+ مكتب ' في زمن +نا هذا ضروري",
    )
    run_unpreprocessor = st.button("Run Un-Pre-Processor")

    if run_unpreprocessor:
        st.write(unpreprocess(input_text_unprep))
backend/qa.py
ADDED
@@ -0,0 +1,50 @@
import streamlit as st

from .qa_utils import annotate_answer
from .services import get_qa_answers


def write():
    _, col1, _ = st.columns(3)

    with col1:
        st.image("images/is2alni_logo.png", width=200)
        st.title("إسألني أي شيء")

    st.markdown(
        """
<style>
p, div, input, label {
    text-align: right;
}
</style>
""",
        unsafe_allow_html=True,
    )

    st.sidebar.header("Info")
    st.sidebar.image("images/AraELECTRA.png", width=150)
    st.sidebar.write("Powered by [AraELECTRA](https://github.com/aub-mind/arabert)")

    st.sidebar.write("\n")
    n_answers = st.sidebar.slider(
        "Max. number of answers", min_value=1, max_value=10, value=2, step=1
    )

    question = st.text_input("", value="من هو جو بايدن؟")
    if "؟" not in question:
        question += "؟"

    run_query = st.button("أجب")
    if run_query:
        # https://discuss.streamlit.io/t/showing-a-gif-while-st-spinner-runs/5084
        with st.spinner("... جاري البحث "):
            results_dict = get_qa_answers(question)

        if len(results_dict) > 0:
            st.write("## :الأجابات هي")
            for result in results_dict["results"][:n_answers]:
                annotate_answer(result)
                f"[**المصدر**](<{result['link']}>)"  # Streamlit "magic": rendered as markdown
        else:
            st.write("## 😞 ليس لدي جواب")
backend/qa_utils.py
ADDED
@@ -0,0 +1,163 @@
import streamlit.components.v1

from htbuilder import HtmlElement, div, span, styles
from htbuilder.units import px, rem, em


def annotation(body, label="", background="#ddd", color="#333", **style):
    """Build an HtmlElement span object with the given body and annotation label.

    The end result will look something like this:

        [body | label]

    Parameters
    ----------
    body : string
        The string to put in the "body" part of the annotation.
    label : string
        The string to put in the "label" part of the annotation.
    background : string
        The color to use for the background "chip" containing this annotation.
    color : string
        The color to use for the body and label text.
    **style : dict
        Any CSS you want to use to customize the containing "chip".

    Examples
    --------

    Produce a simple annotation with default colors:

    >>> annotation("apple", "fruit")

    Produce an annotation with custom colors:

    >>> annotation("apple", "fruit", background="#FF0", color="black")

    Produce an annotation with crazy CSS:

    >>> annotation("apple", "fruit", background="#FF0", border="1px dashed red")

    """

    if "font_family" not in style:
        style["font_family"] = "sans-serif"

    return span(
        style=styles(
            background=background,
            border_radius=rem(0.33),
            color=color,
            padding=(rem(0.17), rem(0.67)),
            display="inline-flex",
            justify_content="center",
            align_items="center",
            **style,
        )
    )(
        body,
        span(
            style=styles(
                color=color,
                font_size=em(0.67),
                opacity=0.5,
                padding_left=rem(0.5),
                text_transform="uppercase",
                margin_bottom=px(-2),
            )
        )(label),
    )


def annotated_text(*args, **kwargs):
    """Writes text with annotations into your Streamlit app.

    Parameters
    ----------
    *args : str, tuple or htbuilder.HtmlElement
        Arguments can be:
        - strings, to draw the string as-is on the screen.
        - tuples of the form (main_text, annotation_text, background, color) where
          background and foreground colors are optional and should be a CSS-valid string such as
          "#aabbcc" or "rgb(10, 20, 30)"
        - HtmlElement objects in case you want to customize the annotations further. In particular,
          you can import the `annotation()` function from this module to easily produce annotations
          whose CSS you can customize via keyword arguments.

    Examples
    --------

    >>> annotated_text(
    ...     "This ",
    ...     ("is", "verb", "#8ef"),
    ...     " some ",
    ...     ("annotated", "adj", "#faa"),
    ...     ("text", "noun", "#afa"),
    ...     " for those of ",
    ...     ("you", "pronoun", "#fea"),
    ...     " who ",
    ...     ("like", "verb", "#8ef"),
    ...     " this sort of ",
    ...     ("thing", "noun", "#afa"),
    ... )

    >>> annotated_text(
    ...     "Hello ",
    ...     annotation("world!", "noun", color="#8ef", border="1px dashed red"),
    ... )

    """
    out = div(
        style=styles(
            font_family="sans-serif",
            line_height="1.45",
            font_size=px(16),
            text_align="right",
        )
    )

    for arg in args:
        if isinstance(arg, str):
            out(arg)

        elif isinstance(arg, HtmlElement):
            out(arg)

        elif isinstance(arg, tuple):
            out(annotation(*arg))

        else:
            raise Exception("Oh noes!")

    streamlit.components.v1.html(str(out), **kwargs)


def shorten_text(text, n, reverse=False):
    if text.isspace() or text == "":
        return text
    if reverse:
        text = text[::-1]
    words = iter(text.split())
    lines, current = [], next(words)
    for word in words:
        if len(current) + 1 + len(word) > n:
            break
        else:
            current += " " + word
    lines.append(current)
    if reverse:
        return lines[0][::-1]
    return lines[0]


def annotate_answer(result):
    annotated_text(
        shorten_text(
            result["original"][: result["new_start"]],
            500,
            reverse=True,
        ),
        (result["new_answer"], "جواب", "#8ef"),
        shorten_text(result["original"][result["new_end"] :], 500) + " ...... إلخ",
    )
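
# The `result` dict consumed by annotate_answer() is produced by
# services.get_qa_answers(); a sketch of the expected shape, inferred from the
# fields accessed above and in qa.py (the values are illustrative only):
#
#     result = {
#         "original": "... the retrieved passage ...",
#         "new_answer": "... the extracted answer span ...",
#         "new_start": 10,  # character offset of the answer inside "original"
#         "new_end": 25,
#         "link": "https://ar.wikipedia.org/wiki/...",
#     }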
backend/sa.py
ADDED
@@ -0,0 +1,76 @@
import streamlit as st
from .services import SentimentAnalyzer
from functools import lru_cache

# @st.cache(allow_output_mutation=False, hash_funcs={Tokenizer: str})
@lru_cache(maxsize=1)
def load_text_generator():
    predictor = SentimentAnalyzer()
    return predictor


predictor = load_text_generator()


def write():
    st.markdown(
        """
# Arabic Sentiment Analysis

This is a simple sentiment analysis app that uses the prediction kernel from my (Wissam's) submission that won the [Arabic Sentiment Analysis competition @ KAUST](https://www.kaggle.com/c/arabic-sentiment-analysis-2021-kaust)
"""
    )
    if st.checkbox("More info: "):
        st.markdown(
            """
### Submission Description:

My submission is based on an ensemble of 5 models with varying preprocessing and classifier designs. All model variants are built over MARBERT [1], which is a BERT-based model pre-trained on 1B dialectal Arabic tweets.

For preprocessing, all models shared the following steps:
- Replacing user mentions with “USER” and links with “URL”
- Replacing the “#” with “HASH”
- Removing the underscore character, since it is missing from the MARBERT vocabulary
- Removing diacritics and elongations (tatweel)
- Spacing out emojis

For classifier design, all models use a dense layer on top of MARBERT unless otherwise specified. Model training is done by hyperparameter grid search with 5-fold cross-validation over the following search space:
- Learning rate: [2e-5, 3e-5, 4e-5]
- Batch size: 128
- Maximum sequence length: 64
- Epochs: 3 (we select the best epoch for the final prediction)
- Warmup ratio: [0, 0.1]
- Seed: [1, 25, 42, 123, 666]

Model I is a vanilla variant with only the preprocessing steps mentioned above applied. Model II enhances the emoji representation by replacing OOV emojis with ones that have a similar meaning (for example, mapping 💊 to 😷).
We noticed the repetitive use of “السلام عليكم” and “ورحمة الله وبركاته” in neutral tweets, especially when users were directing questions to business accounts. This could confuse the classifier if it encountered these words in, for example, a negative tweet; hence, in Model III we removed variations of these phrases using fuzzy matching algorithms.

In Model IV, we tried to help the model by appending a sarcasm label to the input. We first trained a separate MARBERT on the ArSarcasm [2] dataset, and then used it to label the training and test sets.

Model V uses the vanilla preprocessing approach, but instead of a dense layer built on top of MARBERT, we follow the approach detailed by Safaya et al. [3], which uses a CNN-based classifier instead.

For the final prediction, we first average the predictions of the 5 models from cross-validation (this is done for each model variant separately), and we then average the results across the 5 model variants. We observed that the distribution of the predicted sentiment classes doesn't quite match the true distribution; this is due to the model preferring the neutral class over the positive class. To counter that, we apply what we call a label-weighted average: after the final averaging, we rescale the scores with the weights 1.57, 0.98, and 0.93 for positive, neutral, and negative respectively (the weights were determined empirically; a sketch of this step appears in the comments below).

1- https://aclanthology.org/2021.acl-long.551/

2- https://github.com/iabufarha/ArSarcasm

3- https://github.com/alisafaya/OffensEval2020


"""
        )
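
    # A sketch of the label-weighted averaging step described above (not the
    # shipped implementation; `variant_probs`, with shape (5, n_samples, 3),
    # is a hypothetical name for the stacked per-variant class probabilities):
    #
    #     import numpy as np
    #     probs = variant_probs.mean(axis=0)            # average the 5 variants
    #     probs = probs * np.array([1.57, 0.98, 0.93])  # pos, neu, neg weights
    #     preds = probs.argmax(-1)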
    input_text = st.text_input(
        "Enter your text here:",
    )
    if st.button("Predict"):
        with st.spinner("Predicting..."):
            prediction, score, all_score = predictor.predict([input_text])
            st.write(f"Result: {prediction[0]}")
            detailed_score = {
                "Positive": all_score[0][0],
                "Neutral": all_score[0][1],
                "Negative": all_score[0][2],
            }
            st.write("All scores:")
            st.write(detailed_score)
backend/sa_utils.py
ADDED
@@ -0,0 +1,510 @@
import re
from contextlib import contextmanager

import numpy as np
import torch
import torch.nn.functional as F
from fuzzysearch import find_near_matches
from pyarabic import araby
from torch import nn
from transformers import AutoTokenizer, BertModel, BertPreTrainedModel, pipeline
from transformers.modeling_outputs import SequenceClassifierOutput

from .preprocess import ArabertPreprocessor, url_regexes, user_mention_regex

multiple_char_pattern = re.compile(r"(.)\1{2,}", re.DOTALL)

# ASAD-NEW_AraBERT_PREP-Balanced
class NewArabicPreprocessorBalanced(ArabertPreprocessor):
    def __init__(
        self,
        model_name: str,
        keep_emojis: bool = False,
        remove_html_markup: bool = True,
        replace_urls_emails_mentions: bool = True,
        strip_tashkeel: bool = True,
        strip_tatweel: bool = True,
        insert_white_spaces: bool = True,
        remove_non_digit_repetition: bool = True,
        replace_slash_with_dash: bool = None,
        map_hindi_numbers_to_arabic: bool = None,
        apply_farasa_segmentation: bool = None,
    ):
        if "UBC-NLP" in model_name or "CAMeL-Lab" in model_name:
            keep_emojis = True
            remove_non_digit_repetition = True
        super().__init__(
            model_name=model_name,
            keep_emojis=keep_emojis,
            remove_html_markup=remove_html_markup,
            replace_urls_emails_mentions=replace_urls_emails_mentions,
            strip_tashkeel=strip_tashkeel,
            strip_tatweel=strip_tatweel,
            insert_white_spaces=insert_white_spaces,
            remove_non_digit_repetition=remove_non_digit_repetition,
            replace_slash_with_dash=replace_slash_with_dash,
            map_hindi_numbers_to_arabic=map_hindi_numbers_to_arabic,
            apply_farasa_segmentation=apply_farasa_segmentation,
        )
        self.true_model_name = model_name

    def preprocess(self, text):
        if "UBC-NLP" in self.true_model_name:
            return self.ubc_prep(text)
        # fall back to the stock AraBERT preprocessing for other checkpoints
        return super().preprocess(text)

    def ubc_prep(self, text):
        text = re.sub("\s", " ", text)
        text = text.replace("\\n", " ")
        text = text.replace("\\r", " ")
        text = araby.strip_tashkeel(text)
        text = araby.strip_tatweel(text)
        # replace all possible URLs
        for reg in url_regexes:
            text = re.sub(reg, " URL ", text)
        text = re.sub("(URL\s*)+", " URL ", text)
        # replace mentions with USER
        text = re.sub(user_mention_regex, " USER ", text)
        text = re.sub("(USER\s*)+", " USER ", text)
        # replace hashtags with HASH
        # text = re.sub(r"#[\w\d]+", " HASH TAG ", text)
        text = text.replace("#", " HASH ")
        text = text.replace("_", " ")
        text = " ".join(text.split())
        # text = re.sub("\B\\[Uu]\w+", "", text)
        text = text.replace("\\U0001f97a", "🥺")
        text = text.replace("\\U0001f928", "🤨")
        text = text.replace("\\U0001f9d8", "😀")
        text = text.replace("\\U0001f975", "😥")
        text = text.replace("\\U0001f92f", "😲")
        text = text.replace("\\U0001f92d", "🤭")
        text = text.replace("\\U0001f9d1", "😐")
        text = text.replace("\\U000e0067", "")
        text = text.replace("\\U000e006e", "")
        text = text.replace("\\U0001f90d", "♥")
        text = text.replace("\\U0001f973", "🎉")
        text = text.replace("\\U0001fa79", "")
        text = text.replace("\\U0001f92b", "🤐")
        text = text.replace("\\U0001f9da", "🦋")
        text = text.replace("\\U0001f90e", "♥")
        text = text.replace("\\U0001f9d0", "🧐")
        text = text.replace("\\U0001f9cf", "")
        text = text.replace("\\U0001f92c", "😠")
        text = text.replace("\\U0001f9f8", "😸")
        text = text.replace("\\U0001f9b6", "💩")
        text = text.replace("\\U0001f932", "🤲")
        text = text.replace("\\U0001f9e1", "🧡")
        text = text.replace("\\U0001f974", "☹")
        text = text.replace("\\U0001f91f", "")
        text = text.replace("\\U0001f9fb", "💩")
        text = text.replace("\\U0001f92a", "🤪")
        text = text.replace("\\U0001f9fc", "")
        text = text.replace("\\U000e0065", "")
        text = text.replace("\\U0001f92e", "💩")
        text = text.replace("\\U000e007f", "")
        text = text.replace("\\U0001f970", "🥰")
        text = text.replace("\\U0001f929", "🤩")
        text = text.replace("\\U0001f6f9", "")
        text = text.replace("🤍", "♥")
        text = text.replace("🦠", "😷")
        text = text.replace("🤢", "مقرف")
        text = text.replace("🤮", "مقرف")
        text = text.replace("🕠", "⌚")
        text = text.replace("🤬", "😠")
        text = text.replace("🤧", "😷")
        text = text.replace("🥳", "🎉")
        text = text.replace("🥵", "🔥")
        text = text.replace("🥴", "☹")
        text = text.replace("🤫", "🤐")
        text = text.replace("🤥", "كذاب")
        text = text.replace("\\u200d", " ")
        text = text.replace("u200d", " ")
        text = text.replace("\\u200c", " ")
        text = text.replace("u200c", " ")
        text = text.replace('"', "'")
        text = text.replace("\\xa0", "")
        text = text.replace("\\u2066", " ")
        text = re.sub("\B\\\[Uu]\w+", "", text)
        text = super(NewArabicPreprocessorBalanced, self).preprocess(text)

        text = " ".join(text.split())
        return text


"""CNNMarbertArabicPreprocessor"""
|
134 |
+
# ASAD-CNN_MARBERT
|
135 |
+
class CNNMarbertArabicPreprocessor(ArabertPreprocessor):
|
136 |
+
def __init__(
|
137 |
+
self,
|
138 |
+
model_name,
|
139 |
+
keep_emojis=False,
|
140 |
+
remove_html_markup=True,
|
141 |
+
replace_urls_emails_mentions=True,
|
142 |
+
remove_elongations=True,
|
143 |
+
):
|
144 |
+
if "UBC-NLP" in model_name or "CAMeL-Lab" in model_name:
|
145 |
+
keep_emojis = True
|
146 |
+
remove_elongations = False
|
147 |
+
super().__init__(
|
148 |
+
model_name,
|
149 |
+
keep_emojis,
|
150 |
+
remove_html_markup,
|
151 |
+
replace_urls_emails_mentions,
|
152 |
+
remove_elongations,
|
153 |
+
)
|
154 |
+
self.true_model_name = model_name
|
155 |
+
|
156 |
+
def preprocess(self, text):
|
157 |
+
if "UBC-NLP" in self.true_model_name:
|
158 |
+
return self.ubc_prep(text)
|
159 |
+
|
160 |
+
def ubc_prep(self, text):
|
161 |
+
text = re.sub("\s", " ", text)
|
162 |
+
text = text.replace("\\n", " ")
|
163 |
+
text = araby.strip_tashkeel(text)
|
164 |
+
text = araby.strip_tatweel(text)
|
165 |
+
# replace all possible URLs
|
166 |
+
for reg in url_regexes:
|
167 |
+
text = re.sub(reg, " URL ", text)
|
168 |
+
text = re.sub("(URL\s*)+", " URL ", text)
|
169 |
+
# replace mentions with USER
|
170 |
+
text = re.sub(user_mention_regex, " USER ", text)
|
171 |
+
text = re.sub("(USER\s*)+", " USER ", text)
|
172 |
+
# replace hashtags with HASHTAG
|
173 |
+
# text = re.sub(r"#[\w\d]+", " HASH TAG ", text)
|
174 |
+
text = text.replace("#", " HASH ")
|
175 |
+
text = text.replace("_", " ")
|
176 |
+
text = " ".join(text.split())
|
177 |
+
text = super(CNNMarbertArabicPreprocessor, self).preprocess(text)
|
178 |
+
text = text.replace("\u200d", " ")
|
179 |
+
text = text.replace("u200d", " ")
|
180 |
+
text = text.replace("\u200c", " ")
|
181 |
+
text = text.replace("u200c", " ")
|
182 |
+
text = text.replace('"', "'")
|
183 |
+
# text = re.sub('[\d\.]+', ' NUM ', text)
|
184 |
+
# text = re.sub('(NUM\s*)+', ' NUM ', text)
|
185 |
+
text = multiple_char_pattern.sub(r"\1\1", text)
|
186 |
+
text = " ".join(text.split())
|
187 |
+
return text
|
188 |
+
|
189 |
+
|
190 |
+
"""Trial5ArabicPreprocessor"""
|
191 |
+
|
192 |
+
|
193 |
+
class Trial5ArabicPreprocessor(ArabertPreprocessor):
|
194 |
+
def __init__(
|
195 |
+
self,
|
196 |
+
model_name,
|
197 |
+
keep_emojis=False,
|
198 |
+
remove_html_markup=True,
|
199 |
+
replace_urls_emails_mentions=True,
|
200 |
+
):
|
201 |
+
if "UBC-NLP" in model_name:
|
202 |
+
keep_emojis = True
|
203 |
+
super().__init__(
|
204 |
+
model_name, keep_emojis, remove_html_markup, replace_urls_emails_mentions
|
205 |
+
)
|
206 |
+
self.true_model_name = model_name
|
207 |
+
|
208 |
+
def preprocess(self, text):
|
209 |
+
if "UBC-NLP" in self.true_model_name:
|
210 |
+
return self.ubc_prep(text)
|
211 |
+
|
212 |
+
def ubc_prep(self, text):
|
213 |
+
text = re.sub("\s", " ", text)
|
214 |
+
text = text.replace("\\n", " ")
|
215 |
+
text = araby.strip_tashkeel(text)
|
216 |
+
text = araby.strip_tatweel(text)
|
217 |
+
# replace all possible URLs
|
218 |
+
for reg in url_regexes:
|
219 |
+
text = re.sub(reg, " URL ", text)
|
220 |
+
# replace mentions with USER
|
221 |
+
text = re.sub(user_mention_regex, " USER ", text)
|
222 |
+
# replace hashtags with HASHTAG
|
223 |
+
# text = re.sub(r"#[\w\d]+", " HASH TAG ", text)
|
224 |
+
text = text.replace("#", " HASH TAG ")
|
225 |
+
text = text.replace("_", " ")
|
226 |
+
text = " ".join(text.split())
|
227 |
+
text = super(Trial5ArabicPreprocessor, self).preprocess(text)
|
228 |
+
# text = text.replace("السلام عليكم"," ")
|
229 |
+
# text = text.replace(find_near_matches("السلام عليكم",text,max_deletions=3,max_l_dist=3)[0].matched," ")
|
230 |
+
return text
|
231 |
+
|
232 |
+
|
233 |
+
"""SarcasmArabicPreprocessor"""
|
234 |
+
|
235 |
+
|
236 |
+
class SarcasmArabicPreprocessor(ArabertPreprocessor):
|
237 |
+
def __init__(
|
238 |
+
self,
|
239 |
+
model_name,
|
240 |
+
keep_emojis=False,
|
241 |
+
remove_html_markup=True,
|
242 |
+
replace_urls_emails_mentions=True,
|
243 |
+
):
|
244 |
+
if "UBC-NLP" in model_name:
|
245 |
+
keep_emojis = True
|
246 |
+
super().__init__(
|
247 |
+
model_name, keep_emojis, remove_html_markup, replace_urls_emails_mentions
|
248 |
+
)
|
249 |
+
self.true_model_name = model_name
|
250 |
+
|
251 |
+
def preprocess(self, text):
|
252 |
+
if "UBC-NLP" in self.true_model_name:
|
253 |
+
return self.ubc_prep(text)
|
254 |
+
else:
|
255 |
+
return super(SarcasmArabicPreprocessor, self).preprocess(text)
|
256 |
+
|
257 |
+
def ubc_prep(self, text):
|
258 |
+
text = re.sub("\s", " ", text)
|
259 |
+
text = text.replace("\\n", " ")
|
260 |
+
text = araby.strip_tashkeel(text)
|
261 |
+
text = araby.strip_tatweel(text)
|
262 |
+
# replace all possible URLs
|
263 |
+
for reg in url_regexes:
|
264 |
+
text = re.sub(reg, " URL ", text)
|
265 |
+
# replace mentions with USER
|
266 |
+
text = re.sub(user_mention_regex, " USER ", text)
|
267 |
+
# replace hashtags with HASHTAG
|
268 |
+
# text = re.sub(r"#[\w\d]+", " HASH TAG ", text)
|
269 |
+
text = text.replace("#", " HASH TAG ")
|
270 |
+
text = text.replace("_", " ")
|
271 |
+
text = text.replace('"', " ")
|
272 |
+
text = " ".join(text.split())
|
273 |
+
text = super(SarcasmArabicPreprocessor, self).preprocess(text)
|
274 |
+
return text
|
275 |
+
|
276 |
+
|
277 |
+
"""NoAOAArabicPreprocessor"""
|
278 |
+
|
279 |
+
|
280 |
+
class NoAOAArabicPreprocessor(ArabertPreprocessor):
|
281 |
+
def __init__(
|
282 |
+
self,
|
283 |
+
model_name,
|
284 |
+
keep_emojis=False,
|
285 |
+
remove_html_markup=True,
|
286 |
+
replace_urls_emails_mentions=True,
|
287 |
+
):
|
288 |
+
if "UBC-NLP" in model_name:
|
289 |
+
keep_emojis = True
|
290 |
+
super().__init__(
|
291 |
+
model_name, keep_emojis, remove_html_markup, replace_urls_emails_mentions
|
292 |
+
)
|
293 |
+
self.true_model_name = model_name
|
294 |
+
|
295 |
+
def preprocess(self, text):
|
296 |
+
if "UBC-NLP" in self.true_model_name:
|
297 |
+
return self.ubc_prep(text)
|
298 |
+
else:
|
299 |
+
return super(NoAOAArabicPreprocessor, self).preprocess(text)
|
300 |
+
|
301 |
+
def ubc_prep(self, text):
|
302 |
+
text = re.sub("\s", " ", text)
|
303 |
+
text = text.replace("\\n", " ")
|
304 |
+
text = araby.strip_tashkeel(text)
|
305 |
+
text = araby.strip_tatweel(text)
|
306 |
+
# replace all possible URLs
|
307 |
+
for reg in url_regexes:
|
308 |
+
text = re.sub(reg, " URL ", text)
|
309 |
+
# replace mentions with USER
|
310 |
+
text = re.sub(user_mention_regex, " USER ", text)
|
311 |
+
# replace hashtags with HASHTAG
|
312 |
+
# text = re.sub(r"#[\w\d]+", " HASH TAG ", text)
|
313 |
+
text = text.replace("#", " HASH TAG ")
|
314 |
+
text = text.replace("_", " ")
|
315 |
+
text = " ".join(text.split())
|
316 |
+
text = super(NoAOAArabicPreprocessor, self).preprocess(text)
|
317 |
+
text = text.replace("السلام عليكم", " ")
|
318 |
+
text = text.replace("ورحمة الله وبركاته", " ")
|
319 |
+
matched = find_near_matches("السلام عليكم", text, max_deletions=3, max_l_dist=3)
|
320 |
+
if len(matched) > 0:
|
321 |
+
text = text.replace(matched[0].matched, " ")
|
322 |
+
matched = find_near_matches(
|
323 |
+
"ورحمة الله وبركاته", text, max_deletions=3, max_l_dist=3
|
324 |
+
)
|
325 |
+
if len(matched) > 0:
|
326 |
+
text = text.replace(matched[0].matched, " ")
|
327 |
+
return text
|
328 |
+
|
329 |
+
|
330 |
+
class CnnBertForSequenceClassification(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config

        self.bert = BertModel(config)

        filter_sizes = [1, 2, 3, 4, 5]
        num_filters = 32
        # in_channels=4: the CNN head consumes the last 4 hidden layers stacked in forward()
        self.convs1 = nn.ModuleList(
            [nn.Conv2d(4, num_filters, (K, config.hidden_size)) for K in filter_sizes]
        )
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(len(filter_sizes) * num_filters, config.num_labels)

        self.init_weights()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):

        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # hidden states of the last 4 layers (relies on the model returning
        # hidden states, i.e. config.output_hidden_states=True)
        x = outputs[2][-4:]

        x = torch.stack(x, dim=1)
        x = [F.relu(conv(x)).squeeze(3) for conv in self.convs1]
        x = [F.max_pool1d(i, i.size(2)).squeeze(2) for i in x]
        x = torch.cat(x, 1)
        x = self.dropout(x)
        logits = self.classifier(x)

        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (
                    labels.dtype == torch.long or labels.dtype == torch.int
                ):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = nn.MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = nn.CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = nn.BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)
        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=None,
            attentions=outputs.attentions,
        )


class CNNTextClassificationPipeline:
    def __init__(self, model_path, device, return_all_scores=False):
        self.model_path = model_path
        self.model = CnnBertForSequenceClassification.from_pretrained(self.model_path)
        # Special handling
        self.device = torch.device("cpu" if device < 0 else f"cuda:{device}")
        if self.device.type == "cuda":
            self.model = self.model.to(self.device)
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.return_all_scores = return_all_scores

    @contextmanager
    def device_placement(self):
        """
        Context manager allowing tensor allocation on the user-specified device in a framework-agnostic way.
        Returns:
            Context manager
        Examples::
            # Explicitly ask for tensor allocation on CUDA device :0
            pipe = pipeline(..., device=0)
            with pipe.device_placement():
                # Every framework-specific tensor allocation will be done on the requested device
                output = pipe(...)
        """

        if self.device.type == "cuda":
            torch.cuda.set_device(self.device)

        yield

    def ensure_tensor_on_device(self, **inputs):
        """
        Ensure PyTorch tensors are on the specified device.
        Args:
            inputs (keyword arguments that should be :obj:`torch.Tensor`): The tensors to place on :obj:`self.device`.
        Return:
            :obj:`Dict[str, torch.Tensor]`: The same as :obj:`inputs` but on the proper device.
        """
        return {
            name: tensor.to(self.device) if isinstance(tensor, torch.Tensor) else tensor
            for name, tensor in inputs.items()
        }

    def __call__(self, text):
        """
        Classify the text(s) given as inputs.
        Args:
            args (:obj:`str` or :obj:`List[str]`):
                One or several texts (or one list of prompts) to classify.
        Return:
            A list or a list of lists of :obj:`dict`: Each result comes as a list of dictionaries with the following keys:
            - **label** (:obj:`str`) -- The label predicted.
            - **score** (:obj:`float`) -- The corresponding probability.
            If ``self.return_all_scores=True``, one such dictionary is returned per label.
        """
        # outputs = super().__call__(*args, **kwargs)
        inputs = self.tokenizer.batch_encode_plus(
            text,
            add_special_tokens=True,
            max_length=64,
            padding=True,
            truncation="longest_first",
            return_tensors="pt",
        )

        with torch.no_grad():
            inputs = self.ensure_tensor_on_device(**inputs)
            predictions = self.model(**inputs)[0].cpu()

        predictions = predictions.numpy()

        if self.model.config.num_labels == 1:
            scores = 1.0 / (1.0 + np.exp(-predictions))
        else:
            scores = np.exp(predictions) / np.exp(predictions).sum(-1, keepdims=True)
        if self.return_all_scores:
            return [
                [
                    {"label": self.model.config.id2label[i], "score": score.item()}
                    for i, score in enumerate(item)
                ]
                for item in scores
            ]
        else:
            return [
                # map the argmax index to its label name via the model config
                {
                    "label": self.model.config.id2label[item.argmax()],
                    "score": item.max().item(),
                }
                for item in scores
            ]
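
# Usage sketch (an illustrative example; the checkpoint path is a placeholder
# for a fine-tuned CnnBertForSequenceClassification directory):
#
#     pipe = CNNTextClassificationPipeline(
#         "path/to/cnn-marbert-checkpoint", device=-1, return_all_scores=True
#     )
#     pipe(["نص تجريبي"])  # -> [[{"label": ..., "score": ...}, ...]]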
backend/sarcasm.py
ADDED
@@ -0,0 +1,21 @@
import streamlit as st
from .sa import predictor


def write():
    st.markdown(
        """
# Arabic Sarcasm Detection

This is a simple sarcasm detection app that uses the [MARBERT](https://huggingface.co/UBC-NLP/MARBERT) model trained on [ArSarcasm](https://github.com/iabufarha/ArSarcasm)
"""
    )

    input_text = st.text_input(
        "Enter your text here:",
    )
    if st.button("Predict"):
        with st.spinner("Predicting..."):
            prediction, scores = predictor.get_preds_from_sarcasm([input_text])
            st.write(f"Result: {prediction[0]}")
            st.write(f"Score: {scores[0]}")
backend/services.py
ADDED
@@ -0,0 +1,519 @@
import json
import logging
import os
import re  # used by get_qa_answers below; previously only available via the wildcard import
from functools import lru_cache
from typing import List
from urllib.parse import unquote

import more_itertools
import numpy as np  # used by SentimentAnalyzer below; previously only available via the wildcard import
import pandas as pd
import requests
import streamlit as st
import wikipedia
from codetiming import Timer
from fuzzysearch import find_near_matches
from googleapi import google
from tqdm.auto import tqdm
from transformers import (
    AutoTokenizer,
    GPT2LMHeadModel,
    GPT2Tokenizer,
    pipeline,
    set_seed,
)

from .modeling_gpt2 import GPT2LMHeadModel as GROVERLMHeadModel
from .preprocess import ArabertPreprocessor
from .sa_utils import *
from .utils import download_models, softmax

logger = logging.getLogger(__name__)

# Taken and Modified from https://huggingface.co/spaces/flax-community/chef-transformer/blob/main/app.py
class TextGeneration:
    def __init__(self):
        self.debug = False
        self.generation_pipline = {}
        self.preprocessor = ArabertPreprocessor(model_name="aragpt2-mega")
        self.tokenizer = GPT2Tokenizer.from_pretrained(
            "aubmindlab/aragpt2-mega", use_fast=False
        )
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.API_KEY = os.getenv("API_KEY")
        self.headers = {"Authorization": f"Bearer {self.API_KEY}"}
        # self.model_names_or_paths = {
        #     "aragpt2-medium": "D:/ML/Models/aragpt2-medium",
        #     "aragpt2-base": "D:/ML/Models/aragpt2-base",
        # }
        self.model_names_or_paths = {
            # "aragpt2-medium": "aubmindlab/aragpt2-medium",
            "aragpt2-base": "aubmindlab/aragpt2-base",
            # "aragpt2-large": "aubmindlab/aragpt2-large",
            "aragpt2-mega": "aubmindlab/aragpt2-mega",
        }
        set_seed(42)

    def load_pipeline(self):
        for model_name, model_path in self.model_names_or_paths.items():
            if "base" in model_name or "medium" in model_name:
                self.generation_pipline[model_name] = pipeline(
                    "text-generation",
                    model=GPT2LMHeadModel.from_pretrained(model_path),
                    tokenizer=self.tokenizer,
                    device=-1,
                )
            else:
                self.generation_pipline[model_name] = pipeline(
                    "text-generation",
                    model=GROVERLMHeadModel.from_pretrained(model_path),
                    tokenizer=self.tokenizer,
                    device=-1,
                )

    def load(self):
        if not self.debug:
            self.load_pipeline()

    def generate(
        self,
        model_name,
        prompt,
        max_new_tokens: int,
        temperature: float,
        top_k: int,
        top_p: float,
        repetition_penalty: float,
        no_repeat_ngram_size: int,
        do_sample: bool,
        num_beams: int,
    ):
        logger.info(f"Generating with {model_name}")
        prompt = self.preprocessor.preprocess(prompt)
        return_full_text = False
        return_text = True
        num_return_sequences = 1
        pad_token_id = 0
        eos_token_id = 0
        input_tok = self.tokenizer.tokenize(prompt)
        max_length = len(input_tok) + max_new_tokens
        if max_length > 1024:
            max_length = 1024
        if not self.debug:
            generated_text = self.generation_pipline[model_name.lower()](
                prompt,
                max_length=max_length,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
                no_repeat_ngram_size=no_repeat_ngram_size,
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
                return_full_text=return_full_text,
                return_text=return_text,
                do_sample=do_sample,
                num_beams=num_beams,
                num_return_sequences=num_return_sequences,
            )[0]["generated_text"]
        else:
            generated_text = self.generate_by_query(
                prompt,
                model_name,
                max_length=max_length,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
                no_repeat_ngram_size=no_repeat_ngram_size,
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
                return_full_text=return_full_text,
                return_text=return_text,
                do_sample=do_sample,
                num_beams=num_beams,
                num_return_sequences=num_return_sequences,
            )
            # print(generated_text)
            if isinstance(generated_text, dict):
                if "error" in generated_text:
                    if "is currently loading" in generated_text["error"]:
                        return f"Model is currently loading, estimated time is {generated_text['estimated_time']}"
                    return generated_text["error"]
                else:
                    return "Something happened 🤷‍♂️!!"
            else:
                generated_text = generated_text[0]["generated_text"]

        logger.info(f"Prompt: {prompt}")
        logger.info(f"Generated text: {generated_text}")
        return self.preprocessor.unpreprocess(generated_text)

    def query(self, payload, model_name):
        data = json.dumps(payload)
        url = (
            "https://api-inference.huggingface.co/models/aubmindlab/"
            + model_name.lower()
        )
        response = requests.request("POST", url, headers=self.headers, data=data)
        return json.loads(response.content.decode("utf-8"))

    def generate_by_query(
        self,
        prompt: str,
        model_name: str,
        max_length: int,
        temperature: float,
        top_k: int,
        top_p: float,
        repetition_penalty: float,
        no_repeat_ngram_size: int,
        pad_token_id: int,
        eos_token_id: int,
        return_full_text: bool,
        return_text: bool,
        do_sample: bool,
        num_beams: int,
        num_return_sequences: int,
    ):
        payload = {
            "inputs": prompt,
            "parameters": {
                "max_length": max_length,
                "top_k": top_k,
                "top_p": top_p,
                "temperature": temperature,
                "repetition_penalty": repetition_penalty,
                "no_repeat_ngram_size": no_repeat_ngram_size,
                "pad_token_id": pad_token_id,
                "eos_token_id": eos_token_id,
                "return_full_text": return_full_text,
                "return_text": return_text,
                "do_sample": do_sample,
                "num_beams": num_beams,
                "num_return_sequences": num_return_sequences,
            },
            "options": {
                "use_cache": True,
            },
        }
        return self.query(payload, model_name)

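# A minimal driver for TextGeneration (sketch only; the hyperparameter values
# are illustrative, and load() downloads both aragpt2 checkpoints on first run):
#
#   generator = TextGeneration()
#   generator.load()  # builds the aragpt2-base and aragpt2-mega pipelines unless self.debug is set
#   output = generator.generate(
#       "aragpt2-base",
#       "يحكى أن",  # prompt
#       max_new_tokens=30,
#       temperature=1.0,
#       top_k=50,
#       top_p=0.95,
#       repetition_penalty=3.0,
#       no_repeat_ngram_size=3,
#       do_sample=True,
#       num_beams=1,
#   )
#   print(output)
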
class SentimentAnalyzer:
    def __init__(self):
        self.sa_models = [
            "sa_trial5_1",
            # "sa_no_aoa_in_neutral",
            # "sa_cnnbert",
            # "sa_sarcasm",
            # "sar_trial10",
            # "sa_no_AOA",
        ]
        download_models(self.sa_models)
        # fmt: off
        self.processors = {
            "sa_trial5_1": Trial5ArabicPreprocessor(model_name='UBC-NLP/MARBERT'),
            # "sa_no_aoa_in_neutral": NewArabicPreprocessorBalanced(model_name='UBC-NLP/MARBERT'),
            # "sa_cnnbert": CNNMarbertArabicPreprocessor(model_name='UBC-NLP/MARBERT'),
            # "sa_sarcasm": SarcasmArabicPreprocessor(model_name='UBC-NLP/MARBERT'),
            # "sar_trial10": SarcasmArabicPreprocessor(model_name='UBC-NLP/MARBERT'),
            # "sa_no_AOA": NewArabicPreprocessorBalanced(model_name='UBC-NLP/MARBERT'),
        }

        self.pipelines = {
            "sa_trial5_1": [pipeline("sentiment-analysis", model="{}/train_{}/best_model".format("sa_trial5_1", i), device=-1, return_all_scores=True) for i in tqdm(range(0, 5), desc="Loading pipeline for model: sa_trial5_1")],
            # "sa_no_aoa_in_neutral": [pipeline("sentiment-analysis", model="{}/train_{}/best_model".format("sa_no_aoa_in_neutral", i), device=-1, return_all_scores=True) for i in tqdm(range(0, 5), desc="Loading pipeline for model: sa_no_aoa_in_neutral")],
            # "sa_cnnbert": [CNNTextClassificationPipeline("{}/train_{}/best_model".format("sa_cnnbert", i), device=-1, return_all_scores=True) for i in tqdm(range(0, 5), desc="Loading pipeline for model: sa_cnnbert")],
            # "sa_sarcasm": [pipeline("sentiment-analysis", model="{}/train_{}/best_model".format("sa_sarcasm", i), device=-1, return_all_scores=True) for i in tqdm(range(0, 5), desc="Loading pipeline for model: sa_sarcasm")],
            # "sar_trial10": [pipeline("sentiment-analysis", model="{}/train_{}/best_model".format("sar_trial10", i), device=-1, return_all_scores=True) for i in tqdm(range(0, 5), desc="Loading pipeline for model: sar_trial10")],
            # "sa_no_AOA": [pipeline("sentiment-analysis", model="{}/train_{}/best_model".format("sa_no_AOA", i), device=-1, return_all_scores=True) for i in tqdm(range(0, 5), desc="Loading pipeline for model: sa_no_AOA")],
        }
        # fmt: on

    def get_preds_from_sarcasm(self, texts):
        # NOTE: the "sar_trial10" processor and pipelines are commented out above,
        # so this method raises a KeyError until they are re-enabled.
        prep = self.processors["sar_trial10"]
        prep_texts = [prep.preprocess(x) for x in texts]

        preds_df = pd.DataFrame([])
        for i in range(0, 5):
            preds = []
            for s in more_itertools.chunked(list(prep_texts), 128):
                preds.extend(self.pipelines["sar_trial10"][i](s))
            preds_df[f"model_{i}"] = preds

        final_labels = []
        final_scores = []
        for id, row in preds_df.iterrows():
            pos_total = 0
            neu_total = 0
            for pred in row[:]:
                pos_total += pred[0]["score"]
                neu_total += pred[1]["score"]

            pos_avg = pos_total / len(row[:])
            neu_avg = neu_total / len(row[:])

            final_labels.append(
                self.pipelines["sar_trial10"][0].model.config.id2label[
                    np.argmax([pos_avg, neu_avg])
                ]
            )
            final_scores.append(np.max([pos_avg, neu_avg]))

        return final_labels, final_scores

    def get_preds_from_a_model(self, texts: List[str], model_name):
        try:
            prep = self.processors[model_name]

            prep_texts = [prep.preprocess(x) for x in texts]
            if model_name == "sa_sarcasm":
                sarcasm_label, _ = self.get_preds_from_sarcasm(texts)
                sarcastic_map = {"Not_Sarcastic": "غير ساخر", "Sarcastic": "ساخر"}
                # NOTE: labeled_prep_texts is built but never used below;
                # predictions still run on the unlabeled prep_texts.
                labeled_prep_texts = []
                for t, l in zip(prep_texts, sarcasm_label):
                    labeled_prep_texts.append(sarcastic_map[l] + " [SEP] " + t)

            preds_df = pd.DataFrame([])
            for i in range(0, 5):
                preds = []
                for s in more_itertools.chunked(list(prep_texts), 128):
                    preds.extend(self.pipelines[model_name][i](s))
                preds_df[f"model_{i}"] = preds

            final_labels = []
            final_scores = []
            final_scores_list = []
            for id, row in preds_df.iterrows():
                pos_total = 0
                neg_total = 0
                neu_total = 0
                for pred in row[:]:  # average over all five folds (the divisor below is 5)
                    pos_total += pred[0]["score"]
                    neu_total += pred[1]["score"]
                    neg_total += pred[2]["score"]

                pos_avg = pos_total / 5
                neu_avg = neu_total / 5
                neg_avg = neg_total / 5

                if model_name == "sa_no_aoa_in_neutral":
                    final_labels.append(
                        self.pipelines[model_name][0].model.config.id2label[
                            np.argmax([neu_avg, neg_avg, pos_avg])
                        ]
                    )
                else:
                    final_labels.append(
                        self.pipelines[model_name][0].model.config.id2label[
                            np.argmax([pos_avg, neu_avg, neg_avg])
                        ]
                    )
                final_scores.append(np.max([pos_avg, neu_avg, neg_avg]))
                final_scores_list.append((pos_avg, neu_avg, neg_avg))
        except RuntimeError as e:
            if model_name == "sa_cnnbert":
                return (
                    ["Neutral"] * len(texts),
                    [0.0] * len(texts),
                    [(0.0, 0.0, 0.0)] * len(texts),
                )
            else:
                raise RuntimeError(e)
        return final_labels, final_scores, final_scores_list

    def predict(self, texts: List[str]):
        logger.info(f"Predicting for: {texts}")
        # (
        #     new_balanced_label,
        #     new_balanced_score,
        #     new_balanced_score_list,
        # ) = self.get_preds_from_a_model(texts, "sa_no_aoa_in_neutral")
        # (
        #     cnn_marbert_label,
        #     cnn_marbert_score,
        #     cnn_marbert_score_list,
        # ) = self.get_preds_from_a_model(texts, "sa_cnnbert")
        trial5_label, trial5_score, trial5_score_list = self.get_preds_from_a_model(
            texts, "sa_trial5_1"
        )
        # no_aoa_label, no_aoa_score, no_aoa_score_list = self.get_preds_from_a_model(
        #     texts, "sa_no_AOA"
        # )
        # sarcasm_label, sarcasm_score, sarcasm_score_list = self.get_preds_from_a_model(
        #     texts, "sa_sarcasm"
        # )

        id_label_map = {0: "Positive", 1: "Neutral", 2: "Negative"}

        final_ensemble_prediction = []
        final_ensemble_score = []
        final_ensemble_all_score = []
        for entry in zip(
            # new_balanced_score_list,
            # cnn_marbert_score_list,
            trial5_score_list,
            # no_aoa_score_list,
            # sarcasm_score_list,
        ):
            pos_score = 0
            neu_score = 0
            neg_score = 0
            for s in entry:
                pos_score += s[0] * 1.57
                neu_score += s[1] * 0.98
                neg_score += s[2] * 0.93

                # weighted 2
                # pos_score += s[0]*1.67
                # neu_score += s[1]
                # neg_score += s[2]*0.95

            final_ensemble_prediction.append(
                id_label_map[np.argmax([pos_score, neu_score, neg_score])]
            )
            final_ensemble_score.append(np.max([pos_score, neu_score, neg_score]))
            final_ensemble_all_score.append(
                softmax(np.array([pos_score, neu_score, neg_score])).tolist()
            )

        logger.info(f"Result: {final_ensemble_prediction}")
        logger.info(f"Score: {final_ensemble_score}")
        logger.info(f"All Scores: {final_ensemble_all_score}")
        return final_ensemble_prediction, final_ensemble_score, final_ensemble_all_score

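# Sketch of how the sentiment page drives the ensemble above (the input
# sentence is illustrative; constructing the class downloads the five
# sa_trial5_1 folds):
#
#   analyzer = SentimentAnalyzer()
#   preds, scores, all_scores = analyzer.predict(["الخدمة كانت ممتازة"])
#   # preds      -> e.g. ["Positive"] (from id_label_map)
#   # scores     -> the winning weighted ensemble score per input
#   # all_scores -> softmax-normalized [positive, neutral, negative] per input
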
wikipedia.set_lang("ar")

os.environ["TOKENIZERS_PARALLELISM"] = "false"

preprocessor = ArabertPreprocessor("wissamantoun/araelectra-base-artydiqa")
logger.info("Loading QA Pipeline...")
tokenizer = AutoTokenizer.from_pretrained("wissamantoun/araelectra-base-artydiqa")
qa_pipe = pipeline("question-answering", model="wissamantoun/araelectra-base-artydiqa")
logger.info("Finished loading QA Pipeline...")


@lru_cache(maxsize=100)
def get_qa_answers(question):
    logger.info("\n=================================================================")
    logger.info(f"Question: {question}")

    if "وسام أنطون" in question or "wissam antoun" in question.lower():
        return {
            "title": "Creator",
            "results": [
                {
                    "score": 1.0,
                    "new_start": 0,
                    "new_end": 12,
                    "new_answer": "My Creator 😜",
                    "original": "My Creator 😜",
                    "link": "https://github.com/WissamAntoun/",
                }
            ],
        }
    search_timer = Timer(
        "search and wiki", text="Search and Wikipedia Time: {:.2f}", logger=logging.info
    )
    try:
        search_timer.start()
        search_results = google.search(
            question + " site:ar.wikipedia.org", lang="ar", area="ar"
        )
        if len(search_results) == 0:
            return {}

        page_name = search_results[0].link.split("wiki/")[-1]
        wiki_page = wikipedia.page(unquote(page_name))
        wiki_page_content = wiki_page.content
        search_timer.stop()
    except Exception:
        return {}

    sections = []
    for section in re.split("== .+ ==[^=]", wiki_page_content):
        if not section.isspace():
            prep_section = tokenizer.tokenize(preprocessor.preprocess(section))
            if len(prep_section) > 500:
                subsections = []
                for subsection in re.split("=== .+ ===", section):
                    if subsection.isspace():
                        continue
                    prep_subsection = tokenizer.tokenize(
                        preprocessor.preprocess(subsection)
                    )
                    subsections.append(subsection)
                    # logger.info(f"Subsection found with length: {len(prep_subsection)}")
                sections.extend(subsections)
            else:
                # logger.info(f"Regular Section with length: {len(prep_section)}")
                sections.append(section)

    full_len_sections = []
    temp_section = ""
    for section in sections:
        if (
            len(tokenizer.tokenize(preprocessor.preprocess(temp_section)))
            + len(tokenizer.tokenize(preprocessor.preprocess(section)))
            > 384
        ):
            if temp_section == "":
                temp_section = section
                continue
            full_len_sections.append(temp_section)
            # logger.info(
            #     f"full section length: {len(tokenizer.tokenize(preprocessor.preprocess(temp_section)))}"
            # )
            temp_section = ""
        else:
            temp_section += " " + section + " "
    if temp_section != "":
        full_len_sections.append(temp_section)

    reader_time = Timer("electra", text="Reader Time: {:.2f}", logger=logging.info)
    reader_time.start()
    results = qa_pipe(
        question=[preprocessor.preprocess(question)] * len(full_len_sections),
        context=[preprocessor.preprocess(x) for x in full_len_sections],
    )

    if not isinstance(results, list):
        results = [results]

    logger.info(f"Wiki Title: {unquote(page_name)}")
    logger.info(f"Total Sections: {len(sections)}")
    logger.info(f"Total Full Sections: {len(full_len_sections)}")

    for result, section in zip(results, full_len_sections):
        result["original"] = section
        answer_match = find_near_matches(
            " " + preprocessor.unpreprocess(result["answer"]) + " ",
            result["original"],
            max_l_dist=min(5, len(preprocessor.unpreprocess(result["answer"])) // 2),
            max_deletions=0,
        )
        try:
            result["new_start"] = answer_match[0].start
            result["new_end"] = answer_match[0].end
            result["new_answer"] = answer_match[0].matched
            result["link"] = (
                search_results[0].link + "#:~:text=" + result["new_answer"].strip()
            )
        except Exception:
            result["new_start"] = result["start"]
            result["new_end"] = result["end"]
            result["new_answer"] = result["answer"]
            result["original"] = preprocessor.preprocess(result["original"])
            result["link"] = search_results[0].link
        logger.info(f"Answers: {preprocessor.preprocess(result['new_answer'])}")

    sorted_results = sorted(results, reverse=True, key=lambda x: x["score"])

    return_dict = {}
    return_dict["title"] = unquote(page_name)
    return_dict["results"] = sorted_results

    reader_time.stop()
    logger.info(f"Total time spent: {reader_time.last + search_timer.last}")
    return return_dict
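A short end-to-end sketch of the QA flow defined above; it is network-bound (Google search, then Arabic Wikipedia, then the ELECTRA reader), the import itself loads the QA pipeline, and the question is illustrative:

from backend.services import get_qa_answers

answers = get_qa_answers("ما هي عاصمة لبنان؟")
if answers:  # an empty dict means the search or the Wikipedia fetch failed
    print(answers["title"])
    best = answers["results"][0]
    print(best["new_answer"], best["score"], best["link"])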
backend/utils.py
ADDED
@@ -0,0 +1,64 @@
import re
import numpy as np
import psutil
import os
from tqdm.auto import tqdm
import logging

logger = logging.getLogger(__name__)


def get_current_ram_usage():
    ram = psutil.virtual_memory()
    return ram.available / 1024 / 1024 / 1024, ram.total / 1024 / 1024 / 1024


def download_models(models):
    for model in tqdm(models, desc="Downloading models"):
        logger.info(f"Downloading {model}")
        for i in range(0, 5):
            curr_dir = f"{model}/train_{i}/best_model/"
            os.makedirs(curr_dir, exist_ok=True)
            os.system(
                f"wget -q https://huggingface.co/researchaccount/{model}/resolve/main/train_{i}/best_model/config.json -P {curr_dir}"
            )
            os.system(
                f"wget -q https://huggingface.co/researchaccount/{model}/resolve/main/train_{i}/best_model/pytorch_model.bin -P {curr_dir}"
            )
            os.system(
                f"wget -q https://huggingface.co/researchaccount/{model}/resolve/main/train_{i}/best_model/special_tokens_map.json -P {curr_dir}"
            )
            os.system(
                f"wget -q https://huggingface.co/researchaccount/{model}/resolve/main/train_{i}/best_model/tokenizer_config.json -P {curr_dir}"
            )
            os.system(
                f"wget -q https://huggingface.co/researchaccount/{model}/resolve/main/train_{i}/best_model/training_args.bin -P {curr_dir}"
            )
            os.system(
                f"wget -q https://huggingface.co/researchaccount/{model}/resolve/main/train_{i}/best_model/vocab.txt -P {curr_dir}"
            )


def softmax(x):
    return np.exp(x) / sum(np.exp(x))


def ga(file):
    code = """
    <!-- Global site tag (gtag.js) - Google Analytics -->
    <script async src="https://www.googletagmanager.com/gtag/js?id=G-NH9HWCW08F"></script>
    <script>
        window.dataLayer = window.dataLayer || [];
        function gtag(){dataLayer.push(arguments);}
        gtag('js', new Date());
        gtag('config', 'G-NH9HWCW08F');
    </script>
    """

    a = os.path.dirname(file) + "/static/index.html"
    with open(a, "r") as f:
        data = f.read()
        if len(re.findall("G-", data)) == 0:
            with open(a, "w") as ff:
                newdata = re.sub("<head>", "<head>" + code, data)
                ff.write(newdata)
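A quick sanity check for the helpers above (the softmax values are computed by hand; get_current_ram_usage reports gigabytes):

import numpy as np
from backend.utils import get_current_ram_usage, softmax

available_gb, total_gb = get_current_ram_usage()
print(f"RAM: {available_gb:.1f} / {total_gb:.1f} GB available/total")

probs = softmax(np.array([1.0, 2.0, 3.0]))
print(probs)  # ~[0.090, 0.245, 0.665]
assert abs(probs.sum() - 1.0) < 1e-9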
packages.txt
ADDED
@@ -0,0 +1,2 @@
openjdk-11-jre
curl
requirements.txt
ADDED
@@ -0,0 +1,17 @@
streamlit==0.84.2
arabic-reshaper==2.1.3
python-bidi==0.4.2
PyArabic
farasapy==0.0.14
emoji==1.4.2
awesome_streamlit
torch==1.9.0
transformers==4.10.0
psutil==5.8.0
fuzzysearch==0.7.3
more-itertools==8.9.0
cookiecutter
git+https://github.com/dantru7/Google-Search-API
codetiming==1.3.0
htbuilder
wikipedia==1.4.0
test.py
ADDED
@@ -0,0 +1,10 @@
#%%
from transformers import GPT2Tokenizer

# %%
tok = GPT2Tokenizer.from_pretrained("D:/ML/Models/aragpt2-medium", use_fast=False)
# %%
tok.pad_token = tok.eos_token
#%%
tok.pad_token_id = [tok.eos_token_id]  # NOTE: pad_token_id expects a single int, not a list
# %%