Msaqibsharif committed on
Commit
50dfb0d
1 Parent(s): 810681a

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +198 -0
app.py ADDED
@@ -0,0 +1,198 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ from PIL import Image
4
+ import numpy as np
5
+ import traceback
6
+ import gradio as gr
7
+ from transformers import DetrImageProcessor, DetrForObjectDetection, LayoutLMTokenizer, LayoutLMForTokenClassification
8
+ from diffusers import StableDiffusionPipeline, StableDiffusionUpscalePipeline
9
+ from huggingface_hub import login
10
+ import torchvision.transforms as T
11
+ import torchvision.models as models
12
+ from dotenv import load_dotenv
# Pull configuration from a local .env file so secrets never live in source.
load_dotenv()

# Hugging Face access token, used later to authenticate gated model downloads.
# May be None when the variable is not set; downstream code checks for that.
HF_TOKEN = os.environ.get("HF_TOKEN")
## 2.1 Image Analysis with DETR
def load_detr_model():
    """Load the pretrained DETR object-detection model and its processor.

    Returns:
        (model, processor, error): loaded components with ``error=None`` on
        success, or ``(None, None, message)`` when loading fails (e.g. no
        network access or missing weights).
    """
    try:
        detr_model = DetrForObjectDetection.from_pretrained('facebook/detr-resnet-50')
        detr_processor = DetrImageProcessor.from_pretrained('facebook/detr-resnet-50')
        return detr_model, detr_processor, None
    except Exception as e:
        return None, None, f"Error loading DETR model: {e}"

# Load once at import time; consumers degrade gracefully when loading failed.
detr_model, detr_processor, detr_error = load_detr_model()

def detect_objects(image, threshold=0.9):
    """Run DETR object detection on a PIL image.

    Args:
        image: PIL image to analyze.
        threshold: confidence cutoff for reported detections. Defaults to
            0.9, the value that was previously hard-coded.

    Returns:
        (results, error): post-processed detection results for the single
        image, or ``(None, message)`` when the model is unavailable or
        inference fails.
    """
    if detr_model is not None and detr_processor is not None:
        try:
            inputs = detr_processor(images=image, return_tensors="pt")
            outputs = detr_model(**inputs)
            # DETR post-processing wants (height, width); PIL's .size is (width, height).
            target_sizes = torch.tensor([image.size[::-1]])
            results = detr_processor.post_process_object_detection(
                outputs, target_sizes=target_sizes, threshold=threshold
            )[0]
            return results, None
        except Exception as e:
            return None, f"Error in detect_objects: {e}"
    else:
        return None, "DETR models not loaded. Skipping object detection."
## 2.2 Style Transfer with Deep Image Prior
def style_transfer(content_image, style_image):
    """Optimize a copy of *content_image* toward *style_image* using VGG19 features.

    Both images are resized to 512x512 and scaled to [0, 255]; 300 Adam steps
    minimize the sum of content and style feature losses.

    Args:
        content_image: PIL image providing the content.
        style_image: PIL image providing the style target.

    Returns:
        (image, error): the stylized PIL image with ``error=None``, or the
        untouched ``content_image`` plus an error message on failure.
    """
    try:
        transform = T.Compose([
            T.Resize((512, 512)),
            T.ToTensor(),
            T.Lambda(lambda x: x.mul(255)),
        ])

        content = transform(content_image).unsqueeze(0).requires_grad_(False)
        style = transform(style_image).unsqueeze(0).requires_grad_(False)

        vgg = models.vgg19(pretrained=True).features.eval()
        for param in vgg.parameters():
            param.requires_grad_(False)

        # The content/style feature targets never change during optimization,
        # so compute them once here instead of re-running VGG on both images
        # in every one of the 300 iterations (as the original did). The
        # inputs and weights are frozen, so this is numerically identical.
        with torch.no_grad():
            content_features = vgg(content)
            style_features = vgg(style)

        generated = content.clone().requires_grad_(True)
        optimizer = torch.optim.Adam([generated], lr=0.003)

        for _ in range(300):
            generated_features = vgg(generated)

            content_loss = torch.mean((generated_features - content_features) ** 2)
            style_loss = torch.mean((generated_features - style_features) ** 2)

            total_loss = content_loss + style_loss
            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

        # Back to HWC uint8 for PIL.
        generated_image = generated.squeeze().clamp(0, 255).cpu().detach().numpy().transpose(1, 2, 0)
        return Image.fromarray(np.uint8(generated_image)), None
    except Exception as e:
        # Best-effort: fall back to the original content image with a message.
        return content_image, f"Error in style_transfer: {e}"
## 2.3 Layout Generation with LayoutLM
def load_layoutlm_model():
    """Load the pretrained LayoutLM tokenizer and token-classification model.

    Returns:
        (tokenizer, model, error): loaded components with ``error=None`` on
        success, or ``(None, None, message)`` when loading fails.
    """
    try:
        tokenizer = LayoutLMTokenizer.from_pretrained('microsoft/layoutlm-base-uncased')
        model = LayoutLMForTokenClassification.from_pretrained('microsoft/layoutlm-base-uncased')
    except Exception as e:
        return None, None, f"Error loading LayoutLM model: {e}"
    return tokenizer, model, None

# Loaded once at import time; generate_layout checks for failed loads.
layoutlm_tokenizer, layoutlm_model, layoutlm_error = load_layoutlm_model()

