"""Gradio app that turns a starter image into a tileable pattern.

The uploaded image is wrapped around by half its width and height, so its
former outer edges now meet in a cross through the centre. That cross is
then repainted by an SDXL inpainting model on Replicate, producing up to
``NUM_OUTPUTS`` candidate images.
"""

import base64
import io
import os

import gradio as gr
import numpy as np
import replicate
import requests
from dotenv import find_dotenv, load_dotenv
from PIL import Image, ImageChops, ImageDraw

# Locate the .env file and load it into os.environ.
dotenv_path = find_dotenv()
load_dotenv(dotenv_path)
# The token is never passed explicitly below; the replicate client is
# expected to pick it up from the environment.
# NOTE(review): confirm with the installed replicate version.
REPLICATE_API_TOKEN = os.getenv("REPLICATE_API_TOKEN")

MODEL_VERSION = (
    "lucataco/sdxl-inpainting:"
    "a5b13068cc81a89a4fbeefeccc774869fcb34df4dbc92c1555e0f2771d49dde7"
)
MAX_SIDE = 768  # longest side (pixels) of the image sent to the model
SEAM_WIDTH = 50  # thickness (pixels) of the cross that gets inpainted
NUM_OUTPUTS = 3  # images requested from the model == image slots in the UI
DOWNLOAD_TIMEOUT = 60  # seconds allowed per generated-image download


def _to_rgb_image(image):
    """Convert the numpy array from Gradio into an RGB PIL image."""
    pil_image = Image.fromarray(np.asarray(image).astype(np.uint8))
    # JPEG cannot store an alpha channel (Pillow raises OSError for RGBA),
    # so normalise RGBA / greyscale / palette uploads to plain RGB.
    return pil_image if pil_image.mode == "RGB" else pil_image.convert("RGB")


def _limit_size(img, max_side=MAX_SIDE):
    """Downscale *img* so neither side exceeds *max_side*, keeping aspect ratio."""
    width, height = img.size
    if width <= max_side and height <= max_side:
        return img
    # Pin the longest side to exactly max_side; round the other side and
    # clamp it to >= 1 so very thin images never collapse to zero pixels.
    if width >= height:
        new_size = (max_side, max(1, round(height * max_side / width)))
    else:
        new_size = (max(1, round(width * max_side / height)), max_side)
    return img.resize(new_size, Image.LANCZOS)


def _seam_mask(size, seam_width=SEAM_WIDTH):
    """Return a black image with a white cross centred on the image midlines."""
    width, height = size
    mask = Image.new("RGB", size, "black")
    draw = ImageDraw.Draw(mask)
    half = seam_width // 2
    center_x, center_y = width // 2, height // 2
    # Vertical bar.
    draw.rectangle(
        [(center_x - half, 0), (center_x + half, height)], fill="white"
    )
    # Horizontal bar.
    draw.rectangle(
        [(0, center_y - half), (width, center_y + half)], fill="white"
    )
    return mask


def _to_data_uri(img, fmt):
    """Encode *img* in the given Pillow format (e.g. "JPEG", "PNG") as a data URI."""
    buffer = io.BytesIO()
    img.save(buffer, format=fmt)
    encoded = base64.b64encode(buffer.getvalue()).decode("utf-8")
    return f"data:image/{fmt.lower()};base64,{encoded}"


def _download_images(urls, limit):
    """Fetch at most *limit* images from *urls* and open them with Pillow.

    Raises:
        requests.HTTPError: if a download returns an error status.
        requests.Timeout: if a download exceeds ``DOWNLOAD_TIMEOUT``.
    """
    images = []
    for url in list(urls)[:limit]:
        response = requests.get(url, timeout=DOWNLOAD_TIMEOUT)
        # Fail loudly instead of letting Pillow choke on an HTML error page.
        response.raise_for_status()
        images.append(Image.open(io.BytesIO(response.content)))
    return images


def generate_pattern(image, prompt):
    """Generate tileable pattern candidates from a starter image.

    Args:
        image: Numpy image array supplied by the Gradio "image" input, or
            None when nothing was uploaded.
        prompt: Text prompt; " smooth background" is appended before it is
            sent to the model.

    Returns:
        A list of exactly ``NUM_OUTPUTS`` PIL images. Slots the model did not
        fill are grey placeholders the size of the processed input.

    Raises:
        gr.Error: if no image was uploaded.
        requests.HTTPError / requests.Timeout: if downloading a result fails.
    """
    if image is None:
        raise gr.Error("Please upload a starter image.")

    starter = _limit_size(_to_rgb_image(image))
    width, height = starter.size

    # Wrap the image around by half its size. The former outer edges now
    # meet in a cross through the centre, and the new outer edges come from
    # neighbouring pixels of the original, so they line up when tiled.
    shifted = ImageChops.offset(starter, width // 2, height // 2)
    mask = _seam_mask((width, height))

    model_input = {
        "prompt": (prompt or "") + " smooth background",
        "negative_prompt": "worst quality, low quality, cartoons, sketch, ugly, lowres",
        "image": _to_data_uri(shifted, "JPEG"),
        # PNG is lossless, so the mask stays strictly black/white; JPEG
        # compression would smear grey artefacts along the cross edges.
        "mask": _to_data_uri(mask, "PNG"),
        "num_inference_steps": 25,
        "num_outputs": NUM_OUTPUTS,
    }
    output = replicate.run(MODEL_VERSION, input=model_input)

    images = _download_images(output or [], NUM_OUTPUTS)
    # Pad with grey placeholders so every UI output slot receives an image.
    images.extend(
        Image.new("RGB", (width, height), "gray")
        for _ in range(NUM_OUTPUTS - len(images))
    )
    return images


demo = gr.Interface(
    fn=generate_pattern,
    inputs=["image", "text"],
    outputs=["image"] * NUM_OUTPUTS,
)

if __name__ == "__main__":
    demo.launch(share=False)