# fertilizer-catalog-engine / functions/modelling_function.py
# Author: matthewfarant
# Last commit: "Update functions/modelling_function.py" (012daa9)
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import yaml
import os
import warnings
from rapidfuzz import fuzz, utils
from simpletransformers.classification import ClassificationModel, ClassificationArgs
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, classification_report
from scipy.special import softmax
# Integer codes used for the training labels; any other category name maps to _OTHERS_CODE.
_LABEL_CODES = {'Fertilizer - High': 0, 'Pesticide - High': 1}
_OTHERS_CODE = 2
# Number of rows drawn per label code when sampling (insertion order = output order).
_DEFAULT_SAMPLE_SIZES = {0: 1250, 1: 1250, 2: 1500}


def _category_from_labels(labels):
    """
    This function maps a product's list of source categories to a coarse category name.
    :param labels: list of category names attached to a product; any non-list value (or an empty list) is 'Others'
    :return: str, 'Fertilizer - High' / 'Pesticide - High' when exactly one category is present,
             'Fertilizer - Medium' / 'Pesticide - Medium' when several are present, otherwise 'Others'
    """
    if not isinstance(labels, list) or not labels:
        return 'Others'
    confidence = 'High' if len(labels) == 1 else 'Medium'
    # Fertilizer takes precedence over pesticide when a product carries both categories.
    if 'Garden Soil & Fertilizers' in labels:
        return 'Fertilizer - ' + confidence
    if 'Weeds & Pest Control' in labels:
        return 'Pesticide - ' + confidence
    return 'Others'


def generate_training_data(df, text_column, label_column, external_table=None, external_column=None,
                           add_external_table=False, sampling=True, sample_sizes=None, random_state=None):
    """
    This function generates training data for the model.

    If the 'training' folder already contains files, 'training/training_data.csv' is read and returned
    as-is. Otherwise the data is built from df (and optionally external_table). Neither input dataframe
    is modified.
    :param df: pandas.DataFrame, dataframe containing product name and category name
    :param text_column: str, column name containing product name
    :param label_column: str, column name containing the list of category names of each product
    :param external_table: pandas.DataFrame, extra products, all labelled 'Fertilizer - High'
    :param external_column: str, column of external_table containing product name
    :param add_external_table: bool, whether to add external table or not
    :param sampling: bool, whether to do sampling or not (only used when add_external_table is True)
    :param sample_sizes: dict, optional {label code: number of rows to sample}; defaults to
                         {0: 1250, 1: 1250, 2: 1500}. pandas raises ValueError if a class has fewer rows.
    :param random_state: int, optional seed passed to DataFrame.sample for reproducible sampling
    :return: pandas.DataFrame. With add_external_table, the columns are 'product_name' and an integer
             'category_name' (0 = Fertilizer - High, 1 = Pesticide - High, 2 = Others). Without it, the
             columns are text_column and the string 'category_name'.
    :raises FileNotFoundError: if the 'training' folder does not exist
    """
    if os.listdir('training'):
        # A previously generated training set exists; reuse it.
        return pd.read_csv('training/training_data.csv')

    print('Training folder is empty. Generating training data...')
    with open('config.yaml') as config_file:
        units = yaml.load(config_file, Loader=yaml.FullLoader)['excluded_words']

    # Build a new frame instead of adding a column to the caller's dataframe.
    df = df[[text_column]].assign(category_name=df[label_column].apply(_category_from_labels))
    # take only where category_name is Fertilizer - High or Pesticide - High or Others
    df = df[df['category_name'].isin(['Fertilizer - High', 'Pesticide - High', 'Others'])]
    # exclude product names that contain units AND whose category_name is Others.
    # An empty unit list would join to the regex '' (matches everything), so skip the filter then.
    if units:
        has_unit = df[text_column].str.contains('|'.join(units), na=False)
        df = df[~(has_unit & (df['category_name'] == 'Others'))]

    if not add_external_table:
        return df

    external = external_table[[external_column]].assign(category_name='Fertilizer - High')
    external.columns = [text_column, 'category_name']
    training_df = pd.concat([external, df])
    training_df.columns = ['product_name', 'category_name']
    training_df['category_name'] = training_df['category_name'].apply(
        lambda name: _LABEL_CODES.get(name, _OTHERS_CODE))

    if not sampling:
        return training_df
    sizes = _DEFAULT_SAMPLE_SIZES if sample_sizes is None else sample_sizes
    return pd.concat([
        training_df[training_df['category_name'] == code].sample(n=n, random_state=random_state)
        for code, n in sizes.items()
    ])
def category_reassign(row, reference_df, checked_category, threshold=70):
    """
    This function reassigns the category name of a product based on the similarity score between the product name and the reference dataframe.
    Only rows currently labelled checked_category are reconsidered. The first reference product
    (in reference_df order) with a different category whose fuzzy ratio to the row's product name is
    >= threshold donates its category.
    :param row: pandas.Series, row of dataframe (needs 'category_name' and 'product_name')
    :param reference_df: pandas.DataFrame, dataframe containing product name and category name
    :param checked_category: str, category name to be checked
    :param threshold: int, threshold for similarity score
    :return: str, category name
    """
    if row['category_name'] != checked_category:
        return row['category_name']
    if len(reference_df) == 0:
        return checked_category
    # Filter once with a vectorised mask and walk the columns directly instead of
    # materialising a Series per reference row with .iloc (slow in a Python loop).
    candidates = reference_df[reference_df['category_name'] != checked_category]
    for name, category in zip(candidates['product_name'], candidates['category_name']):
        if fuzz.ratio(row['product_name'], name, processor=utils.default_process) >= threshold:
            return category
    return checked_category
