# utils/ai_generator.py
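"""
AI image generation dispatch for HexaGrid: runs Flux generation on a local
CUDA GPU when one is available, and otherwise falls back to the Hugging Face
Inference API with retry/backoff handling.
"""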
import gradio as gr
import os
import time  # Used for implementing retry delays
import random
import io
import requests
from tempfile import NamedTemporaryFile

from torch import cuda
from PIL import Image
from huggingface_hub import InferenceClient

import utils.constants as constants
from utils.ai_generator_diffusers_flux import generate_ai_image_local


def generate_image_from_text(text, model_name="flax-community/dalle-mini", image_width=768, image_height=512, progress=gr.Progress(track_tqdm=True)):
    """Generate an image from a text prompt via the Inference API, resized to the requested size."""
    # Initialize the InferenceClient for the requested model
    client = InferenceClient(model=model_name)
    # text_to_image returns a PIL.Image directly on success
    image = client.text_to_image(text)
    # Resize the image to the requested dimensions
    image = image.resize((image_width, image_height))
    return image


def generate_ai_image(
    map_option,
    prompt_textbox_value,
    neg_prompt_textbox_value,
    model,
    lora_weights=None,
    conditioned_image=None,
    pipeline="FluxPipeline",
    width=912,
    height=512,
    strength=0.5,
    seed=0,
    progress=gr.Progress(track_tqdm=True),
    *args,
    **kwargs
):
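    """
    Dispatch image generation to a local GPU when available, otherwise to
    the Hugging Face Inference API. A seed of 0 is replaced with a random
    seed, and a conditioning image switches the local pipeline to img2img.
    """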
    if seed == 0:
        seed = random.randint(0, constants.MAX_SEED)
    if cuda.is_available() and cuda.device_count() >= 1:  # Check if a local GPU is available
        print("Local GPU available. Generating image locally.")
        # A conditioning image implies img2img, so switch pipelines
        if conditioned_image is not None:
            pipeline = "FluxImg2ImgPipeline"
        return generate_ai_image_local(
            map_option,
            prompt_textbox_value,
            neg_prompt_textbox_value,
            model,
            lora_weights=lora_weights,
            seed=seed,
            conditioned_image=conditioned_image,
            pipeline_name=pipeline,
            strength=strength,
            height=height,
            width=width
        )
    else:
        print("No local GPU available. Sending request to Hugging Face API.")
        return generate_ai_image_remote(
            map_option,
            prompt_textbox_value,
            neg_prompt_textbox_value,
            model,
            height=height,
            width=width,
            seed=seed
        )


def generate_ai_image_remote(
    map_option,
    prompt_textbox_value,
    neg_prompt_textbox_value,
    model,
    height=512,
    width=912,
    num_inference_steps=30,
    guidance_scale=3.5,
    seed=777,
    progress=gr.Progress(track_tqdm=True)
):
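    """
    Generate an image via the Hugging Face Inference API, retrying with
    exponential backoff on rate limits and timeouts. On success the image
    is saved to a temporary PNG and its path is returned; otherwise None.
    """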
    max_retries = 3
    retry_delay = 4  # Initial delay in seconds
    try:
        if map_option != "Prompt":
            prompt = constants.PROMPTS[map_option]
            # Convert the negative prompt string to a list
            negative_prompt_str = constants.NEGATIVE_PROMPTS.get(map_option, "")
            negative_prompt = [p.strip() for p in negative_prompt_str.split(',') if p.strip()]
        else:
            prompt = prompt_textbox_value
            # Convert the negative prompt string to a list
            negative_prompt = [p.strip() for p in neg_prompt_textbox_value.split(',') if p.strip()] if neg_prompt_textbox_value else []
        # Note: negative_prompt is logged below but not included in the
        # request payload, which carries only the positive prompt.
        print("Remotely generating image with the following parameters:")
        print(f"Prompt: {prompt}")
        print(f"Negative Prompt: {negative_prompt}")
        print(f"Height: {height}")
        print(f"Width: {width}")
        print(f"Number of Inference Steps: {num_inference_steps}")
        print(f"Guidance Scale: {guidance_scale}")
        print(f"Seed: {seed}")
        for attempt in range(1, max_retries + 1):
            try:
                if os.getenv("IS_SHARED_SPACE") == "True":
                    client = InferenceClient(
                        model,
                        token=constants.HF_API_TOKEN
                    )
                    # text_to_image takes generation settings as keyword
                    # arguments and returns a PIL.Image on success
                    # (max_sequence_length is not part of its signature,
                    # so it is only sent on the raw-HTTP path below)
                    image = client.text_to_image(
                        prompt,
                        guidance_scale=guidance_scale,
                        num_inference_steps=num_inference_steps,
                        width=width,
                        height=height,
                        seed=seed
                    )
                    break  # Exit the retry loop on success
                else:
                    API_URL = f"https://api-inference.huggingface.co/models/{model}"
                    headers = {
                        "Authorization": f"Bearer {constants.HF_API_TOKEN}",
                        "Content-Type": "application/json"
                    }
                    payload = {
                        "inputs": prompt,
                        "parameters": {
                            "guidance_scale": guidance_scale,
                            "num_inference_steps": num_inference_steps,
                            "width": width,
                            "height": height,
                            "max_sequence_length": 512,
                            # Optional: add 'scheduler' if needed
                            "seed": seed
                        }
                    }
                    print(f"Attempt {attempt}: Sending POST request to Hugging Face API...")
                    response = requests.post(API_URL, headers=headers, json=payload, timeout=300)  # 300-second timeout for slow cold starts
                    if response.status_code == 200:
                        image_bytes = response.content
                        image = Image.open(io.BytesIO(image_bytes))
                        break  # Exit the retry loop on success
                    elif response.status_code == 400:
                        # Handle 400 Bad Request specifically
                        print(f"Bad Request (400): {response.text}")
                        print("Check your request parameters and payload format.")
                        return None  # Do not retry on 400 errors
                    elif response.status_code in [429, 504]:
                        print(f"Received status code {response.status_code}. Retrying in {retry_delay} seconds...")
                        if attempt < max_retries:
                            time.sleep(retry_delay)
                            retry_delay *= 2  # Exponential backoff
                        else:
                            response.raise_for_status()  # Raise exception after max retries
                    else:
                        print(f"Received unexpected status code {response.status_code}: {response.text}")
                        response.raise_for_status()
            except (requests.exceptions.ReadTimeout, requests.exceptions.ConnectTimeout) as timeout_error:
                print(f"Timeout occurred: {timeout_error}. Retrying in {retry_delay} seconds...")
                if attempt < max_retries:
                    time.sleep(retry_delay)
                    retry_delay *= 2  # Exponential backoff
                else:
                    raise  # Re-raise the exception after max retries
            except requests.exceptions.RequestException as req_error:
                print(f"Request exception: {req_error}. Retrying in {retry_delay} seconds...")
                if attempt < max_retries:
                    time.sleep(retry_delay)
                    retry_delay *= 2  # Exponential backoff
                else:
                    raise  # Re-raise the exception after max retries
        else:
            # The for-else runs only when the loop finished without a
            # `break`, i.e. every retry failed to produce an image
            print("Max retries exceeded. Failed to generate image.")
            return None
        # Save the generated image to a temporary PNG and track it for cleanup
        with NamedTemporaryFile(delete=False, suffix=".png") as tmp:
            image.save(tmp.name, format="PNG")
            constants.temp_files.append(tmp.name)
            print(f"Image saved to {tmp.name}")
            return tmp.name
    except Exception as e:
        print(f"Error generating AI image: {e}")
        return None
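

# A minimal smoke test for the dispatch path, assuming utils.constants
# provides HF_API_TOKEN, MAX_SEED, PROMPTS, and temp_files; the model id
# below is only an example.
if __name__ == "__main__":
    result = generate_ai_image(
        map_option="Prompt",
        prompt_textbox_value="A top-down hexagonal grid of colorful terrain tiles",
        neg_prompt_textbox_value="blurry, low quality",
        model="black-forest-labs/FLUX.1-schnell",
    )
    print(f"Generated image: {result}")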