import os
import torch
from PIL import Image
import gradio as gr
from transformers import DetrImageProcessor, DetrForObjectDetection
from diffusers import StableDiffusionPipeline
from huggingface_hub import login
from dotenv import load_dotenv

# Load environment variables from .env file
load_dotenv()

# Retrieve Hugging Face token from environment variable
HF_TOKEN = os.getenv('HF_TOKEN')
if HF_TOKEN is None:
    raise ValueError("Hugging Face token not found in environment variables.")

# Login to Hugging Face using the token
login(token=HF_TOKEN)


# Load DETR model for object detection
def load_detr_model():
    try:
        model = DetrForObjectDetection.from_pretrained('facebook/detr-resnet-50')
        processor = DetrImageProcessor.from_pretrained('facebook/detr-resnet-50')
        return model, processor, None
    except Exception as e:
        return None, None, f"Error loading DETR model: {str(e)}"


detr_model, detr_processor, detr_error = load_detr_model()


def detect_objects(image):
    if image is None:
        return None, "Invalid image: image is None."
    if detr_model is not None and detr_processor is not None:
        try:
            inputs = detr_processor(images=image, return_tensors="pt")
            outputs = detr_model(**inputs)
            target_sizes = torch.tensor([image.size[::-1]])
            results = detr_processor.post_process_object_detection(
                outputs, target_sizes=target_sizes, threshold=0.9
            )[0]
            detected_objects = [
                {"label": detr_model.config.id2label[label.item()], "box": box.tolist()}
                for label, box in zip(results['labels'], results['boxes'])
            ]
            return detected_objects, None
        except Exception as e:
            return None, f"Error in detect_objects: {str(e)}"
    else:
        return None, "DETR models not loaded. Skipping object detection."


# Load Stable Diffusion model for image generation
def load_stable_diffusion_model():
    try:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        pipeline = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4").to(device)
        return pipeline, None
    except Exception as e:
        return None, f"Error loading Stable Diffusion model: {str(e)}"


sd_pipeline, sd_error = load_stable_diffusion_model()


def adjust_dimensions(width, height):
    # Adjust width and height to be divisible by 8
    adjusted_width = (width // 8) * 8
    adjusted_height = (height // 8) * 8
    return adjusted_width, adjusted_height


def generate_image(prompt, width, height):
    if sd_pipeline is not None:
        try:
            adjusted_width, adjusted_height = adjust_dimensions(width, height)
            image = sd_pipeline(prompt, width=adjusted_width, height=adjusted_height).images[0]
            # Resize back to original dimensions if needed
            image = image.resize((width, height), Image.LANCZOS)
            return image, None
        except Exception as e:
            return None, f"Error in generate_image: {str(e)}"
    else:
        return None, "Stable Diffusion model not loaded. Skipping image generation."


def process_image(image):
    if image is None:
        return None, "Invalid image: image is None."
    try:
        # Detect objects in the provided image
        detected_objects, detect_error = detect_objects(image)
        if detect_error:
            return None, detect_error

        # Create a prompt based on detected objects
        prompt = "modern redesign of an interior room with "
        if detected_objects:
            prompt += ", ".join([obj['label'] for obj in detected_objects])
        else:
            prompt += "empty room"

        # Generate a redesigned image based on the prompt
        width, height = image.size
        generated_image, gen_image_error = generate_image(prompt, width, height)
        if gen_image_error:
            return None, gen_image_error

        return generated_image, None
    except Exception as e:
        return None, f"Error in process_image: {str(e)}"


# Custom CSS for styling
custom_css = """
body {
    background-color: black;
}
h1 {
    background: linear-gradient(to right, blue, purple);
    -webkit-background-clip: text;
    color: transparent;
    font-size: 3em;
    text-align: center;
    margin-bottom: 20px;
}
"""

# Creating the Gradio interface with custom styling
iface = gr.Interface(
    fn=process_image,
    inputs=[gr.Image(type="pil", label="Upload Room Image")],
    outputs=[gr.Image(type="pil", label="Redesigned Image"), gr.Textbox(label="Error Message")],
    title="Interior Redesign",
    css=custom_css
)

try:
    iface.launch()
except Exception as e:
    print(f"Error occurred while launching the interface: {str(e)}")
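

# Optional: a minimal headless smoke test (illustrative sketch only).
# "room.jpg" and "room_redesigned.png" are placeholder filenames, not files that
# ship with this script; uncomment and adjust the paths to exercise the
# detection + generation pipeline without the Gradio UI. Note that
# iface.launch() above blocks until the app is closed, so this would normally
# be run with the launch call commented out.
#
# test_image = Image.open("room.jpg").convert("RGB")
# redesigned, error = process_image(test_image)
# if error:
#     print(f"Pipeline error: {error}")
# else:
#     redesigned.save("room_redesigned.png")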