"""
    Extracts higher-order and cancer-characteristic labels from pathology reports.

    Install dependencies by running: pip3 install -r requirements.txt

    Running command example:
    python3 label_extraction.py --path_to_file data.xlsx --column_name report --save_predictions predictions.xlsx --save_json output.json
"""

import copy
import json
import logging
import sys

import fire
import pandas as pd
from tqdm import tqdm

from download_models import check_if_exist
from pipeline import Pipeline

def data_extraction(path_to_file: str, column_name: str, higher_model: str = "clinicalBERT",
                    all_label_model: str = "single_tfidf", save_predictions: str = None,
                    output_model_data: bool = None, save_input: bool = None, save_json: str = None):
    """
        This program takes an Excel/CSV sheet and extracts the higher-order and cancer characteristics from pathology reports.

        Input Options:
            1) path_to_file - Path to an Excel/CSV file with pathology diagnoses: String (Required)
            2) column_name - Name of the column that holds the pathology diagnosis: String (Required)
            3) higher_model - Which version of the higher-order model to use: String (Optional, default "clinicalBERT")
            4) all_label_model - Which version of the all-labels model to use: String (Optional, default "single_tfidf")
            5) save_predictions - Path to save the predictions (see the column sketch below): String (Optional)
            6) output_model_data - Option to include model data in the saved predictions: True/False (Optional)
            7) save_input - Option to include the input fields in the saved predictions: True/False (Optional)
            8) save_json - Path to save the JSON analysis: String (Optional)
    """

    data_orig = read_data(path_to_file)
    data_orig = data_orig.fillna("NA")
    # Drop index columns that pandas reads back as "Unnamed: ..." and keep the report text.
    data = data_orig.loc[:, ~data_orig.columns.str.contains('^Unnamed')][column_name].values

    predictions, higher_order_pred, all_labels_pred = {}, [], []

    if not check_if_exist(higher_model):
        print("\n\t ##### Please Download Model: " + str(higher_model) + " #####")
        sys.exit(1)
    if not check_if_exist(all_label_model):
        print("\n\t ##### Please Download Model: " + str(all_label_model) + " #####")
        sys.exit(1)

    model = Pipeline(bert_option=higher_model, branch_option=all_label_model)

    logging.info("\nRunning predictions for data size of: " + str(len(data)))
    for index in tqdm(range(len(data))):
        d = data[index]
        preds, _all_layer_hidden_states = model.run(d)
        predictions["sample_" + str(index)] = {}
        for ind, pred in enumerate(preds):
            predictions["sample_" + str(index)]["prediction_" + str(ind)] = pred

    # Flatten each sample's predictions into "&&"-joined strings of
    # higher-order labels and of all sub-labels.
    for sample in predictions.values():
        higher, all_p = [], []
        for pred in sample.values():
            for higher_order, sub_arr in pred.items():
                higher.append(higher_order)
                for label, v in sub_arr['labels'].items():
                    all_p.append(label)

        higher_order_pred.append(" && ".join(higher))
        all_labels_pred.append(" && ".join(all_p))


    # Hoist the bulky model data out of each nested prediction and up to the
    # sample level, so the JSON output stays readable.
    predictions_refact = copy.deepcopy(predictions)
    transformer_data = [0 for _ in range(len(data))]
    discriminator_data = [0 for _ in range(len(data))]

    for index in tqdm(range(len(data))):
        key = "sample_" + str(index)
        for k, v in predictions[key].items():
            for k_s, v_s in v.items():
                predictions_refact[key]["data"] = v_s['data']
                predictions_refact[key]["transformer_data"] = v_s['transformer_data']
                predictions_refact[key]["discriminator_data"] = v_s['word_analysis']['discriminator_data']
                transformer_data[index] = v_s['transformer_data']
                discriminator_data[index] = v_s['word_analysis']['discriminator_data']

                del predictions_refact[key][k][k_s]['data']
                del predictions_refact[key][k][k_s]['transformer_data']
                del predictions_refact[key][k][k_s]['word_analysis']['discriminator_data']

    json_output = predictions_refact
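    # Illustrative sketch of one sample's shape in json_output after the hoist
    # above (keys inferred from the loops; values elided):
    #   json_output["sample_0"] = {
    #       "prediction_0": {<higher_order>: {"labels": {...}, "word_analysis": {...}}},
    #       "data": ...,
    #       "transformer_data": ...,
    #       "discriminator_data": ...,
    #   }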
    

    if save_predictions is not None:
        logging.info("Saving predictions")
        if output_model_data is not None:
            all_preds = pd.DataFrame(
                list(zip(higher_order_pred, all_labels_pred, transformer_data, discriminator_data, data)),
                columns=['Higher Order', 'All Labels', 'Higher Order Model Data', 'All Labels Model Data', column_name])
        else:
            all_preds = pd.DataFrame(list(zip(higher_order_pred, all_labels_pred)),
                                     columns=['Higher Order', 'All Labels'])

        if save_input is not None:
            all_preds = pd.concat([data_orig, all_preds], axis=1)
        try:
            all_preds.to_excel(save_predictions)
        except ValueError:
            try:
                all_preds.to_csv(save_predictions)
            except Exception as e:
                logging.exception("Error while saving predictions: " + str(e))
                sys.exit(1)
        logging.info("Done")

    if save_json is not None:
        logging.info("Saving JSON")
        try:
            with open(save_json, 'w') as f:
                # default=str guards against values that are not natively
                # JSON-serializable (e.g. numpy types).
                json.dump(json_output, f, indent=2, default=str)
        except (TypeError, ValueError, OSError) as e:
            logging.exception("Error while saving JSON analysis: " + str(e))
            sys.exit(1)
        logging.info("Done")


def read_data(path_to_file):
    """Read the input file as Excel, falling back to CSV."""
    try:
        return pd.read_excel(path_to_file)
    except ValueError:
        try:
            return pd.read_csv(path_to_file)
        except Exception as e:
            logging.exception("### Error occurred while reading the input file. Info: " + str(e))
            sys.exit(1)


def run():
    fire.Fire(data_extraction)


if __name__ == '__main__':
    logging.basicConfig(format="%(asctime)s - %(levelname)s - %(filename)s - %(message)s",
                        datefmt="%d/%m/%Y %H:%M:%S", level=logging.INFO)
    run()