import os
import gradio as gr
from openai import OpenAI
from PIL import Image, ImageEnhance
import cv2
import torch
from transformers import CLIPProcessor, CLIPModel
import requests
from io import BytesIO


# Initialize the OpenAI client (reads the OPENAI_API_KEY environment variable)
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))

# Load CLIP model and processor
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch16")
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch16")

# Expanded object labels
object_labels = [
    "cat", "dog", "house", "tree", "car", "mountain", "flower", "bird", "person", "robot",
    "a digital artwork", "a portrait", "a landscape", "a futuristic cityscape", "horse", 
    "lion", "tiger", "elephant", "giraffe", "airplane", "train", "ship", "book", "laptop",
    "keyboard", "pen", "clock", "cup", "bottle", "backpack", "chair", "table", "sofa", 
    "bed", "building", "street", "forest", "desert", "waterfall", "sunset", "beach", 
    "bridge", "castle", "statue", "3D model"
]

# Example image for contrast check
EXAMPLE_IMAGE_URL = "https://www.watercoloraffair.com/wp-content/uploads/2023/04/monet-houses-of-parliament-low-key.jpg"  # Low-key (low-contrast) example image
example_image = Image.open(BytesIO(requests.get(EXAMPLE_IMAGE_URL, timeout=10).content))


def process_chat(user_text):
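    # Yielding partial strings makes this a streaming handler: Gradio replaces
    # the output textbox contents with each successive yield.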
    if not user_text.strip():
        yield "⚠️ Please enter a valid question."
        return

    try:
        # Use the OpenAI client for creating a chat completion
        response = client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[
                {"role": "system", "content": "You are a helpful assistant named Diane specializing in digital art advice. Don't use text styling (i.e., bold, italics."},
                {"role": "user", "content": user_text},
            ],
            stream=True  # Enable streaming
        )

        response_text = ""
        for chunk in response:
            # Extract the content correctly
            delta = chunk.choices[0].delta  # Get the delta object
            token = getattr(delta, "content", None)  # Safely get the "content" field
            if token:  # Only process non-None tokens
                response_text += token
                yield response_text

    except Exception as e:
        yield f"❌ An error occurred: {str(e)}"


# Function to analyze image contrast
def analyze_contrast_opencv(image_path):
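    # Standard deviation of grayscale pixel values is a simple proxy for RMS
    # contrast: a low std-dev indicates a flat, low-contrast image.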
    img = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
    contrast = img.std()
    return contrast

# Function to identify objects using CLIP
def identify_objects_with_clip(image_path):
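    # Zero-shot classification: CLIP scores the image against every candidate
    # text label, and a softmax over the similarities gives one probability
    # per label; the highest-scoring label is returned.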
    image = Image.open(image_path).convert("RGB")
    inputs = clip_processor(text=object_labels, images=image, return_tensors="pt", padding=True)
    with torch.no_grad():
        outputs = clip_model(**inputs)
        logits_per_image = outputs.logits_per_image
        probs = logits_per_image.softmax(dim=1).numpy().flatten()
    best_match_label = object_labels[probs.argmax()]
    return best_match_label

# Function to enhance image contrast
def enhance_contrast(image):
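    # ImageEnhance.Contrast factor: 1.0 keeps the image unchanged, 1.5 boosts
    # contrast by 50%; the result is saved to disk so Gradio can display it.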
    enhancer = ImageEnhance.Contrast(image)
    enhanced_image = enhancer.enhance(1.5)
    enhanced_path = "enhanced_image.png"
    enhanced_image.save(enhanced_path)
    return enhanced_path

def provide_suggestions_streaming(object_identified):
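    # Same streaming pattern as process_chat: yield the accumulated text so
    # the Suggestions textbox updates token by token.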
    if not object_identified:
        yield "⚠️ Sorry, I couldn't identify an object in your artwork. Try uploading a different image."
        return

    try:
        # Use the OpenAI client for suggestions
        response = client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[
                {"role": "system", "content": "You are an expert digital art advisor."},
                {"role": "user", "content": f"Suggest ways to improve a digital artwork featuring a {object_identified}."},
            ],
            stream=True  # Enable streaming
        )

        response_text = ""
        for chunk in response:
            # Extract the content safely
            delta = chunk.choices[0].delta  # Get the delta object
            token = getattr(delta, "content", None)  # Safely access the "content" field
            if token:  # Only process non-None tokens
                response_text += token
                yield response_text

    except Exception as e:
        yield f"❌ An error occurred while providing suggestions: {str(e)}"

# Main image processing function
def process_image(image):
    if image is None:
        return "⚠️ Please upload an image.", None, None
    image.save("uploaded_image.png")
    contrast = analyze_contrast_opencv("uploaded_image.png")
    object_identified = identify_objects_with_clip("uploaded_image.png")
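    # A grayscale std-dev below 25 is used here as an empirical cutoff for
    # "low contrast"; adjust the threshold to taste.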
    if contrast < 25:
        enhanced_image_path = enhance_contrast(Image.open("uploaded_image.png"))
        return (
            f"Hey, great artwork of {object_identified}! However, it looks like the contrast is a little low. I've improved the contrast for you. ✨",
            enhanced_image_path,
            object_identified
        )
    return (
        f"Hey, great artwork of {object_identified}! Looks like the color contrast is great. Be proud of yourself! 🌟",
        None,
        object_identified
    )

# Gradio Blocks Interface
demo = gr.Blocks(css="""
    #upload-image, #example-image {
        height: 300px !important;
    }
    .button {
        height: 50px;
        font-size: 16px;
    }
""")

with demo:
    gr.Markdown("## 🎨 DIANE (Digital Imaging and Art Neural Enhancer)")
    gr.Markdown("DIANE is here to assist you in refining your digital art. She can answer questions about digital art, analyze your images, and provide creative suggestions to enhance your work.")

    # Chatbot Section
    with gr.Row():
        with gr.Column():
            gr.Markdown("### πŸ’¬ Ask me about digital art")
            user_text = gr.Textbox(label="Enter your question", placeholder="What is the best tool for a beginner?...")
            chat_output = gr.Textbox(label="Answer", interactive=False)
            chat_button = gr.Button("Ask", elem_classes="button")
        
        chat_button.click(process_chat, inputs=user_text, outputs=chat_output)

    # Image Analysis Section
    with gr.Row():
        with gr.Column():
            gr.Markdown("### πŸ–ΌοΈ Upload an image to check its contrast levels")
            with gr.Row(equal_height=True):
                # Left: Image upload field
                with gr.Column():
                    image_input = gr.Image(label="Upload an image", type="pil", elem_id="upload-image")
                    image_button = gr.Button("Check", elem_classes="button")
                
                # Right: Example image field
                with gr.Column():
                    gr.Image(value=example_image, label="Example Image", interactive=False, elem_id="example-image")
                    example_button = gr.Button("Use Example Image", elem_classes="button")
            image_output_text = gr.Textbox(label="Analysis", interactive=False)
            image_output_image = gr.Image(label="Improved Image", interactive=False)
            suggestion_button = gr.Button("I want to improve this artwork. Any suggestions?", visible=False)
            suggestions_output = gr.Textbox(label="Suggestions", interactive=False)
            state_object = gr.State()  # To store identified object

        # Load example image into the input
        def use_example_image():
            return example_image
        
        example_button.click(
            use_example_image,
            inputs=None,
            outputs=image_input
        )

        # Analyze button
        def update_suggestions_visibility(analysis, enhanced_image, identified_object):
            return gr.update(visible=True), analysis, enhanced_image

        # Run the image pipeline, then reveal the suggestion button once the
        # analysis outputs are ready (chaining with .then() guarantees ordering)
        image_button.click(
            process_image,
            inputs=image_input,
            outputs=[
                image_output_text,
                image_output_image,
                state_object
            ]
        ).then(
            update_suggestions_visibility,
            inputs=[image_output_text, image_output_image, state_object],
            outputs=[suggestion_button, image_output_text, image_output_image]
        )

        # Suggestion button functionality with streaming
        suggestion_button.click(
            provide_suggestions_streaming,
            inputs=state_object,
            outputs=suggestions_output
        )

demo.launch(share=True)  # share=True also creates a temporary public Gradio link