# Spaces: Running on Zero
import os
import sys

from main import *  # presumably provides `workspace` and `run_experiment` — confirm in main.py
def _run_ablation(config_location):
    """Load a config, set up the workspace, and run the experiment.

    Shared implementation behind every ablation entry point below; they
    differ only in which YAML config they load.

    Only the process whose ``LOCAL_RANK`` environment variable equals ``'0'``
    (treated as ``'0'`` when unset) calls ``workspace.create_workspace``;
    every other rank runs with the config exactly as returned by
    ``workspace.load_config``.

    Args:
        config_location: Path to the YAML config file for this run.
    """
    # Setup the workspace (eg. load the config, create a directory for
    # results at config.save_dir, etc.)
    config = workspace.load_config(config_location, None)
    # Presumably restricted to local rank 0 so concurrent distributed workers
    # don't all create the results workspace; an unset LOCAL_RANK counts as
    # rank 0 so single-process runs still create it.
    if os.getenv("LOCAL_RANK", '0') == '0':
        config = workspace.create_workspace(config)
    # Run the experiment
    run_experiment(config)


def default_run():
    """Run the default configuration from ``configs/main.yaml``."""
    _run_ablation("configs/main.yaml")


def with_mast3r_loss():
    """Run the ablation configured in ``configs/with_mast3r_loss.yaml``."""
    _run_ablation("configs/with_mast3r_loss.yaml")


def without_masking():
    """Run the ablation configured in ``configs/without_masking.yaml``."""
    _run_ablation("configs/without_masking.yaml")


def without_lpips_loss():
    """Run the ablation configured in ``configs/without_lpips_loss.yaml``."""
    _run_ablation("configs/without_lpips_loss.yaml")


def without_offset():
    """Run the ablation configured in ``configs/without_offset.yaml``."""
    _run_ablation("configs/without_offset.yaml")
if __name__ == "__main__":
    # Explicit registry of runnable ablations, keyed by command-line name.
    # A `locals().get(...)` lookup at module scope would resolve *any* global
    # name -- including everything pulled in by `from main import *` -- so an
    # argument like "run_experiment" or "os" would be "called" and crash with
    # a TypeError instead of being rejected as an unknown ablation.
    ablations = {
        "default_run": default_run,
        "with_mast3r_loss": with_mast3r_loss,
        "without_masking": without_masking,
        "without_lpips_loss": without_lpips_loss,
        "without_offset": without_offset,
    }

    # Fail with a usage message rather than a bare IndexError when no
    # ablation name is given on the command line.
    if len(sys.argv) < 2:
        raise SystemExit(
            f"usage: {sys.argv[0]} <ablation_name>\n"
            f"available ablations: {', '.join(ablations)}")

    ablation_name = sys.argv[1]
    ablation_function = ablations.get(ablation_name)
    # Run the ablation if it exists
    if ablation_function is None:
        raise NotImplementedError(
            f"Ablation name '{ablation_name}' not recognised")
    ablation_function()