moyanxinxu committed on
Commit ff18d07
1 Parent(s): 981f0e9

Upload 7 files

.gitattributes CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+ demo_video/aerial.mp4 filter=lfs diff=lfs merge=lfs -text
+ demo_video/blurry.mp4 filter=lfs diff=lfs merge=lfs -text
+ demo_video/high-way.mp4 filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,70 @@
+ import gradio as gr
+ import supervision as sv
+ from func import detect_and_track
+ from transformers import DetrForObjectDetection, DetrImageProcessor
+
+ processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
+ model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")
+ tracker = sv.ByteTrack()
+
+ mask_annotator = sv.MaskAnnotator()
+ bbox_annotator = sv.BoundingBoxAnnotator()
+ label_annotator = sv.LabelAnnotator()
+
+
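+ # click callback for the Gradio button: forwards the uploaded video and slider value to func.detect_and_track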
+ def process_video(video_path, confidence_threshold):
+     return detect_and_track(
+         video_path,
+         model,
+         processor,
+         tracker,
+         confidence_threshold,
+         mask_annotator,
+         bbox_annotator,
+         label_annotator,
+     )
+
+
+ with gr.Blocks() as demo:
+     with gr.Row():
+         with gr.Column():
+             in_video = gr.Video(
+                 label="Input video",
+                 show_download_button=True,
+                 show_share_button=True,
+             )
+             slide_confidence = gr.Slider(
+                 minimum=0.0, maximum=1.0, value=0.8, label="Confidence threshold"
+             )
+             examples = gr.Examples(
+                 examples=[
+                     "./demo_video/blurry.mp4",
+                     "./demo_video/high-way.mp4",
+                     "./demo_video/aerial.mp4",
+                 ],
+                 inputs=in_video,
+                 label="Example videos",
+             )
+         with gr.Column():
+             out_video = gr.Video(
+                 label="Detection result",
+                 interactive=False,
+                 show_download_button=True,
+                 show_share_button=True,
+             )
+             combine_video = gr.Video(
+                 interactive=False,
+                 label="Before/after comparison",
+                 show_download_button=True,
+                 show_share_button=True,
+             )
+
+     start_detect = gr.Button(value="Start detection")
+
+     start_detect.click(
+         fn=process_video,
+         inputs=[in_video, slide_confidence],
+         outputs=[out_video, combine_video],
+     )
+
+ demo.launch()
demo_video/aerial.mp4 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a93ee4da5ecd552b615579afee630536b9e5fcb68c22f6b3e150e1b8440ffd8c
+ size 14558955
demo_video/blurry.mp4 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:15a863292935dc5d9cc917bdb9623991c97abbd73f524fb98c1f794c01c17338
+ size 4608340
demo_video/high-way.mp4 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4a4f00da283781c89fdef525eeae99ed8c72167de2b73bf4e82a9b8cfbae6378
+ size 1966824
func.py ADDED
@@ -0,0 +1,215 @@
+ import os
+
+ import cv2 as cv
+ import moviepy.editor as mpe
+ import numpy as np
+ import supervision as sv
+ import torch
+ from hyper import hp
+ from moviepy.video.io.ImageSequenceClip import ImageSequenceClip
+ from PIL import Image
+ from tqdm import tqdm
+
+
+ def detect(frame, model, processor, confidence_threshold):
+     """
+     args:
+         frame: PIL image
+         model: PreTrainedModel
+         processor: PreTrainedProcessor
+         confidence_threshold: float
+     returns:
+         results: list with one dict per image, each holding "boxes", "labels" and "scores"
+
+     examples:
+         [
+             {
+                 "scores": tensor([0.9980, 0.9039, 0.7575, 0.9033]),
+                 "labels": tensor([86, 64, 67, 67]),
+                 "boxes": tensor(
+                     [
+                         [1.1582e03, 1.1893e03, 1.9373e03, 1.9681e03],
+                         [2.4274e02, 1.3234e02, 2.5919e03, 1.9628e03],
+                         [1.1107e-01, 1.5105e03, 3.1980e03, 2.1076e03],
+                         [7.1036e-01, 1.7360e03, 3.1970e03, 2.1100e03],
+                     ]
+                 ),
+             }
+         ]
+     """
+     inputs = processor(images=frame, return_tensors="pt").to(hp.device)
+     with torch.no_grad():
+         outputs = model(**inputs)
+     # (height, width) of the original image, needed to rescale the predicted boxes
+     target_sizes = torch.tensor([frame.size[::-1]])
+
+     results = processor.post_process_object_detection(
+         outputs=outputs, threshold=confidence_threshold, target_sizes=target_sizes
+     )
+     return results
+
+
+ def get_len_frames(video_path):
+     """
+     args:
+         video_path: str
+     returns:
+         int: the number of frames in the video
+     examples:
+         get_len_frames("../demo_video/aerial.mp4")  # 1478
+     """
+     video_info = sv.VideoInfo.from_video_path(video_path)
+     return video_info.total_frames
+
+
+ def track(detected_result, tracker: sv.ByteTrack):
+     """
+     args:
+         detected_result: list with one dict per image, each holding "boxes", "labels" and "scores"
+         tracker: sv.ByteTrack
+     returns:
+         detections: sv.Detections with a tracker_id assigned to every kept box
+     examples:
+         from transformers import DetrImageProcessor, DetrForObjectDetection
+
+         processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
+         model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")
+
+         tracker = sv.ByteTrack()
+
+         image = Image.open("ZJF990.jpg")
+
+         detected_result = detect(image, model, processor, hp.confidence_threshold)
+         tracked_result = track(detected_result, tracker)
+
+         print(detected_result)
+         print(tracked_result)
+
+         [
+             {
+                 "scores": tensor([0.9980, 0.9039, 0.7575, 0.9033]),
+                 "labels": tensor([86, 64, 67, 67]),
+                 "boxes": tensor(
+                     [
+                         [1.1582e03, 1.1893e03, 1.9373e03, 1.9681e03],
+                         [2.4274e02, 1.3234e02, 2.5919e03, 1.9628e03],
+                         [1.1107e-01, 1.5105e03, 3.1980e03, 2.1076e03],
+                         [7.1036e-01, 1.7360e03, 3.1970e03, 2.1100e03],
+                     ]
+                 ),
+             }
+         ]
+
+         Detections(
+             xyxy=array(
+                 [
+                     [1.1581914e03, 1.1892766e03, 1.9372931e03, 1.9680990e03],
+                     [2.4273552e02, 1.3233553e02, 2.5918860e03, 1.9628494e03],
+                     [1.1106834e-01, 1.5105106e03, 3.1980032e03, 2.1075664e03],
+                     [7.1036065e-01, 1.7359819e03, 3.1970449e03, 2.1100107e03],
+                 ],
+                 dtype=float32,
+             ),
+             mask=None,
+             confidence=array([0.9980374, 0.9038882, 0.7575455, 0.9032779], dtype=float32),
+             class_id=array([86, 64, 67, 67]),
+             tracker_id=array([1, 2, 3, 4]),
+             data={},
+         )
+     """
+     detections = sv.Detections.from_transformers(detected_result[0])
+     detections = tracker.update_with_detections(detections)
+     return detections
+
+
+ def annotate_image(
+     frame,
+     detections,
+     labels,
+     mask_annotator: sv.MaskAnnotator,
+     bbox_annotator: sv.BoxAnnotator,
+     label_annotator: sv.LabelAnnotator,
+ ) -> np.ndarray:
+     out_frame = mask_annotator.annotate(frame, detections)
+     out_frame = bbox_annotator.annotate(out_frame, detections)
+     out_frame = label_annotator.annotate(out_frame, detections, labels=labels)
+     return out_frame
+
+
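+ # end-to-end pipeline: read frames, run DETR detection + ByteTrack, then write the original,
+ # annotated and side-by-side comparison videos and return the paths shown in the Gradio UI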
+ def detect_and_track(
+     video_path,
+     model,
+     processor,
+     tracker,
+     confidence_threshold,
+     mask_annotator: sv.MaskAnnotator,
+     bbox_annotator: sv.BoxAnnotator,
+     label_annotator: sv.LabelAnnotator,
+ ):
+     video_info = sv.VideoInfo.from_video_path(video_path)
+     fps = video_info.fps
+     len_frames = video_info.total_frames
+
+     frames_loader = sv.get_video_frames_generator(video_path, end=len_frames)
+
+     result_file_name = "output.mp4"
+     original_file_name = "original.mp4"
+     combined_file_name = "combined.mp4"
+     result_file_path = os.path.join("../output/", result_file_name)
+     original_file_path = os.path.join("../output/", original_file_name)
+     combined_file_path = os.path.join("../output/", combined_file_name)
+     os.makedirs("../output/", exist_ok=True)  # make sure the output directory exists
+
+     concated_frames = []
+     original_frames = []
+     for frame in tqdm(frames_loader, total=len_frames):
+         # supervision yields BGR frames; convert once so detection and moviepy both see RGB
+         frame = cv.cvtColor(frame, cv.COLOR_BGR2RGB)
+         results = detect(Image.fromarray(frame), model, processor, confidence_threshold)
+         tracked_results = track(results, tracker)
+         original_frames.append(frame.copy())
+         scores = tracked_results.confidence.tolist()
+         labels = tracked_results.class_id.tolist()
+
+         frame = annotate_image(
+             frame,
+             tracked_results,
+             labels=[
+                 f"{model.config.id2label[label]}-{score:.2f}"
+                 for label, score in zip(labels, scores)
+             ],
+             mask_annotator=mask_annotator,
+             bbox_annotator=bbox_annotator,
+             label_annotator=label_annotator,
+         )
+         concated_frames.append(frame)  # add the annotated frame to the list
+
+     # build MoviePy clips from the frame lists and write them to disk
+     original_video = mpe.ImageSequenceClip(original_frames, fps=fps)
+     original_video.write_videofile(original_file_path, codec="libx264", fps=fps)
+     concated_video = mpe.ImageSequenceClip(concated_frames, fps=fps)
+     concated_video.write_videofile(result_file_path, codec="libx264", fps=fps)
+
+     combined_video = combine_frames(original_frames, concated_frames, fps)
+     combined_video.write_videofile(combined_file_path, codec="libx264", fps=fps)
+     return result_file_path, combined_file_path
+
+
+ def combine_frames(frames_list1, frames_list2, fps):
+     """
+     args:
+         frames_list1: list of np.ndarray frames (H, W, 3)
+         frames_list2: list of np.ndarray frames (H, W, 3)
+         fps: frames per second of the resulting clip
+     returns:
+         final_clip: moviepy clip with the two frame sequences side by side
+     """
+     clip1 = ImageSequenceClip(frames_list1, fps=fps)
+     clip2 = ImageSequenceClip(frames_list2, fps=fps)
+
+     final_clip = mpe.clips_array([[clip1, clip2]])
+
+     return final_clip
hyper.py ADDED
@@ -0,0 +1,3 @@
+ class hp:
+     # shared hyperparameters: func.detect moves model inputs to `device`
+     device = "cpu"
+     confidence_threshold = 0.7
requirements.txt ADDED
@@ -0,0 +1,12 @@
+ datasets
+ easydict
+ gradio
+ moviepy
+ numpy
+ opencv-python
+ scipy
+ supervision
+ timm
+ torch
+ torchvision
+ git+https://github.com/qubvel/transformers.git@fix-rt-detr-init
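+ # note: func.py imports moviepy.editor, which only exists in the moviepy 1.x API, so moviepy<2 may be required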