File size: 5,080 Bytes
a950ee6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a2f2ef1
a950ee6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image, ImageEnhance, ImageDraw
import torch
import streamlit as st
from model.inference_cpu import inference_case

# Initial drawing-canvas state: a single 100x100 rectangle anchored at
# (left=50, top=50) with a translucent orange fill and a blue 3 px outline.
# NOTE(review): the key set looks like a Fabric.js object serialization
# (version "4.4.0"), as consumed by a drawable-canvas widget -- confirm.
_INITIAL_RECT_OBJECT = {
    "type": "rect",
    "version": "4.4.0",
    "originX": "left",
    "originY": "top",
    "left": 50,
    "top": 50,
    "width": 100,
    "height": 100,
    "fill": "rgba(255, 165, 0, 0.3)",
    "stroke": "#2909F1",
    "strokeWidth": 3,
    "strokeDashArray": None,
    "strokeLineCap": "butt",
    "strokeDashOffset": 0,
    "strokeLineJoin": "miter",
    "strokeUniform": True,
    "strokeMiterLimit": 4,
    "scaleX": 1,
    "scaleY": 1,
    "angle": 0,
    "flipX": False,
    "flipY": False,
    "opacity": 1,
    "shadow": None,
    "visible": True,
    "backgroundColor": "",
    "fillRule": "nonzero",
    "paintFirst": "fill",
    "globalCompositeOperation": "source-over",
    "skewX": 0,
    "skewY": 0,
    "rx": 0,
    "ry": 0,
}

initial_rectangle = {
    "version": "4.4.0",
    "objects": [_INITIAL_RECT_OBJECT],
}

def run():
    """Run model inference on the current case with whichever prompts are enabled.

    Reads the case tensors (``data_item["image"]`` and
    ``data_item["zoom_out_image"]``, both cast with ``.float()``) and the prompt
    toggles from ``st.session_state``. Each prompt is passed as ``None`` unless
    its toggle is on (and, for points, at least one point exists). The two
    values returned by ``inference_case`` are stored in
    ``st.session_state.preds_3D`` and ``st.session_state.preds_3D_ori``.
    """
    state = st.session_state
    case = state.data_item
    image = case["image"].float()
    image_zoom_out = case["zoom_out_image"].float()

    text_prompt = state.text_prompt if state.use_text_prompt else None
    point_prompt = (
        reflect_points_into_model(state.points)
        if state.use_point_prompt and len(state.points) > 0
        else None
    )
    box_prompt = (
        reflect_box_into_model(state.rectangle_3Dbox)
        if state.use_box_prompt
        else None
    )

    # Discard previously cached results before every run.
    # NOTE(review): assumes inference_case is a Streamlit-cached function
    # (it exposes .clear()) -- confirm in model.inference_cpu.
    inference_case.clear()
    state.preds_3D, state.preds_3D_ori = inference_case(
        image,
        image_zoom_out,
        text_prompt=text_prompt,
        _point_prompt=point_prompt,
        _box_prompt=box_prompt,
    )

def reflect_box_into_model(box_3d):
    """Rescale a 3-D box from canvas coordinates to the model's prompt grid.

    Args:
        box_3d: Six numbers ordered ``(z1, y1, x1, z2, y2, x2)``.

    Returns:
        A 1-D ``torch.Tensor`` of six integers in the same order. Each y/x value
        is multiplied by ``256 / 325`` and each z value by ``32 / 325``, then
        truncated toward zero with ``int()``.
        NOTE(review): presumably 325 is the displayed slice size and 256 / 32
        the model's in-plane / depth resolution -- confirm.

    Raises:
        ValueError: if ``box_3d`` does not unpack into exactly six values.
    """
    z1, y1, x1, z2, y2, x2 = box_3d
    scaled = []
    for corner_z, corner_y, corner_x in ((z1, y1, x1), (z2, y2, x2)):
        scaled += [
            int(corner_z * 32.0 / 325.0),
            int(corner_y * 256.0 / 325.0),
            int(corner_x * 256.0 / 325.0),
        ]
    return torch.tensor(np.array(scaled))

