ringorsolya committed
Commit 3d669bb
1 Parent(s): 13cf9ab

Upload folder using huggingface_hub

.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ pooled_v4_xlmRoberta_training.xlsx filter=lfs diff=lfs merge=lfs -text
+ test_data.xlsx filter=lfs diff=lfs merge=lfs -text
.ipynb_checkpoints/Untitled-checkpoint.ipynb ADDED
@@ -0,0 +1,6 @@
+ {
+ "cells": [],
+ "metadata": {},
+ "nbformat": 4,
+ "nbformat_minor": 5
+ }
Untitled.ipynb ADDED
@@ -0,0 +1,747 @@
+ {
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "9e85b4fd-6c00-4d15-9a99-f461461bf660",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Requirement already satisfied: transformers in /home/p_babro/miniconda3/lib/python3.12/site-packages (4.43.4)\n",
+ "Requirement already satisfied: filelock in /home/p_babro/miniconda3/lib/python3.12/site-packages (from transformers) (3.15.4)\n",
+ "Requirement already satisfied: huggingface-hub<1.0,>=0.23.2 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from transformers) (0.24.5)\n",
+ "Requirement already satisfied: numpy>=1.17 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from transformers) (1.26.4)\n",
+ "Requirement already satisfied: packaging>=20.0 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from transformers) (23.2)\n",
+ "Requirement already satisfied: pyyaml>=5.1 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from transformers) (6.0.1)\n",
+ "Requirement already satisfied: regex!=2019.12.17 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from transformers) (2024.7.24)\n",
+ "Requirement already satisfied: requests in /home/p_babro/miniconda3/lib/python3.12/site-packages (from transformers) (2.32.2)\n",
+ "Requirement already satisfied: safetensors>=0.4.1 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from transformers) (0.4.4)\n",
+ "Requirement already satisfied: tokenizers<0.20,>=0.19 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from transformers) (0.19.1)\n",
+ "Requirement already satisfied: tqdm>=4.27 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from transformers) (4.66.4)\n",
+ "Requirement already satisfied: fsspec>=2023.5.0 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from huggingface-hub<1.0,>=0.23.2->transformers) (2024.5.0)\n",
+ "Requirement already satisfied: typing-extensions>=3.7.4.3 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from huggingface-hub<1.0,>=0.23.2->transformers) (4.12.2)\n",
+ "Requirement already satisfied: charset-normalizer<4,>=2 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from requests->transformers) (2.0.4)\n",
+ "Requirement already satisfied: idna<4,>=2.5 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from requests->transformers) (3.7)\n",
+ "Requirement already satisfied: urllib3<3,>=1.21.1 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from requests->transformers) (2.2.2)\n",
+ "Requirement already satisfied: certifi>=2017.4.17 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from requests->transformers) (2024.7.4)\n",
+ "Note: you may need to restart the kernel to use updated packages.\n",
+ "Requirement already satisfied: datasets in /home/p_babro/miniconda3/lib/python3.12/site-packages (2.20.0)\n",
+ "Requirement already satisfied: filelock in /home/p_babro/miniconda3/lib/python3.12/site-packages (from datasets) (3.15.4)\n",
+ "Requirement already satisfied: numpy>=1.17 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from datasets) (1.26.4)\n",
+ "Requirement already satisfied: pyarrow>=15.0.0 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from datasets) (17.0.0)\n",
+ "Requirement already satisfied: pyarrow-hotfix in /home/p_babro/miniconda3/lib/python3.12/site-packages (from datasets) (0.6)\n",
+ "Requirement already satisfied: dill<0.3.9,>=0.3.0 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from datasets) (0.3.8)\n",
+ "Requirement already satisfied: pandas in /home/p_babro/miniconda3/lib/python3.12/site-packages (from datasets) (2.2.2)\n",
+ "Requirement already satisfied: requests>=2.32.2 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from datasets) (2.32.2)\n",
+ "Requirement already satisfied: tqdm>=4.66.3 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from datasets) (4.66.4)\n",
+ "Requirement already satisfied: xxhash in /home/p_babro/miniconda3/lib/python3.12/site-packages (from datasets) (3.4.1)\n",
+ "Requirement already satisfied: multiprocess in /home/p_babro/miniconda3/lib/python3.12/site-packages (from datasets) (0.70.16)\n",
+ "Requirement already satisfied: fsspec<=2024.5.0,>=2023.1.0 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from fsspec[http]<=2024.5.0,>=2023.1.0->datasets) (2024.5.0)\n",
+ "Requirement already satisfied: aiohttp in /home/p_babro/miniconda3/lib/python3.12/site-packages (from datasets) (3.10.1)\n",
+ "Requirement already satisfied: huggingface-hub>=0.21.2 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from datasets) (0.24.5)\n",
+ "Requirement already satisfied: packaging in /home/p_babro/miniconda3/lib/python3.12/site-packages (from datasets) (23.2)\n",
+ "Requirement already satisfied: pyyaml>=5.1 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from datasets) (6.0.1)\n",
+ "Requirement already satisfied: aiohappyeyeballs>=2.3.0 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from aiohttp->datasets) (2.3.4)\n",
+ "Requirement already satisfied: aiosignal>=1.1.2 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from aiohttp->datasets) (1.3.1)\n",
+ "Requirement already satisfied: attrs>=17.3.0 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from aiohttp->datasets) (24.1.0)\n",
+ "Requirement already satisfied: frozenlist>=1.1.1 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from aiohttp->datasets) (1.4.1)\n",
+ "Requirement already satisfied: multidict<7.0,>=4.5 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from aiohttp->datasets) (6.0.5)\n",
+ "Requirement already satisfied: yarl<2.0,>=1.0 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from aiohttp->datasets) (1.9.4)\n",
+ "Requirement already satisfied: typing-extensions>=3.7.4.3 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from huggingface-hub>=0.21.2->datasets) (4.12.2)\n",
+ "Requirement already satisfied: charset-normalizer<4,>=2 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from requests>=2.32.2->datasets) (2.0.4)\n",
+ "Requirement already satisfied: idna<4,>=2.5 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from requests>=2.32.2->datasets) (3.7)\n",
+ "Requirement already satisfied: urllib3<3,>=1.21.1 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from requests>=2.32.2->datasets) (2.2.2)\n",
+ "Requirement already satisfied: certifi>=2017.4.17 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from requests>=2.32.2->datasets) (2024.7.4)\n",
+ "Requirement already satisfied: python-dateutil>=2.8.2 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from pandas->datasets) (2.9.0)\n",
+ "Requirement already satisfied: pytz>=2020.1 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from pandas->datasets) (2024.1)\n",
+ "Requirement already satisfied: tzdata>=2022.7 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from pandas->datasets) (2024.1)\n",
+ "Requirement already satisfied: six>=1.5 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from python-dateutil>=2.8.2->pandas->datasets) (1.16.0)\n",
+ "Note: you may need to restart the kernel to use updated packages.\n",
+ "Requirement already satisfied: sentencepiece in /home/p_babro/miniconda3/lib/python3.12/site-packages (0.2.0)\n",
+ "Note: you may need to restart the kernel to use updated packages.\n",
+ "Requirement already satisfied: pandas in /home/p_babro/miniconda3/lib/python3.12/site-packages (2.2.2)\n",
+ "Requirement already satisfied: openpyxl in /home/p_babro/miniconda3/lib/python3.12/site-packages (3.1.5)\n",
+ "Requirement already satisfied: numpy>=1.26.0 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from pandas) (1.26.4)\n",
+ "Requirement already satisfied: python-dateutil>=2.8.2 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from pandas) (2.9.0)\n",
+ "Requirement already satisfied: pytz>=2020.1 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from pandas) (2024.1)\n",
+ "Requirement already satisfied: tzdata>=2022.7 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from pandas) (2024.1)\n",
+ "Requirement already satisfied: et-xmlfile in /home/p_babro/miniconda3/lib/python3.12/site-packages (from openpyxl) (1.1.0)\n",
+ "Requirement already satisfied: six>=1.5 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from python-dateutil>=2.8.2->pandas) (1.16.0)\n",
+ "Note: you may need to restart the kernel to use updated packages.\n"
+ ]
+ }
+ ],
+ "source": [
+ "%pip install transformers\n",
+ "%pip install datasets\n",
+ "%pip install sentencepiece\n",
+ "%pip install pandas openpyxl"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "f72773a5-ddbc-43f7-a0b8-7b004a8b0db6",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ " labels text\n",
+ "0 1 Strach z osobního selhání často v kritických o...\n",
+ "1 5 Pre týchto ľudí treba nájsť riešenie.\n",
+ "2 5 Čestnými hosty byli bývalý spolkový prezident ...\n",
+ "3 4 Vaše milá slova mi opravdu zlepšila den.\n",
+ "4 4 Ďakujem mnohokrát! Z pochvaly máme radosť.\n"
+ ]
+ }
+ ],
+ "source": [
+ "import pandas as pd\n",
+ "\n",
+ "# Specify the file path\n",
+ "file_path = '/project/home/p_babro/p_babel/v4_slant/pooled_v4_xlmRoberta_training.xlsx'\n",
+ "\n",
+ "# Read the Excel file\n",
+ "df = pd.read_excel(file_path)\n",
+ "\n",
+ "# Display the DataFrame\n",
+ "print(df.head())"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "id": "e8c9c696-9308-4ac1-8364-798c04e7b54a",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Index(['labels', 'text'], dtype='object')\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Load data from Excel file\n",
+ "df = pd.read_excel(file_path)\n",
+ "\n",
+ "# Print the column names to verify\n",
+ "print(df.columns)\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "86d92b6f-03b0-4df2-8f48-a34185180662",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Some weights of XLMRobertaForSequenceClassification were not initialized from the model checkpoint at xlm-roberta-base and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']\n",
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
+ ]
+ }
+ ],
+ "source": [
+ "from transformers import XLMRobertaTokenizer, XLMRobertaForSequenceClassification\n",
+ "\n",
+ "# Model and tokenizer initialization\n",
+ "tokenizer = XLMRobertaTokenizer.from_pretrained('xlm-roberta-base')\n",
+ "model = XLMRobertaForSequenceClassification.from_pretrained('xlm-roberta-base')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "id": "24f34a63-31e4-4b57-bc72-a635cf3297a2",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def start_train(df, model_name, batch_size, lr, max_length, num_epochs):\n",
+ "\n",
+ " # Prepare labels\n",
+ " label_encoder = LabelEncoder()\n",
+ " labels = df[label_column]\n",
+ " labels = label_encoder.fit_transform(labels)\n",
+ " num_labels = len(set(labels))\n",
+ "\n",
+ " # Hugging Face Datasets format\n",
+ " train_dataset = Dataset.from_pandas(train_data)\n",
+ " val_dataset = Dataset.from_pandas(val_data)\n",
+ " test_dataset = Dataset.from_pandas(test_data)\n",
+ "\n",
+ " # Load tokenizer\n",
+ " tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
+ "\n",
+ " # Tokenize\n",
+ " train_dataset = train_dataset.map(lambda x: tokenize_dataset(x, tokenizer, max_length, num_labels), batched=True, remove_columns=train_dataset.column_names)\n",
+ " val_dataset = val_dataset.map(lambda x: tokenize_dataset(x, tokenizer, max_length, num_labels), batched=True, remove_columns=val_dataset.column_names)\n",
+ " test_dataset = test_dataset.map(lambda x: tokenize_dataset(x, tokenizer, max_length, num_labels), batched=True, remove_columns=test_dataset.column_names)\n",
+ "\n",
+ " # Load model\n",
+ " model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=num_labels, problem_type=\"multi_label_classification\")\n",
+ "\n",
+ " # Training arguments\n",
+ " training_args = TrainingArguments(\n",
+ " output_dir=drive_folder_to_save,\n",
+ " logging_dir=drive_folder_to_save,\n",
+ " logging_strategy='epoch',\n",
+ " logging_steps=100,\n",
+ " num_train_epochs=num_epochs,\n",
+ " per_device_train_batch_size=batch_size,\n",
+ " per_device_eval_batch_size=batch_size,\n",
+ " learning_rate=lr,\n",
+ " seed=42,\n",
+ " save_strategy='epoch',\n",
+ " save_steps=100,\n",
+ " evaluation_strategy='epoch',\n",
+ " eval_steps=100,\n",
+ " save_total_limit=1,\n",
+ " load_best_model_at_end=True,\n",
+ " )\n",
+ "\n",
+ " # Create trainer\n",
+ " trainer = Trainer(\n",
+ " model=model,\n",
+ " args=training_args,\n",
+ " train_dataset=train_dataset,\n",
+ " eval_dataset=val_dataset,\n",
+ " compute_metrics=compute_metrics,\n",
+ " callbacks=[EarlyStoppingCallback(early_stopping_patience=2)]\n",
+ " )\n",
+ "\n",
+ " # Train model\n",
+ " trainer.train()\n",
+ "\n",
+ " # Evaluate results\n",
+ " predictions = trainer.predict(test_dataset).predictions\n",
+ " preds = np.argmax(predictions, axis=1)\n",
+ " accuracy = accuracy_score(test_data[label_column], preds)\n",
+ " print(f'Accuracy: {accuracy}')\n",
+ " precision, recall, f1, _ = precision_recall_fscore_support(test_data[label_column], preds, average='weighted')\n",
+ " print(f'Precision: {precision}')\n",
+ " print(f'Recall: {recall}')\n",
+ " print(f'F1 Score: {f1}')\n",
+ "\n",
+ " # Save model\n",
+ " trainer.save_model(drive_folder_to_save)\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "id": "669ef024-3b2c-47c3-954c-de1e2b50f1d6",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Requirement already satisfied: pandas in /home/p_babro/miniconda3/lib/python3.12/site-packages (2.2.2)\n",
+ "Requirement already satisfied: openpyxl in /home/p_babro/miniconda3/lib/python3.12/site-packages (3.1.5)\n",
+ "Requirement already satisfied: transformers in /home/p_babro/miniconda3/lib/python3.12/site-packages (4.43.4)\n",
+ "Requirement already satisfied: datasets in /home/p_babro/miniconda3/lib/python3.12/site-packages (2.20.0)\n",
+ "Requirement already satisfied: evaluate in /home/p_babro/miniconda3/lib/python3.12/site-packages (0.4.2)\n",
+ "Requirement already satisfied: scikit-learn in /home/p_babro/miniconda3/lib/python3.12/site-packages (1.5.1)\n",
+ "Requirement already satisfied: numpy>=1.26.0 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from pandas) (1.26.4)\n",
+ "Requirement already satisfied: python-dateutil>=2.8.2 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from pandas) (2.9.0)\n",
+ "Requirement already satisfied: pytz>=2020.1 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from pandas) (2024.1)\n",
+ "Requirement already satisfied: tzdata>=2022.7 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from pandas) (2024.1)\n",
+ "Requirement already satisfied: et-xmlfile in /home/p_babro/miniconda3/lib/python3.12/site-packages (from openpyxl) (1.1.0)\n",
+ "Requirement already satisfied: filelock in /home/p_babro/miniconda3/lib/python3.12/site-packages (from transformers) (3.15.4)\n",
+ "Requirement already satisfied: huggingface-hub<1.0,>=0.23.2 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from transformers) (0.24.5)\n",
+ "Requirement already satisfied: packaging>=20.0 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from transformers) (23.2)\n",
+ "Requirement already satisfied: pyyaml>=5.1 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from transformers) (6.0.1)\n",
+ "Requirement already satisfied: regex!=2019.12.17 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from transformers) (2024.7.24)\n",
+ "Requirement already satisfied: requests in /home/p_babro/miniconda3/lib/python3.12/site-packages (from transformers) (2.32.2)\n",
+ "Requirement already satisfied: safetensors>=0.4.1 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from transformers) (0.4.4)\n",
+ "Requirement already satisfied: tokenizers<0.20,>=0.19 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from transformers) (0.19.1)\n",
+ "Requirement already satisfied: tqdm>=4.27 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from transformers) (4.66.4)\n",
+ "Requirement already satisfied: pyarrow>=15.0.0 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from datasets) (17.0.0)\n",
+ "Requirement already satisfied: pyarrow-hotfix in /home/p_babro/miniconda3/lib/python3.12/site-packages (from datasets) (0.6)\n",
+ "Requirement already satisfied: dill<0.3.9,>=0.3.0 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from datasets) (0.3.8)\n",
+ "Requirement already satisfied: xxhash in /home/p_babro/miniconda3/lib/python3.12/site-packages (from datasets) (3.4.1)\n",
+ "Requirement already satisfied: multiprocess in /home/p_babro/miniconda3/lib/python3.12/site-packages (from datasets) (0.70.16)\n",
+ "Requirement already satisfied: fsspec<=2024.5.0,>=2023.1.0 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from fsspec[http]<=2024.5.0,>=2023.1.0->datasets) (2024.5.0)\n",
+ "Requirement already satisfied: aiohttp in /home/p_babro/miniconda3/lib/python3.12/site-packages (from datasets) (3.10.1)\n",
+ "Requirement already satisfied: scipy>=1.6.0 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from scikit-learn) (1.14.0)\n",
+ "Requirement already satisfied: joblib>=1.2.0 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from scikit-learn) (1.4.2)\n",
+ "Requirement already satisfied: threadpoolctl>=3.1.0 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from scikit-learn) (3.5.0)\n",
+ "Requirement already satisfied: aiohappyeyeballs>=2.3.0 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from aiohttp->datasets) (2.3.4)\n",
+ "Requirement already satisfied: aiosignal>=1.1.2 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from aiohttp->datasets) (1.3.1)\n",
+ "Requirement already satisfied: attrs>=17.3.0 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from aiohttp->datasets) (24.1.0)\n",
+ "Requirement already satisfied: frozenlist>=1.1.1 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from aiohttp->datasets) (1.4.1)\n",
+ "Requirement already satisfied: multidict<7.0,>=4.5 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from aiohttp->datasets) (6.0.5)\n",
+ "Requirement already satisfied: yarl<2.0,>=1.0 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from aiohttp->datasets) (1.9.4)\n",
+ "Requirement already satisfied: typing-extensions>=3.7.4.3 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from huggingface-hub<1.0,>=0.23.2->transformers) (4.12.2)\n",
+ "Requirement already satisfied: six>=1.5 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from python-dateutil>=2.8.2->pandas) (1.16.0)\n",
+ "Requirement already satisfied: charset-normalizer<4,>=2 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from requests->transformers) (2.0.4)\n",
+ "Requirement already satisfied: idna<4,>=2.5 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from requests->transformers) (3.7)\n",
+ "Requirement already satisfied: urllib3<3,>=1.21.1 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from requests->transformers) (2.2.2)\n",
+ "Requirement already satisfied: certifi>=2017.4.17 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from requests->transformers) (2024.7.4)\n",
+ "Note: you may need to restart the kernel to use updated packages.\n",
+ "Train data shape: (186137, 2)\n",
+ "Val data shape: (23267, 2)\n",
+ "Test data shape: (23268, 2)\n",
+ "/project/home/p_babro/p_babel/v4_slant/test_data.xlsx saved!\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "939855a37b3b43f3a1b5a54f3b7a1031",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Map: 0%| | 0/186137 [00:00<?, ? examples/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "52f14f45fad1420f820b3b96d994ae53",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Map: 0%| | 0/23267 [00:00<?, ? examples/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "384f14f4adf044eda2f43960c9e5d7dc",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Map: 0%| | 0/23268 [00:00<?, ? examples/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Some weights of XLMRobertaForSequenceClassification were not initialized from the model checkpoint at xlm-roberta-base and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']\n",
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
+ "/home/p_babro/miniconda3/lib/python3.12/site-packages/transformers/training_args.py:1525: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead\n",
+ " warnings.warn(\n",
+ "Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " <div>\n",
+ " \n",
+ " <progress value='58170' max='116340' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
+ " [ 58170/116340 2:33:12 < 2:33:12, 6.33 it/s, Epoch 5/10]\n",
+ " </div>\n",
+ " <table border=\"1\" class=\"dataframe\">\n",
+ " <thead>\n",
+ " <tr style=\"text-align: left;\">\n",
+ " <th>Epoch</th>\n",
+ " <th>Training Loss</th>\n",
+ " <th>Validation Loss</th>\n",
+ " <th>Accuracy</th>\n",
+ " <th>Precision</th>\n",
+ " <th>Recall</th>\n",
+ " <th>F1</th>\n",
+ " </tr>\n",
+ " </thead>\n",
+ " <tbody>\n",
+ " <tr>\n",
+ " <td>1</td>\n",
+ " <td>0.177300</td>\n",
+ " <td>0.141849</td>\n",
+ " <td>0.817252</td>\n",
+ " <td>0.818918</td>\n",
+ " <td>0.817252</td>\n",
+ " <td>0.817750</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>2</td>\n",
+ " <td>0.134500</td>\n",
+ " <td>0.133338</td>\n",
+ " <td>0.830103</td>\n",
+ " <td>0.830676</td>\n",
+ " <td>0.830103</td>\n",
+ " <td>0.830280</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>3</td>\n",
+ " <td>0.120100</td>\n",
+ " <td>0.130069</td>\n",
+ " <td>0.834229</td>\n",
+ " <td>0.834528</td>\n",
+ " <td>0.834229</td>\n",
+ " <td>0.833342</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>4</td>\n",
+ " <td>0.110600</td>\n",
+ " <td>0.132942</td>\n",
+ " <td>0.835045</td>\n",
+ " <td>0.834790</td>\n",
+ " <td>0.835045</td>\n",
+ " <td>0.834567</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>5</td>\n",
+ " <td>0.103200</td>\n",
+ " <td>0.131241</td>\n",
+ " <td>0.833455</td>\n",
+ " <td>0.833605</td>\n",
+ " <td>0.833455</td>\n",
+ " <td>0.833047</td>\n",
+ " </tr>\n",
+ " </tbody>\n",
+ "</table><p>"
+ ],
+ "text/plain": [
+ "<IPython.core.display.HTML object>"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "58682414ea364b57bb8cf08b0df06e4f",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Downloading builder script: 0%| | 0.00/4.20k [00:00<?, ?B/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [],
+ "text/plain": [
+ "<IPython.core.display.HTML object>"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Accuracy: 0.8367715317173801\n",
+ "Precision: 0.8369187930273877\n",
+ "Recall: 0.8367715317173801\n",
+ "F1 Score: 0.8360611942926541\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Install necessary libraries\n",
+ "%pip install pandas openpyxl transformers datasets evaluate scikit-learn\n",
+ "\n",
+ "# Import necessary libraries\n",
+ "import pandas as pd\n",
+ "import numpy as np\n",
+ "import torch\n",
+ "from sklearn.model_selection import train_test_split\n",
+ "from sklearn.preprocessing import LabelEncoder\n",
+ "from sklearn.metrics import accuracy_score, precision_recall_fscore_support\n",
+ "from transformers import (XLMRobertaTokenizer, XLMRobertaForSequenceClassification, AutoTokenizer,\n",
+ " AutoModelForSequenceClassification, Trainer, TrainingArguments)\n",
+ "from datasets import Dataset\n",
+ "from transformers.trainer_callback import EarlyStoppingCallback\n",
+ "import evaluate\n",
+ "from typing import List, Tuple\n",
+ "\n",
+ "# Define paths and columns\n",
+ "file_path = '/project/home/p_babro/p_babel/v4_slant/pooled_v4_xlmRoberta_training.xlsx'\n",
+ "text_column = 'text' # Replace with your actual text column name\n",
+ "label_column = 'labels' # Replace with your actual label column name\n",
+ "drive_folder_to_save = '/project/home/p_babro/p_babel/v4_slant' # Replace with your actual save folder path\n",
+ "\n",
+ "# Define functions\n",
+ "def load_data_from_excel(df, text_column: str, label_column: str) -> Tuple[List, List]:\n",
+ " return df[text_column].tolist(), df[label_column].tolist()\n",
+ "\n",
+ "def tokenize_dataset(data, tokenizer, max_length, num_labels):\n",
+ " tokenized = tokenizer(data[text_column],\n",
+ " max_length=max_length,\n",
+ " truncation=True,\n",
+ " padding=\"max_length\")\n",
+ "\n",
+ " labels = [x for x in data[label_column]]\n",
+ " labels_tensor = torch.as_tensor(labels)\n",
+ " labels_binary = torch.nn.functional.one_hot(labels_tensor, num_classes=num_labels).float()\n",
+ "\n",
+ " tokenized['labels'] = labels_binary\n",
+ "\n",
+ " return tokenized\n",
+ "\n",
+ "def compute_metrics(eval_pred):\n",
+ " metric = evaluate.load(\"accuracy\")\n",
+ " logits, labels = eval_pred\n",
+ " predictions = np.argmax(logits, axis=1)\n",
+ " reference_labels = [np.argmax(label) for label in labels]\n",
+ " precision, recall, f1, _ = precision_recall_fscore_support(reference_labels, predictions, average='weighted')\n",
+ " accuracy = accuracy_score(reference_labels, predictions)\n",
+ " return {\n",
+ " 'accuracy': accuracy,\n",
+ " 'precision': precision,\n",
+ " 'recall': recall,\n",
+ " 'f1': f1\n",
+ " }\n",
+ "\n",
+ "# Load data from Excel file\n",
+ "df = pd.read_excel(file_path)\n",
+ "texts, labels = load_data_from_excel(df, text_column, label_column)\n",
+ "\n",
+ "# Split the data\n",
+ "data = pd.DataFrame({text_column: texts, label_column: labels})\n",
+ "train_data, test_data = train_test_split(data, test_size=0.2, random_state=42)\n",
+ "val_data, test_data = train_test_split(test_data, test_size=0.5, random_state=42)\n",
+ "\n",
+ "print(f'Train data shape: {train_data.shape}')\n",
+ "print(f'Val data shape: {val_data.shape}')\n",
+ "print(f'Test data shape: {test_data.shape}')\n",
+ "\n",
+ "# Save test data to Excel\n",
+ "test_data.to_excel(f'{drive_folder_to_save}/test_data.xlsx', index=False)\n",
+ "print(f'{drive_folder_to_save}/test_data.xlsx saved!')\n",
+ "\n",
+ "def start_train(df, model_name, batch_size, lr, max_length, num_epochs):\n",
+ "\n",
+ " # Prepare labels\n",
+ " label_encoder = LabelEncoder()\n",
+ " labels = df[label_column]\n",
+ " labels = label_encoder.fit_transform(labels)\n",
+ " num_labels = len(set(labels))\n",
+ "\n",
+ " # Hugging Face Datasets format\n",
+ " train_dataset = Dataset.from_pandas(train_data)\n",
+ " val_dataset = Dataset.from_pandas(val_data)\n",
+ " test_dataset = Dataset.from_pandas(test_data)\n",
+ "\n",
+ " # Load tokenizer\n",
+ " tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
+ "\n",
+ " # Tokenize\n",
+ " train_dataset = train_dataset.map(lambda x: tokenize_dataset(x, tokenizer, max_length, num_labels), batched=True, remove_columns=train_dataset.column_names)\n",
+ " val_dataset = val_dataset.map(lambda x: tokenize_dataset(x, tokenizer, max_length, num_labels), batched=True, remove_columns=val_dataset.column_names)\n",
+ " test_dataset = test_dataset.map(lambda x: tokenize_dataset(x, tokenizer, max_length, num_labels), batched=True, remove_columns=test_dataset.column_names)\n",
+ "\n",
+ " # Load model\n",
+ " model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=num_labels, problem_type=\"multi_label_classification\")\n",
+ "\n",
+ " # Training arguments\n",
+ " training_args = TrainingArguments(\n",
+ " output_dir=drive_folder_to_save,\n",
+ " logging_dir=drive_folder_to_save,\n",
+ " logging_strategy='epoch',\n",
+ " logging_steps=100,\n",
+ " num_train_epochs=num_epochs,\n",
+ " per_device_train_batch_size=batch_size,\n",
+ " per_device_eval_batch_size=batch_size,\n",
+ " learning_rate=lr,\n",
+ " seed=42,\n",
+ " save_strategy='epoch',\n",
+ " save_steps=100,\n",
+ " evaluation_strategy='epoch',\n",
+ " eval_steps=100,\n",
+ " save_total_limit=1,\n",
+ " load_best_model_at_end=True,\n",
+ " )\n",
+ "\n",
+ " # Create trainer\n",
+ " trainer = Trainer(\n",
+ " model=model,\n",
+ " args=training_args,\n",
+ " train_dataset=train_dataset,\n",
+ " eval_dataset=val_dataset,\n",
+ " compute_metrics=compute_metrics,\n",
+ " callbacks=[EarlyStoppingCallback(early_stopping_patience=2)]\n",
+ " )\n",
+ "\n",
+ " # Train model\n",
+ " trainer.train()\n",
+ "\n",
+ " # Evaluate results\n",
+ " predictions = trainer.predict(test_dataset).predictions\n",
+ " preds = np.argmax(predictions, axis=1)\n",
+ " accuracy = accuracy_score(test_data[label_column], preds)\n",
+ " print(f'Accuracy: {accuracy}')\n",
+ " precision, recall, f1, _ = precision_recall_fscore_support(test_data[label_column], preds, average='weighted')\n",
+ " print(f'Precision: {precision}')\n",
+ " print(f'Recall: {recall}')\n",
+ " print(f'F1 Score: {f1}')\n",
+ "\n",
+ " # Save model\n",
+ " trainer.save_model(drive_folder_to_save)\n",
+ "\n",
+ "# Define training parameters\n",
+ "model_name = 'xlm-roberta-base'\n",
+ "batch_size = 16\n",
+ "learning_rate = 5e-6\n",
+ "max_length = 128\n",
+ "num_epochs = 10\n",
+ "\n",
+ "# Start training\n",
+ "start_train(df, model_name, batch_size, learning_rate, max_length, num_epochs)\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "id": "b47790d8-771e-45b9-a5c7-31d939de35b5",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Requirement already satisfied: transformers in /home/p_babro/miniconda3/lib/python3.12/site-packages (4.43.4)\n",
+ "Requirement already satisfied: huggingface_hub in /home/p_babro/miniconda3/lib/python3.12/site-packages (0.24.5)\n",
+ "Collecting python-dotenv\n",
+ " Downloading python_dotenv-1.0.1-py3-none-any.whl.metadata (23 kB)\n",
+ "Requirement already satisfied: filelock in /home/p_babro/miniconda3/lib/python3.12/site-packages (from transformers) (3.15.4)\n",
+ "Requirement already satisfied: numpy>=1.17 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from transformers) (1.26.4)\n",
+ "Requirement already satisfied: packaging>=20.0 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from transformers) (23.2)\n",
+ "Requirement already satisfied: pyyaml>=5.1 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from transformers) (6.0.1)\n",
+ "Requirement already satisfied: regex!=2019.12.17 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from transformers) (2024.7.24)\n",
+ "Requirement already satisfied: requests in /home/p_babro/miniconda3/lib/python3.12/site-packages (from transformers) (2.32.2)\n",
+ "Requirement already satisfied: safetensors>=0.4.1 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from transformers) (0.4.4)\n",
+ "Requirement already satisfied: tokenizers<0.20,>=0.19 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from transformers) (0.19.1)\n",
+ "Requirement already satisfied: tqdm>=4.27 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from transformers) (4.66.4)\n",
+ "Requirement already satisfied: fsspec>=2023.5.0 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from huggingface_hub) (2024.5.0)\n",
+ "Requirement already satisfied: typing-extensions>=3.7.4.3 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from huggingface_hub) (4.12.2)\n",
+ "Requirement already satisfied: charset-normalizer<4,>=2 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from requests->transformers) (2.0.4)\n",
+ "Requirement already satisfied: idna<4,>=2.5 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from requests->transformers) (3.7)\n",
+ "Requirement already satisfied: urllib3<3,>=1.21.1 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from requests->transformers) (2.2.2)\n",
+ "Requirement already satisfied: certifi>=2017.4.17 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from requests->transformers) (2024.7.4)\n",
+ "Downloading python_dotenv-1.0.1-py3-none-any.whl (19 kB)\n",
+ "Installing collected packages: python-dotenv\n",
+ "Successfully installed python-dotenv-1.0.1\n",
+ "Note: you may need to restart the kernel to use updated packages.\n"
+ ]
+ },
+ {
+ "ename": "ValueError",
+ "evalue": "Please set the HF_TOKEN environment variable.",
+ "output_type": "error",
+ "traceback": [
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+ "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)",
+ "Cell \u001b[0;32mIn[12], line 16\u001b[0m\n\u001b[1;32m 14\u001b[0m hf_token \u001b[38;5;241m=\u001b[39m os\u001b[38;5;241m.\u001b[39mgetenv(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mHF_TOKEN\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 15\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m hf_token:\n\u001b[0;32m---> 16\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mPlease set the HF_TOKEN environment variable.\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 18\u001b[0m \u001b[38;5;66;03m# Define your save directory and Hugging Face repository information\u001b[39;00m\n\u001b[1;32m 19\u001b[0m drive_folder_to_save \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124m/project/home/p_babro/p_babel/v4_slant\u001b[39m\u001b[38;5;124m'\u001b[39m\n",
+ "\u001b[0;31mValueError\u001b[0m: Please set the HF_TOKEN environment variable."
+ ]
+ }
+ ],
+ "source": [
+ "# Install necessary libraries\n",
+ "%pip install transformers huggingface_hub python-dotenv\n",
+ "\n",
+ "# Import necessary libraries\n",
+ "from transformers import AutoTokenizer\n",
+ "from huggingface_hub import HfApi\n",
+ "import os\n",
+ "from dotenv import load_dotenv\n",
+ "\n",
+ "# Load environment variables from .env file\n",
+ "load_dotenv()\n",
+ "\n",
+ "# Retrieve the token from the environment variable\n",
+ "hf_token = os.getenv(\"HF_TOKEN\")\n",
+ "if not hf_token:\n",
+ " raise ValueError(\"Please set the HF_TOKEN environment variable.\")\n",
+ "\n",
+ "# Define your save directory and Hugging Face repository information\n",
+ "drive_folder_to_save = '/project/home/p_babro/p_babel/v4_slant'\n",
+ "repo_id = \"ringorsolya/Emotion_RoBERTa_pooled_V4\"\n",
+ "\n",
+ "# Set environment variable to avoid the parallelism warning\n",
+ "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n",
+ "\n",
+ "# Initialize the HfApi with your token\n",
+ "api = HfApi()\n",
+ "\n",
+ "# Ensure the folder exists and contains files\n",
+ "if os.path.exists(drive_folder_to_save) and os.listdir(drive_folder_to_save):\n",
+ " print(f\"Uploading folder {drive_folder_to_save} to Hugging Face repository {repo_id}\")\n",
+ " \n",
+ " # Upload the model folder to the Hugging Face repository\n",
+ " api.upload_folder(\n",
+ " folder_path=drive_folder_to_save,\n",
+ " repo_id=repo_id,\n",
+ " token=hf_token\n",
+ " )\n",
+ " \n",
+ " print(\"Folder upload completed.\")\n",
+ "else:\n",
+ " print(f\"The folder {drive_folder_to_save} does not exist or is empty.\")\n",
+ "\n",
+ "# Load the tokenizer (use the correct model name if different)\n",
+ "tokenizer = AutoTokenizer.from_pretrained(\"xlm-roberta-base\") # Or the name of your saved model\n",
+ "\n",
+ "# Push the tokenizer to the Hugging Face repository\n",
+ "tokenizer.push_to_hub(\n",
+ " repo_id=repo_id,\n",
+ " token=hf_token\n",
+ ")\n",
+ "\n",
+ "print(\"Tokenizer upload completed.\")\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "5bcd6e0f-f56f-4d6f-a323-04286e7d06f8",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.12.2"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+ }
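
The notebook above trains the classifier and pushes it to ringorsolya/Emotion_RoBERTa_pooled_V4. For readers who want to try the uploaded model, here is a minimal inference sketch against that repo, assuming the upload succeeded; note the checkpoint ships only the generic LABEL_0 to LABEL_5 names from config.json, and the example sentence is one of the training rows printed above:

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

repo_id = "ringorsolya/Emotion_RoBERTa_pooled_V4"  # repo id from the upload cell
tokenizer = AutoTokenizer.from_pretrained(repo_id)
model = AutoModelForSequenceClassification.from_pretrained(repo_id).eval()

text = "Vaše milá slova mi opravdu zlepšila den."  # "Your kind words really made my day."
inputs = tokenizer(text, truncation=True, max_length=128, return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits
print(model.config.id2label[logits.argmax(dim=-1).item()])

Training label-encoded the original label values into 0-based class ids, so mapping a predicted id back to an original label requires the same LabelEncoder fit used in the notebook.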
checkpoint-34902/config.json ADDED
@@ -0,0 +1,45 @@
+ {
+ "_name_or_path": "xlm-roberta-base",
+ "architectures": [
+ "XLMRobertaForSequenceClassification"
+ ],
+ "attention_probs_dropout_prob": 0.1,
+ "bos_token_id": 0,
+ "classifier_dropout": null,
+ "eos_token_id": 2,
+ "hidden_act": "gelu",
+ "hidden_dropout_prob": 0.1,
+ "hidden_size": 768,
+ "id2label": {
+ "0": "LABEL_0",
+ "1": "LABEL_1",
+ "2": "LABEL_2",
+ "3": "LABEL_3",
+ "4": "LABEL_4",
+ "5": "LABEL_5"
+ },
+ "initializer_range": 0.02,
+ "intermediate_size": 3072,
+ "label2id": {
+ "LABEL_0": 0,
+ "LABEL_1": 1,
+ "LABEL_2": 2,
+ "LABEL_3": 3,
+ "LABEL_4": 4,
+ "LABEL_5": 5
+ },
+ "layer_norm_eps": 1e-05,
+ "max_position_embeddings": 514,
+ "model_type": "xlm-roberta",
+ "num_attention_heads": 12,
+ "num_hidden_layers": 12,
+ "output_past": true,
+ "pad_token_id": 1,
+ "position_embedding_type": "absolute",
+ "problem_type": "multi_label_classification",
+ "torch_dtype": "float32",
+ "transformers_version": "4.43.4",
+ "type_vocab_size": 1,
+ "use_cache": true,
+ "vocab_size": 250002
+ }
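
The config keeps the default LABEL_0 to LABEL_5 names; the mapping back to the original label values is not recorded in this commit. A hypothetical sketch of patching in readable names, where the class_i placeholders stand in for whatever the real classes are:

from transformers import AutoConfig

config = AutoConfig.from_pretrained("ringorsolya/Emotion_RoBERTa_pooled_V4")
# Placeholder names only; the true classes come from the LabelEncoder fit at training time.
id2label = {i: f"class_{i}" for i in range(config.num_labels)}
config.id2label = id2label
config.label2id = {name: i for i, name in id2label.items()}
config.save_pretrained("patched_config")  # hypothetical local output directory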
checkpoint-34902/model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:bc5cad10fb2e65429f1414393260729d04922c20b143fb2308c4ebf89e178a75
+ size 1112217312
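
These three lines are a Git LFS pointer, not the weights themselves; the roughly 1.1 GB safetensors file lives in LFS storage. A sketch of fetching it directly, assuming the repo id used in the notebook (hf_hub_download resolves the pointer and returns a local cache path):

from huggingface_hub import hf_hub_download

path = hf_hub_download(
    repo_id="ringorsolya/Emotion_RoBERTa_pooled_V4",  # assumed from the upload cell
    filename="checkpoint-34902/model.safetensors",
)
print(path)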
checkpoint-34902/optimizer.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a997112d66429766ee61cd0e1349976747bb1fcd91512108b712d3abf74309f3
+ size 2224554234
checkpoint-34902/rng_state.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:bca90ea2211b516c7fae8e0149a76a5ac06875de579dc64e16b5f64d6f22fa2d
+ size 14244
checkpoint-34902/scheduler.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:076211ed8fa209fc22da0556749ec591ad6116d3c67eabbd8264ccd06b091a34
+ size 1064
checkpoint-34902/trainer_state.json ADDED
@@ -0,0 +1,99 @@
+ {
+ "best_metric": 0.13006852567195892,
+ "best_model_checkpoint": "/project/home/p_babro/p_babel/v4_slant/checkpoint-34902",
+ "epoch": 3.0,
+ "eval_steps": 100,
+ "global_step": 34902,
+ "is_hyper_param_search": false,
+ "is_local_process_zero": true,
+ "is_world_process_zero": true,
+ "log_history": [
+ {
+ "epoch": 1.0,
+ "grad_norm": 3.1385393142700195,
+ "learning_rate": 4.5e-06,
+ "loss": 0.1773,
+ "step": 11634
+ },
+ {
+ "epoch": 1.0,
+ "eval_accuracy": 0.8172519018352172,
+ "eval_f1": 0.8177504983784618,
+ "eval_loss": 0.14184926450252533,
+ "eval_precision": 0.8189176699515038,
+ "eval_recall": 0.8172519018352172,
+ "eval_runtime": 60.5693,
+ "eval_samples_per_second": 384.138,
+ "eval_steps_per_second": 24.022,
+ "step": 11634
+ },
+ {
+ "epoch": 2.0,
+ "grad_norm": 6.301619052886963,
+ "learning_rate": 4.000000000000001e-06,
+ "loss": 0.1345,
+ "step": 23268
+ },
+ {
+ "epoch": 2.0,
+ "eval_accuracy": 0.8301027205913956,
+ "eval_f1": 0.8302797870565826,
+ "eval_loss": 0.13333828747272491,
+ "eval_precision": 0.8306761234771176,
+ "eval_recall": 0.8301027205913956,
+ "eval_runtime": 59.9815,
+ "eval_samples_per_second": 387.903,
+ "eval_steps_per_second": 24.257,
+ "step": 23268
+ },
+ {
+ "epoch": 3.0,
+ "grad_norm": 5.983363151550293,
+ "learning_rate": 3.5e-06,
+ "loss": 0.1201,
+ "step": 34902
+ },
+ {
+ "epoch": 3.0,
+ "eval_accuracy": 0.8342287359779946,
+ "eval_f1": 0.8333415258252606,
+ "eval_loss": 0.13006852567195892,
+ "eval_precision": 0.8345277590208003,
+ "eval_recall": 0.8342287359779946,
+ "eval_runtime": 56.7072,
+ "eval_samples_per_second": 410.301,
+ "eval_steps_per_second": 25.658,
+ "step": 34902
+ }
+ ],
+ "logging_steps": 100,
+ "max_steps": 116340,
+ "num_input_tokens_seen": 0,
+ "num_train_epochs": 10,
+ "save_steps": 100,
+ "stateful_callbacks": {
+ "EarlyStoppingCallback": {
+ "args": {
+ "early_stopping_patience": 2,
+ "early_stopping_threshold": 0.0
+ },
+ "attributes": {
+ "early_stopping_patience_counter": 0
+ }
+ },
+ "TrainerControl": {
+ "args": {
+ "should_epoch_stop": false,
+ "should_evaluate": false,
+ "should_log": false,
+ "should_save": true,
+ "should_training_stop": false
+ },
+ "attributes": {}
+ }
+ },
+ "total_flos": 3.673234605593549e+16,
+ "train_batch_size": 16,
+ "trial_name": null,
+ "trial_params": null
+ }
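
trainer_state.json records one training entry and one eval entry per epoch, so the eval curve (loss 0.142 -> 0.133 -> 0.130 over the three saved epochs) can be pulled straight out of log_history. A sketch, assuming a local copy of the checkpoint directory:

import json

with open("checkpoint-34902/trainer_state.json") as f:
    state = json.load(f)

for entry in state["log_history"]:
    if "eval_loss" in entry:  # eval entries; training entries carry "loss" instead
        print(entry["epoch"], entry["eval_loss"], entry["eval_f1"])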
checkpoint-34902/training_args.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d6b93185a28c952ca066cf86488db93fa194fd44cf76b3b59fc1937acd9fdeab
+ size 5112
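
training_args.bin is a pickled TrainingArguments object rather than readable JSON. A sketch for inspecting it locally (torch.load needs weights_only=False to unpickle arbitrary objects; only do this for files you trust):

import torch

# Unpickling reconstructs the transformers.TrainingArguments object.
args = torch.load("checkpoint-34902/training_args.bin", weights_only=False)
print(args.learning_rate, args.per_device_train_batch_size, args.num_train_epochs)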
config.json ADDED
@@ -0,0 +1,45 @@
+ {
+ "_name_or_path": "xlm-roberta-base",
+ "architectures": [
+ "XLMRobertaForSequenceClassification"
+ ],
+ "attention_probs_dropout_prob": 0.1,
+ "bos_token_id": 0,
+ "classifier_dropout": null,
+ "eos_token_id": 2,
+ "hidden_act": "gelu",
+ "hidden_dropout_prob": 0.1,
+ "hidden_size": 768,
+ "id2label": {
+ "0": "LABEL_0",
+ "1": "LABEL_1",
+ "2": "LABEL_2",
+ "3": "LABEL_3",
+ "4": "LABEL_4",
+ "5": "LABEL_5"
+ },
+ "initializer_range": 0.02,
+ "intermediate_size": 3072,
+ "label2id": {
+ "LABEL_0": 0,
+ "LABEL_1": 1,
+ "LABEL_2": 2,
+ "LABEL_3": 3,
+ "LABEL_4": 4,
+ "LABEL_5": 5
+ },
+ "layer_norm_eps": 1e-05,
+ "max_position_embeddings": 514,
+ "model_type": "xlm-roberta",
+ "num_attention_heads": 12,
+ "num_hidden_layers": 12,
+ "output_past": true,
+ "pad_token_id": 1,
+ "position_embedding_type": "absolute",
+ "problem_type": "multi_label_classification",
+ "torch_dtype": "float32",
+ "transformers_version": "4.43.4",
+ "type_vocab_size": 1,
+ "use_cache": true,
+ "vocab_size": 250002
+ }
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:bc5cad10fb2e65429f1414393260729d04922c20b143fb2308c4ebf89e178a75
+ size 1112217312
pooled_v4_xlmRoberta_training.xlsx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:5c641f391ef8fa3c68a7509cc06d3d56cb2dd58d4955f0a46b89359628d78afa
+ size 8796228
test_data.xlsx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2b873bd285540583a83067e6feaa2433e0e547993896cfce879f47d135730847
+ size 1544302
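
test_data.xlsx is the 23,268-row held-out split written by the notebook, so the reported test metrics can be re-checked against the uploaded model. A sketch with one caveat inherited from the notebook: the sheet stores raw label values while the model predicts label-encoded class ids, so an encoder fit is needed for the comparison (fitting on the test split alone assumes every class appears in it):

import pandas as pd
import torch
from sklearn.metrics import accuracy_score
from sklearn.preprocessing import LabelEncoder
from transformers import AutoTokenizer, AutoModelForSequenceClassification

repo_id = "ringorsolya/Emotion_RoBERTa_pooled_V4"
df = pd.read_excel("test_data.xlsx")
encoded = LabelEncoder().fit_transform(df["labels"])  # mirrors the training-time encoding

tokenizer = AutoTokenizer.from_pretrained(repo_id)
model = AutoModelForSequenceClassification.from_pretrained(repo_id).eval()

preds = []
for start in range(0, len(df), 32):  # batch to keep memory bounded
    batch = df["text"].iloc[start:start + 32].tolist()
    enc = tokenizer(batch, truncation=True, max_length=128, padding=True, return_tensors="pt")
    with torch.no_grad():
        preds.extend(model(**enc).logits.argmax(dim=-1).tolist())

print("accuracy:", accuracy_score(encoded, preds))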
training_args.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d6b93185a28c952ca066cf86488db93fa194fd44cf76b3b59fc1937acd9fdeab
+ size 5112