def train_model(df, train_type, label_column, stratify=True, model_type='bert', use_existing_model=False, model_name=None):
    """
    This function trains the model using the configuration in config.yaml
    :param df: pandas.DataFrame, dataframe containing product name and category name
    :param train_type: str, key under parameters.class_names in config.yaml selecting the class names
    :param label_column: str, column holding the labels used for the split and the report
                         (presumably integer codes 0..len(class_names)-1 - confirm against the data)
    :param stratify: bool, whether to do stratified sampling on label_column or not
    :param model_type: str, type of model to use (also the key under parameters.model_types)
    :param use_existing_model: bool, whether to use existing model or not
    :param model_name: str, name/path of existing model (used only when use_existing_model is True)
    :return: simpletransformers.classification.ClassificationModel, model
    :return: numpy.ndarray, predictions
    :return: str, classification report
    :return: pandas.DataFrame, training dataframe
    :return: pandas.DataFrame, testing dataframe
    :return: list, list of class names
    """
    # NOTE: this silences warnings process-wide, not just for this call.
    warnings.filterwarnings('ignore')
    # Read config.yaml once (and close it) instead of re-opening it for every setting.
    with open('config.yaml') as config_file:
        parameters = yaml.load(config_file, Loader=yaml.FullLoader)['parameters']

    test_size = parameters['training_args']['test_size']
    # Honour the stratify flag (it was previously ignored and the split was always stratified).
    train_df, test_df = train_test_split(
        df, test_size=test_size, stratify=df[label_column] if stratify else None)

    # Optional model configuration
    model_config = parameters['model_args']
    model_args = ClassificationArgs()
    for key in ('num_train_epochs', 'train_batch_size', 'eval_batch_size',
                'overwrite_output_dir', 'fp16', 'do_lower_case'):
        setattr(model_args, key, model_config[key])

    # Create a ClassificationModel: either from a saved model or from the configured checkpoint.
    class_names = parameters['class_names'][train_type]
    model_source = model_name if use_existing_model else parameters['model_types'][model_type]
    model = ClassificationModel(model_type, model_source, num_labels=len(class_names),
                                args=model_args, use_cuda=False)
    # Train the model
    model.train_model(train_df)
    # Evaluate the model; only the raw outputs are needed to derive predictions.
    _, model_outputs, _ = model.eval_model(test_df)
    preds = np.argmax(model_outputs, axis=1)
    class_report = classification_report(test_df[label_column], preds, target_names=class_names)
    return model, preds, class_report, train_df, test_df, class_names
def save_model(model, model_name):
    """
    This function writes the model weights, tokenizer and config to the model_name directory.
    :param model: simpletransformers.classification.ClassificationModel, model
    :param model_name: str, target directory name
    :return: None
    """
    # Weights first, then tokenizer, both into model_name.
    for part in ('model', 'tokenizer'):
        getattr(model, part).save_pretrained(model_name)
    target_dir = model_name + '/'
    model.config.save_pretrained(target_dir)
    print('Model saved to ' + target_dir)
def show_confusion_matrix(test_category, preds, class_names, labels=None):
    """
    This function shows the confusion matrix as an annotated heatmap.
    :param test_category: numpy.ndarray, array of true labels
    :param preds: numpy.ndarray, array of predictions
    :param class_names: list, display names for the rows/columns, in label order
    :param labels: list, optional label values indexing the matrix (passed to sklearn's confusion_matrix).
                   When None, the sorted union of labels in test_category and preds is used, which must
                   then have exactly len(class_names) entries.
    :return: matplotlib Axes containing the heatmap
    """
    cm = confusion_matrix(test_category, preds, labels=labels)
    df_cm = pd.DataFrame(cm, index=class_names, columns=class_names)
    hmap = sns.heatmap(df_cm, annot=True, fmt="d", cmap="Blues")
    # Restyle the existing tick labels in place rather than re-setting them.
    plt.setp(hmap.get_yticklabels(), rotation=0, ha='right')
    plt.setp(hmap.get_xticklabels(), rotation=30, ha='right')
    hmap.set_ylabel('True Topics')
    hmap.set_xlabel('Predicted Topics')
    return hmap
def predict_proba(model, text, class_names=('Fertilizer', 'Pesticide', 'Others')):
    """
    This function predicts the probability of each class, prints it (in a text form) and returns it.
    :param model: simpletransformers.classification.ClassificationModel, model
    :param text: str, text to predict
    :param class_names: sequence of str, display name of each output class, in label order
    :return: numpy.ndarray, array of probabilities (one per class)
    """
    # Element [1] of model.predict's result is taken as the raw outputs of the single input.
    # softmax uses axis=None (whole array), which equals a per-row softmax because only one
    # text is passed; [0] then selects that row.
    proba = softmax(model.predict([text])[1])[0]
    print('-----------------------------')
    print('Text to Predict: ', text)
    print('Probability of each class:')
    # zip stops at the shorter sequence, so a model with a different label count does not raise.
    for name, probability in zip(class_names, proba):
        print(name + ': ', probability)
    return proba
def predict_proba_array(model, text):
    """
    This function returns the probability of each class for one text (in an array form).
    :param model: simpletransformers.classification.ClassificationModel, model
    :param text: str, text to predict
    :return: numpy.ndarray, array of probabilities
    """
    raw_outputs = model.predict([text])[1]
    # Softmax over the whole (single-row) output, then take that row.
    probabilities = softmax(raw_outputs)
    return probabilities[0]