# Author: Thiago — first commit (6b767f8)
# NOTE: the lines above were file-viewer residue ("raw / history blame / 5.91 kB")
# captured when this file was copied; kept here only as a comment so the module parses.
import pandas as pd
import numpy as np
from tqdm import tqdm
import re
import fire
import json
from tqdm import tqdm
import logging
from pipeline import Pipeline
import copy
from download_models import check_if_exist
"""
Install dependecies by running: pip3 install -r requirements.txt
Running command example:
python3 label_extraction.py --path_to_file data.xlsx --column_name report --save_predictions predictions.xlsx --save_json output.json
"""
def data_extraction(path_to_file:str, column_name:str, higher_model:str="clinicalBERT", all_label_model="single_tfidf", save_predictions:str=None, output_model_data=None,save_input=None, save_json:str=None):
    """
    Extract higher-order and cancer-characteristic labels from pathology reports.

    Reads an Excel/CSV sheet, runs each report through the two-stage model
    pipeline, and optionally saves tabular predictions and/or a JSON analysis.
    Input Options:
    1) path_to_file - Path to an Excel/CSV with pathology diagnosis: String (Required)
    2) column_name - Which column has the pathology diagnosis: String (Required)
    3) higher_model - Which version of higher order model to use: String (default "clinicalBERT")
    4) all_label_model - Which version of all labels model to use: String (default "single_tfidf")
    5) save_predictions - Path to save output: String (Optional)
    6) output_model_data - Option to output model data to csv True/False (Optional)
    7) save_input - Option to output the input fields True/False (Optional)
    8) save_json - Path to save json analysis: String (Optional)
    """
    data_orig = read_data(path_to_file)
    data_orig = data_orig.fillna("NA")
    # Drop pandas' auto-generated "Unnamed" index columns, then pull the report text.
    data = data_orig.loc[:, ~data_orig.columns.str.contains('^Unnamed')][column_name].values
    predictions, higher_order_pred, all_labels_pred = {}, [], []
    # Both model bundles must already be downloaded locally.
    if not check_if_exist(higher_model):
        print("\n\t ##### Please Download Model: " + str(higher_model) + "#####")
        exit()
    if not check_if_exist(all_label_model):
        print("\n\t ##### Please Download Model: " + str(all_label_model) + "#####")
        exit()
    model = Pipeline(bert_option=higher_model, branch_option=all_label_model)
    logging.info("\nRunning Predictions for data size of: " + str(len(data)))
    for index in tqdm(range(len(data))):
        # model.run returns (predictions, all-layer hidden states); the hidden
        # states are not needed here.
        preds, _ = model.run(data[index])
        predictions["sample_" + str(index)] = {}
        for ind, pred in enumerate(preds):
            predictions["sample_" + str(index)]["prediction_" + str(ind)] = pred
    # Flatten each sample's nested predictions into " && "-joined label strings.
    for sample in predictions.values():
        higher, all_p = [], []
        for pred in sample.values():
            for higher_order, sub_arr in pred.items():
                higher.append(higher_order)
                for label in sub_arr['labels']:
                    all_p.append(label)
        higher_order_pred.append(" && ".join(higher))
        all_labels_pred.append(" && ".join(all_p))
    # Hoist the bulky per-model payloads up to the sample level (for the JSON
    # output) while keeping flat per-sample copies for the optional CSV columns.
    predictions_refact = copy.deepcopy(predictions)
    transformer_data = [0 for _ in range(len(data))]
    discriminator_data = [0 for _ in range(len(data))]
    for index in tqdm(range(len(data))):
        key = "sample_" + str(index)
        for k, v in predictions[key].items():
            for k_s, v_s in v.items():
                predictions_refact[key]["data"] = v_s['data']
                predictions_refact[key]["transformer_data"] = v_s['transformer_data']
                predictions_refact[key]["discriminator_data"] = v_s['word_analysis']['discriminator_data']
                transformer_data[index] = v_s['transformer_data']
                discriminator_data[index] = v_s['word_analysis']['discriminator_data']
                del predictions_refact[key][k][k_s]['data']
                del predictions_refact[key][k][k_s]['transformer_data']
                del predictions_refact[key][k][k_s]['word_analysis']['discriminator_data']
    json_output = predictions_refact
    if save_predictions is not None:
        logging.info("Saving Predictions")
        if output_model_data is not None:
            all_preds = pd.DataFrame(
                list(zip(higher_order_pred, all_labels_pred, transformer_data, discriminator_data, data)),
                columns=['Higher Order', "All Labels", 'Higher Order Model Data', 'All Labels Model Data', column_name])
        else:
            all_preds = pd.DataFrame(
                list(zip(higher_order_pred, all_labels_pred)),
                columns=['Higher Order', "All Labels"])
        if save_input is not None:
            all_preds = pd.concat([data_orig, all_preds], axis=1)
        try:
            all_preds.to_excel(save_predictions)
        except ValueError:
            # Destination is not a valid Excel target; fall back to CSV.
            try:
                all_preds.to_csv(save_predictions)
            except Exception as e:  # bug fix: `e` was previously unbound here
                logging.exception("Error while saving predictions " + str(e))
                exit()
        logging.info("Done")
    if save_json is not None:
        logging.info("Saving Json")
        try:
            # Bug fix: the original hand-built pseudo-JSON lines (Python reprs,
            # unbalanced braces, no closing); emit real JSON with the stdlib
            # encoder instead. default=str stringifies any non-serializable values.
            with open(save_json, 'w') as f:
                json.dump(json_output, f, indent=2, default=str)
        except Exception as e:  # bug fix: `e` was previously unbound here
            logging.exception("Error while saving json analysis " + str(e))
            exit()
        logging.info("Done")
def read_data(path_to_file):
    """
    Load a report sheet from disk as a DataFrame.

    Tries Excel first; on a ValueError (not a recognizable Excel file) falls
    back to CSV. If neither parser succeeds, logs the failure and exits the
    process.

    :param path_to_file: path to an .xlsx/.xls or .csv file
    :return: pandas.DataFrame with the file contents
    """
    try:
        return pd.read_excel(path_to_file)
    except ValueError:
        # Not Excel — retry as CSV (pandas ParserError subclasses ValueError).
        try:
            return pd.read_csv(path_to_file)
        except ValueError as e:  # bug fix: `e` was previously unbound here
            logging.exception("### Error occurred while reading input file. Info: " + str(e))
            exit()
def run():
    """Command-line entry point: expose data_extraction as a CLI via python-fire."""
    fire.Fire(data_extraction)
if __name__ == '__main__':
    # Configure root logging before dispatching to the CLI entry point.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(filename)s - %(message)s",
        datefmt="%d/%m/%Y %H:%M:%S",
        level=logging.INFO,
    )
    run()