File size: 2,986 Bytes
83c8e0b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 |
import json
import os
import sys
import urllib
import urllib.request
from pprint import pprint

import wget
from tqdm import tqdm

from UVR import DEMUCS_NEWER_REPO_DIR, VR_MODELS_DIR, MDX_MODELS_DIR
from gui_data.constants import DOWNLOAD_CHECKS, NORMAL_REPO, UPDATE_REPO
# Model catalogue fetched from DOWNLOAD_CHECKS when this module is imported
# (i.e. importing this module performs network I/O). The HTTP response is
# closed as soon as the JSON has been parsed.
with urllib.request.urlopen(DOWNLOAD_CHECKS) as _response:
    online_model_data = json.load(_response)

# Model name -> catalogue entry for every downloadable MDX-family model. An
# entry is either a filename or a dict whose first key is the filename (see
# get_mdx_model_filelist). On a name collision the later list wins.
# The VIP catalogues are intentionally left disabled.
mdx_download_list = {
    **online_model_data["mdx_download_list"],
    **online_model_data["mdx23c_download_list"],
    **online_model_data["mdx23_download_list"],
    # **online_model_data["mdx_download_vip_list"],
    # **online_model_data["mdx23c_download_vip_list"],
}
# Model name -> filename of the VR model.
vr_download_list = online_model_data["vr_download_list"]
# Model name -> {filename: download URL} for every file of the Demucs model.
demucs_download_list = online_model_data["demucs_download_list"]
def get_mdx_model_file(model):
    """Return the local path of the MDX model file for *model*.

    This is the path half of the first ``(path, url)`` pair produced by
    ``get_mdx_model_filelist``.
    """
    first_entry = get_mdx_model_filelist(model)[0]
    return first_entry[0]
def get_mdx_model_filelist(model):
    """Return ``[(local_path, url)]`` for the MDX model named *model*.

    The ``mdx_download_list`` entry is either a plain filename or a dict; for
    a dict only its first key is taken as the filename. The file lives under
    ``MDX_MODELS_DIR`` and is fetched from ``NORMAL_REPO``.

    Raises:
        KeyError: if *model* is not in ``mdx_download_list``.
    """
    entry = mdx_download_list[model]
    file_name = list(entry)[0] if isinstance(entry, dict) else str(entry)
    return [(os.path.join(MDX_MODELS_DIR, file_name), f"{NORMAL_REPO}{file_name}")]
def get_vr_model_file(model):
    """Return the local path of the VR model file for *model*.

    This is the path half of the first ``(path, url)`` pair produced by
    ``get_vr_model_filelist``.
    """
    first_entry = get_vr_model_filelist(model)[0]
    return first_entry[0]
def get_vr_model_filelist(model):
    """Return ``[(local_path, url)]`` for the VR model named *model*.

    The file named by ``vr_download_list[model]`` lives under
    ``VR_MODELS_DIR`` and is fetched from ``NORMAL_REPO``.

    Raises:
        KeyError: if *model* is not in ``vr_download_list``.
    """
    file_name = vr_download_list[model]
    download_url = f"{NORMAL_REPO}{file_name}"
    return [(os.path.join(VR_MODELS_DIR, file_name), download_url)]
def get_demucs_model_file(model):
    """Return the local path of the first ``.yaml`` file of Demucs model *model*.

    The extension check is case-insensitive. Returns ``None`` when none of the
    model's files is a ``.yaml`` file.
    """
    yaml_paths = (
        path
        for path, _url in get_demucs_model_filelist(model)
        if path.lower().endswith('.yaml')
    )
    return next(yaml_paths, None)
def get_demucs_model_filelist(model):
    """Return ``[(local_path, url), ...]`` for every file of Demucs model *model*.

    ``demucs_download_list[model]`` maps each filename to its download URL;
    every file is placed under ``DEMUCS_NEWER_REPO_DIR``. Order follows the
    catalogue's dict order.

    Raises:
        KeyError: if *model* is not in ``demucs_download_list``.
    """
    return [
        (os.path.join(DEMUCS_NEWER_REPO_DIR, filename), url)
        for filename, url in demucs_download_list[model].items()
    ]
def get_model_file(model_name):
    """Return the local path of the primary file for *model_name*.

    The catalogues are checked in the order MDX, VR, Demucs; the first one
    containing *model_name* decides which resolver is used. For Demucs models
    the result comes from ``get_demucs_model_file``, which may be ``None``.

    Raises:
        FileNotFoundError: if *model_name* is in none of the catalogues.
    """
    # Ordered (catalogue, resolver) pairs: the first match wins.
    resolvers = (
        (mdx_download_list, get_mdx_model_file),
        (vr_download_list, get_vr_model_file),
        (demucs_download_list, get_demucs_model_file),
    )
    for catalogue, resolve in resolvers:
        if model_name in catalogue:
            return resolve(model_name)
    raise FileNotFoundError(f"Can't find model {model_name}")
def download_model(model_name):
    """Download every file of *model_name* that is not already present locally.

    Each ``(local_path, url)`` pair from the matching catalogue is fetched with
    ``wget.download`` unless ``local_path`` already exists as a file. The
    target directory is created first if it is missing.

    Raises:
        FileNotFoundError: if *model_name* is in none of the catalogues.
    """
    if model_name in mdx_download_list:
        filelist = get_mdx_model_filelist(model_name)
    elif model_name in vr_download_list:
        filelist = get_vr_model_filelist(model_name)
    elif model_name in demucs_download_list:
        filelist = get_demucs_model_filelist(model_name)
    else:
        raise FileNotFoundError(f"Can't find model {model_name}")
    for model_path, url in filelist:
        # Skip only this file: a multi-file (Demucs) model may be partially
        # present, and one existing file must not abort the remaining ones.
        if os.path.isfile(model_path):
            continue
        target_dir = os.path.dirname(model_path)
        if target_dir:
            os.makedirs(target_dir, exist_ok=True)
        print(f'Downloading from {url} to {model_path}')
        wget.download(url, model_path)
if __name__ == '__main__':
    # Usage: python <script> MODEL_NAME [MODEL_NAME ...]
    if len(sys.argv) < 2:
        print(f"Usage: {sys.argv[0]} MODEL_NAME [MODEL_NAME ...]", file=sys.stderr)
        sys.exit(2)
    for model_name in sys.argv[1:]:
        download_model(model_name)
|