import os
import torch
import PIL.Image
import numpy as np
from torch import nn
import torch.distributed as dist
import timm.models.hub as timm_hub

"""Modified from https://github.com/CompVis/taming-transformers.git"""

import hashlib
import requests
from tqdm import tqdm
try:
    import piq
except ImportError:
    pass

_CONTEXT_PARALLEL_GROUP = None
_CONTEXT_PARALLEL_SIZE = None

def is_dist_avail_and_initialized():
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True


def get_world_size():
    if not is_dist_avail_and_initialized():
        return 1
    return dist.get_world_size()


def get_rank():
    if not is_dist_avail_and_initialized():
        return 0
    return dist.get_rank()


def is_main_process():
    return get_rank() == 0

def is_context_parallel_initialized():
    return _CONTEXT_PARALLEL_GROUP is not None


def set_context_parallel_group(size, group):
    global _CONTEXT_PARALLEL_GROUP
    global _CONTEXT_PARALLEL_SIZE
    _CONTEXT_PARALLEL_GROUP = group
    _CONTEXT_PARALLEL_SIZE = size

def initialize_context_parallel(context_parallel_size):
    global _CONTEXT_PARALLEL_GROUP
    global _CONTEXT_PARALLEL_SIZE

    assert _CONTEXT_PARALLEL_GROUP is None, "context parallel group is already initialized"
    _CONTEXT_PARALLEL_SIZE = context_parallel_size

    rank = torch.distributed.get_rank()
    world_size = torch.distributed.get_world_size()

    for i in range(0, world_size, context_parallel_size):
        ranks = range(i, i + context_parallel_size)
        group = torch.distributed.new_group(ranks)
        if rank in ranks:
            _CONTEXT_PARALLEL_GROUP = group
            break

def get_context_parallel_group():
    assert _CONTEXT_PARALLEL_GROUP is not None, "context parallel group is not initialized"
    return _CONTEXT_PARALLEL_GROUP


def get_context_parallel_world_size():
    assert _CONTEXT_PARALLEL_SIZE is not None, "context parallel size is not initialized"
    return _CONTEXT_PARALLEL_SIZE


def get_context_parallel_rank():
    assert _CONTEXT_PARALLEL_SIZE is not None, "context parallel size is not initialized"
    rank = get_rank()
    cp_rank = rank % _CONTEXT_PARALLEL_SIZE
    return cp_rank


def get_context_parallel_group_rank():
    assert _CONTEXT_PARALLEL_SIZE is not None, "context parallel size is not initialized"
    rank = get_rank()
    cp_group_rank = rank // _CONTEXT_PARALLEL_SIZE
    return cp_group_rank
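
# Illustrative note (not part of the original file): with a hypothetical
# world_size of 8 and context_parallel_size of 2, initialize_context_parallel
# splits the ranks into groups [0, 1], [2, 3], [4, 5], [6, 7]; rank 5 then has
# get_context_parallel_rank() == 1 (its position inside the group) and
# get_context_parallel_group_rank() == 2 (the index of its group).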

def download_cached_file(url, check_hash=True, progress=False):
    """
    Download a file from a URL and cache it locally. If the file already exists, it is not downloaded again.
    If distributed, only the main process downloads the file, and the other processes wait for the file to be downloaded.
    """

    def get_cached_file_path():
        # a hack to sync the file path across processes
        parts = torch.hub.urlparse(url)
        filename = os.path.basename(parts.path)
        cached_file = os.path.join(timm_hub.get_cache_dir(), filename)
        return cached_file

    if is_main_process():
        timm_hub.download_cached_file(url, check_hash, progress)

    if is_dist_avail_and_initialized():
        dist.barrier()

    return get_cached_file_path()
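
# Illustrative usage sketch (assumed URL, not part of the original file):
#   path = download_cached_file("https://example.com/weights.pth", progress=True)
# Only the main process downloads into timm's cache directory; the remaining
# ranks wait at the barrier and then resolve the same cached path.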

def convert_weights_to_fp16(model: nn.Module):
    """Convert applicable model parameters to fp16"""

    def _convert_weights_to_fp16(l):
        if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.Linear)):
            l.weight.data = l.weight.data.to(torch.float16)
            if l.bias is not None:
                l.bias.data = l.bias.data.to(torch.float16)

    model.apply(_convert_weights_to_fp16)

def convert_weights_to_bf16(model: nn.Module):
    """Convert applicable model parameters to bf16"""

    def _convert_weights_to_bf16(l):
        if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.Linear)):
            l.weight.data = l.weight.data.to(torch.bfloat16)
            if l.bias is not None:
                l.bias.data = l.bias.data.to(torch.bfloat16)

    model.apply(_convert_weights_to_bf16)
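
# Illustrative note (not part of the original file): these helpers cast only
# conv/linear weights and biases in place, e.g. convert_weights_to_bf16(model);
# other modules such as LayerNorm keep their original dtype.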

def save_result(result, result_dir, filename, remove_duplicate="", save_format="json"):
    import json
    import jsonlines

    print("Dump result")

    # Make the temp dir for saving results
    if not os.path.exists(result_dir):
        if is_main_process():
            os.makedirs(result_dir)
    if is_dist_avail_and_initialized():
        torch.distributed.barrier()

    result_file = os.path.join(
        result_dir, "%s_rank%d.json" % (filename, get_rank())
    )
    final_result_file = os.path.join(result_dir, f"{filename}.{save_format}")
    json.dump(result, open(result_file, "w"))

    if is_dist_avail_and_initialized():
        torch.distributed.barrier()

    if is_main_process():
        # combine results from all processes
        result = []
        for rank in range(get_world_size()):
            result_file = os.path.join(result_dir, "%s_rank%d.json" % (filename, rank))
            res = json.load(open(result_file, "r"))
            result += res

        # remove duplicate entries by the given key
        if remove_duplicate:
            result_new = []
            id_set = set()
            for res in result:
                if res[remove_duplicate] not in id_set:
                    id_set.add(res[remove_duplicate])
                    result_new.append(res)
            result = result_new

        if save_format == "json":
            json.dump(result, open(final_result_file, "w"))
        else:
            assert save_format == "jsonl", "Only support json and jsonl format"
            with jsonlines.open(final_result_file, "w") as writer:
                writer.write_all(result)

    return final_result_file
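
# Illustrative usage sketch (hypothetical names, not part of the original file):
# every rank calls
#   save_result(results, "eval_out", "captions", remove_duplicate="image_id")
# which writes eval_out/captions_rank<r>.json per rank; the main process then
# merges the per-rank files into eval_out/captions.json (or .jsonl when
# save_format="jsonl"), dropping entries with duplicate "image_id" values.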

# resizing utils
# TODO: clean up later
def _resize_with_antialiasing(input, size, interpolation="bicubic", align_corners=True):
    h, w = input.shape[-2:]
    factors = (h / size[0], w / size[1])

    # First, we have to determine sigma
    # Taken from skimage: https://github.com/scikit-image/scikit-image/blob/v0.19.2/skimage/transform/_warps.py#L171
    sigmas = (
        max((factors[0] - 1.0) / 2.0, 0.001),
        max((factors[1] - 1.0) / 2.0, 0.001),
    )

    # Now kernel size. Good results are for 3 sigma, but that is kind of slow. Pillow uses 1 sigma
    # https://github.com/python-pillow/Pillow/blob/master/src/libImaging/Resample.c#L206
    # But they do it in the 2 passes, which gives better results. Let's try 2 sigmas for now
    ks = int(max(2.0 * 2 * sigmas[0], 3)), int(max(2.0 * 2 * sigmas[1], 3))

    # Make sure it is odd
    if (ks[0] % 2) == 0:
        ks = ks[0] + 1, ks[1]
    if (ks[1] % 2) == 0:
        ks = ks[0], ks[1] + 1

    input = _gaussian_blur2d(input, ks, sigmas)

    output = torch.nn.functional.interpolate(input, size=size, mode=interpolation, align_corners=align_corners)
    return output

def _compute_padding(kernel_size):
    """Compute padding tuple."""
    # 4 or 6 ints: (padding_left, padding_right, padding_top, padding_bottom)
    # https://pytorch.org/docs/stable/nn.html#torch.nn.functional.pad
    if len(kernel_size) < 2:
        raise AssertionError(kernel_size)
    computed = [k - 1 for k in kernel_size]

    # for even kernels we need to do asymmetric padding :(
    out_padding = 2 * len(kernel_size) * [0]

    for i in range(len(kernel_size)):
        computed_tmp = computed[-(i + 1)]

        pad_front = computed_tmp // 2
        pad_rear = computed_tmp - pad_front

        out_padding[2 * i + 0] = pad_front
        out_padding[2 * i + 1] = pad_rear

    return out_padding

def _filter2d(input, kernel):
    # prepare kernel
    b, c, h, w = input.shape
    tmp_kernel = kernel[:, None, ...].to(device=input.device, dtype=input.dtype)

    tmp_kernel = tmp_kernel.expand(-1, c, -1, -1)

    height, width = tmp_kernel.shape[-2:]

    padding_shape: list[int] = _compute_padding([height, width])
    input = torch.nn.functional.pad(input, padding_shape, mode="reflect")

    # kernel and input tensor reshape to align element-wise or batch-wise params
    tmp_kernel = tmp_kernel.reshape(-1, 1, height, width)
    input = input.view(-1, tmp_kernel.size(0), input.size(-2), input.size(-1))

    # convolve the tensor with the kernel.
    output = torch.nn.functional.conv2d(input, tmp_kernel, groups=tmp_kernel.size(0), padding=0, stride=1)

    out = output.view(b, c, h, w)
    return out

def _gaussian(window_size: int, sigma):
    if isinstance(sigma, float):
        sigma = torch.tensor([[sigma]])

    batch_size = sigma.shape[0]
    x = (torch.arange(window_size, device=sigma.device, dtype=sigma.dtype) - window_size // 2).expand(batch_size, -1)

    if window_size % 2 == 0:
        x = x + 0.5

    gauss = torch.exp(-x.pow(2.0) / (2 * sigma.pow(2.0)))

    return gauss / gauss.sum(-1, keepdim=True)

def _gaussian_blur2d(input, kernel_size, sigma):
    if isinstance(sigma, tuple):
        sigma = torch.tensor([sigma], dtype=input.dtype)
    else:
        sigma = sigma.to(dtype=input.dtype)

    ky, kx = int(kernel_size[0]), int(kernel_size[1])
    bs = sigma.shape[0]
    kernel_x = _gaussian(kx, sigma[:, 1].view(bs, 1))
    kernel_y = _gaussian(ky, sigma[:, 0].view(bs, 1))
    out_x = _filter2d(input, kernel_x[..., None, :])
    out = _filter2d(out_x, kernel_y[..., None])
    return out
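
# Illustrative usage sketch (not part of the original file):
#   frames = torch.randn(1, 3, 512, 512)
#   small = _resize_with_antialiasing(frames, (224, 224))
# The Gaussian pre-blur, sized from the downscaling factor, suppresses aliasing
# that a plain bicubic interpolate would introduce.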

URL_MAP = {
    "vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"
}

CKPT_MAP = {
    "vgg_lpips": "vgg.pth"
}

MD5_MAP = {
    "vgg_lpips": "d507d7349b931f0638a25a48a722f98a"
}

def download(url, local_path, chunk_size=1024):
    os.makedirs(os.path.split(local_path)[0], exist_ok=True)
    with requests.get(url, stream=True) as r:
        total_size = int(r.headers.get("content-length", 0))
        with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
            with open(local_path, "wb") as f:
                for data in r.iter_content(chunk_size=chunk_size):
                    if data:
                        f.write(data)
                        pbar.update(len(data))

def md5_hash(path):
    with open(path, "rb") as f:
        content = f.read()
    return hashlib.md5(content).hexdigest()

def get_ckpt_path(name, root, check=False):
    assert name in URL_MAP
    path = os.path.join(root, CKPT_MAP[name])
    if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]):
        print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
        download(URL_MAP[name], path)
        md5 = md5_hash(path)
        assert md5 == MD5_MAP[name], md5
    return path
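
# Illustrative usage sketch (hypothetical directory, not part of the original file):
#   vgg_path = get_ckpt_path("vgg_lpips", "checkpoints")
# downloads vgg.pth into ./checkpoints on first use and, with check=True,
# verifies its MD5 against MD5_MAP.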

class KeyNotFoundError(Exception):
    def __init__(self, cause, keys=None, visited=None):
        self.cause = cause
        self.keys = keys
        self.visited = visited
        messages = list()
        if keys is not None:
            messages.append("Key not found: {}".format(keys))
        if visited is not None:
            messages.append("Visited: {}".format(visited))
        messages.append("Cause:\n{}".format(cause))
        message = "\n".join(messages)
        super().__init__(message)

def retrieve(
    list_or_dict, key, splitval="/", default=None, expand=True, pass_success=False
):
    """Given a nested list or dict return the desired value at key expanding
    callable nodes if necessary and :attr:`expand` is ``True``. The expansion
    is done in-place.

    Parameters
    ----------
    list_or_dict : list or dict
        Possibly nested list or dictionary.
    key : str
        key/to/value, path like string describing all keys necessary to
        consider to get to the desired value. List indices can also be
        passed here.
    splitval : str
        String that defines the delimiter between keys of the
        different depth levels in `key`.
    default : obj
        Value returned if :attr:`key` is not found.
    expand : bool
        Whether to expand callable nodes on the path or not.
    pass_success : bool
        If ``True``, additionally return a flag indicating whether the
        lookup succeeded.

    Returns
    -------
    The desired value or if :attr:`default` is not ``None`` and the
    :attr:`key` is not found returns ``default``.

    Raises
    ------
    Exception if ``key`` not in ``list_or_dict`` and :attr:`default` is
    ``None``.
    """

    keys = key.split(splitval)

    success = True
    try:
        visited = []
        parent = None
        last_key = None
        for key in keys:
            if callable(list_or_dict):
                if not expand:
                    raise KeyNotFoundError(
                        ValueError(
                            "Trying to get past callable node with expand=False."
                        ),
                        keys=keys,
                        visited=visited,
                    )
                list_or_dict = list_or_dict()
                parent[last_key] = list_or_dict

            last_key = key
            parent = list_or_dict

            try:
                if isinstance(list_or_dict, dict):
                    list_or_dict = list_or_dict[key]
                else:
                    list_or_dict = list_or_dict[int(key)]
            except (KeyError, IndexError, ValueError) as e:
                raise KeyNotFoundError(e, keys=keys, visited=visited)

            visited += [key]
        # final expansion of retrieved value
        if expand and callable(list_or_dict):
            list_or_dict = list_or_dict()
            parent[last_key] = list_or_dict
    except KeyNotFoundError as e:
        if default is None:
            raise e
        else:
            list_or_dict = default
            success = False

    if not pass_success:
        return list_or_dict
    else:
        return list_or_dict, success
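
# Illustrative usage sketch (hypothetical config, not part of the original file):
#   config = {"model": {"params": {"embed_dim": 256}}}
#   retrieve(config, "model/params/embed_dim")            # -> 256
#   retrieve(config, "model/params/missing", default=0)   # -> 0
#   retrieve(config, "model/params/missing", default=0, pass_success=True)  # -> (0, False)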