File size: 3,860 Bytes
d20f4f3
 
 
 
 
 
 
 
 
 
 
 
e615f64
d20f4f3
a3c1650
e615f64
554867f
d20f4f3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c7b56d5
d20f4f3
 
 
 
 
 
 
 
 
 
 
 
 
 
5eb2ba5
d20f4f3
 
 
 
 
f38254a
d20f4f3
 
 
 
 
 
 
 
 
 
 
 
 
554867f
d20f4f3
 
 
 
 
f38254a
d20f4f3
 
 
 
 
3a4430e
 
 
d20f4f3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
from transformers import Owlv2Processor, Owlv2ForObjectDetection
from typing import List
import os
import numpy as np
import supervision as sv
import uuid
import torch
from tqdm import tqdm
import gradio as gr
import torch
import numpy as np
from PIL import Image
import spaces

# Checkpoint shared by the processor and the detection model.
_OWLV2_CHECKPOINT = "google/owlv2-base-patch16-ensemble"

# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# The processor prepares text/image inputs and post-processes model outputs;
# the model itself is moved to the selected device once at import time.
processor = Owlv2Processor.from_pretrained(_OWLV2_CHECKPOINT)
model = Owlv2ForObjectDetection.from_pretrained(_OWLV2_CHECKPOINT)
model = model.to(device)

# Module-level annotators shared by every call to annotate_image.
BOUNDING_BOX_ANNOTATOR = sv.BoundingBoxAnnotator()
MASK_ANNOTATOR = sv.MaskAnnotator()
LABEL_ANNOTATOR = sv.LabelAnnotator()


def calculate_end_frame_index(source_video_path, max_seconds=2):
    """Return the frame index at which processing of a video should stop.

    The result is capped at ``max_seconds`` worth of frames, or at the
    video's total frame count when the video is shorter than that.

    Args:
        source_video_path: Path to the video file, passed to
            ``sv.VideoInfo.from_video_path``.
        max_seconds: Maximum duration to process, in seconds. Defaults to 2,
            the value that used to be hard-coded here.

    Returns:
        int: ``min(total_frames, fps * max_seconds)``.
    """
    video_info = sv.VideoInfo.from_video_path(source_video_path)
    # int() keeps the result an integer even for a fractional max_seconds
    # (or a non-integer fps), so it can be used as a frame index.
    frame_cap = int(video_info.fps * max_seconds)
    return min(video_info.total_frames, frame_cap)


def annotate_image(
    input_image,
    detections,
    labels
) -> np.ndarray:
    """Draw masks, bounding boxes and text labels for ``detections``.

    Args:
        input_image: Image to draw on.
        detections: Detections handed to each supervision annotator.
        labels: One caption per detection, forwarded to the label annotator.

    Returns:
        The image returned by the last annotator (labels on top of boxes on
        top of masks).

    NOTE(review): supervision annotators may draw on the array they are
    given rather than on a copy - confirm before reusing ``input_image``.
    """
    annotated = input_image
    # Masks first, then boxes, so the box outlines stay visible.
    for annotator in (MASK_ANNOTATOR, BOUNDING_BOX_ANNOTATOR):
        annotated = annotator.annotate(annotated, detections)
    return LABEL_ANNOTATOR.annotate(annotated, detections, labels=labels)

