Abdo-Alshoki's picture
removed a debugging print line
68a92d9 verified
raw
history blame
3.26 kB
import torch
import gradio as gr
import torch.nn as nn
import torchvision
import cv2
import numpy as np
import tempfile
class MyModel(nn.Module):
    """Video classifier built on torchvision's R3D-18 backbone.

    The backbone's final fully connected layer is replaced by a new
    ``nn.Linear`` with ``num_classes`` outputs.
    """

    # Height and width every sampled frame is resized to (R3D-18 input size).
    _FRAME_SIZE = 112

    def __init__(self, num_classes=1, pretrained=True):
        """Build the network.

        Args:
            num_classes: Number of outputs of the replacement ``fc`` layer.
            pretrained: Forwarded to ``r3d_18``. ``load_model`` passes
                ``False`` because the checkpoint overwrites every weight.
        """
        super().__init__()  # Initialize nn.Module
        self.model = torchvision.models.video.r3d_18(pretrained=pretrained)
        self.model.fc = nn.Linear(self.model.fc.in_features, num_classes)

    @classmethod
    def _frames_to_tensor(cls, frames, num_frames):
        """Stack channels-first frames into a zero-padded float clip tensor.

        Args:
            frames: Sequence of arrays shaped ``(3, size, size)``; at most
                ``num_frames`` of them are used.
            num_frames: Temporal length of the returned clip. Missing
                trailing frames are left as zeros.

        Returns:
            A float32 tensor shaped ``(1, 3, num_frames, size, size)``.
        """
        size = cls._FRAME_SIZE
        # Preallocating the zero buffer makes padding implicit and avoids the
        # slow list-of-ndarrays path of torch.tensor().
        clip = np.zeros((num_frames, 3, size, size), dtype=np.float32)
        usable = list(frames)[:num_frames]
        if usable:
            clip[:len(usable)] = np.stack(usable)
        # (frames, channels, H, W) -> (channels, frames, H, W), plus batch dim.
        return torch.from_numpy(clip).permute(1, 0, 2, 3).unsqueeze(0)

    def preprocess_video(self, video_path, num_frames=40):
        """Sample evenly spaced frames from a video and return a clip tensor.

        Frames are resized to 112x112 and converted to channels-first layout.
        Pixel values are only cast to float; no scaling or mean/std
        normalization is applied here.
        NOTE(review): frames are used exactly as OpenCV returns them (no
        colour conversion) - confirm this matches how the checkpoint was
        trained before changing it.

        Args:
            video_path: Path handed to ``cv2.VideoCapture``.
            num_frames: Number of frames to sample; frames that fail to
                decode are replaced by zero padding at the end of the clip.

        Returns:
            A float tensor shaped ``(1, 3, num_frames, 112, 112)``.

        Raises:
            ValueError: If not a single frame could be decoded.
        """
        size = self._FRAME_SIZE
        sampled_frames = []
        cap = cv2.VideoCapture(video_path)
        try:
            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            if cap.isOpened() and total_frames > 0:
                frame_indices = np.linspace(
                    0, total_frames - 1, num=num_frames, dtype=int
                )
                for idx in frame_indices:
                    cap.set(cv2.CAP_PROP_POS_FRAMES, int(idx))
                    ret, frame = cap.read()
                    if not ret:
                        continue
                    frame = cv2.resize(frame, (size, size))
                    # Channels-first
                    sampled_frames.append(np.transpose(frame, (2, 0, 1)))
        finally:
            # Always free the capture handle, even if decoding/resizing fails.
            cap.release()

        if not sampled_frames:
            # Previously this surfaced as an obscure np.concatenate
            # dimension-mismatch ValueError; fail with a clear message.
            raise ValueError(
                f"Could not read any frames from video: {video_path!r}"
            )
        return self._frames_to_tensor(sampled_frames, num_frames)

    def forward(self, x):
        """Run the wrapped backbone on ``x`` and return its raw output."""
        return self.model(x)

    def predict(self, video_path):
        """Return the sigmoid score for one video as a one-element list.

        The model is switched to eval mode and run without gradient
        tracking. ``.item()`` is used, so the output must hold exactly one
        value (the default ``num_classes=1``).
        """
        self.model.eval()
        with torch.no_grad():
            X = self.preprocess_video(video_path)
            # Keep the input on the same device as the weights.
            X = X.to(next(self.model.parameters()).device)
            output = self.model(X)
            pred = torch.sigmoid(output)  # Apply sigmoid for binary classification
        return [pred.item()]

    def save_model(self, filepath):
        """Save this module's ``state_dict`` under ``'model_state_dict'``."""
        torch.save({
            'model_state_dict': self.state_dict(),
        }, filepath)

    @staticmethod
    def load_model(filepath, num_classes=1):
        """Load a checkpoint written by ``save_model`` and return the model.

        The returned model is on the CPU and in eval mode.
        """
        # The checkpoint replaces every weight, so skip the pretrained ones.
        model = MyModel(num_classes, pretrained=False)
        # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
        checkpoint = torch.load(filepath, map_location="cpu", weights_only=True)
        model.load_state_dict(checkpoint['model_state_dict'])
        model.eval()
        return model
# load_model is a static method, so call it on the class. Instantiating
# MyModel() first would build a pretrained network that is thrown away.
model = MyModel.load_model('pre_3D_model.h5')
def classify_video(video):
    """Classify an uploaded video and describe how likely it is violent.

    The first value returned by ``model.predict`` is treated as the
    non-violent score: 0.5 or higher yields "Non-violent", and its
    complement is reported as the violence percentage.

    Returns:
        A ``(label, message)`` tuple of strings.
    """
    non_violent_score = model.predict(video)[0]

    if non_violent_score >= 0.5:
        label = "Non-violent"
    else:
        label = "Violent"

    violent_pct = round((1 - non_violent_score) * 100, 2)
    violent_prob_percentage = f"{violent_pct}% chance of being violent"
    return label, violent_prob_percentage
# Build the Gradio UI: one video upload in, two text fields out.
# Components are created in the same order as before (input first).
_video_input = gr.Video()  # Allows video upload
_label_output = gr.Text(label="Classification")  # Classification result
_probability_output = gr.Text(label="Violence Probability")  # Probability text

interface = gr.Interface(
    fn=classify_video,
    inputs=_video_input,
    outputs=[_label_output, _probability_output],
    title="Violence Detection in Videos",
    description="Upload a video to classify it as violent or non-violent with a probability score.",
)

# Start the app with a public share link and debug mode enabled.
interface.launch(share=True, debug=True)