matthewfarant committed on
Commit
012daa9
1 Parent(s): 0a02c34

Update functions/modelling_function.py

Browse files
Files changed (1) hide show
  1. functions/modelling_function.py +4 -4
functions/modelling_function.py CHANGED
@@ -78,7 +78,7 @@ def category_reassign(row, reference_df, checked_category, threshold=70):
78
  else:
79
  return row['category_name']
80
 
81
- def train_model(df, stratify=True, model_type='bert', use_existing_model=False, model_name=None):
82
  """
83
  This function trains the model using the configuration in config.yaml
84
 
@@ -98,7 +98,7 @@ def train_model(df, stratify=True, model_type='bert', use_existing_model=False,
98
  warnings.filterwarnings('ignore')
99
 
100
  test_size = yaml.load(open('config.yaml'), Loader=yaml.FullLoader)['parameters']['training_args']['test_size']
101
- train_df, test_df = train_test_split(df, test_size=test_size, stratify=df['category_name'])
102
 
103
  # Optional model configuration
104
  model_config = yaml.load(open('config.yaml'), Loader=yaml.FullLoader)['parameters']['model_args']
@@ -112,7 +112,7 @@ def train_model(df, stratify=True, model_type='bert', use_existing_model=False,
112
 
113
  # Create a ClassificationModel
114
  model_detail = yaml.load(open('config.yaml'), Loader=yaml.FullLoader)['parameters']['model_types']
115
- class_names = yaml.load(open('config.yaml'), Loader=yaml.FullLoader)['parameters']['class_names']
116
 
117
  if use_existing_model:
118
  model = ClassificationModel(model_type, model_name, num_labels=len(class_names), args=model_args, use_cuda=False)
@@ -125,7 +125,7 @@ def train_model(df, stratify=True, model_type='bert', use_existing_model=False,
125
  # Evaluate the model
126
  result, model_outputs, wrong_predictions = model.eval_model(test_df)
127
  preds = np.argmax(model_outputs, axis=1)
128
- class_report =classification_report(test_df['category_name'], preds, target_names=class_names)
129
 
130
  return model, preds, class_report, train_df, test_df, class_names
131
 
 
78
  else:
79
  return row['category_name']
80
 
81
+ def train_model(df, train_type, label_column, stratify=True, model_type='bert', use_existing_model=False, model_name=None):
82
  """
83
  This function trains the model using the configuration in config.yaml
84
 
 
98
  warnings.filterwarnings('ignore')
99
 
100
  test_size = yaml.load(open('config.yaml'), Loader=yaml.FullLoader)['parameters']['training_args']['test_size']
101
+ train_df, test_df = train_test_split(df, test_size=test_size, stratify=df[label_column])
102
 
103
  # Optional model configuration
104
  model_config = yaml.load(open('config.yaml'), Loader=yaml.FullLoader)['parameters']['model_args']
 
112
 
113
  # Create a ClassificationModel
114
  model_detail = yaml.load(open('config.yaml'), Loader=yaml.FullLoader)['parameters']['model_types']
115
+ class_names = yaml.load(open('config.yaml'), Loader=yaml.FullLoader)['parameters']['class_names'][train_type]
116
 
117
  if use_existing_model:
118
  model = ClassificationModel(model_type, model_name, num_labels=len(class_names), args=model_args, use_cuda=False)
 
125
  # Evaluate the model
126
  result, model_outputs, wrong_predictions = model.eval_model(test_df)
127
  preds = np.argmax(model_outputs, axis=1)
128
+ class_report =classification_report(test_df[label_column], preds, target_names=class_names)
129
 
130
  return model, preds, class_report, train_df, test_df, class_names
131