@spaces.GPU
def process_video(
    input_video,
    labels,
    progress=gr.Progress(track_tqdm=True)
):
    """Detect the given labels in the first frames of a video and save the result.

    Each processed frame is passed to ``query``, annotated with the returned
    detections and written to a new ``.mp4`` file in the working directory.
    The number of frames is limited by ``calculate_end_frame_index``.

    Args:
        input_video: Path to the source video file.
        labels: Comma-separated candidate labels, e.g. ``"dog,cat"``.
            Whitespace around each label is ignored.
        progress: Gradio progress tracker; with ``track_tqdm=True`` it picks
            up the tqdm loop below, so it is not referenced directly.

    Returns:
        str: Path of the annotated video file.

    Raises:
        gr.Error: If no video or no non-empty label was supplied.
    """
    if not input_video:
        raise gr.Error("Please upload a video.")

    # Strip whitespace so "dog, cat" queries "cat" rather than " cat", and
    # drop empty entries produced by stray commas.
    labels = [part.strip() for part in (labels or "").split(",") if part.strip()]
    if not labels:
        raise gr.Error("Please enter at least one label.")

    video_info = sv.VideoInfo.from_video_path(input_video)
    total = calculate_end_frame_index(input_video)
    frame_generator = sv.get_video_frames_generator(
        source_path=input_video,
        end=total
    )

    result_file_name = f"{uuid.uuid4()}.mp4"
    result_file_path = os.path.join("./", result_file_name)
    with sv.VideoSink(result_file_path, video_info=video_info) as sink:
        # Iterate the generator itself instead of calling next() `total`
        # times: if the container reports more frames than it holds, the
        # loop simply ends early rather than raising StopIteration.
        for frame in tqdm(frame_generator, total=total, desc="Processing video.."):
            # supervision reads frames with OpenCV, i.e. in BGR channel
            # order, while the OWLv2 processor expects RGB input.
            rgb_frame = np.ascontiguousarray(frame[:, :, ::-1])
            # results holds one dict per image; "labels" contains indices
            # into the list of text queries.
            results = query(rgb_frame, labels)
            detections = sv.Detections.from_transformers(results[0])
            final_labels = [
                labels[class_index]
                for class_index in results[0]["labels"].tolist()
            ]
            # Annotation and writing stay on the original BGR frame, which
            # is what the video sink expects.
            frame = annotate_image(
                input_image=frame,
                detections=detections,
                labels=final_labels,
            )
            sink.write_frame(frame)
    return result_file_path

def query(image, texts, threshold=0.3):
    """Run OWLv2 zero-shot object detection on a single image.

    Args:
        image: Image as an array whose first two dimensions are height and
            width (read from ``image.shape``).
        texts: Text queries to look for in the image.
        threshold: Minimum score a detection needs to be kept. Defaults to
            0.3, the value that used to be hard-coded here.

    Returns:
        The output of ``processor.post_process_object_detection``; callers
        read ``results[0]["labels"]`` as indices into ``texts``.
    """
    inputs = processor(text=texts, images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs)

    # OWLv2 pads the image to a square (bottom/right) before resizing, so
    # the predicted boxes are normalised to that padded square. Scaling by
    # the longer side maps them back to pixel coordinates of the original
    # image; scaling by (height, width) would squash boxes on non-square
    # frames. A square target size is also correct for transformers
    # versions that already compensate for the padding internally.
    height, width = image.shape[:2]
    longer_side = max(height, width)
    target_sizes = torch.Tensor([[longer_side, longer_side]])

    return processor.post_process_object_detection(
        outputs=outputs,
        threshold=threshold,
        target_sizes=target_sizes,
    )



# Gradio UI: one "Video" tab with input/output players, a label textbox and a
# submit button, plus a clickable example.
with gr.Blocks() as demo:
    gr.Markdown("## Zero-shot Object Tracking with OWLv2 🦉")
    # The model link previously pointed at "huggingface.co./" (stray dot).
    gr.Markdown("This is a demo for zero-shot object tracking using [OWLv2](https://huggingface.co/google/owlv2-base-patch16-ensemble) model by Google.")
    gr.Markdown("Simply upload a video and enter the candidate labels, or try the example below. 👇")
    with gr.Tab(label="Video"):
        with gr.Row():
            input_video = gr.Video(
                label='Input Video'
            )
            output_video = gr.Video(
                label='Output Video'
            )
        with gr.Row():
            candidate_labels = gr.Textbox(
                label='Labels',
                placeholder='Labels separated by a comma',
            )
            submit = gr.Button()
        # The example fills the same inputs that the submit button reads.
        gr.Examples(
            fn=process_video,
            examples=[["./cats.mp4", "dog,cat"]],
            inputs=[
                input_video,
                candidate_labels,
            ],
            outputs=output_video
        )

    # Run detection on the uploaded video when the button is clicked.
    submit.click(
        fn=process_video,
        inputs=[input_video, candidate_labels],
        outputs=output_video
    )

demo.launch(debug=False, show_error=True)