|
import json
import os
import sys
import urllib
import urllib.request
from pprint import pprint

import wget
from tqdm import tqdm

from gui_data.constants import DOWNLOAD_CHECKS, NORMAL_REPO, UPDATE_REPO
|
|
|
# Absolute directory that contains this script; every model folder below is
# resolved relative to it, independent of the current working directory.
BASE_PATH = os.path.dirname(os.path.abspath(__file__))

# Root folder holding all downloaded models.
MODELS_DIR = os.path.join(BASE_PATH, 'models')

# One sub-folder per model architecture.
VR_MODELS_DIR = os.path.join(BASE_PATH, 'models', 'VR_Models')
MDX_MODELS_DIR = os.path.join(BASE_PATH, 'models', 'MDX_Net_Models')
DEMUCS_MODELS_DIR = os.path.join(BASE_PATH, 'models', 'Demucs_Models')

# Destination for every file listed in the Demucs catalogue.
DEMUCS_NEWER_REPO_DIR = os.path.join(DEMUCS_MODELS_DIR, 'v3_v4_repo')
|
|
|
# Model catalogue downloaded from DOWNLOAD_CHECKS. This runs at import time,
# so importing this module performs a network request.
# The response is used as a context manager so the connection is closed
# instead of being left for the garbage collector.
with urllib.request.urlopen(DOWNLOAD_CHECKS) as _catalogue_response:
    online_model_data = json.load(_catalogue_response)

# All MDX-family catalogues merged into one lookup table. On duplicate keys,
# later sections win (mdx23 over mdx23c over mdx).
mdx_download_list = {
    **online_model_data["mdx_download_list"],
    **online_model_data["mdx23c_download_list"],
    **online_model_data["mdx23_download_list"],
}

vr_download_list = online_model_data["vr_download_list"]

demucs_download_list = online_model_data["demucs_download_list"]
|
|
|
|
|
def get_mdx_model_file(model):
    """Return the local path of the MDX model file for ``model``.

    This is the path component of the first entry produced by
    ``get_mdx_model_filelist``. The file is not required to exist.
    """
    first_path, _first_url = get_mdx_model_filelist(model)[0]
    return first_path
|
|
|
|
|
def get_mdx_model_filelist(model):
    """Build the download list for an MDX-family model.

    Args:
        model: key into ``mdx_download_list``.

    Returns:
        A one-element list ``[(local_path, url)]``. ``local_path`` lives in
        MDX_MODELS_DIR and ``url`` is NORMAL_REPO followed by the file name.

    Raises:
        KeyError: if ``model`` is not in ``mdx_download_list``.
    """
    catalogue_entry = mdx_download_list[model]
    # Dict entries contribute only their first key as the file name; the
    # dict's values are ignored here.
    # NOTE(review): if those values name companion files (e.g. configs), they
    # are never downloaded by this function -- confirm that is intended.
    file_name = [*catalogue_entry][0] if isinstance(catalogue_entry, dict) else str(catalogue_entry)
    return [(os.path.join(MDX_MODELS_DIR, file_name), f"{NORMAL_REPO}{file_name}")]
|
|
|
|
|
def get_vr_model_file(model):
    """Return the local path of the VR model file for ``model``.

    This is the path component of the first entry produced by
    ``get_vr_model_filelist``. The file is not required to exist.
    """
    first_path, _first_url = get_vr_model_filelist(model)[0]
    return first_path
|
|
|
|
|
def get_vr_model_filelist(model):
    """Build the download list for a VR model.

    Args:
        model: key into ``vr_download_list``.

    Returns:
        A one-element list ``[(local_path, url)]``. ``local_path`` lives in
        VR_MODELS_DIR and ``url`` is NORMAL_REPO followed by the file name.

    Raises:
        KeyError: if ``model`` is not in ``vr_download_list``.
    """
    vr_file = vr_download_list[model]
    remote_url = f"{NORMAL_REPO}{vr_file}"
    return [(os.path.join(VR_MODELS_DIR, vr_file), remote_url)]
|
|
|
|
|
def get_demucs_model_file(model):
    """Return the local path of the Demucs model's ``.yaml`` file.

    Scans the paths from ``get_demucs_model_filelist`` in order and returns the
    first one whose name ends in ``.yaml`` (case-insensitive). Returns ``None``
    when the model lists no such file.
    """
    yaml_paths = (
        local_path
        for local_path, _url in get_demucs_model_filelist(model)
        if local_path.lower().endswith('.yaml')
    )
    return next(yaml_paths, None)
|
|
|
|
|
def get_demucs_model_filelist(model):
    """Build the download list for a Demucs model.

    Args:
        model: key into ``demucs_download_list``. Its value maps file names
            to download URLs.

    Returns:
        A list of ``(local_path, url)`` pairs in catalogue order. Every file
        is placed directly in DEMUCS_NEWER_REPO_DIR.

    Raises:
        KeyError: if ``model`` is not in ``demucs_download_list``.
    """
    return [
        (os.path.join(DEMUCS_NEWER_REPO_DIR, file_name), file_url)
        for file_name, file_url in demucs_download_list[model].items()
    ]
|
|
|
|
|
def get_model_file(model_name):
    """Return the primary local file path for ``model_name``.

    Catalogues are checked in order: MDX, then VR, then Demucs. For MDX and VR
    models this is the single model file. For Demucs models it is the ``.yaml``
    file, or ``None`` if the model lists none. Whether the file exists on disk
    is not checked.

    Raises:
        FileNotFoundError: if ``model_name`` is in none of the catalogues.
    """
    if model_name in mdx_download_list:
        return get_mdx_model_file(model_name)
    if model_name in vr_download_list:
        return get_vr_model_file(model_name)
    if model_name in demucs_download_list:
        return get_demucs_model_file(model_name)
    raise FileNotFoundError(f"Can't find model {model_name}")
|
|
|
|
|
def download_model(model_name):
    """Download every file of ``model_name`` that is not already on disk.

    Catalogues are checked in order: MDX, then VR, then Demucs. Each
    ``(model_path, url)`` pair is then fetched with ``wget.download``. Files
    already present are skipped one by one. A model whose files are only
    partly present (e.g. a multi-file Demucs model) is therefore completed
    rather than ignored.

    Args:
        model_name: catalogue key of the model to download.

    Raises:
        FileNotFoundError: if ``model_name`` is in none of the catalogues.
    """
    if model_name in mdx_download_list:
        filelist = get_mdx_model_filelist(model_name)
    elif model_name in vr_download_list:
        filelist = get_vr_model_filelist(model_name)
    elif model_name in demucs_download_list:
        filelist = get_demucs_model_filelist(model_name)
    else:
        raise FileNotFoundError(f"Can't find model {model_name}")

    for model_path, url in filelist:
        # `continue`, not `return`: one file being present says nothing about
        # the remaining files of a multi-file model.
        if os.path.isfile(model_path):
            continue
        # Ensure the destination folder exists before handing the path to wget.
        os.makedirs(os.path.dirname(model_path), exist_ok=True)
        print(f'Downloading from {url} to {model_path}')
        wget.download(url, model_path)
|
|
|
|
|
if __name__ == '__main__':
    # Usage: <script> MODEL_NAME [MODEL_NAME ...]
    # Report a usage message instead of an IndexError traceback when no model
    # name is given; extra names are downloaded in turn.
    if len(sys.argv) < 2:
        sys.exit(f"Usage: {sys.argv[0]} MODEL_NAME [MODEL_NAME ...]")
    for requested_model in sys.argv[1:]:
        download_model(requested_model)
|
|