import pathlib |
|
from transformers import VideoMAEImageProcessor, VideoMAEForVideoClassification |
|
import os |
|
from huggingface_hub import hf_hub_download |
|
|
|
import pytorchvideo.data |
|
from pytorchvideo.transforms import ( |
|
ApplyTransformToKey, |
|
Normalize, |
|
RandomShortSideScale, |
|
RemoveKey, |
|
ShortSideScale, |
|
UniformTemporalSubsample, |
|
) |
|
from torchvision.transforms import ( |
|
Compose, |
|
Lambda, |
|
RandomCrop, |
|
RandomHorizontalFlip, |
|
Resize, |
|
) |
|
from transformers import TrainingArguments, Trainer |
|
import evaluate |
|
import numpy as np |
|
import torch |
|
|
|
model_ckpt = "MCG-NJU/videomae-base" |
|
batch_size = 8 |
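# model_ckpt is the pre-trained VideoMAE checkpoint that gets a new classification
# head below; batch_size is used as the per-device batch size for both training
# and evaluation.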
|
|
|
if __name__ == '__main__': |
|
dataset_root_path = "/home/om0009/.cache/huggingface/hub/datasets--sayakpaul--ucf101-subset/snapshots/b9984b8d2a95e4a1879e1b071e9433858d0bc24a/UCF101_subset" |
|
dataset_root_path = pathlib.Path(dataset_root_path) |
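# The hard-coded path above points at a local snapshot of the
# sayakpaul/ucf101-subset dataset inside the Hugging Face hub cache (e.g. as left
# behind by an earlier hf_hub_download of the archive). The globs below assume the
# layout UCF101_subset/{train,val,test}/<class_name>/<video>.avi.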
|
|
|
video_count_train = len(list(dataset_root_path.glob("train/*/*.avi"))) |
|
video_count_val = len(list(dataset_root_path.glob("val/*/*.avi"))) |
|
video_count_test = len(list(dataset_root_path.glob("test/*/*.avi"))) |
|
video_total = video_count_train + video_count_val + video_count_test |
|
print(f"Total videos: {video_total}") |
|
|
|
all_video_file_paths = ( |
|
list(dataset_root_path.glob("train/*/*.avi")) |
|
+ list(dataset_root_path.glob("val/*/*.avi")) |
|
+ list(dataset_root_path.glob("test/*/*.avi")) |
|
) |
|
|
|
class_labels = sorted({path.parts[-2] for path in all_video_file_paths})
|
label2id = {label: i for i, label in enumerate(class_labels)} |
|
id2label = {i: label for label, i in label2id.items()} |
|
|
|
print(f"Unique classes: {list(label2id.keys())}.") |
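# label2id is passed to the model config so the new classification head gets one
# output per class; id2label is the inverse map used to turn predicted indices back
# into human-readable class names.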
|
|
|
|
|
image_processor = VideoMAEImageProcessor.from_pretrained(model_ckpt) |
|
model = VideoMAEForVideoClassification.from_pretrained( |
|
model_ckpt, |
|
label2id=label2id, |
|
id2label=id2label, |
|
ignore_mismatched_sizes=True, |
|
|
|
) |
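# label2id/id2label size the classification head to this dataset's classes;
# ignore_mismatched_sizes=True additionally allows starting from a checkpoint whose
# existing head has a different size (those head weights are discarded and
# re-initialised). A warning about newly initialised weights is expected here.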
|
|
|
|
|
mean = image_processor.image_mean |
|
std = image_processor.image_std |
|
if "shortest_edge" in image_processor.size: |
|
height = width = image_processor.size["shortest_edge"] |
|
else: |
|
height = image_processor.size["height"] |
|
width = image_processor.size["width"] |
|
resize_to = (height, width) |
|
|
|
num_frames_to_sample = model.config.num_frames |
|
sample_rate = 4 |
|
fps = 30 |
|
clip_duration = num_frames_to_sample * sample_rate / fps |
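# For the videomae-base checkpoint, model.config.num_frames is typically 16, so with
# a sample rate of 4 frames and an assumed 30 fps source, each clip covers
# 16 * 4 / 30 ≈ 2.13 seconds of video.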
|
|
|
|
|
train_transform = Compose( |
|
[ |
|
ApplyTransformToKey( |
|
key="video", |
|
transform=Compose( |
|
[ |
|
UniformTemporalSubsample(num_frames_to_sample), |
|
Lambda(lambda x: x / 255.0), |
|
Normalize(mean, std), |
|
RandomShortSideScale(min_size=256, max_size=320), |
|
RandomCrop(resize_to), |
|
RandomHorizontalFlip(p=0.5), |
|
] |
|
), |
|
), |
|
] |
|
) |
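# Training-time transform, applied to the "video" key of each sample: uniformly
# subsample num_frames_to_sample frames, scale pixel values to [0, 1], normalise
# with the image processor's mean/std, then apply spatial augmentation (random
# short-side rescale, random crop to the model's input size, horizontal flip).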
|
|
|
|
|
train_dataset = pytorchvideo.data.Ucf101( |
|
data_path=os.path.join(dataset_root_path, "train"), |
|
clip_sampler=pytorchvideo.data.make_clip_sampler("random", clip_duration), |
|
decode_audio=False, |
|
transform=train_transform, |
|
) |
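# pytorchvideo.data.Ucf101 returns a LabeledVideoDataset; the "random" clip sampler
# draws one random clip of clip_duration seconds from each video. The dataset is
# iterable-style (no usable __len__), which is why TrainingArguments below sets
# max_steps explicitly instead of relying on num_train_epochs.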
|
|
|
|
|
val_transform = Compose( |
|
[ |
|
ApplyTransformToKey( |
|
key="video", |
|
transform=Compose( |
|
[ |
|
UniformTemporalSubsample(num_frames_to_sample), |
|
Lambda(lambda x: x / 255.0), |
|
Normalize(mean, std), |
|
Resize(resize_to), |
|
] |
|
), |
|
), |
|
] |
|
) |
|
|
|
|
|
val_dataset = pytorchvideo.data.Ucf101( |
|
data_path=os.path.join(dataset_root_path, "val"), |
|
clip_sampler=pytorchvideo.data.make_clip_sampler("uniform", clip_duration), |
|
decode_audio=False, |
|
transform=val_transform, |
|
) |
|
|
|
test_dataset = pytorchvideo.data.Ucf101( |
|
data_path=os.path.join(dataset_root_path, "test"), |
|
clip_sampler=pytorchvideo.data.make_clip_sampler("uniform", clip_duration), |
|
decode_audio=False, |
|
transform=val_transform, |
|
) |
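# Validation and test use the "uniform" clip sampler and a resize-only transform, so
# evaluation clips are deterministic and free of training-time augmentation.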
|
|
|
sample_video = next(iter(train_dataset)) |
|
print(sample_video.keys())
|
|
|
|
|
def investigate_video(sample_video): |
|
"""Utility to investigate the keys present in a single video sample.""" |
|
for k in sample_video: |
|
if k == "video": |
|
print(k, sample_video["video"].shape) |
|
else: |
|
print(k, sample_video[k]) |
|
|
|
print(f"Video label: {id2label[sample_video['label']]}")
|
|
|
|
|
investigate_video(sample_video) |
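# Each sample returned by the dataset is a dict: the "video" entry holds the decoded
# clip as a tensor and "label" holds the integer class id (other bookkeeping keys
# such as clip/video indices may also be present, depending on the pytorchvideo
# version).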
|
|
|
|
|
|
|
model_name = model_ckpt.split("/")[-1] |
|
new_model_name = f"{model_name}-finetuned-ucf101-subset" |
|
num_epochs = 4 |
|
|
|
args = TrainingArguments( |
|
new_model_name, |
|
remove_unused_columns=False, |
|
evaluation_strategy="epoch", |
|
save_strategy="epoch", |
|
learning_rate=5e-5, |
|
per_device_train_batch_size=batch_size, |
|
per_device_eval_batch_size=batch_size, |
|
warmup_ratio=0.1, |
|
logging_steps=10, |
|
load_best_model_at_end=True, |
|
metric_for_best_model="accuracy", |
|
push_to_hub=False, |
|
max_steps=(train_dataset.num_videos // batch_size) * num_epochs, |
|
) |
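# max_steps is required because the pytorchvideo datasets are iterable and expose no
# length from which the Trainer could derive steps per epoch;
# (num_videos // batch_size) * num_epochs approximates num_epochs passes over the
# training set. Note that in recent transformers releases the evaluation_strategy
# argument has been renamed to eval_strategy.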
|
|
metric = evaluate.load("accuracy") |
|
|
def compute_metrics(eval_pred): |
|
"""Computes accuracy on a batch of predictions.""" |
|
predictions = np.argmax(eval_pred.predictions, axis=1) |
|
return metric.compute(predictions=predictions, references=eval_pred.label_ids) |
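# eval_pred.predictions holds the model's logits for each clip; argmax converts them
# to predicted class ids, which are compared against the ground-truth label_ids.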
|
|
|
|
|
def collate_fn(examples): |
|
"""The collation function to be used by `Trainer` to prepare data batches.""" |
|
|
|
pixel_values = torch.stack( |
|
[example["video"].permute(1, 0, 2, 3) for example in examples] |
|
) |
|
labels = torch.tensor([example["label"] for example in examples]) |
|
return {"pixel_values": pixel_values, "labels": labels} |
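# Each example["video"] comes out of pytorchvideo as (num_channels, num_frames, H, W);
# permute(1, 0, 2, 3) reorders it to (num_frames, num_channels, H, W), and stacking
# produces the (batch, num_frames, num_channels, H, W) pixel_values layout that
# VideoMAEForVideoClassification expects.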
|
|
|
|
|
from huggingface_hub import login |
|
# Read the Hub token from the environment rather than hard-coding a secret in the
# script (only needed when interacting with private repos or pushing to the Hub).
login(token=os.environ.get("HF_TOKEN"))
|
|
|
trainer = Trainer( |
|
model, |
|
args, |
|
train_dataset=train_dataset, |
|
eval_dataset=val_dataset, |
|
tokenizer=image_processor, |
|
compute_metrics=compute_metrics, |
|
data_collator=collate_fn, |
|
) |
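# Remaining training and evaluation steps: a minimal sketch assuming the standard
# Trainer API. Saving the model and printing the test metrics are illustrative
# choices here, not prescribed by the setup above.
train_results = trainer.train()
trainer.save_model()

test_metrics = trainer.evaluate(test_dataset)
print(f"Test metrics: {test_metrics}")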
|
|
|