"""Gradio demo: turn a moodboard image plus a text prompt into an SDXL render.

Pipeline:
1. The uploaded moodboard is sent to OpenAI (gpt-4o) which describes the
   shared design language of its elements in one sentence.
2. That description is appended to the user's prompt and sent to the SDXL
   model hosted on Replicate.
3. The generated image is downloaded and returned to the Gradio UI.

Credentials are read from the environment (optionally via a ``.env`` file):
``OPENAI_API_KEY`` and ``REPLICATE_API_TOKEN``.
"""
import base64
import io
import os

import gradio as gr
import numpy as np
import replicate
import requests
from dotenv import find_dotenv, load_dotenv
from flask import Flask, request
from openai import OpenAI
from PIL import Image

# Locate the .env file
dotenv_path = find_dotenv()
load_dotenv(dotenv_path)

OPENAI_API_KEY = os.getenv('OPENAI_API_KEY')
REPLICATE_API_TOKEN = os.getenv('REPLICATE_API_TOKEN')

# The OpenAI client picks up OPENAI_API_KEY from the environment loaded above.
client = OpenAI()

# Question sent to the vision model together with the moodboard image.
MOODBOARD_QUESTION = (
    "You are a product designer. I've attached a moodboard here. "
    "In one sentence, what do all of these elements have in common? "
    "Answer from a design language perspective, if you were telling another "
    "designer to create something similar, including any repeating colors "
    "and materials and shapes and textures"
)

SDXL_MODEL = (
    "stability-ai/sdxl:"
    "7762fd07cf82c948538e41f63f77d685e02b063e37e496e96eefd46c929f9bdc"
)
NEGATIVE_PROMPT = (
    "worst quality, low quality, illustration, 2d, painting, cartoons, sketch"
)

# Number of leading characters dropped from the model's description before it
# is appended to the render prompt.
# NOTE(review): presumably meant to strip a boilerplate opening such as
# "All of these elements" - the fixed offset is fragile; confirm the intent.
RESPONSE_PREFIX_CHARS = 20

# Upper bound (seconds) on downloading the generated image, so a stalled
# connection cannot hang the request forever.
DOWNLOAD_TIMEOUT_SECONDS = 60


def call_openai(pil_image):
    """Ask gpt-4o to describe the design language of a moodboard image.

    Args:
        pil_image: A ``PIL.Image.Image`` holding the moodboard.

    Returns:
        The text content of the first completion choice. The API may return
        ``None`` when the model produces no text.
    """
    # JPEG cannot store alpha or palette modes; normalise so save() cannot fail.
    if pil_image.mode != "RGB":
        pil_image = pil_image.convert("RGB")

    # Save the PIL image to a bytes buffer
    buffered = io.BytesIO()
    pil_image.save(buffered, format="JPEG")

    # Encode the image to base64 so it can be embedded as a data URL
    image_data = base64.b64encode(buffered.getvalue()).decode('utf-8')

    response = client.chat.completions.create(
        model="gpt-4o",
        messages=[
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": MOODBOARD_QUESTION},
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": "data:image/jpeg;base64," + image_data,
                        },
                    },
                ],
            }
        ],
        max_tokens=300,
    )
    return response.choices[0].message.content


def image_classifier(moodboard, prompt):
    """Generate a render of ``prompt`` styled after the given moodboard.

    Args:
        moodboard: The uploaded image as a numpy array (as delivered by the
            Gradio ``"image"`` input), or ``None`` when nothing was uploaded.
        prompt: Free-text description of the object to render.

    Returns:
        A ``PIL.Image.Image`` with the generated render.

    Raises:
        gr.Error: If no moodboard was supplied or the model returned no image.
        requests.HTTPError: If downloading the generated image fails.
    """
    if moodboard is None:
        raise gr.Error("Please upload a moodboard image.")
    prompt = prompt or ""

    # Convert the numpy array to a PIL image
    pil_image = Image.fromarray(moodboard.astype('uint8'))

    # The completion content may be None; treat that as an empty description.
    openai_response = call_openai(pil_image) or ""
    openai_response = openai_response.replace('moodboard', '')

    # Call Stable Diffusion API with the response from OpenAI
    sdxl_input = {
        "width": 768,
        "height": 768,
        "prompt": (
            "high quality render of " + prompt + ", "
            + openai_response[RESPONSE_PREFIX_CHARS:]
        ),
        "negative_prompt": NEGATIVE_PROMPT,
        "refine": "expert_ensemble_refiner",
        "apply_watermark": False,
        "num_inference_steps": 25,
    }
    output = replicate.run(SDXL_MODEL, input=sdxl_input)
    if not output:
        raise gr.Error("Image generation returned no output.")

    # Download the image from the URL. str() keeps this working both with
    # older clients that return plain URL strings and with newer ones that
    # return file-like output objects whose string form is the URL.
    image_url = str(output[0])
    print(image_url)
    response = requests.get(image_url, timeout=DOWNLOAD_TIMEOUT_SECONDS)
    print(response)
    # Fail loudly on HTTP errors instead of handing an error page to PIL.
    response.raise_for_status()

    return Image.open(io.BytesIO(response.content))  # Return the image object


# app = Flask(__name__)
# os.environ.get("REPLICATE_API_TOKEN")
# @app.route("/")
# def index():

demo = gr.Interface(
    fn=image_classifier,
    inputs=["image", "text"],
    outputs="image",
)

# Only start the server when run as a script, so importing this module (e.g.
# for reuse of image_classifier) does not launch a public share link.
if __name__ == "__main__":
    demo.launch(share=True)