|
|
|
""" Download pre-trained models from Google drive. """ |
|
import os |
|
import argparse |
|
import zipfile |
|
import logging |
|
import requests |
|
from tqdm import tqdm |
|
import fire |
|
import re |
|
|
|
# Configure the root logger at import time: INFO level and above, each record
# prefixed with a day/month/year timestamp, the level name and the name of the
# file that emitted it.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(filename)s - %(message)s",
    datefmt="%d/%m/%Y %H:%M:%S",
    level=logging.INFO,
)
|
|
|
|
|
# Google Drive share link for every downloadable model, keyed by model name.
# The Drive file ID is the text that follows ``id=`` in each link.
MODEL_TO_URL = {
    # Models extracted under models/higher_order_hierarchy/.
    "PathologyEmoryPubMedBERT": "https://drive.google.com/open?id=1l_el_mYXoTIQvGwKN2NZbp97E4svH4Fh",
    # NOTE(review): this ID is one character longer than every other entry's
    # (34 vs 33) -- confirm the link is still valid.
    "PathologyEmoryBERT": "https://drive.google.com/open?id=11vzo6fJBw1RcdHVBAh6nnn8yua-4kj2IX",
    "ClinicalBERT": "https://drive.google.com/open?id=1UK9HqSspVneK8zGg7B93vIdTGKK9MI_v",
    "BlueBERT": "https://drive.google.com/open?id=1o-tcItErOiiwqZ-YRa3sMM3hGB4d3WkP",
    "BioBERT": "https://drive.google.com/open?id=1m7EkWkFBIBuGbfwg7j0R_WINNnYk3oS9",
    "BERT": "https://drive.google.com/open?id=1SB_AQAAsHkF79iSAaB3kumYT1rwcOJru",
    # Models extracted under models/all_labels_hierarchy/.
    "single_tfidf": "https://drive.google.com/open?id=1-hxf7sKRtFGMOenlafdkeAr8_9pOz6Ym",
    "branch_tfidf": "https://drive.google.com/open?id=1pDSnwLFn3YzPRac9rKFV_FN9kdzj2Lb0",
}
|
|
|
""" |
|
For large Files, Drive requires a Virus Check. |
|
This function is responsible for extracting the link from the confirmation button |
|
""" |
|
def get_url_from_gdrive_confirmation(contents):
    """Extract the real download URL from a Google Drive confirmation page.

    For large files Drive answers with an HTML page (virus-check warning)
    instead of the file.  The page is scanned line by line and the first of
    these patterns that matches determines the result:

    * an ``href="/uc?export=download..."`` link (made absolute against
      ``https://docs.google.com``),
    * the ``action`` attribute of the ``downloadForm`` form,
    * a ``"downloadUrl":"..."`` JSON field.

    Args:
        contents: Text of the response returned by Drive.

    Returns:
        The download URL as a string, or ``None`` when no pattern matches.

    Raises:
        RuntimeError: If the page contains a ``uc-error-subcaption``
            paragraph; the paragraph text is used as the error message.
    """
    # Function-scope import: ``html`` is only needed by this helper.
    from html import unescape

    for line in contents.splitlines():
        m = re.search(r'href="(\/uc\?export=download[^"]+)', line)
        if m:
            # Attribute values are HTML-escaped; decode the escaped
            # ampersands so the query string is usable as a URL.
            return "https://docs.google.com" + unescape(m.group(1))

        m = re.search(r'id="downloadForm" action="(.+?)"', line)
        if m:
            return unescape(m.group(1))

        m = re.search(r'"downloadUrl":"([^"]+)', line)
        if m:
            # This value sits inside JSON, where "=" and "&" are written as
            # \u003d and \u0026 escape sequences.
            url = m.group(1)
            url = url.replace("\\u003d", "=")
            url = url.replace("\\u0026", "&")
            return url

        m = re.search(r'<p class="uc-error-subcaption">(.*)</p>', line)
        if m:
            raise RuntimeError(m.group(1))

    return None
