|
import asyncio
import datetime
import functools
import logging
import os
import random
import string
import time
from collections import defaultdict
from contextlib import asynccontextmanager
from typing import Optional

import boto3
import numpy as np
import torch
from diffusers import CogVideoXImageToVideoPipeline
from diffusers import FluxPipeline
from diffusers.models import FluxControlNetModel
from diffusers.pipelines import FluxControlNetInpaintPipeline
from diffusers.pipelines import FluxControlNetPipeline
from diffusers.pipelines import FluxImg2ImgPipeline
from diffusers.pipelines import FluxInpaintPipeline
from diffusers.utils import export_to_video
from diffusers.utils import load_image
from fastapi import FastAPI, HTTPException, Request, Response
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image
from pydantic import BaseModel, constr, conint
|
|
|
|
|
# Root logger configuration: INFO and above are written both to "error.txt"
# and to the console (StreamHandler's default stream, stderr).
# NOTE(review): despite its name, "error.txt" receives every INFO+ record, not
# only errors.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
    handlers=[
        logging.FileHandler("error.txt"),
        logging.StreamHandler(),
    ],
)
|
|
|
app = FastAPI()

# CORS policy: every origin, method and header is accepted, with credentials.
# NOTE(review): a wildcard origin combined with allow_credentials=True may let
# any website make credentialed cross-origin calls to this API; consider an
# explicit origin allow-list before exposing the service publicly.
_cors_settings = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_settings)
|
|
|
# Upper bound (inclusive) for user-supplied and randomly drawn seeds: the
# largest value representable by a signed 32-bit integer.
MAX_SEED: int = int(np.iinfo(np.int32).max)
|
|
|
|
|
# S3 configuration is read from the environment so credentials are never
# committed to source control.
#   AWS_ACCESS_KEY_ID / AWS_SECRET_ACCESS_KEY: when unset these are None, and
#       boto3 falls back to its default credential chain (shared config
#       files, instance/role credentials, ...).
#   AWS_REGION / S3_BUCKET_NAME: also used to build the public object URL in
#       upload_image_to_s3; the placeholders remain as obvious "not configured"
#       defaults.
AWS_ACCESS_KEY_ID = os.environ.get("AWS_ACCESS_KEY_ID")
AWS_SECRET_ACCESS_KEY = os.environ.get("AWS_SECRET_ACCESS_KEY")
AWS_REGION = os.environ.get("AWS_REGION", "your-region")
S3_BUCKET_NAME = os.environ.get("S3_BUCKET_NAME", "your-bucket-name")

s3_client = boto3.client(
    's3',
    aws_access_key_id=AWS_ACCESS_KEY_ID,
    aws_secret_access_key=AWS_SECRET_ACCESS_KEY,
    region_name=AWS_REGION,
)
|
|
|
|
|
async def log_requests(user_key: str, prompt: str):
    """Append a ``"<timestamp>, <user_key>, <prompt>"`` line to ``key_requests.txt``.

    The endpoints call this for requests rejected by the rate limiter.

    Args:
        user_key: Caller's key as sent in the request body.
        prompt: The request's prompt. Carriage returns and newlines are
            replaced by spaces so one request always produces exactly one log
            line and a crafted prompt cannot forge extra entries.
    """
    timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    safe_prompt = prompt.replace("\r", " ").replace("\n", " ")
    log_entry = f"{timestamp}, {user_key}, {safe_prompt}\n"

    def _append() -> None:
        with open("key_requests.txt", "a", encoding="utf-8") as log_file:
            log_file.write(log_entry)

    # Plain file I/O blocks; run it in the default executor so the event loop
    # stays responsive (replaces the never-imported aiofiles dependency).
    await asyncio.get_running_loop().run_in_executor(None, _append)
|
|
|
|
|
async def upload_image_to_s3(image_path: str, s3_path: str):
    """Upload a local file to ``S3_BUCKET_NAME`` and return the object's URL.

    Works for any file type (the video endpoint uses it for MP4s as well).

    Args:
        image_path: Path of the local file to upload.
        s3_path: Destination object key inside the bucket.

    Returns:
        The virtual-hosted-style ``https://<bucket>.s3.<region>.amazonaws.com/<key>``
        URL, with the key percent-encoded ('/' kept as the path separator).

    Raises:
        HTTPException: 500 if the upload fails; the original error is chained.
    """
    from urllib.parse import quote

    loop = asyncio.get_running_loop()
    try:
        # boto3 is synchronous; run it in a worker thread so the event loop
        # keeps serving other requests during the upload.
        await loop.run_in_executor(None, s3_client.upload_file, image_path, S3_BUCKET_NAME, s3_path)
    except Exception as e:
        logging.error(f"Error uploading image to S3: {e}")
        raise HTTPException(status_code=500, detail=f"Image upload failed: {str(e)}") from e
    return f"https://{S3_BUCKET_NAME}.s3.{AWS_REGION}.amazonaws.com/{quote(s3_path)}"
