# NovelAI-V3 / app.py
# Hugging Face Space by nyanko7 — commit 10817d9 ("Update app.py"), 5.1 kB.
import datetime
import io
import json
import os
import random
import uuid
import zipfile

import boto3
import gradio as gr
import numpy as np
import requests
from PIL import Image
# Module-wide S3 client reused by every upload in save_to_s3.
# No credentials or region are passed here, so boto3 must find them in the
# environment it runs in (presumably the Space's configured secrets — confirm).
s3 = boto3.client(service_name="s3")
def save_to_s3(image_data, payload, file_name, bucket_name="dataset-novelai", prefix="gradio"):
    """Upload an encoded image and its JSON payload to S3 under a per-day folder.

    Both objects share one key stem, ``{prefix}/{YYYY-MM-DD}/{file_name}``, with
    ``.webp`` and ``.json`` extensions. The date is the local date at call time.

    Args:
        image_data: Binary file-like object (e.g. ``io.BytesIO``) holding the
            WebP bytes. It is rewound to offset 0 before uploading, so callers
            may pass a buffer they have just written to.
        payload: JSON text stored next to the image, encoded as UTF-8.
        file_name: Object base name shared by both uploads, without extension.
        bucket_name: Target bucket. Defaults to the bucket this app always used.
        prefix: Top-level key prefix. Defaults to ``"gradio"``.
    """
    folder_name = datetime.datetime.now().strftime("%Y-%m-%d")
    # Compute the stem once so the image and its payload always land side by side.
    base_key = f"{prefix}/{folder_name}/{file_name}"

    # The caller has usually just written into the buffer, leaving the cursor at the end.
    image_data.seek(0)
    s3.upload_fileobj(
        image_data,
        bucket_name,
        f"{base_key}.webp",
        ExtraArgs={"ContentType": "image/webp"},
    )

    payload_data = io.BytesIO(payload.encode("utf-8"))
    s3.upload_fileobj(
        payload_data,
        bucket_name,
        f"{base_key}.json",
        ExtraArgs={"ContentType": "application/json"},
    )
# Aspect-ratio labels offered by the UI radio, mapped to (width, height) in pixels.
_RATIO_DIMENSIONS = {
    "Landscape (1216x832)": (1216, 832),
    "Square (1024x1024)": (1024, 1024),
    "Portrait (832x1216)": (832, 1216),
}

# Response media types accepted as a zip archive of generated images.
_ZIP_MEDIA_TYPES = {"application/x-zip-compressed", "application/zip"}

# Ceiling (seconds) on the NovelAI HTTP call, so a stalled request cannot hold
# one of the interface's concurrency slots indefinitely.
_NAI_REQUEST_TIMEOUT = 120


def _extract_first_image(zip_bytes):
    """Return the first member of a zip archive decoded as a PIL image, or None if the archive is empty."""
    with zipfile.ZipFile(io.BytesIO(zip_bytes), "r") as archive:
        names = archive.namelist()
        if not names:
            return None
        image = Image.open(io.BytesIO(archive.read(names[0])))
        # Image.open only parses the header; decode the pixel data right away.
        image.load()
        return image