def generate_layout(text):
    """Predict per-token layout labels for *text* with LayoutLM.

    Args:
        text: input string to tokenize and classify.

    Returns:
        (layout, error): argmax label tensor per token, or ``(None, message)``
        when the model is unavailable or inference fails.
    """
    if layoutlm_tokenizer is None or layoutlm_model is None:
        return None, "LayoutLM models not loaded. Skipping layout generation."
    try:
        encoded = layoutlm_tokenizer(text, return_tensors="pt")
        predictions = layoutlm_model(**encoded)
        return predictions.logits.argmax(dim=-1), None
    except Exception as e:
        return None, f"Error in generate_layout: {e}"
## 2.4 Image Generation with Stable Diffusion
def load_stable_diffusion_model():
    """Authenticate with Hugging Face and load the Stable Diffusion pipeline.

    Requires ``HF_TOKEN`` to be set in the environment.

    Returns:
        (pipeline, error): the pipeline with ``error=None`` on success, or
        ``(None, message)`` when the token is missing or loading fails.
    """
    try:
        if HF_TOKEN is None:
            raise ValueError("Hugging Face token not found in environment variables.")
        login(token=HF_TOKEN)
        # Fall back to CPU when CUDA is absent; the original unconditionally
        # called .to("cuda") and failed on GPU-less machines.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        sd_pipeline = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4").to(device)
        return sd_pipeline, None
    except Exception as e:
        return None, f"Error loading Stable Diffusion model: {e}"

# Loaded once at import time; generate_image checks for failed loads.
sd_pipeline, sd_error = load_stable_diffusion_model()

def generate_image(prompt):
    """Generate one image from a text *prompt* with Stable Diffusion.

    Returns:
        (image, error): the first generated PIL image, or ``(None, message)``
        when the pipeline is unavailable or generation fails.
    """
    if sd_pipeline is not None:
        try:
            image = sd_pipeline(prompt).images[0]
            return image, None
        except Exception as e:
            return None, f"Error in generate_image: {e}"
    else:
        return None, "Stable Diffusion model not loaded. Skipping image generation."
## 2.5 Super-Resolution
def load_upscale_pipeline():
    """Load the Stable Diffusion x4 upscaler pipeline.

    Returns:
        (pipeline, error): the pipeline with ``error=None`` on success, or
        ``(None, message)`` when loading fails.
    """
    try:
        # Fall back to CPU when CUDA is absent; the original unconditionally
        # called .to("cuda") and failed on GPU-less machines.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        upscale_pipeline = StableDiffusionUpscalePipeline.from_pretrained("stabilityai/stable-diffusion-x4-upscaler").to(device)
        return upscale_pipeline, None
    except Exception as e:
        return None, f"Error loading Upscale Pipeline: {e}"

# Loaded once at import time; super_resolve checks for failed loads.
upscale_pipeline, upscale_error = load_upscale_pipeline()

def super_resolve(image):
    """Upscale a PIL *image* 4x with the Stable Diffusion upscaler.

    Args:
        image: PIL image to upscale.

    Returns:
        (image, error): the upscaled image, or the input image plus a message
        when the pipeline is unavailable, or ``(None, message)`` on failure.
    """
    if upscale_pipeline is not None:
        try:
            if not isinstance(image, Image.Image):
                raise ValueError("Input must be a PIL image.")
            upscaled_image = upscale_pipeline(image=image).images[0]
            return upscaled_image, None
        except Exception as e:
            return None, f"Error in super_resolve: {e}"
    else:
        # Best-effort: hand back the original image untouched.
        return image, "Upscale Pipeline not loaded. Skipping super-resolution."
# Step 3: Gradio Interface and Integration
def process_image(image, style_image, text_prompt):
    """Run the end-to-end design pipeline for the Gradio UI.

    Stages, in order: object detection, style transfer, layout generation,
    Stable Diffusion image generation, and super-resolution. Every stage
    returns ``(result, error)``; the first non-empty error aborts the
    pipeline and is surfaced to the UI.

    Args:
        image: uploaded room photo (PIL image).
        style_image: uploaded style reference (PIL image).
        text_prompt: free-text design prompt.

    Returns:
        (final_image, error): the super-resolved generated image with
        ``error=None``, or ``(None, message)`` on any failure.
    """
    try:
        object_results, stage_error = detect_objects(image)
        if stage_error:
            return None, stage_error

        # NOTE(review): styled_image, object_results and layout_results are
        # computed but not fed into the generation prompt below — confirm
        # this is intended.
        styled_image, stage_error = style_transfer(image, style_image)
        if stage_error:
            return None, stage_error

        layout_results, stage_error = generate_layout(text_prompt)
        if stage_error:
            return None, stage_error

        generated_image, stage_error = generate_image("modern interior design based on layout")
        if stage_error:
            return None, stage_error

        final_image, stage_error = super_resolve(generated_image)
        if stage_error:
            return None, stage_error

        return final_image, None
    except Exception as e:
        return None, f"Error in process_image: {e}"
# Gradio UI: two image uploads plus a text prompt in, generated image plus
# an error-message textbox out.
iface = gr.Interface(
    fn=process_image,
    inputs=[
        gr.Image(type="pil", label="Upload Room Image"),
        gr.Image(type="pil", label="Upload Style Image"),
        gr.Textbox(label="Enter Design Prompt")
    ],
    outputs=[
        gr.Image(type="pil", label="Generated Image"),
        gr.Textbox(label="Error Message")
    ]
)

# Launch only when executed as a script, so importing this module (e.g. in
# tests or tooling) does not start a web server. The original launched at
# import time unconditionally.
if __name__ == "__main__":
    try:
        iface.launch()
    except Exception as e:
        print(f"Error occurred while launching the interface: {e}")
        traceback.print_exc()