|
|
|
|
|
def generate_random_sequence():
    """Return a random token: 12 digits, an underscore, then 11 lowercase letters.

    Used to build unique file names and S3 keys. Draws from the module-level
    ``random`` generator, so it is not suitable for secrets.
    """
    digits_part = "".join(random.choices(string.digits, k=12))
    letters_part = "".join(random.choices(string.ascii_lowercase, k=11))
    return "_".join((digits_part, letters_part))
|
|
|
|
|
# Model pipelines are loaded once at import time and shared by every request.
# The three Flux pipelines load the same "pranavajay/flow" checkpoint and use
# model CPU offload to limit GPU memory use.
flux_pipe = FluxPipeline.from_pretrained("pranavajay/flow", torch_dtype=torch.bfloat16)
flux_pipe.enable_model_cpu_offload()
logging.info("FluxPipeline loaded successfully.")

img_pipe = FluxImg2ImgPipeline.from_pretrained("pranavajay/flow", torch_dtype=torch.bfloat16)
img_pipe.enable_model_cpu_offload()
logging.info("FluxImg2ImgPipeline loaded successfully.")

# Must be the inpainting pipeline: /inpainting/ passes ``mask_image``, which
# FluxImg2ImgPipeline (previously used here by mistake) does not accept.
inpainting_pipe = FluxInpaintPipeline.from_pretrained("pranavajay/flow", torch_dtype=torch.bfloat16)
inpainting_pipe.enable_model_cpu_offload()
logging.info("FluxInpaintPipeline loaded successfully.")

# Image-to-video model used by /text_to_video/. Sequential offload plus VAE
# tiling/slicing trade speed for a smaller GPU memory footprint.
video = CogVideoXImageToVideoPipeline.from_pretrained(
    "THUDM/CogVideoX-5b-I2V", torch_dtype=torch.bfloat16
)
video.enable_sequential_cpu_offload()
video.vae.enable_tiling()
video.vae.enable_slicing()
logging.info("CogVideoXImageToVideoPipeline loaded successfully.")

# Populated on demand by set_controlnet_adapter(); None until an adapter is requested.
flux_controlnet_pipe = None
|
|
|
|
|
# State for the sliding-window limiter in rate_limit():
RATE_LIMIT = 30  # max accepted requests per user_key inside one window
TIME_WINDOW = 5  # window length in seconds, compared against time.time() deltas
request_timestamps = defaultdict(list)  # user_key -> timestamps of recent accepted requests
|
|
|
|
|
# Style name -> LoRA repository to load and the trigger word that
# apply_lora_style() prepends to the prompt.
style_lora_mapping = {
    name: {"path": repo, "triggered_word": trigger}
    for name, repo, trigger in (
        ("Uncensored", "enhanceaiteam/Flux-uncensored", "nsfw"),
        ("Logo", "Shakker-Labs/FLUX.1-dev-LoRA-Logo-Design", "logo"),
        ("Yarn", "Shakker-Labs/FLUX.1-dev-LoRA-MiaoKa-Yarn-World", "mkym this is made of wool"),
        ("Anime", "prithivMLmods/Canopus-LoRA-Flux-Anime", "anime"),
        ("Comic", "wkplhc/comic", "comic"),
    )
}

# Adapter name -> ControlNet model repository loaded by set_controlnet_adapter().
adapter_controlnet_mapping = dict(
    Canny="InstantX/FLUX.1-dev-controlnet-canny",
    Depth="Shakker-Labs/FLUX.1-dev-ControlNet-Depth",
    Pose="Shakker-Labs/FLUX.1-dev-ControlNet-Pose",
    Upscale="jasperai/Flux.1-dev-Controlnet-Upscaler",
)
|
|
|
|
|
class GenerateImageRequest(BaseModel):
    """Request body for ``/text_to_image/``.

    ``style`` selects an entry of ``style_lora_mapping`` and ``adapter`` an
    entry of ``adapter_controlnet_mapping``; both are optional and accept an
    explicit ``null`` (a bare ``str = None`` annotation rejects ``null`` under
    pydantic v2). ``control_image_url`` is only used when ``adapter`` is set.
    """

    prompt: constr(min_length=1)
    guidance_scale: float = 7.5
    seed: conint(ge=0, le=MAX_SEED) = 42
    randomize_seed: bool = False
    height: conint(gt=0) = 768
    width: conint(gt=0) = 1360
    control_image_url: str = "https://enhanceai.s3.amazonaws.com/792e2322-77fe-4070-aac4-7fa8d9e29c11_1.png"
    controlnet_conditioning_scale: float = 0.6
    num_inference_steps: conint(gt=0) = 50
    num_images_per_prompt: conint(gt=0, le=5) = 1
    style: Optional[str] = None
    adapter: Optional[str] = None
    user_key: str
