File size: 1,844 Bytes
807a8fb
a190fc6
 
 
0d6882d
a190fc6
807a8fb
 
 
3f72a26
 
807a8fb
 
 
 
 
 
 
 
 
 
0d6882d
 
807a8fb
 
 
 
 
 
 
 
 
 
0d6882d
a190fc6
807a8fb
 
 
 
 
 
 
 
 
 
 
a190fc6
 
 
807a8fb
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import os
from collections import Counter

import cv2
import gradio as gr
import numpy as np
from ultralytics import YOLO

# Weight files, one per model size. Entry i of model_options pairs with entry
# i of model_names; process_image relies on this alignment to pick a model.
model_options = ["yolo-8n-shiprs.pt", "yolo-8s-shiprs.pt", "yolo-8m-shiprs.pt"]
model_names = ["Nano", "Small", "Medium"]
# Every model is loaded eagerly, once, at import time.
models = [YOLO(option) for option in model_options]
# One row per example image (only the image input is supplied). The listing is
# sorted because os.listdir returns entries in arbitrary order, which would
# make the example gallery order vary between machines; dotfiles such as
# .gitkeep are skipped so they are not offered as images.
example_list = [
    [os.path.join("examples", example)]
    for example in sorted(os.listdir("examples"))
    if not example.startswith(".")
]

def _detections(result):
    """Return the detection container (``obb`` or ``boxes``) of one result.

    The app's models are oriented-bounding-box models, so ``result.obb`` is
    normally populated. If it is ``None`` (e.g. a plain detection model),
    fall back to ``result.boxes`` instead of failing with AttributeError.
    """
    return result.obb if result.obb is not None else result.boxes


def _format_class_counts(names, class_ids):
    """Build the per-class summary text shown in the "More info" box.

    Args:
        names: Mapping from integer class id to class label.
        class_ids: Iterable of class ids, one per detection. Tensor scalars
            and floats are accepted and converted with ``int``.

    Returns:
        ``"Class Counts:\\n"`` followed by one ``"\\n<label>: <count>"`` entry
        per class, in order of first appearance.
    """
    counts = Counter(names[int(cls)] for cls in class_ids)
    return "Class Counts:\n" + "".join(
        f"\n{label}: {count}" for label, count in counts.items()
    )


def process_image(input_image, model_name, conf):
    """Run the selected YOLO model on an image and summarize its detections.

    Args:
        input_image: RGB image array from the ``gr.Image`` input, or ``None``.
        model_name: One of ``model_names``; ``None`` selects the first entry.
        conf: Confidence threshold forwarded to ``model.predict``;
            ``None`` means 0.6.

    Returns:
        An ``(image, text)`` pair matching the two Gradio outputs. On success
        this is the annotated image (RGB) and the per-class counts. It is
        ``(None, message)`` when there is no input or nothing was detected.

    Raises:
        ValueError: if ``model_name`` is not in ``model_names``.
    """
    if input_image is None:
        return None, "No image"
    if model_name is None:
        model_name = model_names[0]
    if conf is None:
        conf = 0.6

    model = models[model_names.index(model_name)]
    # Gradio supplies RGB; the model is fed OpenCV-style BGR.
    bgr_image = cv2.cvtColor(input_image, cv2.COLOR_RGB2BGR)
    results = model.predict(bgr_image, conf=conf)

    # One input image yields at most one Results object. Guard the empty case
    # so the function always returns two values, as the two outputs require.
    if not results:
        return None, "No objects detected."
    result = results[0]

    detections = _detections(result)
    if detections is None or len(detections.cls) == 0:
        return None, "No objects detected."

    # plot() renders in BGR; convert back to RGB for display in Gradio.
    annotated = cv2.cvtColor(result.plot().astype(np.uint8), cv2.COLOR_BGR2RGB)
    return annotated, _format_class_counts(result.names, detections.cls)

def _build_interface():
    """Assemble the Gradio UI for process_image.

    Inputs are an image, a model-size radio and a confidence slider. Outputs
    are the annotated image and a text box with the detection summary.
    """
    image_input = gr.Image()
    model_picker = gr.Radio(model_names, label="Choose model", value=model_names[0])
    conf_slider = gr.Slider(
        minimum=0.2,
        maximum=1.0,
        step=0.1,
        label="Confidence Threshold",
        value=0.6,
    )
    info_box = gr.Textbox(label="More info")
    return gr.Interface(
        fn=process_image,
        inputs=[image_input, model_picker, conf_slider],
        outputs=["image", info_box],
        title="YOLOv8-obb aerial detection",
        description="YOLOv8-obb trained on DOTAv1.5",
        examples=example_list,
    )


iface = _build_interface()

iface.launch()