hocheewai commited on
Commit
d79cbd2
·
1 Parent(s): 6cab04e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +137 -1
app.py CHANGED
@@ -1,3 +1,139 @@
 
1
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
- gr.Interface.load("models/hocheewai/videomae-base-finetuned-ucf101-subset").launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
  import gradio as gr
3
+ import imutils
4
+ import numpy as np
5
+ import torch
6
+ from pytorchvideo.transforms import (
7
+ ApplyTransformToKey,
8
+ Normalize,
9
+ RandomShortSideScale,
10
+ RemoveKey,
11
+ ShortSideScale,
12
+ UniformTemporalSubsample,
13
+ )
14
+ from torchvision.transforms import (
15
+ Compose,
16
+ Lambda,
17
+ RandomCrop,
18
+ RandomHorizontalFlip,
19
+ Resize,
20
+ )
21
+ from transformers import VideoMAEFeatureExtractor, VideoMAEForVideoClassification
22
 
23
+ MODEL_CKPT = "hocheewai/videomae-base-finetuned-ucf101-subset"
24
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
25
+
26
+ MODEL = VideoMAEForVideoClassification.from_pretrained(MODEL_CKPT).to(DEVICE)
27
+ PROCESSOR = VideoMAEFeatureExtractor.from_pretrained(MODEL_CKPT)
28
+
29
+ RESIZE_TO = PROCESSOR.size["shortest_edge"]
30
+ NUM_FRAMES_TO_SAMPLE = MODEL.config.num_frames
31
+ IMAGE_STATS = {"image_mean": [0.485, 0.456, 0.406], "image_std": [0.229, 0.224, 0.225]}
32
+ VAL_TRANSFORMS = Compose(
33
+ [
34
+ UniformTemporalSubsample(NUM_FRAMES_TO_SAMPLE),
35
+ Lambda(lambda x: x / 255.0),
36
+ Normalize(IMAGE_STATS["image_mean"], IMAGE_STATS["image_std"]),
37
+ Resize((RESIZE_TO, RESIZE_TO)),
38
+ ]
39
+ )
40
+ LABELS = list(MODEL.config.label2id.keys())
41
+
42
+
43
+ def parse_video(video_file):
44
+ """A utility to parse the input videos.
45
+ Reference: https://pyimagesearch.com/2018/11/12/yolo-object-detection-with-opencv/
46
+ """
47
+ vs = cv2.VideoCapture(video_file)
48
+
49
+ # try to determine the total number of frames in the video file
50
+ try:
51
+ prop = (
52
+ cv2.cv.CV_CAP_PROP_FRAME_COUNT
53
+ if imutils.is_cv2()
54
+ else cv2.CAP_PROP_FRAME_COUNT
55
+ )
56
+ total = int(vs.get(prop))
57
+ print("[INFO] {} total frames in video".format(total))
58
+
59
+ # an error occurred while trying to determine the total
60
+ # number of frames in the video file
61
+ except:
62
+ print("[INFO] could not determine # of frames in video")
63
+ print("[INFO] no approx. completion time can be provided")
64
+ total = -1
65
+
66
+ frames = []
67
+
68
+ # loop over frames from the video file stream
69
+ while True:
70
+ # read the next frame from the file
71
+ (grabbed, frame) = vs.read()
72
+ if frame is not None:
73
+ frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
74
+ frames.append(frame)
75
+ # if the frame was not grabbed, then we have reached the end
76
+ # of the stream
77
+ if not grabbed:
78
+ break
79
+
80
+ return frames
81
+
82
+
83
+ def preprocess_video(frames: list):
84
+ """Utility to apply preprocessing transformations to a video tensor."""
85
+ # Each frame in the `frames` list has the shape: (height, width, num_channels).
86
+ # Collated together the `frames` has the the shape: (num_frames, height, width, num_channels).
87
+ # So, after converting the `frames` list to a torch tensor, we permute the shape
88
+ # such that it becomes (num_channels, num_frames, height, width) to make
89
+ # the shape compatible with the preprocessing transformations. After applying the
90
+ # preprocessing chain, we permute the shape to (num_frames, num_channels, height, width)
91
+ # to make it compatible with the model. Finally, we add a batch dimension so that our video
92
+ # classification model can operate on it.
93
+ video_tensor = torch.tensor(np.array(frames).astype(frames[0].dtype))
94
+ video_tensor = video_tensor.permute(
95
+ 3, 0, 1, 2
96
+ ) # (num_channels, num_frames, height, width)
97
+ video_tensor_pp = VAL_TRANSFORMS(video_tensor)
98
+ video_tensor_pp = video_tensor_pp.permute(
99
+ 1, 0, 2, 3
100
+ ) # (num_frames, num_channels, height, width)
101
+ video_tensor_pp = video_tensor_pp.unsqueeze(0)
102
+ return video_tensor_pp.to(DEVICE)
103
+
104
+
105
+ def infer(video_file):
106
+ frames = parse_video(video_file)
107
+ video_tensor = preprocess_video(frames)
108
+ inputs = {"pixel_values": video_tensor}
109
+
110
+ # forward pass
111
+ with torch.no_grad():
112
+ outputs = MODEL(**inputs)
113
+ logits = outputs.logits
114
+ softmax_scores = torch.nn.functional.softmax(logits, dim=-1).squeeze(0)
115
+ confidences = {LABELS[i]: float(softmax_scores[i]) for i in range(len(LABELS))}
116
+ return confidences
117
+
118
+
119
+ gr.Interface(
120
+ fn=infer,
121
+ inputs=gr.Video(type="file"),
122
+ outputs=gr.Label(num_top_classes=3),
123
+ examples=[
124
+ ["examples/babycrawling.mp4"],
125
+ ["examples/baseball.mp4"],
126
+ ["examples/balancebeam.mp4"],
127
+ ],
128
+ title="VideoMAE fine-tuned on a subset of UCF-101",
129
+ description=(
130
+ "Gradio demo for VideoMAE for video classification. To use it, simply upload your video or click one of the"
131
+ " examples to load them. Read more at the links below."
132
+ ),
133
+ article=(
134
+ "<div style='text-align: center;'><a href='https://huggingface.co/docs/transformers/model_doc/videomae' target='_blank'>VideoMAE</a>"
135
+ " <center><a href='https://huggingface.co/sayakpaul/videomae-base-finetuned-kinetics-finetuned-ucf101-subset' target='_blank'>Fine-tuned Model</a></center></div>"
136
+ ),
137
+ allow_flagging=False,
138
+ allow_screenshot=False,
139
+ ).launch()