|
|
|
|
|
async def apply_lora_style(pipe, style, prompt):
    """Make ``style``'s LoRA the only active LoRA on ``pipe`` and return the prompt to use.

    The pipelines are long-lived module globals. Without cleanup, a LoRA loaded
    for one request stays active for every later request, including requests
    with no style. Requesting the same style twice also loads it again on top
    of itself. The style currently loaded by this function is therefore
    recorded on the pipeline in a private ``_active_lora_style`` attribute:

    * a different (or no) style unloads the previous LoRA first;
    * the same style is not reloaded.

    Args:
        pipe: Diffusers pipeline exposing ``load_lora_weights`` /
            ``unload_lora_weights``.
        style: Key of ``style_lora_mapping``. ``None`` or an unknown key means
            "no style".
        prompt: The user's prompt.

    Returns:
        ``"<trigger word> <prompt>"`` for a known style, otherwise ``prompt``
        unchanged.
    """
    # NOTE(review): pipelines are shared across concurrent requests, so swapping
    # LoRAs while another request is generating is still racy (as before).
    wanted = style if style in style_lora_mapping else None
    active = getattr(pipe, "_active_lora_style", None)

    if active is not None and active != wanted:
        pipe.unload_lora_weights()
        pipe._active_lora_style = None

    if wanted is None:
        return prompt

    entry = style_lora_mapping[wanted]
    if active != wanted:
        pipe.load_lora_weights(entry["path"])
        pipe._active_lora_style = wanted
    return f"{entry['triggered_word']} {prompt}"
|
|
|
|
|
async def set_controlnet_adapter(adapter: str, is_inpainting: bool = False):
    """Load the ControlNet named ``adapter`` and install it as ``flux_controlnet_pipe``.

    A fresh ControlNet pipeline is built on every call and moved to CUDA.

    Args:
        adapter: Key of ``adapter_controlnet_mapping`` (e.g. ``"Canny"``).
        is_inpainting: Build a ``FluxControlNetInpaintPipeline`` instead of a
            ``FluxControlNetPipeline``.

    Raises:
        ValueError: If ``adapter`` is not a known adapter name.
    """
    global flux_controlnet_pipe

    if adapter not in adapter_controlnet_mapping:
        raise ValueError(f"Invalid ControlNet adapter: {adapter}")

    # Drop the module's reference to the previous pipeline and release cached
    # GPU memory *before* loading the replacement. Otherwise both pipelines sit
    # on the GPU while the new one loads. Requests still generating with the
    # old pipeline keep it alive until they finish.
    if flux_controlnet_pipe is not None:
        import gc

        flux_controlnet_pipe = None
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    # NOTE(review): from_pretrained blocks the event loop while it loads/downloads.
    controlnet_model_path = adapter_controlnet_mapping[adapter]
    controlnet = FluxControlNetModel.from_pretrained(controlnet_model_path, torch_dtype=torch.bfloat16)
    pipeline_cls = FluxControlNetInpaintPipeline if is_inpainting else FluxControlNetPipeline
    pipe = pipeline_cls.from_pretrained(
        "pranavajay/flow", controlnet=controlnet, torch_dtype=torch.bfloat16
    )
    pipe.to("cuda")
    # Publish only a fully initialised pipeline.
    flux_controlnet_pipe = pipe
    logging.info(f"ControlNet adapter '{adapter}' loaded successfully.")
|
|
|
|
|
async def rate_limit(user_key: str):
    """Sliding-window rate check for ``user_key``.

    Timestamps at least ``TIME_WINDOW`` seconds old are dropped from
    ``request_timestamps[user_key]``. If ``RATE_LIMIT`` or more recent
    timestamps remain, the request is refused: it is logged, not recorded, and
    ``False`` is returned. Otherwise the current time is recorded and ``True``
    is returned.
    """
    now = time.time()
    recent = [stamp for stamp in request_timestamps[user_key] if now - stamp < TIME_WINDOW]
    request_timestamps[user_key] = recent

    if len(recent) < RATE_LIMIT:
        recent.append(now)
        return True

    logging.info(f"Rate limit exceeded for user_key: {user_key}")
    return False
