# Import required libraries
import gradio as gr
import random
import json
import os
import re
from datetime import datetime
from huggingface_hub import InferenceClient
import subprocess
import torch
from PIL import Image
from transformers import AutoProcessor, AutoModelForCausalLM
import openai  # OpenAI API client library

# Install flash-attn at startup (a common Hugging Face Spaces pattern); skipping the
# CUDA build keeps installation fast on machines without a local compiler toolchain.
subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
# Configure the OpenAI API client (legacy openai<1.0 interface)
openai.api_key = os.getenv("OPENAI_API_KEY")
# Initialize the Florence-2 captioning model
device = "cuda" if torch.cuda.is_available() else "cpu"
florence_model = AutoModelForCausalLM.from_pretrained('microsoft/Florence-2-base', trust_remote_code=True).to(device).eval()
florence_processor = AutoProcessor.from_pretrained('microsoft/Florence-2-base', trust_remote_code=True)
# Florence caption function
def florence_caption(image):
if not isinstance(image, Image.Image):
image = Image.fromarray(image)
inputs = florence_processor(text="<MORE_DETAILED_CAPTION>", images=image, return_tensors="pt").to(device)
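    # Beam-search decoding; special task tokens are kept and parsed out by the
    # processor's post_process_generation step below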
generated_ids = florence_model.generate(
input_ids=inputs["input_ids"],
pixel_values=inputs["pixel_values"],
max_new_tokens=1024,
early_stopping=False,
do_sample=False,
num_beams=3,
)
generated_text = florence_processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
parsed_answer = florence_processor.post_process_generation(
generated_text,
task="<MORE_DETAILED_CAPTION>",
image_size=(image.width, image.height)
)
return parsed_answer["<MORE_DETAILED_CAPTION>"]
# Helper to load a JSON option list from the data/ directory
def load_json_file(file_name):
    file_path = os.path.join("data", file_name)
    with open(file_path, "r", encoding="utf-8") as file:
        return json.load(file)
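
# Option lists used to populate the dropdowns in the interface below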
ARTFORM = load_json_file("artform.json")
PHOTO_TYPE = load_json_file("photo_type.json")
BODY_TYPES = load_json_file("body_types.json")
DEFAULT_TAGS = load_json_file("default_tags.json")
ROLES = load_json_file("roles.json")
HAIRSTYLES = load_json_file("hairstyles.json")
ADDITIONAL_DETAILS = load_json_file("additional_details.json")
PHOTOGRAPHY_STYLES = load_json_file("photography_styles.json")
DEVICE = load_json_file("device.json")
PHOTOGRAPHER = load_json_file("photographer.json")
ARTIST = load_json_file("artist.json")
DIGITAL_ARTFORM = load_json_file("digital_artform.json")
PLACE = load_json_file("place.json")
LIGHTING = load_json_file("lighting.json")
CLOTHING = load_json_file("clothing.json")
COMPOSITION = load_json_file("composition.json")
POSE = load_json_file("pose.json")
BACKGROUND = load_json_file("background.json")
# PromptGenerator: assembles randomized prompts from the option lists above
class PromptGenerator:
def __init__(self, seed=None):
self.rng = random.Random(seed)
    def split_and_choose(self, input_str):
        # Pick one entry at random from a comma-separated string
        choices = [choice.strip() for choice in input_str.split(",")]
        return self.rng.choice(choices)
    def get_choice(self, input_str, default_choices):
        # "disabled" -> empty string; a comma list -> one random entry;
        # "random" -> one random default; anything else passes through unchanged
        if input_str.lower() == "disabled":
            return ""
        elif "," in input_str:
            return self.split_and_choose(input_str)
        elif input_str.lower() == "random":
            return self.rng.choice(default_choices)
        else:
            return input_str
def clean_consecutive_commas(self, input_string):
cleaned_string = re.sub(r',\s*,', ',', input_string)
return cleaned_string
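
    # Split the assembled prompt into the three encoder streams used by FLUX/SDXL-style
    # pipelines: text between paired BREAK_CLIPL / BREAK_CLIPG markers feeds the CLIP-L /
    # CLIP-G encoders, and whatever remains feeds T5-XXL.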
def process_string(self, replaced, seed):
replaced = re.sub(r'\s*,\s*', ',', replaced)
replaced = re.sub(r',+', ',', replaced)
original = replaced
first_break_clipl_index = replaced.find("BREAK_CLIPL")
second_break_clipl_index = replaced.find("BREAK_CLIPL", first_break_clipl_index + len("BREAK_CLIPL"))
if first_break_clipl_index != -1 and second_break_clipl_index != -1:
clip_content_l = replaced[first_break_clipl_index + len("BREAK_CLIPL"):second_break_clipl_index]
replaced = replaced[:first_break_clipl_index].strip(", ") + replaced[second_break_clipl_index + len("BREAK_CLIPL"):].strip(", ")
clip_l = clip_content_l
else:
clip_l = ""
first_break_clipg_index = replaced.find("BREAK_CLIPG")
second_break_clipg_index = replaced.find("BREAK_CLIPG", first_break_clipg_index + len("BREAK_CLIPG"))
if first_break_clipg_index != -1 and second_break_clipg_index != -1:
clip_content_g = replaced[first_break_clipg_index + len("BREAK_CLIPG"):second_break_clipg_index]
replaced = replaced[:first_break_clipg_index].strip(", ") + replaced[second_break_clipg_index + len("BREAK_CLIPG"):].strip(", ")
clip_g = clip_content_g
else:
clip_g = ""
t5xxl = replaced
original = original.replace("BREAK_CLIPL", "").replace("BREAK_CLIPG", "")
original = re.sub(r'\s*,\s*', ',', original)
original = re.sub(r',+', ',', original)
clip_l = re.sub(r'\s*,\s*', ',', clip_l)
clip_l = re.sub(r',+', ',', clip_l)
clip_g = re.sub(r'\s*,\s*', ',', clip_g)
clip_g = re.sub(r',+', ',', clip_g)
if clip_l.startswith(","):
clip_l = clip_l[1:]
if clip_g.startswith(","):
clip_g = clip_g[1:]
if original.startswith(","):
original = original[1:]
if t5xxl.startswith(","):
t5xxl = t5xxl[1:]
return original, seed, t5xxl, clip_l, clip_g
    def generate_prompt(self, seed, custom, subject, artform, photo_type, body_types, default_tags, roles, hairstyles,
                        additional_details, photography_styles, device, photographer, artist, digital_artform,
                        place, lighting, clothing, composition, pose, background, input_image=None):
        # Full prompt-assembly logic omitted in the original source. Minimal placeholder:
        # join the active fields and return the five outputs the interface expects.
        fields = [custom, subject, artform, photo_type, body_types, default_tags, roles, hairstyles,
                  additional_details, photography_styles, device, photographer, artist, digital_artform,
                  place, lighting, clothing, composition, pose, background]
        combined = ", ".join(str(f) for f in fields if f and str(f).lower() != "disabled")
        return self.process_string(self.clean_consecutive_commas(combined), seed)
def add_caption_to_prompt(self, prompt, caption):
if caption:
return f"{prompt}, {caption}"
return prompt
# Text generation via hosted Hugging Face inference models
class HuggingFaceInferenceNode:
def __init__(self):
self.clients = {
"Mixtral": InferenceClient("NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO"),
"Mistral": InferenceClient("mistralai/Mistral-7B-Instruct-v0.3"),
"Llama 3": InferenceClient("meta-llama/Meta-Llama-3-8B-Instruct"),
"Mistral-Nemo": InferenceClient("mistralai/Mistral-Nemo-Instruct-2407")
}
self.prompts_dir = "./prompts"
os.makedirs(self.prompts_dir, exist_ok=True)
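
    # Persist a generated prompt to a timestamped text file under ./prompts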
def save_prompt(self, prompt):
filename_text = "hf_" + prompt.split(',')[0].strip()
filename_text = re.sub(r'[^\w\-_\. ]', '_', filename_text)
filename_text = filename_text[:30]
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
base_filename = f"{filename_text}_{timestamp}.txt"
filename = os.path.join(self.prompts_dir, base_filename)
with open(filename, "w") as file:
file.write(prompt)
print(f"Prompt saved to {filename}")
    def generate(self, model, input_text, happy_talk, compress, compression_level, poster, custom_base_prompt=""):
        # Full prompt-rewriting logic omitted in the original source. Minimal placeholder
        # that forwards the text to the selected model via the chat-completion API.
        client = self.clients[model]
        response = client.chat_completion(messages=[{"role": "user", "content": input_text}], max_tokens=1024)
        return response.choices[0].message.content
# Prompt generation via gpt-4o-mini and Cohere Command R+ (legacy openai<1.0 ChatCompletion API)
def call_gpt4o_mini(content, system_message, max_tokens=1000, temperature=0.7, top_p=1):
response = openai.ChatCompletion.create(
model="gpt-4o-mini",
messages=[
{"role": "system", "content": system_message},
{"role": "user", "content": content},
],
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
)
return response.choices[0].message['content']
def call_cohere(content, temperature=0.7, max_tokens=1000):
    # Note: this routes a Cohere model name through the OpenAI client, which only works
    # if openai.api_base points at an OpenAI-compatible proxy that serves this model.
    response = openai.ChatCompletion.create(
        model="Cohere-Command-R+",
messages=[
{"role": "user", "content": content},
],
max_tokens=max_tokens,
temperature=temperature,
)
return response.choices[0].message['content']
# Build the Gradio interface
def create_interface():
    prompt_generator = PromptGenerator()
    huggingface_node = HuggingFaceInferenceNode()
with gr.Blocks(theme='Nymbo/Nymbo_Theme') as demo:
        gr.HTML("""<h1 align="center">FLUX Prompt Generator</h1>
        <p><center>Generates long prompts from an image or a short text, and improves short prompts.</center></p>""")
with gr.Row():
with gr.Column(scale=2):
                with gr.Accordion("Basic Settings"):
                    seed = gr.Number(label="Seed", value=random.randint(0, 1000000))
                    custom = gr.Textbox(label="Custom input prompt (optional)")
                    subject = gr.Textbox(label="Subject (optional)")
                    global_option = gr.Radio(["Disabled", "Random"], label="Set all options:", value="Disabled")
                with gr.Accordion("Art Form and Photo Type", open=False):
                    artform = gr.Dropdown(["Disabled", "Random"] + ARTFORM, label="Art form", value="Disabled")
                    photo_type = gr.Dropdown(["Disabled", "Random"] + PHOTO_TYPE, label="Photo type", value="Disabled")
                with gr.Accordion("Character Details", open=False):
                    body_types = gr.Dropdown(["Disabled", "Random"] + BODY_TYPES, label="Body type", value="Disabled")
                    default_tags = gr.Dropdown(["Disabled", "Random"] + DEFAULT_TAGS, label="Default tags", value="Disabled")
                    roles = gr.Dropdown(["Disabled", "Random"] + ROLES, label="Role", value="Disabled")
                    hairstyles = gr.Dropdown(["Disabled", "Random"] + HAIRSTYLES, label="Hairstyle", value="Disabled")
                    clothing = gr.Dropdown(["Disabled", "Random"] + CLOTHING, label="Clothing", value="Disabled")
                with gr.Accordion("Scene Details", open=False):
                    place = gr.Dropdown(["Disabled", "Random"] + PLACE, label="Place", value="Disabled")
                    lighting = gr.Dropdown(["Disabled", "Random"] + LIGHTING, label="Lighting", value="Disabled")
                    composition = gr.Dropdown(["Disabled", "Random"] + COMPOSITION, label="Composition", value="Disabled")
                    pose = gr.Dropdown(["Disabled", "Random"] + POSE, label="Pose", value="Disabled")
                    background = gr.Dropdown(["Disabled", "Random"] + BACKGROUND, label="Background", value="Disabled")
                with gr.Accordion("Style and Artist", open=False):
                    additional_details = gr.Dropdown(["Disabled", "Random"] + ADDITIONAL_DETAILS, label="Additional details", value="Disabled")
                    photography_styles = gr.Dropdown(["Disabled", "Random"] + PHOTOGRAPHY_STYLES, label="Photography style", value="Disabled")
                    device = gr.Dropdown(["Disabled", "Random"] + DEVICE, label="Device", value="Disabled")
                    photographer = gr.Dropdown(["Disabled", "Random"] + PHOTOGRAPHER, label="Photographer", value="Disabled")
                    artist = gr.Dropdown(["Disabled", "Random"] + ARTIST, label="Artist", value="Disabled")
                    digital_artform = gr.Dropdown(["Disabled", "Random"] + DIGITAL_ARTFORM, label="Digital art form", value="Disabled")
                generate_button = gr.Button("Generate Prompt")
with gr.Column(scale=2):
                with gr.Accordion("Image and Caption", open=False):
                    input_image = gr.Image(label="Input image (optional)")
                    caption_output = gr.Textbox(label="Generated caption", lines=3)
                    create_caption_button = gr.Button("Generate Caption")
                    add_caption_button = gr.Button("Add Caption to Prompt")
                with gr.Accordion("Prompt Generation", open=True):
                    output = gr.Textbox(label="Generated prompt / input text", lines=4)
                    t5xxl_output = gr.Textbox(label="T5XXL output", visible=True)
                    clip_l_output = gr.Textbox(label="CLIP L output", visible=True)
                    clip_g_output = gr.Textbox(label="CLIP G output", visible=True)
with gr.Column(scale=2):
                with gr.Accordion("Prompt Generation with LLM", open=False):
                    model = gr.Dropdown(["Mixtral", "Mistral", "Llama 3", "Mistral-Nemo", "gpt-4o-mini", "Cohere-Command-R+"], label="Model", value="Llama 3")
                    happy_talk = gr.Checkbox(label="Happy talk", value=True)
                    compress = gr.Checkbox(label="Compress", value=True)
                    compression_level = gr.Radio(["Soft", "Medium", "Hard"], label="Compression level", value="Hard")
                    poster = gr.Checkbox(label="Poster format", value=False)
                    custom_base_prompt = gr.Textbox(label="Custom base prompt", lines=5)
                generate_text_button = gr.Button("Generate Prompt with LLM")
                text_output = gr.Textbox(label="Generated text", lines=10)
def create_caption(image):
if image is not None:
return florence_caption(image)
return ""
create_caption_button.click(
create_caption,
inputs=[input_image],
outputs=[caption_output]
)
generate_button.click(
prompt_generator.generate_prompt,
inputs=[seed, custom, subject, artform, photo_type, body_types, default_tags, roles, hairstyles,
additional_details, photography_styles, device, photographer, artist, digital_artform,
place, lighting, clothing, composition, pose, background],
outputs=[output, gr.Number(visible=False), t5xxl_output, clip_l_output, clip_g_output]
)
add_caption_button.click(
prompt_generator.add_caption_to_prompt,
inputs=[output, caption_output],
outputs=[output]
)
        # Dispatch: gpt-4o-mini -> OpenAI, Cohere -> OpenAI-compatible proxy, else HF InferenceClient
        def generate_with_llm(model_name, input_text, happy_talk, compress, compression_level, poster, custom_base_prompt):
            if model_name == "gpt-4o-mini":
                return call_gpt4o_mini(input_text, custom_base_prompt)
            if model_name == "Cohere-Command-R+":
                return call_cohere(input_text)
            return huggingface_node.generate(model_name, input_text, happy_talk, compress, compression_level, poster, custom_base_prompt)
        generate_text_button.click(
            generate_with_llm,
            inputs=[model, output, happy_talk, compress, compression_level, poster, custom_base_prompt],
            outputs=text_output
        )
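
        # Keep every dropdown in sync with the global "Set all options" radio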
def update_all_options(choice):
return {dropdown: gr.update(value=choice) for dropdown in [
artform, photo_type, body_types, default_tags, roles, hairstyles, clothing,
place, lighting, composition, pose, background, additional_details,
photography_styles, device, photographer, artist, digital_artform
]}
global_option.change(
update_all_options,
inputs=[global_option],
outputs=[
artform, photo_type, body_types, default_tags, roles, hairstyles, clothing,
place, lighting, composition, pose, background, additional_details,
photography_styles, device, photographer, artist, digital_artform
]
)
return demo
if __name__ == "__main__":
demo = create_interface()
demo.launch()