# Video action recognition demo: Gradio UI over a fine-tuned VideoMAE model.
import gradio as gr
import torch
from pytorchvideo.data.encoded_video import EncodedVideo
from pytorchvideo.transforms import UniformTemporalSubsample
from transformers import VideoMAEForVideoClassification
import torch.nn.functional as F
import torchvision.transforms.functional as F_t  # Changed import

# Select the GPU when one is available, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the fine-tuned VideoMAE classifier from the local "model" directory
# and move it to the chosen device.
model_path = "model"
loaded_model = VideoMAEForVideoClassification.from_pretrained(model_path).to(device)
loaded_model.eval()  # inference mode: disables dropout / batch-norm updates

# Human-readable class names, index-aligned with the model's output logits.
label_names = [
    'Archery', 'BalanceBeam', 'BenchPress', 'ApplyEyeMakeup', 'BasketballDunk',
    'BandMarching', 'BabyCrawling', 'ApplyLipstick', 'BaseballPitch', 'Basketball'
]

def load_video(video_path):
    """Decode the entire clip at *video_path* and return its frame tensor.

    Raises:
        ValueError: wrapping any decoding failure.
    """
    try:
        clip = EncodedVideo.from_path(video_path)
        # Grab the whole clip, from the first second to the last.
        frames = clip.get_clip(start_sec=0, end_sec=clip.duration)
        return frames['video']
    except Exception as e:
        raise ValueError(f"Error loading video: {str(e)}")

def preprocess_video(video_frames):
    """Prepare decoded frames for VideoMAE inference.

    Args:
        video_frames: uint8 frame tensor, either [C, T, H, W] (as returned
            by pytorchvideo) or already [T, C, H, W] — TODO confirm with
            the caller's decoder output.

    Returns:
        Float tensor of shape [1, 16, C, 224, 224], scaled to [0, 1] and
        normalized with ImageNet statistics.

    Raises:
        ValueError: wrapping any failure during transformation.
    """
    try:
        # Uniformly pick 16 frames across the clip (the model's clip length).
        video_frames = UniformTemporalSubsample(16)(video_frames)

        # Convert to float and scale pixel values to [0, 1].
        video_frames = video_frames.float() / 255.0

        # Ensure layout is [T, C, H, W]. NOTE(review): this heuristic
        # misfires if a [T, C, H, W] clip happens to have exactly 3 frames.
        if video_frames.shape[0] == 3:
            video_frames = video_frames.permute(1, 0, 2, 3)

        # ImageNet mean/std. torchvision's functional normalize accepts a
        # batched [T, C, H, W] tensor, so one call replaces the former
        # per-frame Python loop (same result, no in-place mutation).
        mean = torch.tensor([0.485, 0.456, 0.406])
        std = torch.tensor([0.229, 0.224, 0.225])
        video_frames = F_t.normalize(video_frames, mean, std)

        # Resize to the model's 224x224 input resolution. F_t.resize
        # operates on [..., H, W], so the whole clip is resized at once
        # instead of stacking per-frame results.
        video_frames = F_t.resize(video_frames, [224, 224], antialias=True)

        # Add the batch dimension -> [1, T, C, H, W].
        return video_frames.unsqueeze(0)
    except Exception as e:
        raise ValueError(f"Error preprocessing video: {str(e)}")

def predict_video(video):
    """Classify an uploaded video and report the top-3 action labels.

    Returns a newline-separated "label: probability" string on success,
    or an error-message string if anything fails.
    """
    try:
        # Decode, preprocess, and move the clip onto the model's device.
        frames = load_video(video)
        batch = preprocess_video(frames).to(device)

        # Forward pass without gradient tracking.
        with torch.no_grad():
            logits = loaded_model(batch).logits
            probs = F.softmax(logits, dim=-1)[0]
            top_3 = torch.topk(probs, 3)

        # Render each of the three best (label, probability) pairs.
        lines = []
        for idx, prob in zip(top_3.indices, top_3.values):
            lines.append(f"{label_names[idx.item()]}: {prob.item():.2%}")
        return "\n".join(lines)
    except Exception as e:
        return f"Error processing video: {str(e)}"

# Gradio interface: one video input, one text output, with bundled
# example clips selectable from the UI.
iface = gr.Interface(
    fn=predict_video,
    inputs=gr.Video(label="Upload Video"),
    outputs=gr.Textbox(label="Top 3 Predictions"),
    title="Video Action Recognition",
    description="Upload a video to classify the action being performed. The model will return the top 3 predictions.",
    examples=[
        ["test_video_1.avi"],
        ["test_video_2.avi"],
        ["test_video_3.avi"]
    ]
)

# Launch only when run as a script. share=True exposes a public URL;
# debug=True keeps the process attached and prints tracebacks.
if __name__ == "__main__":
    iface.launch(debug=True, share=True)