# counterfactual-world-models/cwm/eval/Flow/create_spring_submission_unified.py
import argparse
import importlib
import os
import time

import h5py
import matplotlib.pyplot as plt
import numpy as np
import torchvision.transforms as transforms

# Parse command-line arguments.
parser = argparse.ArgumentParser(description='Extract optical flow for one SPRING test scene')
parser.add_argument('--folder', type=str, required=True, help='Scene folder to process')
parser.add_argument('--model', type=str, required=True, help='Dotted path (module.ClassName) of the model used to extract flow')
parser.add_argument('--save_data_path', type=str, required=True, help='Directory where the predicted flow is saved')
parser.add_argument('--gpu', type=int, default=0, help='GPU index to use')
args = parser.parse_args()

# Restrict CUDA to the requested GPU before importing torch, so that device 0
# inside this process maps to the selected physical GPU.
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
import torch  # noqa: E402  (deliberately imported after setting CUDA_VISIBLE_DEVICES)
torch.cuda.set_device(0)
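# Example invocation (the model class below is illustrative, not an actual repo class):
#   python create_spring_submission_unified.py \
#       --folder /ccn2/dataset/Flows_Kinetics/SPRING/spring/test/<scene> \
#       --model my_flow_models.raft_wrapper.RAFTFlow \
#       --save_data_path ./spring_submission \
#       --gpu 0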
def writeFlo5File(flow, filename):
    """Write a flow array to the .flo5 format: an HDF5 file holding a single
    gzip-compressed "flow" dataset."""
    with h5py.File(filename, "w") as f:
        f.create_dataset("flow", data=flow, compression="gzip", compression_opts=5)
if __name__ == '__main__':
    # Resolve the flow model from its dotted "module.ClassName" path and move it
    # to the GPU in eval mode.
    module_name, class_name = args.model.rsplit(".", 1)
    module = importlib.import_module(module_name)
    model = getattr(module, class_name)
    model = model().cuda().eval()

    # Name of the scene folder being processed (last path component).
    folder = args.folder.split('/')[-1]
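    # The class passed via --model is assumed to expose the interface sketched
    # below (the name is illustrative, not an actual CWM model): a no-argument
    # constructor and a forward(image1, image2) call returning a flow map.
    #
    #     class DummyZeroFlow(torch.nn.Module):
    #         def forward(self, image1, image2):
    #             _, h, w = image1.shape       # ToTensor yields (C, H, W) inputs
    #             return torch.zeros(h, w, 2)  # one (dx, dy) vector per pixel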
    # Root of the SPRING test split and output directory for this run.
    dataset_path = '/ccn2/dataset/Flows_Kinetics/SPRING/spring/test/'
    save_data_path = args.save_data_path
    os.makedirs(save_data_path, exist_ok=True)
    # Convert HxWxC images to CxHxW torch tensors (uint8 inputs are scaled to [0, 1]).
    resize_crop = transforms.Compose([
        transforms.ToTensor(),
    ])
    def l2norm(x):
        """Per-pixel L2 norm over the last axis (end-point error when x is a flow difference)."""
        return np.sqrt((x ** 2).sum(-1))
    # End-point errors could be collected here, but this script never loads
    # ground truth, so the list stays empty.
    all_epe = []

    # Magic number of the legacy Middlebury .flo format (unused for .flo5 output).
    TAG_FLOAT = 202021.25
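    # Illustrative only: if ground-truth flow were available as gt_flow, with a
    # prediction pred_flow of the same (H, W, 2) shape, per-frame EPE could be
    # accumulated as
    #     all_epe.append(l2norm(pred_flow - gt_flow).mean())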
    # Iterate over forward/backward flow directions and both stereo views.
    for direction in ['FW', 'BW']:
        for stereo in ['left', 'right']:
            files = sorted(os.listdir(os.path.join(dataset_path, folder, f'frame_{stereo}')))
            output_folder = os.path.join(save_data_path, folder, f'flow_{direction}_{stereo}')
            os.makedirs(output_folder, exist_ok=True)
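            # Output layout, as constructed in this loop:
            #     <save_data_path>/<scene>/flow_{FW,BW}_{left,right}/flow_<direction>_<stereo>_<idx>.flo5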
            # Pair consecutive frames: forward flow uses (frame i, frame i+1),
            # backward flow swaps the pair so that f1 is always the source frame.
            for ct_f in range(len(files) - 1):
                if direction == 'FW':
                    f1 = files[ct_f]
                    f2 = files[ct_f + 1]
                else:
                    f2 = files[ct_f]
                    f1 = files[ct_f + 1]
                t = time.time()
                image1_path = os.path.join(dataset_path, folder, f'frame_{stereo}', f1)
                image2_path = os.path.join(dataset_path, folder, f'frame_{stereo}', f2)

                # Frame index of the source image, used to name the output file.
                idx = image1_path.split('/')[-1].split('.')[0].split('_')[-1]
                flow_save_path = os.path.join(output_folder, f'flow_{direction}_{stereo}_{idx}.flo5')
                # Optionally skip or repair files from an earlier run: flow stored as
                # (2, H, W) is rewritten as (H, W, 2), then the pair is skipped.
                # if os.path.exists(flow_save_path):
                #     try:
                #         with h5py.File(flow_save_path, 'r+') as f:
                #             if f['flow'][:].shape[0] == 2:
                #                 flow = f['flow'][:].transpose([1, 2, 0])
                #                 del f['flow']
                #                 f.create_dataset("flow", data=flow, compression="gzip", compression_opts=5)
                #                 continue
                #             else:
                #                 continue
                #     except:
                #         pass
                # Read the frame pair, convert to tensors, and run the flow model.
                image1_ = plt.imread(image1_path)
                image2_ = plt.imread(image2_path)
                image1 = resize_crop(image1_)
                image2 = resize_crop(image2_)

                # Flow from f1 to f2 (forward or backward, depending on `direction`).
                with torch.no_grad():
                    flow = model.forward(image1, image2)

                writeFlo5File(flow, flow_save_path)
                print(f'{flow_save_path}: {time.time() - t:.2f}s')
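                # Illustrative check (not in the original script): a saved file can be
                # read back with the readFlo5File helper defined above, e.g.
                #     flow_hw2 = readFlo5File(flow_save_path)
                #     print(flow_hw2.shape, l2norm(flow_hw2).mean())  # mean flow magnitude in px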