File size: 6,022 Bytes
05b0e9e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188

""" Download pre-trained models from Google drive. """
import os
import argparse
import zipfile 
import logging
import requests
from tqdm import tqdm
import fire
import re 

logging.basicConfig(
		format="%(asctime)s - %(levelname)s - %(filename)s -   %(message)s",
		datefmt="%d/%m/%Y %H:%M:%S",
		level=logging.INFO)


# Google Drive share links for every downloadable model, keyed by model name.
# The Drive file id is the value that follows ``id=`` in each URL.
MODEL_TO_URL = {
	# Higher-order hierarchy models.
	'PathologyEmoryPubMedBERT': 'https://drive.google.com/open?id=1l_el_mYXoTIQvGwKN2NZbp97E4svH4Fh', 
	'PathologyEmoryBERT': 'https://drive.google.com/open?id=11vzo6fJBw1RcdHVBAh6nnn8yua-4kj2IX', 
	'ClinicalBERT': 'https://drive.google.com/open?id=1UK9HqSspVneK8zGg7B93vIdTGKK9MI_v', 
	'BlueBERT': 'https://drive.google.com/open?id=1o-tcItErOiiwqZ-YRa3sMM3hGB4d3WkP', 
	'BioBERT': 'https://drive.google.com/open?id=1m7EkWkFBIBuGbfwg7j0R_WINNnYk3oS9',
	'BERT': 'https://drive.google.com/open?id=1SB_AQAAsHkF79iSAaB3kumYT1rwcOJru',

	# All-labels hierarchy models.
	'single_tfidf': 'https://drive.google.com/open?id=1-hxf7sKRtFGMOenlafdkeAr8_9pOz6Ym',
	'branch_tfidf': 'https://drive.google.com/open?id=1pDSnwLFn3YzPRac9rKFV_FN9kdzj2Lb0'
}

"""
	For large Files, Drive requires a Virus Check.
	This function is reponsivle to extract the link from the button confirmation
"""
def get_url_from_gdrive_confirmation(contents):
    url = ""
    for line in contents.splitlines():
        m = re.search(r'href="(\/uc\?export=download[^"]+)', line)
        if m:
            url = "https://docs.google.com" + m.groups()[0]
            url = url.replace("&", "&")
            break
        m = re.search('id="downloadForm" action="(.+?)"', line)
        if m:
            url = m.groups()[0]
            url = url.replace("&", "&")
            break
        m = re.search('"downloadUrl":"([^"]+)', line)
        if m:
            url = m.groups()[0]
            url = url.replace("\\u003d", "=")
            url = url.replace("\\u0026", "&")
            break
        m = re.search('<p class="uc-error-subcaption">(.*)</p>', line)
        if m:
            error = m.groups()[0]
            raise RuntimeError(error)
    if not url:
        return None
    return url

def download_file_from_google_drive(id, destination):
	"""
	Download a Google Drive file by its file id and write it to ``destination``.

	If Drive answers with an HTML page, that page is searched for the real
	download link (see get_url_from_gdrive_confirmation), which is then
	requested instead. A ``download_warning*`` cookie token, when present, is
	also sent back as the ``confirm`` query parameter.

	:param id: Google Drive file id (the ``id=`` value of a share link).
	:param destination: local path the downloaded bytes are written to.
	:raises RuntimeError: if Drive's page reports an error message.
	:raises requests.HTTPError: if the final download request is not successful.
	"""
	URL = "https://docs.google.com/uc?export=download"

	# Closing the session releases its pooled connections.
	with requests.Session() as session:
		response = session.get(URL, params={ 'id' : id }, stream=True)

		# Only an HTML response can be the confirmation page. Calling `.text` on a
		# direct binary download would buffer the whole payload in memory and
		# may run charset detection over it.
		if 'text/html' in response.headers.get('Content-Type', ''):
			URL_new = get_url_from_gdrive_confirmation(response.text)
			if URL_new is not None:
				URL = URL_new
				response = session.get(URL, params={ 'id' : id }, stream=True)

		token = get_confirm_token(response)
		if token:
			params = { 'id' : id, 'confirm' : token }
			response = session.get(URL, params=params, stream=True)

		# Fail loudly instead of saving an HTTP error page as the archive.
		response.raise_for_status()
		save_response_content(response, destination)

def get_confirm_token(response):
	"""
	Return the value of the first cookie on ``response`` whose name starts with
	``download_warning``, or None when there is no such cookie.
	"""
	warning_values = (
		cookie_value
		for cookie_name, cookie_value in response.cookies.items()
		if cookie_name.startswith('download_warning')
	)
	return next(warning_values, None)

def save_response_content(response, destination):
	"""
	Stream the body of ``response`` into the file at ``destination``.

	The body is consumed in 32 KiB chunks through a tqdm progress bar; empty
	chunks are skipped. An existing file at ``destination`` is overwritten.
	"""
	chunk_bytes = 32 * 1024

	with open(destination, "wb") as out_file:
		for piece in tqdm(response.iter_content(chunk_bytes)):
			# iter_content can yield empty keep-alive chunks; nothing to write then.
			if not piece:
				continue
			out_file.write(piece)

