# NOTE: residue from the hosting web page, kept as a comment so the file parses:
# author: LeroyWaa — commit 8fc2b4e "add gensim code" (raw / history / blame; 6.62 kB)
# Copyright (c) Aishwarya Kamath & Nicolas Carion. Licensed under the Apache License 2.0. All Rights Reserved
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Misc functions, including distributed helpers.
Mostly copy-paste from torchvision references.
"""
import os
import subprocess
from typing import Any, Dict, List, Optional
import torch
import torchvision
from torch import Tensor
def get_sha():
    """Describe the git state of the directory that contains this module.

    Returns:
        A one-line string of the form
        ``"sha: <commit>, status: <clean|has uncommited changes>, branch: <name>"``.

    Every git call is best-effort. When one fails (git missing, not a repository,
    ...), that field and every later one keep their placeholder values
    (``"N/A"`` for sha/branch, ``"clean"`` for status). No exception escapes.
    """
    repo_dir = os.path.dirname(os.path.abspath(__file__))

    def _git(*args):
        # Run `git <args>` inside repo_dir and return its stripped ASCII stdout.
        raw = subprocess.check_output(["git", *args], cwd=repo_dir)
        return raw.decode("ascii").strip()

    sha, status, branch = "N/A", "clean", "N/A"
    try:
        sha = _git("rev-parse", "HEAD")
        # Output is discarded; presumably run so git refreshes its cached file
        # stat info before `diff-index` below (TODO confirm intent).
        subprocess.check_output(["git", "diff"], cwd=repo_dir)
        if _git("diff-index", "HEAD"):
            status = "has uncommited changes"
        branch = _git("rev-parse", "--abbrev-ref", "HEAD")
    except Exception:
        # Best effort: keep the placeholders for whatever could not be read.
        pass
    return f"sha: {sha}, status: {status}, branch: {branch}"
def _batch_positive_maps(targets, key):
    """Stack the 2-D ``target[key]`` maps of a batch into one float tensor.

    Each map is written into consecutive rows of a single ``(total_rows, max_cols)``
    boolean buffer and zero-padded on the right up to the widest map. Each batch
    element generally has a different number of boxes, so the batch dimension is
    collapsed instead of padded.

    Args:
        targets: sequence of per-sample target dicts, each holding a 2-D tensor
            under ``key``.
        key: name of the map to batch (``"positive_map"`` or ``"positive_map_eval"``).

    Returns:
        A float tensor of shape ``(sum of row counts, max column count)``. The buffer
        is boolean before the final cast, so nonzero input entries become 1.0.
    """
    maps = [target[key] for target in targets]
    max_len = max(m.shape[1] for m in maps)
    nb_boxes = sum(m.shape[0] for m in maps)
    batched = torch.zeros((nb_boxes, max_len), dtype=torch.bool)
    cur_count = 0
    for m in maps:
        batched[cur_count : cur_count + len(m), : m.shape[1]] = m
        cur_count += len(m)
    # Internal sanity check: every row of the buffer has been filled.
    assert cur_count == len(batched)
    return batched.float()


def collate_fn(do_round, batch):
    """Collate a list of ``(image, target, ...)`` samples into one batch dict.

    Args:
        do_round: forwarded to ``NestedTensor.from_tensor_list``. When true, the
            padded spatial size is rounded up.
        batch: sequence of per-sample tuples. Element 0 is the image tensor and
            element 1 the target dict. Any further elements are ignored.

    Returns:
        A dict with:
          - ``"samples"``: the padded images as a ``NestedTensor``;
          - ``"targets"``: the tuple of target dicts, unchanged;
          - ``"positive_map"`` / ``"positive_map_eval"``: row-concatenated float
            maps, present only if the first target has that key;
          - ``"answers"``: ``{key: stacked tensor}`` for every target key whose name
            contains ``"answer"``, present only if the first target has an
            ``"answer"`` or ``"answer_type"`` key.

    All optional keys are probed on the first target only. This presumably assumes
    every target in the batch has the same key set (TODO confirm with the datasets).
    """
    batch = list(zip(*batch))
    targets = batch[1]
    final_batch = {
        "samples": NestedTensor.from_tensor_list(batch[0], do_round),
        "targets": targets,
    }
    for key in ("positive_map", "positive_map_eval"):
        if key in targets[0]:
            final_batch[key] = _batch_positive_maps(targets, key)
    if "answer" in targets[0] or "answer_type" in targets[0]:
        # Stack every answer-related field (substring match on the key name).
        final_batch["answers"] = {
            f: torch.stack([t[f] for t in targets]) for f in targets[0] if "answer" in f
        }
    return final_batch
class NestedTensor(object):
    """A padded batch of tensors paired with a padding mask.

    ``tensors`` holds the batch. ``mask`` may be ``None``. When the batch is built by
    :meth:`from_tensor_list`, ``mask`` is a boolean ``(B, H, W)`` tensor that is
    ``True`` on padded pixels and ``False`` on real image content.
    """

    def __init__(self, tensors, mask):
        self.tensors = tensors
        self.mask = mask

    def to(self, *args, **kwargs):
        """Return a new instance with ``tensors`` and ``mask`` passed through ``.to``.

        The same arguments go to both tensors, so passing a dtype would also cast the
        boolean mask. NOTE(review): callers presumably pass only a device.
        """
        cast_tensor = self.tensors.to(*args, **kwargs)
        cast_mask = self.mask.to(*args, **kwargs) if self.mask is not None else None
        return type(self)(cast_tensor, cast_mask)

    def decompose(self):
        """Return the ``(tensors, mask)`` pair."""
        return self.tensors, self.mask

    @classmethod
    def from_tensor_list(cls, tensor_list, do_round=False):
        """Pad a list of 3-D ``(C, H, W)`` tensors into one ``(B, C, H, W)`` batch.

        Images may have different sizes. Each one is copied into the top-left corner
        of a zero tensor sized to the per-dimension maximum.

        Args:
            tensor_list: non-empty sequence of 3-D tensors. The dtype and device of
                the first element are used for the batch.
            do_round: if true, H and W are additionally rounded up to a multiple
                of 128.

        Returns:
            A ``NestedTensor`` whose boolean ``(B, H, W)`` mask is ``False`` over each
            image's real pixels and ``True`` over padding.

        Raises:
            ValueError: if ``tensor_list`` is empty or its first element is not 3-D.
        """
        # len() rather than truthiness: tensor_list may be a list, a tuple or a tensor.
        if len(tensor_list) == 0:
            raise ValueError("tensor_list must contain at least one tensor")
        # TODO make this more general (only 3-D inputs are handled)
        if tensor_list[0].ndim != 3:
            raise ValueError("not supported")

        max_size = tuple(max(s) for s in zip(*[img.shape for img in tensor_list]))
        b, c, h, w = (len(tensor_list),) + max_size
        if do_round:
            # Round H and W up to the next multiple of p. The upstream rationale is
            # to avoid size-rounding issues in the FPN.
            p = 128
            h = h if h % p == 0 else (h // p + 1) * p
            w = w if w % p == 0 else (w // p + 1) * p
        batch_shape = (b, c, h, w)

        dtype = tensor_list[0].dtype
        device = tensor_list[0].device
        tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
        # Start fully masked (all padding), then unmask each image's real area.
        mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
        for img, pad_img, m in zip(tensor_list, tensor, mask):
            pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
            m[: img.shape[1], : img.shape[2]] = False
        return cls(tensor, mask)

    def __repr__(self):
        return repr(self.tensors)
def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None):
    # type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor
    """
    Call ``torch.nn.functional.interpolate`` and also accept inputs with zero channels.

    Non-empty inputs, and empty inputs whose batch dimension is the empty one, are
    passed straight through. If the channel dimension (dim 1) is empty, dims 0 and 1
    are swapped so the empty axis becomes the batch axis. The input is resized and
    the dims are swapped back. Dims 0 and 1 must not both be empty (asserted).
    """
    if input.numel() == 0:
        assert input.shape[0] != 0 or input.shape[1] != 0, "At least one of the two first dimensions must be non zero"
        if input.shape[1] == 0:
            # Fake an empty batch: move the empty channel axis to dim 0 and back.
            as_empty_batch = input.transpose(0, 1)
            resized = torch.nn.functional.interpolate(as_empty_batch, size, scale_factor, mode, align_corners)
            return resized.transpose(0, 1)
    # Non-empty input, or an empty batch with channels: torch handles it directly.
    return torch.nn.functional.interpolate(input, size, scale_factor, mode, align_corners)
def targets_to(targets: List[Dict[str, Any]], device):
    """Return new target dicts with their tensor values moved to ``device``.

    For each dict, the ``"caption"`` entry is dropped, and keys in ``passthrough``
    below are copied unchanged. Every other value has ``.to(device)`` called on it.
    The input dicts are not mutated, and key order is preserved.
    """
    # Keys whose values are kept as-is rather than moved to the device.
    passthrough = {
        "questionId",
        "tokens_positive",
        "tokens",
        "dataset_name",
        "sentence_id",
        "original_img_id",
        "nb_eval",
        "task_id",
        "original_id",
    }

    def _move(target):
        # Copy one target dict: skip "caption", keep passthrough keys, move the rest.
        return {
            key: (value if key in passthrough else value.to(device))
            for key, value in target.items()
            if key != "caption"
        }

    return [_move(target) for target in targets]