Safwanahmad619 committed on
Commit
624970e
·
verified ·
1 Parent(s): 8c8f9ca

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +67 -0
app.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import cv2
3
+ from transformers import YolosImageProcessor, YolosForObjectDetection
4
+ from PIL import Image
5
+ import torch
6
+
7
# Fetch the YOLOS-tiny image processor and detection model once, at import time,
# so every frame reuses the same loaded objects.
image_processor = YolosImageProcessor.from_pretrained(
    "hustvl/yolos-tiny"
)
model = YolosForObjectDetection.from_pretrained(
    "hustvl/yolos-tiny"
)
10
+
11
def process_frame(frame, size=(640, 360), threshold=0.9):
    """Run object detection on one video frame and draw the detections.

    Args:
        frame: BGR image array as produced by ``cv2.VideoCapture.read``.
        size: ``(width, height)`` the frame is resized to before detection.
            Defaults to the previously hard-coded ``(640, 360)``.
        threshold: Minimum score a detection needs in order to be drawn.
            Defaults to the previously hard-coded ``0.9``.

    Returns:
        The *resized* BGR frame (not the input array) with a green rectangle
        and a ``"<label>: <score>"`` caption for every kept detection.
    """
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.5
    thickness = 2
    green = (0, 255, 0)

    # Downscale to reduce processing time; all drawing happens on this copy.
    frame = cv2.resize(frame, size)

    # OpenCV frames are BGR; the image processor is fed an RGB PIL image.
    image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    inputs = image_processor(images=image, return_tensors="pt")

    # Inference only, so gradient tracking is disabled.
    with torch.no_grad():
        outputs = model(**inputs)

    # PIL reports (width, height); post-processing wants (height, width).
    target_sizes = torch.tensor([image.size[::-1]])
    results = image_processor.post_process_object_detection(
        outputs, threshold=threshold, target_sizes=target_sizes
    )[0]

    for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
        x_min, y_min, x_max, y_max = (int(v) for v in box.tolist())
        cv2.rectangle(frame, (x_min, y_min), (x_max, y_max), green, thickness)

        caption = f"{model.config.id2label[label.item()]}: {round(score.item(), 2)}"
        # Keep the caption inside the frame: a box touching the top edge would
        # otherwise push the text baseline to a negative y and hide the label.
        (_, text_height), _ = cv2.getTextSize(caption, font, font_scale, thickness)
        text_origin = (max(x_min, 0), max(y_min - 10, text_height + 2))
        cv2.putText(frame, caption, text_origin, font, font_scale, green, thickness)

    return frame
37
+
38
def video_object_detection(video):
    """Annotate every frame of a video with detections and write the result.

    Args:
        video: Path of the input video, or ``None`` when no video is provided.

    Returns:
        ``'/tmp/output.mp4'``, the path of the annotated video, or ``None``
        when ``video`` is ``None``.

    Raises:
        ValueError: If the video cannot be opened or yields no frames.
    """
    if video is None:
        # Nothing to process (e.g. the input was cleared).
        return None

    # NOTE: fixed output location, so concurrent calls overwrite each other.
    output_path = '/tmp/output.mp4'
    fallback_fps = 20.0

    cap = cv2.VideoCapture(video)
    writer = None
    try:
        if not cap.isOpened():
            raise ValueError(f"Could not open video: {video!r}")

        # Reuse the source frame rate so playback speed is preserved; fall
        # back to the old fixed rate when the container does not report a
        # usable value (0 or NaN both fail the `> 0` test).
        fps = cap.get(cv2.CAP_PROP_FPS)
        if not fps > 0:
            fps = fallback_fps

        while True:
            ret, frame = cap.read()
            if not ret:
                break

            processed_frame = process_frame(frame)

            # Open the writer lazily: the output size is only known once the
            # first frame has been processed. Writing frame by frame avoids
            # holding the whole video in memory.
            if writer is None:
                height, width = processed_frame.shape[:2]
                writer = cv2.VideoWriter(
                    output_path,
                    cv2.VideoWriter_fourcc(*'mp4v'),
                    fps,
                    (width, height),
                )
            writer.write(processed_frame)
    finally:
        # Release native handles even if decoding or detection fails.
        cap.release()
        if writer is not None:
            writer.release()

    if writer is None:
        raise ValueError("No frames could be read from the video")

    return output_path
64
+
65
# Wire the detector into a Gradio UI: one video input, one video output.
# live=True is passed so the interface is created in Gradio's live mode.
iface = gr.Interface(
    fn=video_object_detection,
    inputs="video",
    outputs="video",
    title="YOLOs-Tiny Video Detection",
    live=True,
)
iface.launch()