import pandas as pd |
import numpy as np |
from tqdm import tqdm |
import re |
import fire |
import json |
from tqdm import tqdm |
import logging |
from pipeline import Pipeline |
import copy |
from download_models import check_if_exist |
""" |
Install dependencies by running: pip3 install -r requirements.txt
Example command:
python3 label_extraction.py --path_to_file data.xlsx --column_name report --save_predictions predictions.xlsx --save_json output.json |
""" |
def data_extraction(
    path_to_file: str,
    column_name: str,
    higher_model: str = "clinicalBERT",
    all_label_model="single_tfidf",
    save_predictions: str = None,
    output_model_data=None,
    save_input=None,
    save_json: str = None,
):
    """
    Extract higher-order and cancer-characteristic labels from pathology reports.

    Reads an Excel/CSV sheet, runs every report in ``column_name`` through the
    prediction ``Pipeline`` and optionally saves the results to a spreadsheet
    and/or a JSON file.

    Input Options:
        1) path_to_file - Path to an Excel/CSV file with pathology diagnoses: String (Required)
        2) column_name - Which column has the pathology diagnosis: String (Required)
        3) higher_model - Which version of the higher-order model to use: String (Optional, default "clinicalBERT")
        4) all_label_model - Which version of the all-labels model to use: String (Optional, default "single_tfidf")
        5) save_predictions - Path to save the predictions (Excel, falling back to CSV): String (Optional)
        6) output_model_data - If truthy, add the model-data columns to the saved predictions: True/False (Optional)
        7) save_input - If truthy, prepend the original input columns to the saved predictions: True/False (Optional)
        8) save_json - Path to save the JSON analysis: String (Optional)

    Exits the process with status 1 if a required model has not been
    downloaded or if an output file cannot be written.
    """
    data_orig = read_data(path_to_file)
    data_orig = data_orig.fillna("NA")
    # Ignore columns whose header starts with "Unnamed" (presumably pandas'
    # auto-named blank-header/index columns) before selecting the report column.
    data = data_orig.loc[:, ~data_orig.columns.str.contains('^Unnamed')][column_name].values

    for model_name in (higher_model, all_label_model):
        if not check_if_exist(model_name):
            print("\n\t ##### Please Download Model: " + str(model_name) + "#####")
            # Non-zero status: a bare exit() would report success to the shell.
            raise SystemExit(1)

    model = Pipeline(bert_option=higher_model, branch_option=all_label_model)
    logging.info("\nRunning Predictions for data size of: " + str(len(data)))
    predictions = {}
    for index, report in enumerate(tqdm(data)):
        # model.run returns (predictions, hidden states); hidden states are unused here.
        preds, _ = model.run(report)
        predictions["sample_" + str(index)] = {
            "prediction_" + str(ind): pred for ind, pred in enumerate(preds)
        }

    higher_order_pred, all_labels_pred = _join_predicted_labels(predictions)
    json_output, transformer_data, discriminator_data = _split_model_data(predictions)

    if save_predictions is not None:
        _save_predictions(
            save_predictions,
            data_orig,
            data,
            column_name,
            higher_order_pred,
            all_labels_pred,
            transformer_data,
            discriminator_data,
            output_model_data,
            save_input,
        )
    if save_json is not None:
        _save_json(save_json, json_output)


def _join_predicted_labels(predictions):
    """Return per-sample " && "-joined higher-order labels and all-label strings.

    Each prediction is assumed to map a higher-order label to a dict whose
    ``'labels'`` entry is a mapping keyed by label name (as accessed below).
    """
    higher_order_pred, all_labels_pred = [], []
    for sample in predictions.values():
        higher, labels = [], []
        for pred in sample.values():
            for higher_order, sub_arr in pred.items():
                higher.append(higher_order)
                labels.extend(sub_arr['labels'])  # iterates the label names (keys)
        higher_order_pred.append(" && ".join(higher))
        all_labels_pred.append(" && ".join(labels))
    return higher_order_pred, all_labels_pred