|
|
|
def download_file_from_google_drive(id, destination):
    """Download the Google Drive file with the given ID to ``destination``.

    The first request may return either the file itself or an HTML
    confirmation page.  In the latter case the real download link is taken
    from the page and requested again; if a ``download_warning`` cookie is
    present, its value is sent back as the ``confirm`` parameter.

    Args:
        id: Google Drive file ID.
        destination: Path of the local file to write.

    Raises:
        RuntimeError: Propagated from ``get_url_from_gdrive_confirmation``
            when Drive returns an error page.
    """
    URL = "https://docs.google.com/uc?export=download"

    # The context manager closes the session's connections once the body has
    # been written to disk.
    with requests.Session() as session:
        response = session.get(URL, params={'id': id}, stream=True)

        # Reading ``response.text`` pulls the whole body into memory and
        # decodes it, which defeats ``stream=True``.  A response carrying a
        # Content-Disposition header is taken to be the file itself, so only
        # responses without it are parsed as a confirmation page.
        if "Content-Disposition" not in response.headers:
            URL_new = get_url_from_gdrive_confirmation(response.text)
            if URL_new is not None:
                URL = URL_new
                response = session.get(URL, params={'id': id}, stream=True)

        token = get_confirm_token(response)
        if token:
            params = {'id': id, 'confirm': token}
            response = session.get(URL, params=params, stream=True)

        save_response_content(response, destination)
|
|
|
def get_confirm_token(response):
    """Return the value of the first ``download_warning*`` cookie, if any.

    Args:
        response: Object exposing a ``cookies`` mapping with ``items()``.

    Returns:
        The value of the first cookie whose name starts with
        ``download_warning``, or ``None`` when there is no such cookie.
    """
    warning_values = (
        cookie_value
        for cookie_name, cookie_value in response.cookies.items()
        if cookie_name.startswith('download_warning')
    )
    return next(warning_values, None)
|
|
|
def save_response_content(response, destination):
    """Write the body of ``response`` to ``destination`` chunk by chunk.

    Args:
        response: Response object providing ``iter_content(chunk_size)``.
        destination: Path of the file to create (opened in binary write mode).
    """
    bytes_per_chunk = 32 * 1024

    with open(destination, "wb") as out_file:
        # tqdm wraps the raw chunk iterator, so empty chunks are still
        # counted by the progress display even though they are not written.
        for block in tqdm(response.iter_content(bytes_per_chunk)):
            if not block:
                continue
            out_file.write(block)
|
|
|
def check_if_exist(model: str = "single_tfidf"):
    """Report whether ``model`` is already present next to this file.

    ``single_vectorizer`` and ``branch_vectorizer`` are accepted as
    alternative names for ``single_tfidf`` and ``branch_tfidf``.

    * TF-IDF models count as present when both
      ``models/all_labels_hierarchy/<model>/classifiers`` and
      ``.../vectorizers`` exist and each holds at least one entry.
    * Any other model counts as present when
      ``models/higher_order_hierarchy/<model>`` exists and holds more than
      one entry.

    Args:
        model: Model name, or ``None``.

    Returns:
        ``True`` if the model files are in place, ``False`` otherwise
        (including when ``model`` is ``None``).
    """
    if model == "single_vectorizer":
        model = "single_tfidf"
    elif model == "branch_vectorizer":
        model = "branch_tfidf"

    project_dir = os.path.dirname(os.path.abspath(__file__))
    if model is None:
        return False

    if model in ('single_tfidf', 'branch_tfidf'):
        base = 'models/all_labels_hierarchy/'
        required_dirs = [
            os.path.join(project_dir, base, model, 'classifiers'),
            os.path.join(project_dir, base, model, 'vectorizers'),
        ]
        # Both folders must exist before either one is listed.
        if not all(os.path.exists(folder) for folder in required_dirs):
            return False
        return all(len(os.listdir(folder)) > 0 for folder in required_dirs)

    base = 'models/higher_order_hierarchy/'
    model_dir = os.path.join(project_dir, base, model)
    if not os.path.exists(model_dir):
        return False
    return len(os.listdir(model_dir + "/")) > 1
