# imagen-tutorial / app.py
# Author: heaversm
# Commit message: "add max generations functionality and do enforce pw functionality"
# Commit: 81e44ac
# standard library (files, JSON, logging, password comparison)
import hmac
import json
import logging
import os
from io import BytesIO
# image generation stuff
from PIL import Image
# gradio / hf / image gen stuff
import gradio as gr
from dotenv import load_dotenv
from google.cloud import aiplatform
import vertexai
from vertexai.preview.vision_models import ImageGenerationModel
from vertexai import preview
# GCP credentials stuff
import pybase64
from google.oauth2 import service_account
import google.auth
load_dotenv()

# The GCP service-account key is supplied base64-encoded in the IMAGEN
# environment variable; decode it, parse the JSON, and build credentials.
_encoded_service_account = os.getenv("IMAGEN")
if not _encoded_service_account:
    # Fail fast with an actionable message instead of an opaque decode error.
    raise RuntimeError(
        "IMAGEN environment variable is not set; it must contain the "
        "base64-encoded GCP service-account JSON key."
    )
service_account_json = pybase64.b64decode(_encoded_service_account)
service_account_info = json.loads(service_account_json)
credentials = service_account.Credentials.from_service_account_info(service_account_info)
project = "pdr-imagen"
aiplatform.init(project=project, credentials=credentials)

# Password-enforcement switch, kept as a string for the "true"/"false"
# comparisons elsewhere in this file. Normalised (whitespace stripped,
# lower-cased) so values like "True" behave like "true"; an unset variable
# becomes "". generate_image enforces the password only when this is "true".
DO_ENFORCE_PW = os.getenv("DO_ENFORCE_PW", "").strip().lower()
def trigger_max_gens():
    """Pop a Gradio warning toast saying the generation cap was reached.

    Bound to a hidden button that the page's custom JavaScript clicks when
    its client-side generation counter hits the limit.
    """
    toast_text = "🖼️ Max Image Generations Reached! 🖼️"
    gr.Warning(toast_text)
def generate_image(pw, prompt, model_name):
    """Generate one image from a text prompt with a Vertex AI Imagen model.

    Args:
        pw: Password entered by the user. Only checked when DO_ENFORCE_PW is
            "true", in which case it must match the PW environment variable.
        prompt: Text description of the image to create.
        model_name: Imagen model identifier passed to
            ``ImageGenerationModel.from_pretrained`` (e.g. "imagegeneration@005").

    Returns:
        A PIL ``Image`` decoded from the bytes of the first generated image.

    Raises:
        gr.Error: If the password check fails, the prompt is empty, the API
            call or decoding fails, or the model returns no image. Gradio
            displays the message to the user; API failures are logged with
            their traceback server-side.
    """
    if DO_ENFORCE_PW == "true":
        expected_pw = os.getenv("PW")
        # Fail closed when PW is not configured. Compare in constant time so
        # response timing does not reveal how much of the password matched;
        # compare bytes because hmac.compare_digest rejects non-ASCII str.
        if expected_pw is None or not hmac.compare_digest(
            (pw or "").encode("utf-8"), expected_pw.encode("utf-8")
        ):
            raise gr.Error("Invalid password. Please try again.")

    # Reject empty prompts locally rather than spending an API call on them.
    if not prompt or not prompt.strip():
        raise gr.Error("Please enter a prompt describing the image to create.")

    try:
        model = ImageGenerationModel.from_pretrained(model_name)
        response = model.generate_images(
            prompt=prompt,
            number_of_images=1,
        )
        images = response.images
        image = Image.open(BytesIO(images[0]._image_bytes)) if images else None
    except Exception as e:
        # UI boundary: keep the full traceback in the server logs and show the
        # user a generic message rather than internal error details.
        logging.getLogger(__name__).exception("Imagen image generation failed")
        raise gr.Error("An error occurred while generating the image") from e

    if image is None:
        # The API returned no images; presumably the prompt or output was
        # blocked by safety filtering — TODO confirm against Vertex AI docs.
        raise gr.Error("No image was generated. Try rephrasing your prompt.")
    return image