# Function to handle the NovelAI API request
def generate_novelai_image(input_text, quality_tags, seed, negative_prompt, scale, ratio, sampler):
    """Generate one image with the ``nai-diffusion-3`` model and archive it to S3.

    Args:
        input_text: Prompt text, sent as the request's ``input``.
        quality_tags: Optional tags appended to the prompt after ", ".
        seed: Generation seed. -1 picks a random seed in [0, 2**32 - 1].
        negative_prompt: Sent as ``parameters.negative_prompt``.
        scale: Guidance scale, sent as ``parameters.scale``.
        ratio: One of the keys of ``_RATIO_DIMENSIONS``.
        sampler: Sampler name, passed through unchanged.

    Returns:
        A tuple ``(image_array, payload_json)``: the generated image as a NumPy
        array, and the submitted request payload pretty-printed as JSON.

    Raises:
        gr.Error: If ``NAI_API_KEY`` is unset, ``ratio`` is unknown, the HTTP
            request fails, or the response carries no image. Gradio displays
            the message to the user.
    """
    jwt_token = os.environ.get("NAI_API_KEY")
    if not jwt_token:
        # Otherwise the request would go out as "Bearer None" and fail opaquely.
        raise gr.Error("The NAI_API_KEY environment variable is not set.")

    if ratio not in _RATIO_DIMENSIONS:
        raise gr.Error(f"Unsupported ratio: {ratio!r}")
    width, height = _RATIO_DIMENSIONS[ratio]

    # Check if quality tags are provided and append to input
    final_input = input_text
    if quality_tags:
        final_input += ", " + quality_tags

    # The seed slider may deliver a float (e.g. 42.0) — int() normalizes it so
    # the JSON payload carries an integer seed.
    seed = int(seed)
    if seed == -1:
        seed = random.randint(0, 2**32 - 1)

    url = "https://api.novelai.net/ai/generate-image"
    headers = {
        "Authorization": f"Bearer {jwt_token}",
        "Content-Type": "application/json",
        "Origin": "https://novelai.net",
        "Referer": "https://novelai.net/",
    }
    payload = {
        "action": "generate",
        "input": final_input,
        "model": "nai-diffusion-3",
        "parameters": {
            "width": width,
            "height": height,
            "scale": scale,
            "sampler": sampler,
            "steps": 28,
            "n_samples": 1,
            "ucPreset": 0,
            "add_original_image": False,
            "cfg_rescale": 0,
            "controlnet_strength": 1,
            "dynamic_thresholding": False,
            "legacy": False,
            "negative_prompt": negative_prompt,
            "noise_schedule": "native",
            "qualityToggle": True,
            "seed": seed,
            "sm": False,
            "sm_dyn": False,
            "uncond_scale": 1,
        },
    }
    payload_json = json.dumps(payload, indent=4)

    try:
        response = requests.post(url, json=payload, headers=headers, timeout=_NAI_REQUEST_TIMEOUT)
    except requests.RequestException as err:
        raise gr.Error(f"Request to NovelAI failed: {err}") from err

    # Compare only the media type, ignoring parameters such as "; charset=...".
    media_type = response.headers.get("Content-Type", "").split(";")[0].strip().lower()
    if media_type not in _ZIP_MEDIA_TYPES:
        raise gr.Error(
            f"The response is not a zip file (HTTP {response.status_code}): {response.text[:300]}"
        )

    image = _extract_first_image(response.content)
    if image is None:
        raise gr.Error("No images found in the zip file.")

    buffered = io.BytesIO()
    image.save(buffered, format="WEBP", quality=98)
    # Several requests can finish within the same second (the interface allows
    # concurrent runs). A random suffix keeps their S3 keys from colliding.
    file_name = f"{int(datetime.datetime.now().timestamp())}_{uuid.uuid4().hex[:8]}"
    save_to_s3(buffered, payload_json, file_name)
    return np.array(image), payload_json
# Default prompt fragments pre-filled in the form.
DEFAULT_QUALITY_TAGS = "best quality, amazing quality, very aesthetic, absurdres"
DEFAULT_NEGATIVE_PROMPT = (
    "nsfw, lowres, {bad}, error, fewer, extra, missing, worst quality, "
    "jpeg artifacts, bad quality, watermark, unfinished, displeasing, "
    "chromatic aberration, signature, extra digits, artistic error, "
    "username, scan, [abstract]"
)

# Choice lists. The ratio labels must match those generate_novelai_image maps
# to pixel dimensions.
RATIO_CHOICES = ["Landscape (1216x832)", "Square (1024x1024)", "Portrait (832x1216)"]
SAMPLER_CHOICES = [
    "k_euler",
    "k_euler_ancestral",
    "k_dpmpp_2s_ancestral",
    "k_dpmpp_2m",
    "k_dpmpp_sde",
    "ddim_v3",
]

# Component order must match generate_novelai_image's positional parameters:
# input_text, quality_tags, seed, negative_prompt, scale, ratio, sampler.
input_components = [
    gr.Textbox(label="Input Text", lines=3),
    gr.Textbox(label="Quality Tags", value=DEFAULT_QUALITY_TAGS),
    gr.Slider(minimum=-1, maximum=2**32 - 1, step=1, value=-1, label="Seed"),
    gr.Textbox(label="Negative Prompt", value=DEFAULT_NEGATIVE_PROMPT),
    gr.Slider(minimum=1, maximum=20, step=1, value=5, label="Scale"),
    gr.Radio(choices=RATIO_CHOICES, value="Portrait (832x1216)"),
    gr.Dropdown(choices=SAMPLER_CHOICES, value="k_euler", label="Sampler"),
]

# Outputs mirror the (image_array, payload_json) tuple the handler returns.
output_components = [
    "image",
    gr.Code(label="Submitted Payload", language="json"),
]

iface = gr.Interface(
    fn=generate_novelai_image,
    inputs=input_components,
    outputs=output_components,
    concurrency_limit=7,  # Gradio's cap on simultaneous runs of this handler
)

# First try a launch with a public share link. If that raises RuntimeError
# (the original note says this fallback is for HF Spaces), launch without sharing.
try:
    iface.launch(share=True)
except RuntimeError:
    iface.launch()