|
|
|
@app.post("/text_to_image/")
async def generate_image(req: GenerateImageRequest):
    """Generate images from a text prompt and upload them to S3.

    Without ``req.adapter`` the base ``flux_pipe`` is used. With an adapter,
    the ControlNet pipeline is loaded and conditioned on
    ``req.control_image_url``. An optional ``req.style`` LoRA is applied and
    its trigger word prepended to the prompt. Generation and upload are retried
    up to three times.

    Returns:
        A dict with ``status``, the S3 URLs in ``output``, the echoed request
        parameters, and the ``seed`` actually used.

    Raises:
        HTTPException:
            * 429 when the caller is rate limited.
            * 400 for an empty prompt, an unknown or unloadable adapter, or an
              unloadable control image.
            * 500 when generation or upload still fails after all attempts.
    """
    # Rejected requests are logged and refused. Previously they were logged
    # and then served anyway.
    if not await rate_limit(req.user_key):
        await log_requests(req.user_key, req.prompt)
        raise HTTPException(status_code=429, detail="Rate limit exceeded")

    if not req.prompt or req.prompt.strip() == "":
        raise HTTPException(status_code=400, detail="Prompt cannot be empty.")

    original_prompt = req.prompt
    seed = random.randint(0, MAX_SEED) if req.randomize_seed else req.seed
    loop = asyncio.get_running_loop()

    # Setup failures are client errors. Resolve them once, before the retry
    # loop, so they surface as 400 instead of being retried and reported as 500.
    if req.adapter:
        try:
            await set_controlnet_adapter(req.adapter)
        except Exception as e:
            logging.error(f"Error setting ControlNet adapter: {e}")
            raise HTTPException(status_code=400, detail=f"Failed to load ControlNet adapter: {str(e)}") from e
        try:
            control_image = await loop.run_in_executor(None, load_image, req.control_image_url)
        except Exception as e:
            logging.error(f"Error loading control image from URL: {e}")
            raise HTTPException(status_code=400, detail="Invalid control image URL or image could not be loaded.") from e
        pipe = flux_controlnet_pipe
        extra_kwargs = {
            "control_image": control_image,
            "controlnet_conditioning_scale": req.controlnet_conditioning_scale,
        }
    else:
        pipe = flux_pipe
        extra_kwargs = {}

    os.makedirs("generated_images", exist_ok=True)
    retries = 3
    for attempt in range(1, retries + 1):
        try:
            # Use the returned prompt so the style's trigger word reaches the model.
            prompt = await apply_lora_style(pipe, req.style, original_prompt)
            pipe_kwargs = {
                "prompt": prompt,
                "guidance_scale": req.guidance_scale,
                "height": req.height,
                "width": req.width,
                "num_inference_steps": req.num_inference_steps,
                "num_images_per_prompt": req.num_images_per_prompt,
                "generator": torch.Generator().manual_seed(seed),
                **extra_kwargs,
            }
            # run_in_executor forwards only positional arguments, so bind the
            # pipeline's keyword arguments with functools.partial.
            result = await loop.run_in_executor(None, functools.partial(pipe, **pipe_kwargs))

            image_urls = []
            for img in result.images:
                image_path = f"generated_images/{generate_random_sequence()}.png"
                try:
                    await loop.run_in_executor(None, img.save, image_path)
                    image_urls.append(await upload_image_to_s3(image_path, image_path))
                finally:
                    # Remove the temporary file even when the upload fails.
                    if os.path.exists(image_path):
                        os.remove(image_path)

            return {
                "status": "success",
                "output": image_urls,
                "prompt": original_prompt,
                "height": req.height,
                "width": req.width,
                "scale": req.guidance_scale,
                "steps": req.num_inference_steps,
                "style": req.style,
                "adapter": req.adapter,
                "seed": seed,
            }
        except torch.cuda.OutOfMemoryError as e:
            logging.error(f"Attempt {attempt} failed: GPU out of memory: {e}")
            torch.cuda.empty_cache()
            if attempt == retries:
                raise HTTPException(status_code=500, detail="GPU overload occurred while generating images. Try reducing the resolution or number of steps.") from e
        except Exception as e:
            logging.error(f"Attempt {attempt} failed: {e}")
            if attempt == retries:
                raise HTTPException(status_code=500, detail=f"Failed to generate image after multiple attempts: {str(e)}") from e
|
|
|
class GenerateImageToImageRequest(BaseModel):
    """Request body for ``/image_to_image/``.

    ``image`` is the URL of the source image and ``strength`` (0..1, checked by
    the endpoint) controls how far the result may depart from it.
    ``control_image_url`` is required only when ``adapter`` is set. Nullable
    fields are typed ``Optional`` so an explicit ``null`` validates under
    pydantic v2.
    """

    prompt: Optional[str] = None
    image: str = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
    strength: float = 0.7
    guidance_scale: float = 7.5
    seed: conint(ge=0, le=MAX_SEED) = 42
    randomize_seed: bool = False
    height: conint(gt=0) = 768
    width: conint(gt=0) = 1360
    control_image_url: Optional[str] = None
    controlnet_conditioning_scale: float = 0.6
    num_inference_steps: conint(gt=0) = 50
    num_images_per_prompt: conint(gt=0, le=5) = 1
    style: Optional[str] = None
    adapter: Optional[str] = None
    user_key: str