# Client-side JavaScript handed to gr.Blocks(js=...); customJS() runs once when
# the page loads. It keeps a per-browser count of generations in localStorage
# and disables the Generate button once MAX_GENERATIONS is reached. This is a
# soft UX limit only: clearing storage, calling the API, or pressing Enter in
# the prompt box (text.submit) all bypass it.
custom_js = """
function customJS() {
    // Limit Image Generation
    const MAX_GENERATIONS = 10;
    const DO_ENFORCE_MAX_GENERATIONS = true;
    const STORAGE_KEY = 'currentGenerations';

    // Stored counter as a number; a missing or corrupt value counts as 0.
    const getCurrentGenerations = function() {
        const stored = parseInt(localStorage.getItem(STORAGE_KEY), 10);
        return Number.isNaN(stored) ? 0 : stored;
    };

    const disableGenerateButton = function() {
        const btn = document.getElementById('btn_generate-images');
        if (!btn) {
            return;
        }
        btn.disabled = true;
        btn.classList.add('not-visible');
    };

    // Click the hidden Gradio button whose server handler shows the warning toast.
    const triggerMaxGenerationsToast = function() {
        const triggerBtn = document.getElementById('trigger-max-gens-btn');
        if (triggerBtn) {
            triggerBtn.click();
        }
    };

    const blockFurtherGenerations = function() {
        triggerMaxGenerationsToast();
        disableGenerateButton();
    };

    // On page load only check the limit: loading the page is not a generation.
    if (DO_ENFORCE_MAX_GENERATIONS && getCurrentGenerations() >= MAX_GENERATIONS) {
        blockFurtherGenerations();
    }

    // Count one generation per click; when the count reaches the limit,
    // disable the button (the click that reached it still generates).
    const generateBtn = document.getElementById('btn_generate-images');
    if (generateBtn) {
        generateBtn.addEventListener('click', function() {
            if (!DO_ENFORCE_MAX_GENERATIONS) {
                return;
            }
            const nextCount = getCurrentGenerations() + 1;
            localStorage.setItem(STORAGE_KEY, String(nextCount));
            console.log(`${nextCount} / ${MAX_GENERATIONS}`);
            if (nextCount >= MAX_GENERATIONS) {
                blockFurtherGenerations();
            }
        });
    }
}
"""
with gr.Blocks(js=custom_js) as demo:
    gr.Markdown("# <center>Google Vertex Imagen Generator</center>")

    # password
    # Show the password box and access note only when generate_image will
    # actually check the password (DO_ENFORCE_PW == "true"), so the UI never
    # asks for a password that is silently ignored.
    show_pw = DO_ENFORCE_PW == "true"
    pw = gr.Textbox(
        label="Password",
        type="password",
        placeholder="Enter the password to unlock the service",
        visible=show_pw,
    )
    gr.Markdown(
        "Need access? Send a DM to @HeaversMike on Twitter or send me an email / Slack msg.",
        visible=show_pw,
    )

    # instructions
    # The first positional argument of gr.Accordion is its label, so it is not
    # repeated as a label= keyword.
    with gr.Accordion("Instructions & Tips", open=False):
        with gr.Row():
            gr.Markdown("**Tips**: Use adjectives (size,color,mood), specify the visual style (realistic,cartoon,8-bit), explain the point of view (from above,first person,wide angle) ")

    # prompts
    with gr.Accordion("Prompt", open=True):
        text = gr.Textbox(
            label="What do you want to create?",
            placeholder="Enter your text and then click on the \"Generate Images\" button",
        )
        model = gr.Dropdown(
            choices=["imagegeneration@002", "imagegeneration@005"],
            label="Model",
            value="imagegeneration@005",
        )
        with gr.Row():
            # elem_id is referenced by custom_js to count clicks and disable the button.
            btn = gr.Button("Generate Images", variant="primary", elem_id="btn_generate-images")

    # output
    with gr.Accordion("Image Output", open=True):
        output_image = gr.Image(label="Image")

    with gr.Row():
        # Hidden button clicked by custom_js to show the max-generations toast.
        trigger_max_gens_btn = gr.Button(value="Show Max Gens Reached", visible=False, elem_id="trigger-max-gens-btn")

    btn.click(fn=generate_image, inputs=[pw, text, model], outputs=output_image, api_name=False)
    # Enter in the prompt box also generates and is exposed as the public API
    # endpoint; note it is not counted by the client-side limit in custom_js.
    text.submit(fn=generate_image, inputs=[pw, text, model], outputs=output_image, api_name="generate_image")

    # js-triggered functionality
    trigger_max_gens_btn.click(trigger_max_gens, None, None)

if __name__ == "__main__":
    demo.launch(share=False)