Spaces:
Sleeping
Sleeping
Commit
·
e391945
1
Parent(s):
6a92f1f
تحديث نظام التدريب لاستخدام مجموعة بيانات اللهجات العربية
Browse files
train.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
import torch
|
2 |
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer
|
3 |
-
from datasets import load_dataset
|
4 |
import numpy as np
|
5 |
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
|
6 |
|
@@ -16,12 +16,35 @@ def compute_metrics(pred):
|
|
16 |
'recall': recall
|
17 |
}
|
18 |
|
19 |
-
class ArabicTextTrainer:
|
20 |
-
def __init__(self, model_name="CAMeL-Lab/bert-base-arabic-camelbert-msa"
|
21 |
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
|
22 |
-
|
|
|
23 |
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
24 |
self.model.to(self.device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
25 |
|
26 |
def tokenize_data(self, examples):
|
27 |
return self.tokenizer(
|
@@ -33,24 +56,26 @@ class ArabicTextTrainer:
|
|
33 |
|
34 |
def prepare_dataset(self, dataset):
|
35 |
tokenized_dataset = dataset.map(self.tokenize_data, batched=True)
|
36 |
-
tokenized_dataset = tokenized_dataset.remove_columns(['text'])
|
37 |
tokenized_dataset = tokenized_dataset.rename_column('label', 'labels')
|
38 |
tokenized_dataset.set_format('torch')
|
39 |
return tokenized_dataset
|
40 |
|
41 |
-
def train(self, train_dataset, eval_dataset=None, output_dir="./
|
|
|
42 |
training_args = TrainingArguments(
|
43 |
output_dir=output_dir,
|
44 |
num_train_epochs=num_train_epochs,
|
45 |
-
per_device_train_batch_size=
|
46 |
-
per_device_eval_batch_size=
|
47 |
warmup_steps=500,
|
48 |
weight_decay=0.01,
|
49 |
logging_dir='./logs',
|
50 |
-
logging_steps=
|
51 |
evaluation_strategy="epoch" if eval_dataset else "no",
|
52 |
save_strategy="epoch",
|
53 |
load_best_model_at_end=True if eval_dataset else False,
|
|
|
54 |
)
|
55 |
|
56 |
trainer = Trainer(
|
@@ -75,30 +100,16 @@ class ArabicTextTrainer:
|
|
75 |
print("تم حفظ النموذج بنجاح!")
|
76 |
|
77 |
def main():
|
78 |
-
# مثال على كيفية استخدام المدرب
|
79 |
-
# يمكنك تغيير مجموعة البيانات حسب احتياجاتك
|
80 |
print("تحميل مجموعة البيانات...")
|
|
|
81 |
|
82 |
-
|
83 |
-
# dataset = load_dataset("arabic_dataset_name")
|
84 |
-
|
85 |
-
# أو إنشاء مجموعة بيانات من قائمة
|
86 |
-
example_data = {
|
87 |
-
'text': ["نص إيجابي", "نص محايد", "نص سلبي"],
|
88 |
-
'label': [2, 1, 0] # 2: إيجابي، 1: محايد، 0: سلبي
|
89 |
-
}
|
90 |
-
dataset = Dataset.from_dict(example_data)
|
91 |
-
|
92 |
-
# تقسيم البيانات إلى مجموعتي تدريب واختبار
|
93 |
-
dataset = dataset.train_test_split(test_size=0.2)
|
94 |
-
|
95 |
-
trainer = ArabicTextTrainer()
|
96 |
|
97 |
-
|
98 |
train_dataset = trainer.prepare_dataset(dataset['train'])
|
99 |
-
eval_dataset = trainer.prepare_dataset(dataset['test'])
|
100 |
|
101 |
-
|
102 |
trainer.train(train_dataset, eval_dataset)
|
103 |
|
104 |
if __name__ == "__main__":
|
|
|
1 |
import torch
|
2 |
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer
|
3 |
+
from datasets import load_dataset
|
4 |
import numpy as np
|
5 |
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
|
6 |
|
|
|
16 |
'recall': recall
|
17 |
}
|
18 |
|
19 |
+
class ArabicDialectTrainer:
|
20 |
+
def __init__(self, model_name="CAMeL-Lab/bert-base-arabic-camelbert-msa"):
|
21 |
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
|
22 |
+
# 18 فئة للهجات العربية المختلفة
|
23 |
+
self.model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=18)
|
24 |
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
25 |
self.model.to(self.device)
|
26 |
+
|
27 |
+
# تعريف تصنيف اللهجات
|
28 |
+
self.dialect_mapping = {
|
29 |
+
0: 'OM', # عُمان
|
30 |
+
1: 'SD', # السودان
|
31 |
+
2: 'SA', # السعودية
|
32 |
+
3: 'KW', # الكويت
|
33 |
+
4: 'QA', # قطر
|
34 |
+
5: 'LB', # لبنان
|
35 |
+
6: 'JO', # الأردن
|
36 |
+
7: 'SY', # سوريا
|
37 |
+
8: 'IQ', # العراق
|
38 |
+
9: 'MA', # المغرب
|
39 |
+
10: 'EG', # مصر
|
40 |
+
11: 'PL', # فلسطين
|
41 |
+
12: 'YE', # اليمن
|
42 |
+
13: 'BH', # البحرين
|
43 |
+
14: 'DZ', # الجزائر
|
44 |
+
15: 'AE', # الإمارات
|
45 |
+
16: 'TN', # تونس
|
46 |
+
17: 'LY' # ليبيا
|
47 |
+
}
|
48 |
|
49 |
def tokenize_data(self, examples):
|
50 |
return self.tokenizer(
|
|
|
56 |
|
57 |
def prepare_dataset(self, dataset):
|
58 |
tokenized_dataset = dataset.map(self.tokenize_data, batched=True)
|
59 |
+
tokenized_dataset = tokenized_dataset.remove_columns(['text', 'id'])
|
60 |
tokenized_dataset = tokenized_dataset.rename_column('label', 'labels')
|
61 |
tokenized_dataset.set_format('torch')
|
62 |
return tokenized_dataset
|
63 |
|
64 |
+
def train(self, train_dataset, eval_dataset=None, output_dir="./trained_model", num_train_epochs=3):
|
65 |
+
print("تهيئة معلمات التدريب...")
|
66 |
training_args = TrainingArguments(
|
67 |
output_dir=output_dir,
|
68 |
num_train_epochs=num_train_epochs,
|
69 |
+
per_device_train_batch_size=32,
|
70 |
+
per_device_eval_batch_size=32,
|
71 |
warmup_steps=500,
|
72 |
weight_decay=0.01,
|
73 |
logging_dir='./logs',
|
74 |
+
logging_steps=100,
|
75 |
evaluation_strategy="epoch" if eval_dataset else "no",
|
76 |
save_strategy="epoch",
|
77 |
load_best_model_at_end=True if eval_dataset else False,
|
78 |
+
metric_for_best_model="f1" if eval_dataset else None,
|
79 |
)
|
80 |
|
81 |
trainer = Trainer(
|
|
|
100 |
print("تم حفظ النموذج بنجاح!")
|
101 |
|
102 |
def main():
|
|
|
|
|
103 |
print("تحميل مجموعة البيانات...")
|
104 |
+
dataset = load_dataset("Abdelrahman-Rezk/Arabic_Dialect_Identification")
|
105 |
|
106 |
+
trainer = ArabicDialectTrainer()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
107 |
|
108 |
+
print("تجهيز البيانات للتدريب...")
|
109 |
train_dataset = trainer.prepare_dataset(dataset['train'])
|
110 |
+
eval_dataset = trainer.prepare_dataset(dataset['validation'])
|
111 |
|
112 |
+
print("بدء عملية التدريب...")
|
113 |
trainer.train(train_dataset, eval_dataset)
|
114 |
|
115 |
if __name__ == "__main__":
|