|
|
|
@app.post("/image_to_image/")
async def generate_image_to_image(req: GenerateImageToImageRequest):
    """Re-render ``req.image`` according to ``req.prompt`` and upload the results to S3.

    Without an adapter, ``img_pipe`` (Flux img2img) transforms the source image
    with ``req.strength``. With ``req.adapter`` set, the ControlNet pipeline is
    loaded and conditioned on ``req.control_image_url``. An optional style LoRA
    is applied. Generation and upload are retried up to three times.

    Returns:
        A dict with ``status``, the S3 URLs in ``output``, the echoed request
        parameters, and the ``seed`` actually used.

    Raises:
        HTTPException:
            * 429 when rate limited.
            * 400 for an empty prompt, a ``strength`` outside [0, 1], an
              unknown or unloadable adapter, a missing or unloadable control
              image, or an unloadable source image.
            * 500 when generation or upload still fails after all attempts.
    """
    if not await rate_limit(req.user_key):
        await log_requests(req.user_key, req.prompt if req.prompt else "No prompt")
        raise HTTPException(status_code=429, detail="Rate limit exceeded")

    if not req.prompt or req.prompt.strip() == "":
        raise HTTPException(status_code=400, detail="Prompt cannot be empty.")
    if not 0.0 <= req.strength <= 1.0:
        raise HTTPException(status_code=400, detail="Strength must be between 0 and 1.")

    original_prompt = req.prompt
    seed = random.randint(0, MAX_SEED) if req.randomize_seed else req.seed
    loop = asyncio.get_running_loop()

    # Client-side problems are detected once, before the retry loop, and
    # reported as 400.
    if req.adapter:
        if not req.control_image_url:
            raise HTTPException(status_code=400, detail="control_image_url is required when an adapter is set.")
        try:
            await set_controlnet_adapter(req.adapter)
        except Exception as e:
            logging.error(f"Error setting ControlNet adapter: {e}")
            raise HTTPException(status_code=400, detail=f"Failed to load ControlNet adapter: {str(e)}") from e
        try:
            control_image = await loop.run_in_executor(None, load_image, req.control_image_url)
        except Exception as e:
            logging.error(f"Error loading control image from URL: {e}")
            raise HTTPException(status_code=400, detail="Invalid control image URL or image could not be loaded.") from e
        # NOTE(review): FluxControlNetPipeline is a text-to-image pipeline, so
        # req.image and req.strength are not used on this path (unchanged
        # behaviour). A ControlNet img2img pipeline would be needed to honour them.
        pipe = flux_controlnet_pipe
        extra_kwargs = {
            "control_image": control_image,
            "controlnet_conditioning_scale": req.controlnet_conditioning_scale,
        }
    else:
        try:
            source = await loop.run_in_executor(None, load_image, req.image)
        except Exception as e:
            logging.error(f"Error loading source image from URL: {e}")
            raise HTTPException(status_code=400, detail="Invalid source image URL or image could not be loaded.") from e
        pipe = img_pipe
        extra_kwargs = {"image": source, "strength": req.strength}

    os.makedirs("generated_images", exist_ok=True)
    retries = 3
    for attempt in range(1, retries + 1):
        try:
            # Use the returned prompt so the style's trigger word reaches the model.
            prompt = await apply_lora_style(pipe, req.style, original_prompt)
            pipe_kwargs = {
                "prompt": prompt,
                "guidance_scale": req.guidance_scale,
                "height": req.height,
                "width": req.width,
                "num_inference_steps": req.num_inference_steps,
                "num_images_per_prompt": req.num_images_per_prompt,
                "generator": torch.Generator().manual_seed(seed),
                **extra_kwargs,
            }
            # run_in_executor forwards only positional arguments; bind kwargs via partial.
            result = await loop.run_in_executor(None, functools.partial(pipe, **pipe_kwargs))

            image_urls = []
            for img in result.images:
                image_path = f"generated_images/{generate_random_sequence()}.png"
                try:
                    await loop.run_in_executor(None, img.save, image_path)
                    image_urls.append(await upload_image_to_s3(image_path, image_path))
                finally:
                    if os.path.exists(image_path):
                        os.remove(image_path)

            return {
                "status": "success",
                "output": image_urls,
                "prompt": original_prompt,
                "height": req.height,
                "width": req.width,
                "image": req.image,
                "strength": req.strength,
                "scale": req.guidance_scale,
                "steps": req.num_inference_steps,
                "style": req.style,
                "adapter": req.adapter,
                "seed": seed,
            }
        except torch.cuda.OutOfMemoryError as e:
            logging.error(f"Attempt {attempt} failed: GPU out of memory: {e}")
            torch.cuda.empty_cache()
            if attempt == retries:
                raise HTTPException(status_code=500, detail="GPU overload occurred while generating images. Try reducing the resolution or number of steps.") from e
        except Exception as e:
            logging.error(f"Attempt {attempt} failed: {e}")
            if attempt == retries:
                raise HTTPException(status_code=500, detail=f"Failed to generate image after multiple attempts: {str(e)}") from e
