# magnopus-demo / app.py
# Last commit: 9f25063 "Update app.py" (eshaan)
# Gradio demo that generates images for one text prompt with several Stable
# Diffusion variants and shows them side by side.
# Run with: python app.py
import replicate
import gradio as gr
import requests
from PIL import Image
from io import BytesIO
import aiohttp
import asyncio
async def get_image_from_url(api_url, payload, file_path):
    """POST ``payload`` as JSON to ``api_url`` and save the response body to ``file_path``.

    The response body is written verbatim in binary mode; callers in this file
    pass ``.png`` paths, so the endpoint presumably returns raw image bytes.

    Args:
        api_url: URL of the image-generation endpoint.
        payload: JSON-serialisable request body (e.g. ``{"prompt": ...}``).
        file_path: destination file; overwritten if it exists.

    Raises:
        aiohttp.ClientResponseError: if the endpoint answers with an HTTP error
            status. Checking this up front avoids saving an error page to disk
            and then handing it to the UI as if it were an image.
    """
    async with aiohttp.ClientSession() as session:
        print("making api request to", api_url)
        async with session.post(api_url, json=payload) as resp:
            resp.raise_for_status()
            data = await resp.read()
    # Write after the session is closed; the file write itself is a small,
    # synchronous operation.
    with open(file_path, "wb") as f:
        f.write(data)
def _lora_prompt(text_input):
    """Return ``text_input`` with Baby Yoda references replaced by the LoRA trigger token ``gr1gu``.

    Both "baby yoda" and "grogu" are matched case-insensitively, so prompts
    such as "Baby Yoda eating soup" or "Grogu" also trigger the LoRA.
    """
    import re

    return re.sub(r"baby yoda|grogu", "gr1gu", text_input, flags=re.IGNORECASE)


async def fake_gan_async(text_input):
    """Generate an image for ``text_input`` with each model and return the ones to display.

    Concurrently queries:
      * Stable Diffusion 2-1 through Replicate (yields an image URL),
      * an SD XL endpoint (ngrok URL),
      * an SD XL + Grogu LoRA endpoint, with the prompt rewritten by
        ``_lora_prompt`` so the LoRA's trigger token is present,
      * an SD XL endpoint loaded with CI-finetuned weights.
    Responses from the three ngrok endpoints are saved to local PNG files.

    Args:
        text_input: the user's prompt.

    Returns:
        ``[sd_image_url, sdxl_filepath, lora_filepath]``, in the same order as
        the image outputs of the Gradio interface.
    """
    sd_model = (
        "stability-ai/stable-diffusion:"
        "ac732df83cea7fff18b8472768c88ad041fa750ff7682a21affe81863cbe77e4"
    )

    # SD XL
    sdxl_api_url = "http://3bfb-34-142-187-46.ngrok-free.app/get_image"
    payload_sdxl = {"prompt": text_input}
    sdxl_filepath = "sdxl.png"

    # SD XL + LoRA trained on Grogu / Baby Yoda
    lora_api_url = "http://841a-35-240-138-5.ngrok-free.app/get_image"
    payload_lora = {
        "prompt": _lora_prompt(text_input),
        "lora_weights_path": "/content/drive/MyDrive/kohya-trainer/output/grogu_lora_v4-000030.safetensors",
    }
    lora_filepath = "grogu_sdxl_lora.png"

    # SD XL finetuned on CI labels
    finetuned_ci_url = "http://f20b-34-125-36-4.ngrok-free.app/get_image"
    payload_finetuned_ci = {
        "prompt": text_input,
        "lora_weights_path": "/content/drive/MyDrive/kohya-trainer/output/sdxl_20_epochs_ci.safetensors",
    }
    finetuned_ci_filepath = "sdxl_finetuned_ci.png"

    loop = asyncio.get_running_loop()
    # replicate.run is a blocking call. Running it in the default thread pool
    # lets it overlap with the HTTP requests instead of running before them.
    sd_response, *_ = await asyncio.gather(
        loop.run_in_executor(
            None, lambda: replicate.run(sd_model, input={"prompt": text_input})
        ),
        get_image_from_url(sdxl_api_url, payload_sdxl, sdxl_filepath),
        get_image_from_url(lora_api_url, payload_lora, lora_filepath),
        get_image_from_url(finetuned_ci_url, payload_finetuned_ci, finetuned_ci_filepath),
    )
    sd_image_url = sd_response[0]

    # NOTE(review): the CI-finetuned image is generated and saved to
    # finetuned_ci_filepath but is not returned (the interface has no output
    # for it). Either add it back to the UI or drop the request.
    return [
        sd_image_url,
        sdxl_filepath,
        lora_filepath,
    ]
title = "The Mandalorian Vango Demo"
# The description lists only the models whose images are displayed. The
# CI-finetuned SD XL result produced by fake_gan_async has no output panel.
description = "We experimented with a few models to make this demo! Here we compare A) Stable Diffusion 2-1, B) Stable Diffusion XL and C) LoRA finetuned SD XL on 20 hand written labels of Baby Yoda"

demo = gr.Interface(
    # fake_gan_async is a coroutine function. asyncio.run drives it to
    # completion so Gradio can call it like a regular synchronous function.
    fn=lambda x: asyncio.run(fake_gan_async(x)),
    inputs=gr.Textbox(
        label="Enter your prompt",
        show_label=False,
        max_lines=1,
        placeholder="Enter your prompt",
    ).style(
        container=False,
    ),
    # Order must match the list returned by fake_gan_async.
    outputs=[
        gr.outputs.Image(type="pil", label="Stable Diffusion 2-1"),
        gr.outputs.Image(type="pil", label="Stable Diffusion XL"),
        gr.outputs.Image(type="pil", label="SD XL, Grogu LoRA on 20 hand written labels"),
    ],
    title=title,
    description=description,
)

if __name__ == "__main__":
    # Start the server only when executed as a script (e.g. `python app.py`),
    # so importing this module has no side effects and the app is launched
    # exactly once.
    demo.launch()