# https://huggingface.co./docs/transformers/tasks/video_classification
# https://huggingface.co./docs/transformers/model_doc/videomae
from transformers.utils import send_example_telemetry
import pathlib
from transformers import VideoMAEImageProcessor, VideoMAEForVideoClassification
import os
from huggingface_hub import hf_hub_download
import pytorchvideo.data
from pytorchvideo.transforms import (
ApplyTransformToKey,
Normalize,
RandomShortSideScale,
RemoveKey,
ShortSideScale,
UniformTemporalSubsample,
)
from torchvision.transforms import (
Compose,
Lambda,
RandomCrop,
RandomHorizontalFlip,
Resize,
)
from transformers import TrainingArguments, Trainer
import evaluate
import numpy as np
import torch
model_ckpt = "MCG-NJU/videomae-base" # pre-trained model from which to fine-tune
batch_size = 8 # batch size for training and evaluation
if __name__ == '__main__':
    # get example videos
# dataset_root_path = "/vol/research/NOBACKUP/CVSSP/scratch_4weeks/om0009/Datasets/MeineDGS_Some_Videos"
# dataset_root_path = pathlib.Path(dataset_root_path)
#
# all_video_file_paths = (
# list(dataset_root_path.glob("train/*.mp4"))
# + list(dataset_root_path.glob("val/*.mp4"))
# + list(dataset_root_path.glob("test/*.mp4"))
# )
# class_labels = ["zero", "one", "two"]
# label2id = {label: i for i, label in enumerate(class_labels)}
# id2label = {i: label for label, i in label2id.items()}
# Loading the dataset
# hf_dataset_identifier = "sayakpaul/ucf101-subset"
# filename = "UCF101_subset.tar.gz"
# file_path = hf_hub_download(
# repo_id=hf_dataset_identifier, filename=filename, repo_type="dataset"
# )
dataset_root_path = "/home/om0009/.cache/huggingface/hub/datasets--sayakpaul--ucf101-subset/snapshots/b9984b8d2a95e4a1879e1b071e9433858d0bc24a/UCF101_subset"
dataset_root_path = pathlib.Path(dataset_root_path)
video_count_train = len(list(dataset_root_path.glob("train/*/*.avi")))
video_count_val = len(list(dataset_root_path.glob("val/*/*.avi")))
video_count_test = len(list(dataset_root_path.glob("test/*/*.avi")))
video_total = video_count_train + video_count_val + video_count_test
print(f"Total videos: {video_total}")
all_video_file_paths = (
list(dataset_root_path.glob("train/*/*.avi"))
+ list(dataset_root_path.glob("val/*/*.avi"))
+ list(dataset_root_path.glob("test/*/*.avi"))
)
class_labels = sorted({str(path).split("/")[-2] for path in all_video_file_paths})
label2id = {label: i for i, label in enumerate(class_labels)}
id2label = {i: label for label, i in label2id.items()}
print(f"Unique classes: {list(label2id.keys())}.")
# loading the model
image_processor = VideoMAEImageProcessor.from_pretrained(model_ckpt)
model = VideoMAEForVideoClassification.from_pretrained(
model_ckpt,
label2id=label2id,
id2label=id2label,
ignore_mismatched_sizes=True,
# provide this in case you're planning to fine-tune an already fine-tuned checkpoint
)
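    # The base checkpoint has no classification head, so a new head is randomly initialised for
    # these labels; the "newly initialized" warnings emitted by from_pretrained are expected here.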
# construct the dataset for training
mean = image_processor.image_mean
std = image_processor.image_std
if "shortest_edge" in image_processor.size:
height = width = image_processor.size["shortest_edge"]
else:
height = image_processor.size["height"]
width = image_processor.size["width"]
resize_to = (height, width)
num_frames_to_sample = model.config.num_frames
sample_rate = 4
fps = 30
clip_duration = num_frames_to_sample * sample_rate / fps
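    # For videomae-base num_frames is 16, so each clip spans 16 * 4 / 30 ≈ 2.13 seconds of video.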
# Training dataset transformations.
train_transform = Compose(
[
ApplyTransformToKey(
key="video",
transform=Compose(
[
UniformTemporalSubsample(num_frames_to_sample),
Lambda(lambda x: x / 255.0),
Normalize(mean, std),
RandomShortSideScale(min_size=256, max_size=320),
RandomCrop(resize_to),
RandomHorizontalFlip(p=0.5),
]
),
),
]
)
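    # Training pipeline: subsample num_frames_to_sample frames, scale pixel values to [0, 1],
    # normalise with the processor's mean/std, then random short-side scale, random crop to the
    # model's input resolution, and a random horizontal flip.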
# Training dataset.
train_dataset = pytorchvideo.data.Ucf101(
data_path=os.path.join(dataset_root_path, "train"),
clip_sampler=pytorchvideo.data.make_clip_sampler("random", clip_duration),
decode_audio=False,
transform=train_transform,
)
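    # pytorchvideo.data.Ucf101 returns a LabeledVideoDataset (an IterableDataset) that decodes and
    # samples one clip per video on the fly, so it has no __len__; its num_videos attribute is
    # used below to derive max_steps for the Trainer.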
# Validation and evaluation datasets' transformations.
val_transform = Compose(
[
ApplyTransformToKey(
key="video",
transform=Compose(
[
UniformTemporalSubsample(num_frames_to_sample),
Lambda(lambda x: x / 255.0),
Normalize(mean, std),
Resize(resize_to),
]
),
),
]
)
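    # Validation/test clips skip the random augmentations and are simply resized, so evaluation
    # is deterministic.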
# Validation and evaluation datasets.
val_dataset = pytorchvideo.data.Ucf101(
data_path=os.path.join(dataset_root_path, "val"),
clip_sampler=pytorchvideo.data.make_clip_sampler("uniform", clip_duration),
decode_audio=False,
transform=val_transform,
)
test_dataset = pytorchvideo.data.Ucf101(
data_path=os.path.join(dataset_root_path, "test"),
clip_sampler=pytorchvideo.data.make_clip_sampler("uniform", clip_duration),
decode_audio=False,
transform=val_transform,
)
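    # Sanity check on the number of videos per split (as in the linked tutorial).
    print(train_dataset.num_videos, val_dataset.num_videos, test_dataset.num_videos)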
sample_video = next(iter(train_dataset))
    print(sample_video.keys())
def investigate_video(sample_video):
"""Utility to investigate the keys present in a single video sample."""
for k in sample_video:
if k == "video":
print(k, sample_video["video"].shape)
else:
print(k, sample_video[k])
print(f"Video label: {id2label[sample_video[k]]}")
investigate_video(sample_video)
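    # After the transforms the "video" tensor is channels-first, (C, T, H, W) = (3, 16, 224, 224)
    # for videomae-base; it is permuted to (T, C, H, W) in collate_fn below.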
# training the model
model_name = model_ckpt.split("/")[-1]
new_model_name = f"{model_name}-finetuned-ucf101-subset"
num_epochs = 4
args = TrainingArguments(
new_model_name,
remove_unused_columns=False,
evaluation_strategy="epoch",
save_strategy="epoch",
learning_rate=5e-5,
per_device_train_batch_size=batch_size,
per_device_eval_batch_size=batch_size,
warmup_ratio=0.1,
logging_steps=10,
load_best_model_at_end=True,
metric_for_best_model="accuracy",
push_to_hub=False,
max_steps=(train_dataset.num_videos // batch_size) * num_epochs,
)
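    # max_steps is set explicitly because the IterableDataset returned by pytorchvideo has no
    # __len__, so the Trainer cannot infer the number of optimisation steps per epoch by itself.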
metric = evaluate.load("accuracy")
# the compute_metrics function takes a Named Tuple as input:
# predictions, which are the logits of the model as Numpy arrays,
# and label_ids, which are the ground-truth labels as Numpy arrays.
def compute_metrics(eval_pred):
"""Computes accuracy on a batch of predictions."""
predictions = np.argmax(eval_pred.predictions, axis=1)
return metric.compute(predictions=predictions, references=eval_pred.label_ids)
def collate_fn(examples):
"""The collation function to be used by `Trainer` to prepare data batches."""
# permute to (num_frames, num_channels, height, width)
pixel_values = torch.stack(
[example["video"].permute(1, 0, 2, 3) for example in examples]
)
labels = torch.tensor([example["label"] for example in examples])
return {"pixel_values": pixel_values, "labels": labels}
    # Hugging Face login; read the access token from the HF_TOKEN environment variable rather
    # than hard-coding it in the source.
    from huggingface_hub import login
    login(token=os.environ.get("HF_TOKEN"))
trainer = Trainer(
model,
args,
train_dataset=train_dataset,
eval_dataset=val_dataset,
tokenizer=image_processor,
compute_metrics=compute_metrics,
data_collator=collate_fn,
)
print("break")