|
|
|
|
|
|
|
class GenerateInpaintingRequest(BaseModel):
    """Request body for ``/inpainting/``.

    ``image`` and ``mask_image`` are URLs of the source image and its mask.
    ``control_image_url`` is required only when ``adapter`` is set. Nullable
    fields are typed ``Optional`` so an explicit ``null`` validates under
    pydantic v2.
    """

    prompt: Optional[str] = None
    image: str = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
    mask_image: str = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
    guidance_scale: float = 7.5
    seed: conint(ge=0, le=MAX_SEED) = 42
    randomize_seed: bool = False
    height: conint(gt=0) = 768
    width: conint(gt=0) = 1360
    control_image_url: Optional[str] = None
    controlnet_conditioning_scale: float = 0.6
    num_inference_steps: conint(gt=0) = 50
    num_images_per_prompt: conint(gt=0, le=5) = 1
    style: Optional[str] = None
    adapter: Optional[str] = None
    user_key: str
|
|
|
@app.post("/inpainting/")
async def generate_inpainting(req: GenerateInpaintingRequest):
    """Inpaint the masked region of ``req.image`` and upload the results to S3.

    Without an adapter, ``inpainting_pipe`` is used. With ``req.adapter``, a
    ControlNet inpainting pipeline is loaded and additionally conditioned on
    ``req.control_image_url``. An optional style LoRA is applied. Generation
    and upload are retried up to three times.

    Returns:
        A dict with ``status``, the S3 URLs in ``output``, the echoed request
        parameters, and the ``seed`` actually used.

    Raises:
        HTTPException:
            * 429 when rate limited.
            * 400 for an empty prompt, an unloadable source/mask image, an
              unknown or unloadable adapter, or a missing or unloadable
              control image.
            * 500 when generation or upload still fails after all attempts.
    """
    if not await rate_limit(req.user_key):
        await log_requests(req.user_key, req.prompt if req.prompt else "No prompt")
        raise HTTPException(status_code=429, detail="Rate limit exceeded")

    if not req.prompt or req.prompt.strip() == "":
        raise HTTPException(status_code=400, detail="Prompt cannot be empty.")

    original_prompt = req.prompt
    seed = random.randint(0, MAX_SEED) if req.randomize_seed else req.seed
    loop = asyncio.get_running_loop()

    # Client-side problems are detected once, before the retry loop, and
    # reported as 400.
    try:
        source = await loop.run_in_executor(None, load_image, req.image)
        mask = await loop.run_in_executor(None, load_image, req.mask_image)
    except Exception as e:
        logging.error(f"Error loading source or mask image from URL: {e}")
        raise HTTPException(status_code=400, detail="Invalid source or mask image URL, or image could not be loaded.") from e

    if req.adapter:
        if not req.control_image_url:
            raise HTTPException(status_code=400, detail="control_image_url is required when an adapter is set.")
        try:
            await set_controlnet_adapter(req.adapter, is_inpainting=True)
        except Exception as e:
            logging.error(f"Error setting ControlNet adapter: {e}")
            raise HTTPException(status_code=400, detail=f"Failed to load ControlNet adapter: {str(e)}") from e
        try:
            control_image = await loop.run_in_executor(None, load_image, req.control_image_url)
        except Exception as e:
            logging.error(f"Error loading control image from URL: {e}")
            raise HTTPException(status_code=400, detail="Invalid control image URL or image could not be loaded.") from e
        # set_controlnet_adapter(is_inpainting=True) installs the inpainting
        # variant in the flux_controlnet_pipe global.
        pipe = flux_controlnet_pipe
        extra_kwargs = {
            "image": source,
            "mask_image": mask,
            "control_image": control_image,
            "controlnet_conditioning_scale": req.controlnet_conditioning_scale,
        }
    else:
        pipe = inpainting_pipe
        extra_kwargs = {"image": source, "mask_image": mask}

    # The prompt is user-controlled: keep only ASCII letters/digits in the
    # S3 key and bound its length.
    prompt_slug = "".join(ch if ch.isascii() and ch.isalnum() else "_" for ch in original_prompt)[:50]

    os.makedirs("generated_images", exist_ok=True)
    retries = 3
    for attempt in range(1, retries + 1):
        try:
            # Use the returned prompt so the style's trigger word reaches the model.
            prompt = await apply_lora_style(pipe, req.style, original_prompt)
            pipe_kwargs = {
                "prompt": prompt,
                "guidance_scale": req.guidance_scale,
                "height": req.height,
                "width": req.width,
                "num_inference_steps": req.num_inference_steps,
                "num_images_per_prompt": req.num_images_per_prompt,
                "generator": torch.Generator().manual_seed(seed),
                **extra_kwargs,
            }
            # run_in_executor forwards only positional arguments; bind kwargs via partial.
            result = await loop.run_in_executor(None, functools.partial(pipe, **pipe_kwargs))

            image_urls = []
            for i, img in enumerate(result.images):
                image_path = f"generated_images/inpainting_{generate_random_sequence()}.png"
                s3_path = f"inpainting/{prompt_slug}_{generate_random_sequence()}_{i}.png"
                try:
                    await loop.run_in_executor(None, img.save, image_path)
                    image_urls.append(await upload_image_to_s3(image_path, s3_path))
                finally:
                    if os.path.exists(image_path):
                        os.remove(image_path)

            return {
                "status": "success",
                "output": image_urls,
                "prompt": original_prompt,
                "height": req.height,
                "width": req.width,
                "scale": req.guidance_scale,
                "style": req.style,
                "adapter": req.adapter,
                "seed": seed,
            }
        except torch.cuda.OutOfMemoryError as e:
            logging.error(f"Attempt {attempt} failed: GPU out of memory: {e}")
            torch.cuda.empty_cache()
            if attempt == retries:
                raise HTTPException(status_code=500, detail="GPU overload occurred while generating images. Try reducing the resolution or number of steps.") from e
        except Exception as e:
            logging.error(f"Attempt {attempt} failed: {e}")
            if attempt == retries:
                raise HTTPException(status_code=500, detail=f"Failed to generate inpainting after multiple attempts: {str(e)}") from e
