Spaces:
Runtime error
Runtime error
matthewfarant
commited on
Commit
•
012daa9
1
Parent(s):
0a02c34
Update functions/modelling_function.py
Browse files
functions/modelling_function.py
CHANGED
@@ -78,7 +78,7 @@ def category_reassign(row, reference_df, checked_category, threshold=70):
|
|
78 |
else:
|
79 |
return row['category_name']
|
80 |
|
81 |
-
def train_model(df, stratify=True, model_type='bert', use_existing_model=False, model_name=None):
|
82 |
"""
|
83 |
This function trains the model using the configuration in config.yaml
|
84 |
|
@@ -98,7 +98,7 @@ def train_model(df, stratify=True, model_type='bert', use_existing_model=False,
|
|
98 |
warnings.filterwarnings('ignore')
|
99 |
|
100 |
test_size = yaml.load(open('config.yaml'), Loader=yaml.FullLoader)['parameters']['training_args']['test_size']
|
101 |
-
train_df, test_df = train_test_split(df, test_size=test_size, stratify=df[
|
102 |
|
103 |
# Optional model configuration
|
104 |
model_config = yaml.load(open('config.yaml'), Loader=yaml.FullLoader)['parameters']['model_args']
|
@@ -112,7 +112,7 @@ def train_model(df, stratify=True, model_type='bert', use_existing_model=False,
|
|
112 |
|
113 |
# Create a ClassificationModel
|
114 |
model_detail = yaml.load(open('config.yaml'), Loader=yaml.FullLoader)['parameters']['model_types']
|
115 |
-
class_names = yaml.load(open('config.yaml'), Loader=yaml.FullLoader)['parameters']['class_names']
|
116 |
|
117 |
if use_existing_model:
|
118 |
model = ClassificationModel(model_type, model_name, num_labels=len(class_names), args=model_args, use_cuda=False)
|
@@ -125,7 +125,7 @@ def train_model(df, stratify=True, model_type='bert', use_existing_model=False,
|
|
125 |
# Evaluate the model
|
126 |
result, model_outputs, wrong_predictions = model.eval_model(test_df)
|
127 |
preds = np.argmax(model_outputs, axis=1)
|
128 |
-
class_report =classification_report(test_df[
|
129 |
|
130 |
return model, preds, class_report, train_df, test_df, class_names
|
131 |
|
|
|
78 |
else:
|
79 |
return row['category_name']
|
80 |
|
81 |
+
def train_model(df, train_type, label_column, stratify=True, model_type='bert', use_existing_model=False, model_name=None):
|
82 |
"""
|
83 |
This function trains the model using the configuration in config.yaml
|
84 |
|
|
|
98 |
warnings.filterwarnings('ignore')
|
99 |
|
100 |
test_size = yaml.load(open('config.yaml'), Loader=yaml.FullLoader)['parameters']['training_args']['test_size']
|
101 |
+
train_df, test_df = train_test_split(df, test_size=test_size, stratify=df[label_column])
|
102 |
|
103 |
# Optional model configuration
|
104 |
model_config = yaml.load(open('config.yaml'), Loader=yaml.FullLoader)['parameters']['model_args']
|
|
|
112 |
|
113 |
# Create a ClassificationModel
|
114 |
model_detail = yaml.load(open('config.yaml'), Loader=yaml.FullLoader)['parameters']['model_types']
|
115 |
+
class_names = yaml.load(open('config.yaml'), Loader=yaml.FullLoader)['parameters']['class_names'][train_type]
|
116 |
|
117 |
if use_existing_model:
|
118 |
model = ClassificationModel(model_type, model_name, num_labels=len(class_names), args=model_args, use_cuda=False)
|
|
|
125 |
# Evaluate the model
|
126 |
result, model_outputs, wrong_predictions = model.eval_model(test_df)
|
127 |
preds = np.argmax(model_outputs, axis=1)
|
128 |
+
class_report =classification_report(test_df[label_column], preds, target_names=class_names)
|
129 |
|
130 |
return model, preds, class_report, train_df, test_df, class_names
|
131 |
|