import io
import os
import random
import time
from tempfile import NamedTemporaryFile

import gradio as gr
import requests
from huggingface_hub import InferenceClient
from PIL import Image
from torch import cuda

import utils.constants as constants
from utils.ai_generator_diffusers_flux import generate_ai_image_local


def generate_image_from_text(text, model_name="flax-community/dalle-mini", image_width=768, image_height=512, progress=gr.Progress(track_tqdm=True)):
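    """Generate an image from a text prompt via the Hugging Face Inference API and resize it to the requested dimensions."""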
    client = InferenceClient()
    # InferenceClient instances are not callable; text_to_image() sends the request
    # and returns a PIL.Image directly, so no manual response decoding is needed.
    image = client.text_to_image(text, model=model_name)
    image = image.resize((image_width, image_height))
    return image


def generate_ai_image(
    map_option,
    prompt_textbox_value,
    neg_prompt_textbox_value,
    model,
    lora_weights=None,
    conditioned_image=None,
    pipeline="FluxPipeline",
    width=912,
    height=512,
    strength=0.5,
    seed=0,
    progress=gr.Progress(track_tqdm=True),
    *args,
    **kwargs
):
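    """
    Route an image-generation request to the local diffusers pipeline when a
    CUDA GPU is available, otherwise to the Hugging Face Inference API.
    A seed of 0 is treated as "pick a random seed".
    """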
    if seed == 0:
        # Seed 0 means "randomize": draw a fresh seed for this request.
        seed = random.randint(0, constants.MAX_SEED)
    if cuda.is_available() and cuda.device_count() >= 1:
        print("Local GPU available. Generating image locally.")
        if conditioned_image is not None:
            # A conditioning image implies img2img, so switch pipelines.
            pipeline = "FluxImg2ImgPipeline"
        return generate_ai_image_local(
            map_option,
            prompt_textbox_value,
            neg_prompt_textbox_value,
            model,
            lora_weights=lora_weights,
            seed=seed,
            conditioned_image=conditioned_image,
            pipeline_name=pipeline,
            strength=strength,
            height=height,
            width=width
        )
    else:
        print("No local GPU available. Sending request to Hugging Face API.")
        return generate_ai_image_remote(
            map_option,
            prompt_textbox_value,
            neg_prompt_textbox_value,
            model,
            height=height,
            width=width,
            seed=seed
        )


def generate_ai_image_remote(map_option, prompt_textbox_value, neg_prompt_textbox_value, model, height=512, width=912, num_inference_steps=30, guidance_scale=3.5, seed=777, progress=gr.Progress(track_tqdm=True)):
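    """
    Generate an image through the Hugging Face Inference API, retrying
    transient failures. Returns the path of a temporary PNG on success,
    or None on failure.
    """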
    max_retries = 3
    retry_delay = 4  # seconds; doubled after each failed attempt

    try:
        if map_option != "Prompt":
            prompt = constants.PROMPTS[map_option]
            negative_prompt_str = constants.NEGATIVE_PROMPTS.get(map_option, "")
            negative_prompt = [p.strip() for p in negative_prompt_str.split(',') if p.strip()]
        else:
            prompt = prompt_textbox_value
            negative_prompt = [p.strip() for p in neg_prompt_textbox_value.split(',') if p.strip()] if neg_prompt_textbox_value else []

        print("Remotely generating image with the following parameters:")
        print(f"Prompt: {prompt}")
        print(f"Negative Prompt: {negative_prompt}")
        print(f"Height: {height}")
        print(f"Width: {width}")
        print(f"Number of Inference Steps: {num_inference_steps}")
        print(f"Guidance Scale: {guidance_scale}")
        print(f"Seed: {seed}")

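        # Retry transient failures (HTTP 429/504 and timeouts) with exponential backoff.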
        for attempt in range(1, max_retries + 1):
            try:
                if os.getenv("IS_SHARED_SPACE") == "True":
                    client = InferenceClient(
                        model,
                        token=constants.HF_API_TOKEN
                    )
                    # text_to_image() takes keyword arguments and returns a PIL.Image;
                    # it has no `inputs`/`parameters` arguments. Depending on the
                    # installed huggingface_hub version, `seed` and `max_sequence_length`
                    # may not be accepted here, so they are only sent on the raw-request
                    # path below.
                    image = client.text_to_image(
                        prompt,
                        negative_prompt=", ".join(negative_prompt) if negative_prompt else None,
                        guidance_scale=guidance_scale,
                        num_inference_steps=num_inference_steps,
                        width=width,
                        height=height
                    )
                    break  # success; without this break the loop would repeat max_retries times
                else:
                    API_URL = f"https://api-inference.huggingface.co/models/{model}"
                    headers = {
                        "Authorization": f"Bearer {constants.HF_API_TOKEN}",
                        "Content-Type": "application/json"
                    }
                    payload = {
                        "inputs": prompt,
                        "parameters": {
                            "guidance_scale": guidance_scale,
                            "num_inference_steps": num_inference_steps,
                            "width": width,
                            "height": height,
                            "max_sequence_length": 512,
                            "seed": seed
                        }
                    }

                    print(f"Attempt {attempt}: Sending POST request to Hugging Face API...")
                    response = requests.post(API_URL, headers=headers, json=payload, timeout=300)
                    if response.status_code == 200:
                        image_bytes = response.content
                        image = Image.open(io.BytesIO(image_bytes))
                        break
                    elif response.status_code == 400:
                        print(f"Bad Request (400): {response.text}")
                        print("Check your request parameters and payload format.")
                        return None
                    elif response.status_code in [429, 504]:
                        print(f"Received status code {response.status_code}. Retrying in {retry_delay} seconds...")
                        if attempt < max_retries:
                            time.sleep(retry_delay)
                            retry_delay *= 2
                        else:
                            response.raise_for_status()
                    else:
                        print(f"Received unexpected status code {response.status_code}: {response.text}")
                        response.raise_for_status()
            except (requests.exceptions.ReadTimeout, requests.exceptions.ConnectTimeout) as timeout_error:
                print(f"Timeout occurred: {timeout_error}. Retrying in {retry_delay} seconds...")
                if attempt < max_retries:
                    time.sleep(retry_delay)
                    retry_delay *= 2
                else:
                    raise
            except requests.exceptions.RequestException as req_error:
                print(f"Request exception: {req_error}. Retrying in {retry_delay} seconds...")
                if attempt < max_retries:
                    time.sleep(retry_delay)
                    retry_delay *= 2
                else:
                    raise
        else:
            # for/else: this branch runs only when the loop finishes without a
            # break, i.e. no attempt produced an image.
            print("Max retries exceeded. Failed to generate image.")
            return None
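
        # Save the result to a named temp file and record the path in
        # constants.temp_files, presumably so it can be cleaned up later.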
        with NamedTemporaryFile(delete=False, suffix=".png") as tmp:
            image.save(tmp.name, format="PNG")
            constants.temp_files.append(tmp.name)
            print(f"Image saved to {tmp.name}")
        return tmp.name

    except Exception as e:
        print(f"Error generating AI image: {e}")
        return None