|
|
|
|
|
class GenerateVideoRequest(BaseModel):
    """Request body for ``/text_to_video/``.

    Mirrors ``GenerateImageRequest`` (keyframes come from the same Flux
    pipelines) and adds ``num_frames`` for the CogVideoX image-to-video stage.
    Nullable fields are typed ``Optional`` so an explicit ``null`` validates
    under pydantic v2.
    """

    prompt: constr(min_length=1)
    guidance_scale: float = 7.5
    seed: conint(ge=0, le=MAX_SEED) = 42
    randomize_seed: bool = False
    height: conint(gt=0) = 768
    width: conint(gt=0) = 1360
    control_image_url: str = "https://enhanceai.s3.amazonaws.com/792e2322-77fe-4070-aac4-7fa8d9e29c11_1.png"
    controlnet_conditioning_scale: float = 0.6
    num_inference_steps: conint(gt=0) = 50
    num_images_per_prompt: conint(gt=0, le=5) = 1
    # Frames per generated clip. 49 is assumed to match CogVideoX's default;
    # verify against the installed diffusers version.
    num_frames: conint(gt=0) = 49
    style: Optional[str] = None
    adapter: Optional[str] = None
    user_key: str
|
|
|
|
|
@app.post("/text_to_video/")
async def generate_video(req: GenerateVideoRequest):
    """Generate Flux keyframes, animate each with CogVideoX, and upload the clips to S3.

    Keyframes are produced like ``/text_to_image/``: the base ``flux_pipe``,
    or the ControlNet pipeline when ``req.adapter`` is set, plus an optional
    style LoRA. Each keyframe is then processed as follows:

    * it is passed to the ``video`` image-to-video pipeline together with the
      user's original prompt;
    * the resulting frames are written to an MP4 at 8 fps;
    * the MP4 is uploaded under ``videos/``.

    The whole process is retried up to three times.

    Returns:
        A dict with ``status``, the S3 URLs in ``output``, the echoed request
        parameters (including ``num_frames``), and the ``seed`` used.

    Raises:
        HTTPException:
            * 429 when rate limited.
            * 400 for an empty prompt, an unknown or unloadable adapter, or an
              unloadable control image.
            * 500 when generation, encoding or upload still fails after all
              attempts.
    """
    # rate_limit is a coroutine. Without await, the coroutine object is always
    # truthy and the check never triggers.
    if not await rate_limit(req.user_key):
        await log_requests(req.user_key, req.prompt)
        raise HTTPException(status_code=429, detail="Rate limit exceeded")

    if not req.prompt or req.prompt.strip() == "":
        raise HTTPException(status_code=400, detail="Prompt cannot be empty.")

    original_prompt = req.prompt
    seed = random.randint(0, MAX_SEED) if req.randomize_seed else req.seed
    # Tolerate request models without a num_frames field; 49 is assumed to be
    # CogVideoX's default frame count.
    num_frames = getattr(req, "num_frames", 49)
    loop = asyncio.get_running_loop()

    if req.adapter:
        try:
            await set_controlnet_adapter(req.adapter)
        except Exception as e:
            logging.error(f"Error setting ControlNet adapter: {e}")
            raise HTTPException(status_code=400, detail=f"Failed to load ControlNet adapter: {str(e)}") from e
        try:
            control_image = await loop.run_in_executor(None, load_image, req.control_image_url)
        except Exception as e:
            logging.error(f"Error loading control image from URL: {e}")
            raise HTTPException(status_code=400, detail="Invalid control image URL or image could not be loaded.") from e
        pipe = flux_controlnet_pipe
        extra_kwargs = {
            "control_image": control_image,
            "controlnet_conditioning_scale": req.controlnet_conditioning_scale,
        }
    else:
        pipe = flux_pipe
        extra_kwargs = {}

    # The prompt is user-controlled: keep only ASCII letters/digits in the S3 key.
    prompt_slug = "".join(ch if ch.isascii() and ch.isalnum() else "_" for ch in original_prompt)[:50]

    retries = 3
    for attempt in range(1, retries + 1):
        # Reset per attempt so URLs from a failed attempt are not returned.
        s3_urls = []
        try:
            prompt = await apply_lora_style(pipe, req.style, original_prompt)
            pipe_kwargs = {
                "prompt": prompt,
                "guidance_scale": req.guidance_scale,
                "height": req.height,
                "width": req.width,
                "num_inference_steps": req.num_inference_steps,
                "num_images_per_prompt": req.num_images_per_prompt,
                "generator": torch.Generator().manual_seed(seed),
                **extra_kwargs,
            }
            # run_in_executor forwards only positional arguments; bind kwargs via partial.
            keyframes = await loop.run_in_executor(None, functools.partial(pipe, **pipe_kwargs))

            for i, keyframe in enumerate(keyframes.images):
                # The image-to-video pipeline takes the image itself, not a file path.
                clip = await loop.run_in_executor(None, functools.partial(
                    video,
                    prompt=original_prompt,
                    image=keyframe,
                    num_videos_per_prompt=1,
                    num_inference_steps=req.num_inference_steps,
                    num_frames=num_frames,
                    guidance_scale=req.guidance_scale,
                    generator=torch.Generator(device="cuda").manual_seed(seed),
                ))

                video_path = f"generated_video_{i}_{generate_random_sequence()}.mp4"
                s3_path = f"videos/{prompt_slug}_{generate_random_sequence()}_{i}.mp4"
                try:
                    # .frames holds one frame list per generated video; we asked for one.
                    await loop.run_in_executor(None, functools.partial(export_to_video, clip.frames[0], video_path, fps=8))
                    s3_urls.append(await upload_image_to_s3(video_path, s3_path))
                finally:
                    if os.path.exists(video_path):
                        os.remove(video_path)

            return {
                "status": "success",
                "output": s3_urls,
                "prompt": original_prompt,
                "height": req.height,
                "width": req.width,
                "num_frames": num_frames,
                "scale": req.guidance_scale,
                "style": req.style,
                "adapter": req.adapter,
                "seed": seed,
            }
        except torch.cuda.OutOfMemoryError as e:
            logging.error(f"Attempt {attempt} failed: GPU out of memory: {e}")
            torch.cuda.empty_cache()
            if attempt == retries:
                raise HTTPException(status_code=500, detail="GPU overload occurred while generating images. Try reducing the resolution or number of steps.") from e
        except Exception as e:
            logging.error(f"Attempt {attempt} failed: {e}")
            if attempt == retries:
                raise HTTPException(status_code=500, detail=f"Failed to generate video after multiple attempts: {str(e)}") from e
|
|
|
@app.on_event("shutdown")
def shutdown_event():
    """Log a message when the application shuts down.

    Registered through FastAPI's ``on_event("shutdown")`` hook. The stray
    ``@asynccontextmanager`` decorator that used to sit on top is removed: it
    only replaced the module-level name with a context-manager factory wrapping
    a non-generator function, which would fail if anything ever used it.
    """
    logging.info("Shutting down the application gracefully.")
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Development entry point: serve the FastAPI app with uvicorn on every
    # interface, port 8000.
    from uvicorn import run as serve

    serve(app, port=8000, host="0.0.0.0")
|
|