def reflect_json_data_to_3D_box(json_data, view):
    """Copy the first canvas rectangle into ``st.session_state.rectangle_3Dbox``.

    The box is indexed ``[z1, y1, x1, z2, y2, x2]``, the same order that
    ``reflect_box_into_model`` unpacks. For ``view == 'xy'`` only the y/x
    entries (indices 1, 2, 4, 5) are overwritten from the rectangle's
    ``top``/``left`` corner and its scaled ``height``/``width``. The z entries
    are left untouched. Any other view changes nothing. The resulting box is
    always printed (debug output).

    Args:
        json_data: Canvas JSON with an ``'objects'`` list. Element 0 is read
            and must have ``top``, ``left``, ``height``, ``width``, ``scaleX``
            and ``scaleY``.
        view: Name of the slice orientation. Only ``'xy'`` is handled.
    """
    if view == 'xy':
        rect = json_data['objects'][0]
        box = st.session_state.rectangle_3Dbox
        box[1] = rect['top']
        box[2] = rect['left']
        # The canvas keeps the base size separate from the user's resize
        # factor, so the true extent is size * scale.
        box[4] = rect['top'] + rect['height'] * rect['scaleY']
        box[5] = rect['left'] + rect['width'] * rect['scaleX']
    print(st.session_state.rectangle_3Dbox)

def reflect_points_into_model(points):
    """Rescale clicked points to the model's prompt grid and label them.

    Args:
        points: Iterable of ``(z, y, x)`` triples in canvas coordinates.

    Returns:
        ``(coords, labels)``. ``coords`` is a ``torch.Tensor`` of shape
        ``(N, 3)`` holding ``[z, y, x]``, with y/x multiplied by ``256 / 325``
        and z by ``32 / 325`` and then truncated with ``int()``. ``labels`` is a
        float ``torch.Tensor`` of ``N`` ones, one label per point.
        Both arrays are also printed (debug output).
    """
    coords = np.array([
        [int(pz * 32.0 / 325.0), int(py * 256.0 / 325.0), int(px * 256.0 / 325.0)]
        for pz, py, px in points
    ])
    labels = np.ones(coords.shape[0])
    print(coords, labels)
    return torch.tensor(coords), torch.tensor(labels)

def show_points(points_ax, points_label, ax):
    """Scatter one point on a matplotlib-style axis.

    Args:
        points_ax: Sequence whose first two items are the horizontal and
            vertical plot coordinates.
        points_label: ``0`` draws the marker in red; any other value draws it
            in blue.
        ax: Object with a ``scatter`` method (e.g. a matplotlib Axes).
    """
    marker_color = 'blue'
    if points_label == 0:
        marker_color = 'red'
    horizontal, vertical = points_ax[0], points_ax[1]
    ax.scatter(horizontal, vertical, c=marker_color, marker='o', s=200)

def make_fig(image, preds, point_axs=None, current_idx=None, view=None):
    """Render a 2-D slice as an RGB PIL image with optional mask and point overlays.

    Args:
        image: 2-D numpy array. It is scaled by 255 and cast to ``uint8``.
            NOTE(review): values are presumably normalized to [0, 1]; values
            outside that range are not clipped before the cast.
        preds: Optional array the same shape as ``image``. Pixels equal to 1
            are tinted yellow (R=G=255, B=0) and alpha-blended with weight
            ``st.session_state.transparency``.
        point_axs: Optional iterable of ``(z, y, x)`` points. In view ``'xy'``
            a point is drawn at ``(x, y)`` when ``z == current_idx``. In view
            ``'xz'`` it is drawn at ``(x, z)`` when ``y == current_idx``. Other
            views draw nothing.
        current_idx: Index of the displayed slice along the collapsed axis.
        view: ``'xy'``, ``'xz'``, or anything else (no points drawn).

    Returns:
        The composed ``PIL.Image.Image`` in RGB mode.
    """
    # 8-bit grayscale -> RGB, then double the contrast for visibility.
    as_uint8 = (image * 255).astype(np.uint8)
    canvas = Image.fromarray(as_uint8).convert("RGB")
    canvas = ImageEnhance.Contrast(canvas).enhance(2.0)

    if preds is not None:
        # Red + green channels on, blue off, gives a yellow overlay.
        channel_on = np.where(preds == 1, 255, 0).astype(np.uint8)
        channel_off = np.zeros_like(channel_on, dtype=np.uint8)
        overlay = Image.merge(
            "RGB",
            (Image.fromarray(channel_on),
             Image.fromarray(channel_on),
             Image.fromarray(channel_off)),
        )
        canvas = Image.blend(canvas.convert("RGB"), overlay, alpha=st.session_state.transparency)

    if point_axs is not None:
        pen = ImageDraw.Draw(canvas)
        r = 5  # marker radius in pixels
        for pz, py, px in point_axs:
            if view == 'xy' and pz == current_idx:
                pen.ellipse((px - r, py - r, px + r, py + r), fill="blue")
            elif view == 'xz' and py == current_idx:
                pen.ellipse((px - r, pz - r, px + r, pz + r), fill="blue")
    return canvas