# Video Action Recognition demo — classifies the action in an uploaded video
# with a fine-tuned VideoMAE model and serves predictions through Gradio.
import gradio as gr
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as F_t
from pytorchvideo.data.encoded_video import EncodedVideo
from pytorchvideo.transforms import UniformTemporalSubsample
from transformers import VideoMAEForVideoClassification
# Run on GPU when one is available, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the fine-tuned VideoMAE classifier from the local "model" directory,
# move it to the chosen device, and freeze it in inference mode.
model_path = "model"
loaded_model = VideoMAEForVideoClassification.from_pretrained(model_path).to(device)
loaded_model.eval()

# Human-readable class names, indexed by the model's output logit position
# (assumed to match the label order used at fine-tuning time — TODO confirm).
label_names = [
    'Archery', 'BalanceBeam', 'BenchPress', 'ApplyEyeMakeup', 'BasketballDunk',
    'BandMarching', 'BabyCrawling', 'ApplyLipstick', 'BaseballPitch', 'Basketball',
]
def load_video(video_path):
    """Decode the entire clip at ``video_path`` and return its frames.

    Returns the ``'video'`` entry of the decoded clip (a frame tensor).
    Raises ValueError with a descriptive message if decoding fails.
    """
    try:
        clip = EncodedVideo.from_path(video_path)
        decoded = clip.get_clip(start_sec=0, end_sec=clip.duration)
        return decoded['video']
    except Exception as e:
        raise ValueError(f"Error loading video: {str(e)}")
def preprocess_video(video_frames):
    """Prepare decoded video frames for VideoMAE inference.

    Args:
        video_frames: raw frame tensor, expected as [C, T, H, W] with
            uint8 pixel values (as produced by load_video — TODO confirm
            for all input codecs).

    Returns:
        Float tensor of shape [1, 16, 3, 224, 224]: 16 uniformly sampled
        frames, scaled to [0, 1], ImageNet-normalized, resized to 224x224,
        with a leading batch dimension.

    Raises:
        ValueError: wrapping any failure during preprocessing.
    """
    try:
        # Uniformly sample 16 frames along the temporal axis (dim -3) —
        # the exact algorithm of pytorchvideo's UniformTemporalSubsample(16),
        # inlined in pure torch.
        t = video_frames.shape[-3]
        indices = torch.clamp(torch.linspace(0, t - 1, 16), 0, t - 1).long()
        video_frames = torch.index_select(video_frames, -3, indices)

        # Scale uint8 pixels to [0, 1].
        video_frames = video_frames.float() / 255.0

        # Reorder [C, T, H, W] -> [T, C, H, W] when channels come first.
        # NOTE(review): heuristic — a 3-frame clip already in [T, C, H, W]
        # would be permuted incorrectly; acceptable for real videos.
        if video_frames.shape[0] == 3:
            video_frames = video_frames.permute(1, 0, 2, 3)

        # ImageNet normalization, vectorized over all frames at once
        # (replaces the original per-frame Python loop).
        mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
        std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
        video_frames = (video_frames - mean) / std

        # Resize every frame in a single batched call; bilinear with
        # antialias matches torchvision's resize(..., antialias=True)
        # (replaces the original per-frame resize loop).
        video_frames = F.interpolate(
            video_frames, size=(224, 224), mode='bilinear',
            align_corners=False, antialias=True,
        )

        # Add the batch dimension -> [1, T, C, H, W].
        return video_frames.unsqueeze(0)
    except Exception as e:
        raise ValueError(f"Error preprocessing video: {str(e)}")
def predict_video(video):
    """Classify an uploaded video and return the top-3 labels as text.

    Any failure is caught and returned as an error string so the Gradio
    UI can display it instead of crashing.
    """
    try:
        # Decode, preprocess, and move the clip to the model's device.
        frames = load_video(video)
        batch = preprocess_video(frames).to(device)

        # Forward pass without gradient tracking.
        with torch.no_grad():
            logits = loaded_model(batch).logits

        probs = F.softmax(logits, dim=-1)[0]
        top = torch.topk(probs, 3)

        # One "Label: percent" line per top prediction.
        lines = []
        for idx, prob in zip(top.indices, top.values):
            lines.append(f"{label_names[idx.item()]}: {prob.item():.2%}")
        return "\n".join(lines)
    except Exception as e:
        return f"Error processing video: {str(e)}"
# Gradio UI: one video upload in, a textbox with the top-3 predictions out.
iface = gr.Interface(
    fn=predict_video,
    inputs=gr.Video(label="Upload Video"),
    outputs=gr.Textbox(label="Top 3 Predictions"),
    title="Video Action Recognition",
    description=(
        "Upload a video to classify the action being performed. "
        "The model will return the top 3 predictions."
    ),
    examples=[[f"test_video_{i}.avi"] for i in range(1, 4)],
)

if __name__ == "__main__":
    # debug=True surfaces tracebacks in the UI; share=True opens a public link.
    iface.launch(debug=True, share=True)