|
|
|
def download_model(all_labels='single_tfidf', higher_order='PathologyEmoryPubMedBERT'):
    """Download and unpack the requested models unless already present.

    Each model is fetched from the link in ``MODEL_TO_URL`` as a zip archive,
    extracted next to this file under ``models/higher_order_hierarchy/``
    (``higher_order``) or ``models/all_labels_hierarchy/`` (``all_labels``),
    and the archive is then deleted.  The higher-order model is handled
    first.

    Args:
        all_labels: Name of the all-labels model, or ``None`` to skip it.
        higher_order: Name of the higher-order model, or ``None`` to skip it.

    Raises:
        KeyError: If a requested name has no entry in ``MODEL_TO_URL``.
        zipfile.BadZipFile: If the downloaded file is not a valid zip archive.
    """
    project_dir = os.path.dirname(os.path.abspath(__file__))

    path_all_labels = 'models/all_labels_hierarchy/'
    path_higher_order = 'models/higher_order_hierarchy/'

    def extract_model(path_file, name):
        """Download the archive for ``name`` into ``path_file`` and unpack it there."""
        # Resolve the Drive ID first so an unknown name fails before any
        # directory is created.
        file_id = MODEL_TO_URL[name].split('id=')[-1]

        target_dir = os.path.join(project_dir, path_file)
        os.makedirs(target_dir, exist_ok=True)
        file_destination = os.path.join(target_dir, name + '.zip')

        try:
            logging.info('Downloading %s model (~1000MB zip archive)', name)
            download_file_from_google_drive(file_id, file_destination)

            logging.info(
                'Extracting model from archive (~1300MB folder) and saving to %s',
                file_destination,
            )
            with zipfile.ZipFile(file_destination, 'r') as zip_ref:
                zip_ref.extractall(path=os.path.dirname(file_destination))
        finally:
            # The archive is never reused, so delete it even when the
            # download or the extraction failed part-way.
            if os.path.exists(file_destination):
                logging.info('Removing archive')
                os.remove(file_destination)
        logging.info('Done.')

    requested = (
        (path_higher_order, higher_order),
        (path_all_labels, all_labels),
    )
    for target_path, model_name in requested:
        if model_name is None:
            continue
        if check_if_exist(model_name):
            logging.info('Model %s already exist', model_name)
        else:
            extract_model(target_path, model_name)
|
|
|
|
|
|
|
|
|
def download(all_labels: str = "single_tfidf", higher_order: str = "PathologyEmoryPubMedBERT"):
    """Validate the requested model names and download the missing ones.

    Input Options:
        all_labels   : single_tfidf, branch_tfidf
        higher_order : PathologyEmoryPubMedBERT, PathologyEmoryBERT,
                       ClinicalBERT, BlueBERT, BioBERT, BERT

    Raises:
        SystemExit: With status 1, after printing the valid choices, when
            either name is not one of the options above.
    """
    all_labels_options = ["single_tfidf", "branch_tfidf"]
    higher_order_option = [
        "PathologyEmoryPubMedBERT",
        "PathologyEmoryBERT",
        "ClinicalBERT",
        "BlueBERT",
        "BioBERT",
        "BERT",
    ]

    if all_labels not in all_labels_options or higher_order not in higher_order_option:
        print("\n\tPlease provide a valid model for downloading")
        print("\n\t\tall_labels: " + " ".join(all_labels_options))
        # List the valid choices, not the characters of the rejected argument.
        print("\n\t\thigher_order: " + " ".join(higher_order_option))
        # Non-zero status so shells and scripts can detect the invalid request.
        raise SystemExit(1)

    download_model(all_labels, higher_order)
|
|
|
if __name__ == "__main__":
    # Command-line entry point: python-fire maps the command-line arguments
    # onto the parameters of ``download``.
    fire.Fire(download)
|
|
|
|
|
|
|
|