def check_if_exist(model:str = "single_tfidf"):
	"""
	Tell whether ``model`` already appears to be downloaded and unpacked.

	``single_vectorizer`` / ``branch_vectorizer`` are accepted as aliases of
	``single_tfidf`` / ``branch_tfidf``. Paths are resolved relative to the
	directory containing this file:

	* TF-IDF models count as present when both ``classifiers`` and
	  ``vectorizers`` exist under ``models/all_labels_hierarchy/<model>/`` and
	  are non-empty.
	* Any other model counts as present when
	  ``models/higher_order_hierarchy/<model>/`` exists and holds more than one
	  entry.

	:param model: model name, or None (always reported as absent).
	:return: True if the model files are found, otherwise False.
	"""
	if model == "single_vectorizer":
		model = "single_tfidf"
	elif model == "branch_vectorizer":
		model = "branch_tfidf"

	if model is None:
		return False

	base_dir = os.path.dirname(os.path.abspath(__file__))

	if model in ['single_tfidf', 'branch_tfidf' ]:
		model_dir = os.path.join(base_dir, 'models/all_labels_hierarchy/', model)
		sub_dirs = [os.path.join(model_dir, 'classifiers'), os.path.join(model_dir, 'vectorizers')]
		# Both sub-folders must exist before either one is listed.
		if not all(os.path.exists(sub_dir) for sub_dir in sub_dirs):
			return False
		return all(len(os.listdir(sub_dir)) > 0 for sub_dir in sub_dirs)

	model_dir = os.path.join(base_dir, 'models/higher_order_hierarchy/', model)
	return os.path.exists(model_dir) and len(os.listdir(model_dir + "/")) > 1

def download_model(all_labels='single_tfidf', higher_order='PathologyEmoryPubMedBERT'):
	"""
	Download and unpack the requested models unless they are already present.

	The higher-order model goes to ``models/higher_order_hierarchy/`` and the
	all-labels model to ``models/all_labels_hierarchy/``, both relative to the
	directory containing this file. The higher-order model is handled first.

	:param all_labels: MODEL_TO_URL key of the all-labels model, or None to skip it.
	:param higher_order: MODEL_TO_URL key of the higher-order model, or None to skip it.
	:raises KeyError: if a requested name is not a key of MODEL_TO_URL.
	"""
	project_dir = os.path.dirname(os.path.abspath(__file__))

	path_all_labels='models/all_labels_hierarchy/'
	path_higher_order='models/higher_order_hierarchy/'

	def extract_model(path_file, name):
		"""Download ``name``'s zip archive into ``path_file``, unpack it there, then delete it."""
		target_dir = os.path.join(project_dir, path_file)
		os.makedirs(target_dir, exist_ok=True)

		file_destination = os.path.join(target_dir, name + '.zip')

		# The Drive file id is the value after "id=" in the share link.
		file_id = MODEL_TO_URL[name].split('id=')[-1]

		logging.info(f'Downloading {name} model (~1000MB zip archive)')
		try:
			download_file_from_google_drive(file_id, file_destination)

			logging.info('Extracting model from archive (~1300MB folder) into ' + str(target_dir))
			with zipfile.ZipFile(file_destination, 'r') as zip_ref:
				zip_ref.extractall(path=target_dir)
		finally:
			# Delete the archive even when the download or extraction fails, so a
			# truncated or corrupt zip is not left on disk.
			if os.path.exists(file_destination):
				logging.info('Removing archive')
				os.remove(file_destination)
		logging.info('Done.')

	for path_file, name in ((path_higher_order, higher_order), (path_all_labels, all_labels)):
		if name is None:
			continue
		if check_if_exist(name):
			logging.info('Model ' + str(name) + ' already exist')
		else:
			extract_model(path_file, name)

	
	

def download(all_labels:str = "single_tfidf", higher_order:str = "PathologyEmoryPubMedBERT"):
	"""
		Validate the requested model names and download the models.

		Input Options:
			all_labels : single_tfidf, branch_tfidf
			higher_order : PathologyEmoryPubMedBERT, PathologyEmoryBERT, ClinicalBERT, BlueBERT, BioBERT, BERT

		If either name is not one of the options above, the valid options are
		printed and the process exits with status 1.
	"""
	import sys

	all_labels_options  = [ "single_tfidf", "branch_tfidf"]
	higher_order_option = [ "PathologyEmoryPubMedBERT", "PathologyEmoryBERT", "ClinicalBERT", "BlueBERT","BioBERT","BERT" ]

	if all_labels not in all_labels_options or higher_order not in higher_order_option:
		print("\n\tPlease provide a valid model for downloading")
		print("\n\t\tall_labels: " + " ".join(all_labels_options))
		# List the accepted higher-order models, not the characters of the bad input.
		print("\n\t\thigher_order: " + " ".join(higher_order_option))
		# sys.exit works without the `site` module, and a non-zero status signals failure.
		sys.exit(1)

	download_model(all_labels,higher_order)

if __name__ == "__main__":
	fire.Fire(download)