pranavajay committed on
Commit
8a49fba
·
verified ·
1 Parent(s): 5322ffd

Upload api.py

Browse files
Files changed (1)
  api.py +720 -0
api.py ADDED
@@ -0,0 +1,720 @@
import os
import time
import asyncio
import functools
import random
import string
import logging
import datetime
from collections import defaultdict
from typing import Optional

import aiofiles
import boto3
import numpy as np
import torch
from fastapi import FastAPI, HTTPException, Request, Response
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, constr, conint
from diffusers import (FluxPipeline, FluxControlNetPipeline,
                       FluxControlNetModel, FluxControlNetInpaintPipeline,
                       FluxImg2ImgPipeline, FluxInpaintPipeline,
                       CogVideoXImageToVideoPipeline)
from diffusers.utils import load_image, export_to_video
from PIL import Image

# Setup logging
logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(levelname)s - %(message)s',
                    handlers=[
                        logging.FileHandler("error.txt"),
                        logging.StreamHandler()
                    ])

app = FastAPI()

# Allow CORS for specific origins if needed
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Update with specific domains as necessary
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

MAX_SEED = np.iinfo(np.int32).max

# AWS S3 Configuration
AWS_ACCESS_KEY_ID = "your-access-key-id"
AWS_SECRET_ACCESS_KEY = "your-secret-access-key"
AWS_REGION = "your-region"
S3_BUCKET_NAME = "your-bucket-name"

# Initialize S3 client
s3_client = boto3.client(
    's3',
    aws_access_key_id=AWS_ACCESS_KEY_ID,
    aws_secret_access_key=AWS_SECRET_ACCESS_KEY,
    region_name=AWS_REGION
)

# Local directory where generated files are written before upload
os.makedirs("generated_images", exist_ok=True)

# Asynchronously log requests
async def log_requests(user_key: str, prompt: str):
    timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    log_entry = f"{timestamp}, {user_key}, {prompt}\n"
    async with aiofiles.open("key_requests.txt", "a") as log_file:
        await log_file.write(log_entry)

# Asynchronously upload a file to S3 (the blocking boto3 call runs in a thread executor)
async def upload_image_to_s3(image_path: str, s3_path: str):
    try:
        loop = asyncio.get_running_loop()
        await loop.run_in_executor(None, s3_client.upload_file, image_path, S3_BUCKET_NAME, s3_path)
        return f"https://{S3_BUCKET_NAME}.s3.{AWS_REGION}.amazonaws.com/{s3_path}"
    except Exception as e:
        logging.error(f"Error uploading image to S3: {e}")
        raise HTTPException(status_code=500, detail=f"Image upload failed: {str(e)}")

# Generate a random sequence of 12 digits and 11 lowercase letters
def generate_random_sequence():
    random_numbers = ''.join(random.choices(string.digits, k=12))  # 12 random digits
    random_letters = ''.join(random.choices(string.ascii_lowercase, k=11))  # 11 random letters
    return f"{random_numbers}_{random_letters}"

# Load the default pipelines once globally for efficiency
flux_pipe = FluxPipeline.from_pretrained("pranavajay/flow", torch_dtype=torch.bfloat16)
flux_pipe.enable_model_cpu_offload()
logging.info("FluxPipeline loaded successfully.")

img_pipe = FluxImg2ImgPipeline.from_pretrained("pranavajay/flow", torch_dtype=torch.bfloat16)
img_pipe.enable_model_cpu_offload()
logging.info("FluxImg2ImgPipeline loaded successfully.")

inpainting_pipe = FluxInpaintPipeline.from_pretrained("pranavajay/flow", torch_dtype=torch.bfloat16)
inpainting_pipe.enable_model_cpu_offload()
logging.info("FluxInpaintPipeline loaded successfully.")

video = CogVideoXImageToVideoPipeline.from_pretrained(
    "THUDM/CogVideoX-5b-I2V", torch_dtype=torch.bfloat16
)
video.enable_sequential_cpu_offload()
video.vae.enable_tiling()
video.vae.enable_slicing()
logging.info("CogVideoXImageToVideoPipeline loaded successfully.")

# The ControlNet pipeline is loaded lazily when an adapter is requested
flux_controlnet_pipe = None

# Rate limiting variables
request_timestamps = defaultdict(list)  # Store timestamps of requests per user key
RATE_LIMIT = 30  # Maximum requests allowed per time window
TIME_WINDOW = 5  # Time window in seconds

# Available LoRA styles and ControlNet adapters
style_lora_mapping = {
    "Uncensored": {"path": "enhanceaiteam/Flux-uncensored", "triggered_word": "nsfw"},
    "Logo": {"path": "Shakker-Labs/FLUX.1-dev-LoRA-Logo-Design", "triggered_word": "logo"},
    "Yarn": {"path": "Shakker-Labs/FLUX.1-dev-LoRA-MiaoKa-Yarn-World", "triggered_word": "mkym this is made of wool"},
    "Anime": {"path": "prithivMLmods/Canopus-LoRA-Flux-Anime", "triggered_word": "anime"},
    "Comic": {"path": "wkplhc/comic", "triggered_word": "comic"}
}

adapter_controlnet_mapping = {
    "Canny": "InstantX/FLUX.1-dev-controlnet-canny",
    "Depth": "Shakker-Labs/FLUX.1-dev-ControlNet-Depth",
    "Pose": "Shakker-Labs/FLUX.1-dev-ControlNet-Pose",
    "Upscale": "jasperai/Flux.1-dev-Controlnet-Upscaler"
}

# Request model for query parameters
class GenerateImageRequest(BaseModel):
    prompt: constr(min_length=1)  # Ensures prompt is not empty
    guidance_scale: float = 7.5
    seed: conint(ge=0, le=MAX_SEED) = 42
    randomize_seed: bool = False
    height: conint(gt=0) = 768
    width: conint(gt=0) = 1360
    control_image_url: str = "https://enhanceai.s3.amazonaws.com/792e2322-77fe-4070-aac4-7fa8d9e29c11_1.png"
    controlnet_conditioning_scale: float = 0.6
    num_inference_steps: conint(gt=0) = 50
    num_images_per_prompt: conint(gt=0, le=5) = 1  # Limit to max 5 images per request
    style: Optional[str] = None  # Optional LoRA style
    adapter: Optional[str] = None  # Optional ControlNet adapter
    user_key: str  # API user key

# Apply a LoRA style to the pipeline and prepend its trigger word to the prompt
async def apply_lora_style(pipe, style, prompt):
    if style in style_lora_mapping:
        lora_path = style_lora_mapping[style]["path"]
        triggered_word = style_lora_mapping[style]["triggered_word"]
        pipe.load_lora_weights(lora_path)
        return f"{triggered_word} {prompt}"
    return prompt

# Set ControlNet adapter for the pipeline
async def set_controlnet_adapter(adapter: str, is_inpainting: bool = False):
    global flux_controlnet_pipe
    if adapter not in adapter_controlnet_mapping:
        raise ValueError(f"Invalid ControlNet adapter: {adapter}")

    controlnet_model_path = adapter_controlnet_mapping[adapter]
    controlnet = FluxControlNetModel.from_pretrained(controlnet_model_path, torch_dtype=torch.bfloat16)
    # Use the ControlNet variant for inpainting so the controlnet argument is actually consumed
    pipeline_cls = FluxControlNetInpaintPipeline if is_inpainting else FluxControlNetPipeline
    flux_controlnet_pipe = pipeline_cls.from_pretrained(
        "pranavajay/flow", controlnet=controlnet, torch_dtype=torch.bfloat16
    )
    flux_controlnet_pipe.to("cuda")
    logging.info(f"ControlNet adapter '{adapter}' loaded successfully.")

# Rate limit user requests
async def rate_limit(user_key: str):
    current_time = time.time()
    request_timestamps[user_key] = [t for t in request_timestamps[user_key] if current_time - t < TIME_WINDOW]
    if len(request_timestamps[user_key]) >= RATE_LIMIT:
        logging.info(f"Rate limit exceeded for user_key: {user_key}")
        return False
    request_timestamps[user_key].append(current_time)
    return True

@app.post("/text_to_image/")
async def generate_image(req: GenerateImageRequest):
    seed = req.seed or random.randint(0, MAX_SEED)

    # Rate limit check
    if not await rate_limit(req.user_key):
        await log_requests(req.user_key, req.prompt)
        raise HTTPException(status_code=429, detail="Rate limit exceeded")

    retries = 3  # Number of retries for transient errors
    loop = asyncio.get_running_loop()

    for attempt in range(retries):
        try:
            # Check if prompt is None or empty
            if not req.prompt or req.prompt.strip() == "":
                raise ValueError("Prompt cannot be empty.")

            original_prompt = req.prompt  # Save the original prompt

            # Set ControlNet if adapter is provided
            if req.adapter:
                try:
                    await set_controlnet_adapter(req.adapter)
                except Exception as e:
                    logging.error(f"Error setting ControlNet adapter: {e}")
                    raise HTTPException(status_code=400, detail=f"Failed to load ControlNet adapter: {str(e)}")

                prompt_text = await apply_lora_style(flux_controlnet_pipe, req.style, req.prompt)

                # Load control image asynchronously
                try:
                    control_image = await loop.run_in_executor(None, load_image, req.control_image_url)
                except Exception as e:
                    logging.error(f"Error loading control image from URL: {e}")
                    raise HTTPException(status_code=400, detail="Invalid control image URL or image could not be loaded.")

                # Image generation with ControlNet
                try:
                    if req.randomize_seed:
                        seed = random.randint(0, MAX_SEED)
                    generator = torch.Generator().manual_seed(seed)

                    result = await loop.run_in_executor(None, functools.partial(
                        flux_controlnet_pipe,
                        prompt=prompt_text,
                        guidance_scale=req.guidance_scale,
                        height=req.height,
                        width=req.width,
                        num_inference_steps=req.num_inference_steps,
                        num_images_per_prompt=req.num_images_per_prompt,
                        control_image=control_image,
                        generator=generator,
                        controlnet_conditioning_scale=req.controlnet_conditioning_scale
                    ))
                    images = result.images
                except torch.cuda.OutOfMemoryError:
                    logging.error("GPU out of memory error while generating images with ControlNet.")
                    raise HTTPException(status_code=500, detail="GPU overload occurred while generating images. Try reducing the resolution or number of steps.")
                except Exception as e:
                    logging.error(f"Error during image generation with ControlNet: {e}")
                    raise HTTPException(status_code=500, detail=f"Error during image generation: {str(e)}")
            else:
                # Image generation without ControlNet
                try:
                    prompt_text = await apply_lora_style(flux_pipe, req.style, req.prompt)
                    if req.randomize_seed:
                        seed = random.randint(0, MAX_SEED)
                    generator = torch.Generator().manual_seed(seed)

                    result = await loop.run_in_executor(None, functools.partial(
                        flux_pipe,
                        prompt=prompt_text,
                        guidance_scale=req.guidance_scale,
                        height=req.height,
                        width=req.width,
                        num_inference_steps=req.num_inference_steps,
                        num_images_per_prompt=req.num_images_per_prompt,
                        generator=generator
                    ))
                    images = result.images
                except torch.cuda.OutOfMemoryError:
                    logging.error("GPU out of memory error while generating images without ControlNet.")
                    raise HTTPException(status_code=500, detail="GPU overload occurred while generating images. Try reducing the resolution or number of steps.")
                except Exception as e:
                    logging.error(f"Error during image generation without ControlNet: {e}")
                    raise HTTPException(status_code=500, detail=f"Error during image generation: {str(e)}")

            # Saving images and uploading to S3 asynchronously
            image_urls = []
            for img in images:
                image_path = f"generated_images/{generate_random_sequence()}.png"
                await loop.run_in_executor(None, img.save, image_path)
                image_url = await upload_image_to_s3(image_path, image_path)
                image_urls.append(image_url)
                os.remove(image_path)  # Clean up local files after upload

            return {
                "status": "success",
                "output": image_urls,
                "prompt": original_prompt,
                "height": req.height,
                "width": req.width,
                "scale": req.guidance_scale,
                "steps": req.num_inference_steps,
                "style": req.style,
                "adapter": req.adapter
            }

        except HTTPException:
            raise  # Propagate explicit HTTP errors instead of retrying them
        except Exception as e:
            logging.error(f"Attempt {attempt + 1} failed: {e}")
            if attempt == retries - 1:  # Last attempt
                raise HTTPException(status_code=500, detail=f"Failed to generate image after multiple attempts: {str(e)}")
            continue  # Retry on transient errors

class GenerateImageToImageRequest(BaseModel):
    prompt: Optional[str] = None  # Prompt can be None
    image: str = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
    strength: float = 0.7
    guidance_scale: float = 7.5
    seed: conint(ge=0, le=MAX_SEED) = 42
    randomize_seed: bool = False
    height: conint(gt=0) = 768
    width: conint(gt=0) = 1360
    control_image_url: Optional[str] = None  # Optional ControlNet image
    controlnet_conditioning_scale: float = 0.6
    num_inference_steps: conint(gt=0) = 50
    num_images_per_prompt: conint(gt=0, le=5) = 1
    style: Optional[str] = None  # Optional LoRA style
    adapter: Optional[str] = None  # Optional ControlNet adapter
    user_key: str  # API user key

@app.post("/image_to_image/")
async def generate_image_to_image(req: GenerateImageToImageRequest):
    seed = req.seed
    original_prompt = req.prompt
    modified_prompt = original_prompt

    # Check if user is exceeding rate limit
    if not await rate_limit(req.user_key):
        await log_requests(req.user_key, req.prompt if req.prompt else "No prompt")
        raise HTTPException(status_code=429, detail="Rate limit exceeded")

    retries = 3  # Number of retries for transient errors
    loop = asyncio.get_running_loop()

    for attempt in range(retries):
        try:
            # Check if prompt is None or empty
            if not req.prompt or req.prompt.strip() == "":
                raise ValueError("Prompt cannot be empty.")

            original_prompt = req.prompt  # Save the original prompt

            # Set ControlNet if adapter is provided
            if req.adapter:
                try:
                    await set_controlnet_adapter(req.adapter)
                except Exception as e:
                    logging.error(f"Error setting ControlNet adapter: {e}")
                    raise HTTPException(status_code=400, detail=f"Failed to load ControlNet adapter: {str(e)}")

                modified_prompt = await apply_lora_style(flux_controlnet_pipe, req.style, req.prompt)

                # Load control image asynchronously
                try:
                    control_image = await loop.run_in_executor(None, load_image, req.control_image_url)
                except Exception as e:
                    logging.error(f"Error loading control image from URL: {e}")
                    raise HTTPException(status_code=400, detail="Invalid control image URL or image could not be loaded.")

                # Image generation with ControlNet
                try:
                    if req.randomize_seed:
                        seed = random.randint(0, MAX_SEED)
                    generator = torch.Generator().manual_seed(seed)

                    result = await loop.run_in_executor(None, functools.partial(
                        flux_controlnet_pipe,
                        prompt=modified_prompt,
                        guidance_scale=req.guidance_scale,
                        height=req.height,
                        width=req.width,
                        num_inference_steps=req.num_inference_steps,
                        num_images_per_prompt=req.num_images_per_prompt,
                        control_image=control_image,
                        generator=generator,
                        controlnet_conditioning_scale=req.controlnet_conditioning_scale
                    ))
                    images = result.images
                except torch.cuda.OutOfMemoryError:
                    logging.error("GPU out of memory error while generating images with ControlNet.")
                    raise HTTPException(status_code=500, detail="GPU overload occurred while generating images. Try reducing the resolution or number of steps.")
                except Exception as e:
                    logging.error(f"Error during image generation with ControlNet: {e}")
                    raise HTTPException(status_code=500, detail=f"Error during image generation: {str(e)}")
            else:
                # Image generation without ControlNet
                try:
                    modified_prompt = await apply_lora_style(img_pipe, req.style, req.prompt)
                    if req.randomize_seed:
                        seed = random.randint(0, MAX_SEED)
                    generator = torch.Generator().manual_seed(seed)

                    source = await loop.run_in_executor(None, load_image, req.image)

                    result = await loop.run_in_executor(None, functools.partial(
                        img_pipe,
                        prompt=modified_prompt,
                        image=source,
                        strength=req.strength,
                        guidance_scale=req.guidance_scale,
                        height=req.height,
                        width=req.width,
                        num_inference_steps=req.num_inference_steps,
                        num_images_per_prompt=req.num_images_per_prompt,
                        generator=generator
                    ))
                    images = result.images
                except torch.cuda.OutOfMemoryError:
                    logging.error("GPU out of memory error while generating images without ControlNet.")
                    raise HTTPException(status_code=500, detail="GPU overload occurred while generating images. Try reducing the resolution or number of steps.")
                except Exception as e:
                    logging.error(f"Error during image generation without ControlNet: {e}")
                    raise HTTPException(status_code=500, detail=f"Error during image generation: {str(e)}")

            # Saving images and uploading to S3 asynchronously
            image_urls = []
            for img in images:
                image_path = f"generated_images/{generate_random_sequence()}.png"
                await loop.run_in_executor(None, img.save, image_path)
                image_url = await upload_image_to_s3(image_path, image_path)
                image_urls.append(image_url)
                os.remove(image_path)  # Clean up local files after upload

            return {
                "status": "success",
                "output": image_urls,
                "prompt": original_prompt,
                "height": req.height,
                "width": req.width,
                "image": req.image,
                "strength": req.strength,
                "scale": req.guidance_scale,
                "steps": req.num_inference_steps,
                "style": req.style,
                "adapter": req.adapter
            }

        except HTTPException:
            raise  # Propagate explicit HTTP errors instead of retrying them
        except Exception as e:
            logging.error(f"Attempt {attempt + 1} failed: {e}")
            if attempt == retries - 1:  # Last attempt
                raise HTTPException(status_code=500, detail=f"Failed to generate image after multiple attempts: {str(e)}")
            continue  # Retry on transient errors

class GenerateInpaintingRequest(BaseModel):
    prompt: Optional[str] = None  # Prompt can be None
    image: str = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
    mask_image: str = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
    guidance_scale: float = 7.5
    seed: conint(ge=0, le=MAX_SEED) = 42
    randomize_seed: bool = False
    height: conint(gt=0) = 768
    width: conint(gt=0) = 1360
    control_image_url: Optional[str] = None  # Optional ControlNet image
    controlnet_conditioning_scale: float = 0.6
    num_inference_steps: conint(gt=0) = 50
    num_images_per_prompt: conint(gt=0, le=5) = 1
    style: Optional[str] = None  # Optional LoRA style
    adapter: Optional[str] = None  # Optional ControlNet adapter
    user_key: str  # API user key

@app.post("/inpainting/")
async def generate_inpainting(req: GenerateInpaintingRequest):
    seed = req.seed
    original_prompt = req.prompt
    modified_prompt = original_prompt

    # Check if user is exceeding rate limit
    if not await rate_limit(req.user_key):
        await log_requests(req.user_key, req.prompt if req.prompt else "No prompt")
        raise HTTPException(status_code=429, detail="Rate limit exceeded")

    retries = 3  # Number of retries for transient errors
    loop = asyncio.get_running_loop()

    for attempt in range(retries):
        try:
            # Check if prompt is None or empty
            if not req.prompt or req.prompt.strip() == "":
                raise ValueError("Prompt cannot be empty.")

            # Set ControlNet if adapter is provided
            if req.adapter:
                try:
                    await set_controlnet_adapter(req.adapter, is_inpainting=True)
                except Exception as e:
                    logging.error(f"Error setting ControlNet adapter: {e}")
                    raise HTTPException(status_code=400, detail=f"Failed to load ControlNet adapter: {str(e)}")

                modified_prompt = await apply_lora_style(flux_controlnet_pipe, req.style, req.prompt)

                # Load control image asynchronously
                try:
                    control_image = await loop.run_in_executor(None, load_image, req.control_image_url)
                except Exception as e:
                    logging.error(f"Error loading control image from URL: {e}")
                    raise HTTPException(status_code=400, detail="Invalid control image URL or image could not be loaded.")

                # Image generation with ControlNet
                try:
                    if req.randomize_seed:
                        seed = random.randint(0, MAX_SEED)
                    generator = torch.Generator().manual_seed(seed)

                    source = await loop.run_in_executor(None, load_image, req.image)
                    mask = await loop.run_in_executor(None, load_image, req.mask_image)

                    result = await loop.run_in_executor(None, functools.partial(
                        flux_controlnet_pipe,
                        prompt=modified_prompt,
                        image=source,
                        mask_image=mask,
                        guidance_scale=req.guidance_scale,
                        height=req.height,
                        width=req.width,
                        num_inference_steps=req.num_inference_steps,
                        num_images_per_prompt=req.num_images_per_prompt,
                        control_image=control_image,
                        generator=generator,
                        controlnet_conditioning_scale=req.controlnet_conditioning_scale
                    ))
                    images = result.images
                except torch.cuda.OutOfMemoryError:
                    logging.error("GPU out of memory error while generating images with ControlNet.")
                    raise HTTPException(status_code=500, detail="GPU overload occurred while generating images. Try reducing the resolution or number of steps.")
                except Exception as e:
                    logging.error(f"Error during image generation with ControlNet: {e}")
                    raise HTTPException(status_code=500, detail=f"Error during image generation: {str(e)}")
            else:
                # Image generation without ControlNet
                try:
                    modified_prompt = await apply_lora_style(inpainting_pipe, req.style, req.prompt)
                    if req.randomize_seed:
                        seed = random.randint(0, MAX_SEED)
                    generator = torch.Generator().manual_seed(seed)

                    source = await loop.run_in_executor(None, load_image, req.image)
                    mask = await loop.run_in_executor(None, load_image, req.mask_image)

                    result = await loop.run_in_executor(None, functools.partial(
                        inpainting_pipe,
                        prompt=modified_prompt,
                        image=source,
                        mask_image=mask,
                        guidance_scale=req.guidance_scale,
                        height=req.height,
                        width=req.width,
                        num_inference_steps=req.num_inference_steps,
                        num_images_per_prompt=req.num_images_per_prompt,
                        generator=generator
                    ))
                    images = result.images
                except torch.cuda.OutOfMemoryError:
                    logging.error("GPU out of memory error while generating images without ControlNet.")
                    raise HTTPException(status_code=500, detail="GPU overload occurred while generating images. Try reducing the resolution or number of steps.")
                except Exception as e:
                    logging.error(f"Error during image generation without ControlNet: {e}")
                    raise HTTPException(status_code=500, detail=f"Error during image generation: {str(e)}")

            # Saving generated images and uploading them to S3
            image_urls = []
            for i, img in enumerate(images):
                image_path = f"generated_images/inpainting_{generate_random_sequence()}.png"
                await loop.run_in_executor(None, img.save, image_path)

                s3_path = f"inpainting/{original_prompt.replace(' ', '_')}_{generate_random_sequence()}_{i}.png"
                s3_url = await upload_image_to_s3(image_path, s3_path)
                image_urls.append(s3_url)

                # Clean up temporary files
                os.remove(image_path)

            return {
                "status": "success",
                "output": image_urls,
                "prompt": original_prompt,
                "height": req.height,
                "width": req.width,
                "scale": req.guidance_scale,
                "style": req.style,
                "adapter": req.adapter
            }

        except HTTPException:
            raise  # Propagate explicit HTTP errors instead of retrying them
        except Exception as e:
            logging.error(f"Attempt {attempt + 1} failed: {e}")
            if attempt == retries - 1:  # Last attempt
                raise HTTPException(status_code=500, detail=f"Failed to generate inpainting after multiple attempts: {str(e)}")
            continue  # Retry on transient errors

class GenerateVideoRequest(BaseModel):
    prompt: constr(min_length=1)  # Ensures prompt is not empty
    guidance_scale: float = 7.5
    seed: conint(ge=0, le=MAX_SEED) = 42
    randomize_seed: bool = False
    height: conint(gt=0) = 768
    width: conint(gt=0) = 1360
    control_image_url: str = "https://enhanceai.s3.amazonaws.com/792e2322-77fe-4070-aac4-7fa8d9e29c11_1.png"
    controlnet_conditioning_scale: float = 0.6
    num_inference_steps: conint(gt=0) = 50
    num_images_per_prompt: conint(gt=0, le=5) = 1  # Limit to max 5 images per request
    num_frames: conint(gt=0) = 49  # Frames per generated video; field added because the endpoint reads req.num_frames
    style: Optional[str] = None  # Optional LoRA style
    adapter: Optional[str] = None  # Optional ControlNet adapter
    user_key: str  # API user key

@app.post("/text_to_video/")
async def generate_video(req: GenerateVideoRequest):
    seed = req.seed
    if not await rate_limit(req.user_key):
        await log_requests(req.user_key, req.prompt)  # Log the request when rate limit is exceeded
        raise HTTPException(status_code=429, detail="Rate limit exceeded")

    retries = 3  # Number of retries for transient errors
    s3_urls = []  # List to store S3 URLs of generated videos
    loop = asyncio.get_running_loop()  # Get the current event loop

    for attempt in range(retries):
        try:
            # Check if prompt is None or empty
            if not req.prompt or req.prompt.strip() == "":
                raise ValueError("Prompt cannot be empty.")

            original_prompt = req.prompt  # Save the original prompt

            # Set ControlNet if adapter is provided
            if req.adapter:
                try:
                    await set_controlnet_adapter(req.adapter)
                except Exception as e:
                    logging.error(f"Error setting ControlNet adapter: {e}")
                    raise HTTPException(status_code=400, detail=f"Failed to load ControlNet adapter: {str(e)}")

                # Load control image asynchronously
                try:
                    control_image = await loop.run_in_executor(None, load_image, req.control_image_url)
                except Exception as e:
                    logging.error(f"Error loading control image from URL: {e}")
                    raise HTTPException(status_code=400, detail="Invalid control image URL or image could not be loaded.")

                # Image generation with ControlNet
                try:
                    if req.randomize_seed:
                        seed = random.randint(0, MAX_SEED)
                    generator = torch.Generator().manual_seed(seed)

                    result = await loop.run_in_executor(None, functools.partial(
                        flux_controlnet_pipe,
                        prompt=original_prompt,
                        guidance_scale=req.guidance_scale,
                        height=req.height,
                        width=req.width,
                        num_inference_steps=req.num_inference_steps,
                        num_images_per_prompt=req.num_images_per_prompt,
                        control_image=control_image,
                        generator=generator,
                        controlnet_conditioning_scale=req.controlnet_conditioning_scale
                    ))
                    images = result.images
                except torch.cuda.OutOfMemoryError:
                    logging.error("GPU out of memory error while generating images with ControlNet.")
                    raise HTTPException(status_code=500, detail="GPU overload occurred while generating images. Try reducing the resolution or number of steps.")
                except Exception as e:
                    logging.error(f"Error during image generation with ControlNet: {e}")
                    raise HTTPException(status_code=500, detail=f"Error during image generation: {str(e)}")
            else:
                # Image generation without ControlNet
                try:
                    prompt_text = await apply_lora_style(flux_pipe, req.style, req.prompt)
                    if req.randomize_seed:
                        seed = random.randint(0, MAX_SEED)
                    generator = torch.Generator().manual_seed(seed)

                    result = await loop.run_in_executor(None, functools.partial(
                        flux_pipe,
                        prompt=prompt_text,
                        guidance_scale=req.guidance_scale,
                        height=req.height,
                        width=req.width,
                        num_inference_steps=req.num_inference_steps,
                        num_images_per_prompt=req.num_images_per_prompt,
                        generator=generator
                    ))
                    images = result.images
                except torch.cuda.OutOfMemoryError:
                    logging.error("GPU out of memory error while generating images without ControlNet.")
                    raise HTTPException(status_code=500, detail="GPU overload occurred while generating images. Try reducing the resolution or number of steps.")
                except Exception as e:
                    logging.error(f"Error during image generation without ControlNet: {e}")
                    raise HTTPException(status_code=500, detail=f"Error during image generation: {str(e)}")

            # Turn each generated image into a video and upload it to S3
            for i, img in enumerate(images):
                image_path = f"generated_images/{generate_random_sequence()}.png"

                # Save image asynchronously
                await loop.run_in_executor(None, img.save, image_path)

                # Generate video from the image
                if req.randomize_seed:
                    seed = random.randint(0, MAX_SEED)
                video_result = await loop.run_in_executor(None, functools.partial(
                    video,
                    prompt=original_prompt,
                    image=img,
                    num_videos_per_prompt=1,
                    num_inference_steps=req.num_inference_steps,
                    num_frames=req.num_frames,
                    guidance_scale=req.guidance_scale,
                    generator=torch.Generator(device="cuda").manual_seed(seed)
                ))
                frames = video_result.frames[0]

                # Export the video to a file asynchronously
                video_path = f"generated_video_{i}_{generate_random_sequence()}.mp4"
                await loop.run_in_executor(None, export_to_video, frames, video_path, 8)

                # Upload the video to S3 asynchronously
                s3_path = f"videos/{original_prompt.replace(' ', '_')}_{generate_random_sequence()}_{i}.mp4"
                s3_url = await upload_image_to_s3(video_path, s3_path)
                s3_urls.append(s3_url)

                # Clean up temporary files
                os.remove(image_path)
                os.remove(video_path)

            return {
                "status": "success",
                "output": s3_urls,
                "prompt": original_prompt,
                "height": req.height,
                "width": req.width,
                "num_frames": req.num_frames,
                "scale": req.guidance_scale,
                "style": req.style,
                "adapter": req.adapter
            }

        except HTTPException:
            raise  # Propagate explicit HTTP errors instead of retrying them
        except Exception as e:
            logging.error(f"Attempt {attempt + 1} failed: {e}")
            if attempt == retries - 1:  # Last attempt
                raise HTTPException(status_code=500, detail=f"Failed to generate video after multiple attempts: {str(e)}")
            continue  # Retry on transient errors

@app.on_event("shutdown")
def shutdown_event():
    """Perform any cleanup activities on shutdown."""
    logging.info("Shutting down the application gracefully.")

# Additional endpoints can be added here as needed.

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
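
For reference, a minimal client-side sketch of how the /text_to_image/ endpoint above could be called once the server is running. The host, port, and the "demo-key" user_key are assumptions, not part of api.py; the request fields mirror GenerateImageRequest.

# Minimal client sketch (assumes the server runs locally on port 8000 and that
# "demo-key" is an accepted user_key; both values are placeholders).
import requests

payload = {
    "prompt": "a logo of a fox made of yarn",
    "guidance_scale": 7.5,
    "seed": 42,
    "randomize_seed": True,
    "height": 768,
    "width": 1360,
    "num_inference_steps": 50,
    "num_images_per_prompt": 1,
    "style": "Yarn",      # One of the keys in style_lora_mapping, or omit
    "adapter": None,      # One of the keys in adapter_controlnet_mapping, or None
    "user_key": "demo-key",
}

resp = requests.post("http://localhost:8000/text_to_image/", json=payload, timeout=600)
resp.raise_for_status()
print(resp.json()["output"])  # List of S3 URLs for the generated images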