Update app.py
Browse files
app.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1 |
import gradio as gr
|
2 |
-
from transformers import VideoMAEForVideoClassification,
|
|
|
3 |
import cv2 # OpenCV for video processing
|
4 |
|
5 |
# Model ID for video classification (UCF101 subset)
|
@@ -8,22 +9,25 @@ model_id = "MCG-NJU/videomae-base"
|
|
8 |
def analyze_video(video):
|
9 |
# Extract key frames from the video using OpenCV
|
10 |
frames = extract_key_frames(video)
|
11 |
-
|
12 |
# Load model and feature extractor manually
|
13 |
model = VideoMAEForVideoClassification.from_pretrained(model_id)
|
14 |
-
|
|
|
|
|
|
|
15 |
|
16 |
-
#
|
17 |
-
|
|
|
18 |
|
19 |
-
|
|
|
|
|
|
|
20 |
results = []
|
21 |
-
for
|
22 |
-
|
23 |
-
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
24 |
-
predictions = classifier([frame_rgb]) # Assuming model outputs probabilities
|
25 |
-
# Analyze predictions for insights related to the play
|
26 |
-
result = analyze_predictions_ucf101(predictions)
|
27 |
results.append(result)
|
28 |
|
29 |
# Aggregate results across frames and provide a final analysis
|
@@ -40,24 +44,29 @@ def extract_key_frames(video):
|
|
40 |
for i in range(frame_count):
|
41 |
ret, frame = cap.read()
|
42 |
if ret and i % (fps // 2) == 0: # Extract a frame every half second
|
43 |
-
frames.append(frame)
|
44 |
|
45 |
cap.release()
|
46 |
return frames
|
47 |
|
48 |
-
def analyze_predictions_ucf101(
|
49 |
-
#
|
50 |
-
|
51 |
-
|
|
|
|
|
|
|
|
|
|
|
52 |
|
53 |
relevant_actions = ["running", "sliding", "jumping"]
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
else:
|
62 |
return "inconclusive"
|
63 |
|
@@ -80,6 +89,7 @@ interface = gr.Interface(
|
|
80 |
outputs="text",
|
81 |
title="Baseball Play Analysis (UCF101 Subset Exploration)",
|
82 |
description="Upload a video of a baseball play (safe/out at a base). This app explores using a video classification model (UCF101 subset) for analysis. Note: The model might not be specifically trained for baseball plays.",
|
|
|
83 |
)
|
84 |
|
85 |
interface.launch()
|
|
|
1 |
import gradio as gr
|
2 |
+
from transformers import VideoMAEForVideoClassification, VideoMAEImageProcessor
|
3 |
+
import torch
|
4 |
import cv2 # OpenCV for video processing
|
5 |
|
6 |
# Model ID for video classification (UCF101 subset)
|
|
|
9 |
def analyze_video(video):
|
10 |
# Extract key frames from the video using OpenCV
|
11 |
frames = extract_key_frames(video)
|
12 |
+
|
13 |
# Load model and feature extractor manually
|
14 |
model = VideoMAEForVideoClassification.from_pretrained(model_id)
|
15 |
+
processor = VideoMAEImageProcessor.from_pretrained(model_id)
|
16 |
+
|
17 |
+
# Prepare frames for the model
|
18 |
+
inputs = processor(images=frames, return_tensors="pt")
|
19 |
|
20 |
+
# Make predictions
|
21 |
+
with torch.no_grad():
|
22 |
+
outputs = model(**inputs)
|
23 |
|
24 |
+
logits = outputs.logits
|
25 |
+
predictions = torch.argmax(logits, dim=-1)
|
26 |
+
|
27 |
+
# Analyze predictions for insights related to the play
|
28 |
results = []
|
29 |
+
for prediction in predictions:
|
30 |
+
result = analyze_predictions_ucf101(prediction.item())
|
|
|
|
|
|
|
|
|
31 |
results.append(result)
|
32 |
|
33 |
# Aggregate results across frames and provide a final analysis
|
|
|
44 |
for i in range(frame_count):
|
45 |
ret, frame = cap.read()
|
46 |
if ret and i % (fps // 2) == 0: # Extract a frame every half second
|
47 |
+
frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)) # Convert to RGB
|
48 |
|
49 |
cap.release()
|
50 |
return frames
|
51 |
|
52 |
+
def analyze_predictions_ucf101(prediction):
    """Translate a raw class index into a baseball-play verdict.

    The index-to-action mapping below is hypothetical (placeholder ids,
    not the real UCF101 label space). Any index outside the mapping, and
    any action without a specific verdict, yields "inconclusive".

    Args:
        prediction: Integer class index produced by the classifier.

    Returns:
        One of "potentially safe", "potentially out", or "inconclusive".
    """
    # Hypothetical id -> action-name table; extend as more labels are mapped.
    action_labels = {
        0: "running",
        1: "sliding",
        2: "jumping",
        # Add more labels as necessary
    }
    action = action_labels.get(prediction, "unknown")

    # Verdict per action. "jumping" (and any unmapped/unknown action) has
    # no specific verdict, so the lookup falls through to "inconclusive".
    verdicts = {
        "sliding": "potentially safe",
        "running": "potentially out",
    }
    return verdicts.get(action, "inconclusive")
|
72 |
|
|
|
89 |
outputs="text",
|
90 |
title="Baseball Play Analysis (UCF101 Subset Exploration)",
|
91 |
description="Upload a video of a baseball play (safe/out at a base). This app explores using a video classification model (UCF101 subset) for analysis. Note: The model might not be specifically trained for baseball plays.",
|
92 |
+
share=True
|
93 |
)
|
94 |
|
95 |
interface.launch()
|