def _split_model_data(predictions):
    """Move bulky model data out of the nested predictions.

    Returns ``(json_output, transformer_data, discriminator_data)``:
    ``json_output`` is a deep copy of ``predictions`` in which every label's
    ``'data'``, ``'transformer_data'`` and ``['word_analysis']['discriminator_data']``
    entries are removed and stored instead at the sample level. If a sample has
    several labels, the last label's values win. The two lists hold the same
    per-sample values, with 0 for samples that produced no predictions.
    """
    json_output = copy.deepcopy(predictions)
    n_samples = len(predictions)
    transformer_data = [0] * n_samples
    discriminator_data = [0] * n_samples
    for index in tqdm(range(n_samples)):
        key = "sample_" + str(index)
        sample_out = json_output[key]
        # Iterate the untouched original so adding sample-level keys to the copy is safe.
        for pred_key, pred in predictions[key].items():
            for label_key, label_info in pred.items():
                discriminator = label_info['word_analysis']['discriminator_data']
                sample_out["data"] = label_info['data']
                sample_out["transformer_data"] = label_info['transformer_data']
                sample_out["discriminator_data"] = discriminator
                transformer_data[index] = label_info['transformer_data']
                discriminator_data[index] = discriminator
                nested = sample_out[pred_key][label_key]
                del nested['data']
                del nested['transformer_data']
                del nested['word_analysis']['discriminator_data']
    return json_output, transformer_data, discriminator_data


def _save_predictions(path, data_orig, data, column_name, higher_order_pred,
                      all_labels_pred, transformer_data, discriminator_data,
                      output_model_data, save_input):
    """Write the predictions table to ``path`` as Excel, falling back to CSV.

    Exits with status 1 if neither writer accepts the path.
    """
    logging.info("Saving Predictions")
    if output_model_data:
        all_preds = pd.DataFrame(
            list(zip(higher_order_pred, all_labels_pred, transformer_data, discriminator_data, data)),
            columns=['Higher Order', "All Labels", 'Higher Order Model Data', 'All Labels Model Data', column_name],
        )
    else:
        all_preds = pd.DataFrame(
            list(zip(higher_order_pred, all_labels_pred)),
            columns=['Higher Order', "All Labels"],
        )
    if save_input:
        all_preds = pd.concat([data_orig, all_preds], axis=1)
    try:
        all_preds.to_excel(path)
    except ValueError:
        # e.g. an extension pandas has no Excel writer for -> write CSV instead.
        try:
            all_preds.to_csv(path)
        except ValueError as e:
            logging.exception("Error while saving predictions " + str(e))
            raise SystemExit(1)
    logging.info("Done")


def _json_default(obj):
    """Fallback for json.dump: use ``tolist()`` when available (e.g. numpy
    arrays/scalars), otherwise store the object's ``str()``."""
    to_list = getattr(obj, "tolist", None)
    if callable(to_list):
        return to_list()
    return str(obj)


def _save_json(path, json_output):
    """Write ``json_output`` to ``path`` as valid JSON; exit with status 1 on failure."""
    logging.info("Saving Json")
    try:
        with open(path, 'w', encoding='utf-8') as f:
            json.dump(json_output, f, indent=2, ensure_ascii=False, default=_json_default)
    except (OSError, TypeError, ValueError) as e:
        logging.exception("Error while saving json analysis " + str(e))
        raise SystemExit(1)
    logging.info("Done")
def read_data(path_to_file):
    """Load a sheet of reports into a DataFrame.

    The file is first read as Excel. If pandas rejects it with a ValueError
    (e.g. the Excel format cannot be determined), it is re-read as CSV.

    Args:
        path_to_file: path to an Excel or CSV file.

    Returns:
        pandas.DataFrame with the file contents.

    Exits the process with status 1 if the file can be read as neither.
    """
    try:
        return pd.read_excel(path_to_file)
    except ValueError:
        # Nested so the Excel error stays attached as context in the logged traceback.
        try:
            return pd.read_csv(path_to_file)
        except ValueError as e:
            logging.exception("### Error occurred while reading document. Info: " + str(e))
            raise SystemExit(1)
def run():
    """Command-line entry point: let python-fire build a CLI around data_extraction."""
    fire.Fire(component=data_extraction)
if __name__ == '__main__':
    # Timestamped INFO-level logging for command-line runs.
    _log_format = "%(asctime)s - %(levelname)s - %(filename)s - %(message)s"
    _date_format = "%d/%m/%Y %H:%M:%S"
    logging.basicConfig(level=logging.INFO, format=_log_format, datefmt=_date_format)
    run()