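# NOTE(review): the line below is hosting-page status text that was captured
# together with the source; it is not part of the program.
# Spaces: Runtime error / Runtime error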
# Import required libraries
import json
import os
import random
import re
import subprocess
import sys
from datetime import datetime

import gradio as gr
import openai  # OpenAI API client library
import torch
from huggingface_hub import InferenceClient
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor
# Best-effort install of flash-attn at start-up.
# The child process inherits the current environment with the CUDA-build skip
# flag added on top; passing only the flag as `env` would remove PATH, HOME and
# every other variable from the child. Running pip through the current
# interpreter guarantees the package lands in the environment this app uses.
subprocess.run(
    [sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    check=False,  # a failed install is tolerated, exactly as before
)

huggingface_token = os.getenv("HUGGINGFACE_TOKEN")

# Configure the OpenAI API client
openai.api_key = os.getenv("OPENAI_API_KEY")

# Initialize Florence model (GPU when available, otherwise CPU)
device = "cuda" if torch.cuda.is_available() else "cpu"
florence_model = AutoModelForCausalLM.from_pretrained(
    'microsoft/Florence-2-base', trust_remote_code=True
).to(device).eval()
florence_processor = AutoProcessor.from_pretrained(
    'microsoft/Florence-2-base', trust_remote_code=True
)
# Florence caption function
def florence_caption(image):
    """Return a detailed caption for *image* produced by the Florence model.

    Args:
        image: A ``PIL.Image.Image``; anything else is converted with
            ``Image.fromarray`` before captioning.

    Returns:
        The value stored under the ``<MORE_DETAILED_CAPTION>`` key of the
        processor's post-processed generation result.
    """
    task_prompt = "<MORE_DETAILED_CAPTION>"

    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)

    model_inputs = florence_processor(
        text=task_prompt, images=image, return_tensors="pt"
    ).to(device)

    # Deterministic beam search (no sampling) with three beams.
    generation_kwargs = {
        "input_ids": model_inputs["input_ids"],
        "pixel_values": model_inputs["pixel_values"],
        "max_new_tokens": 1024,
        "early_stopping": False,
        "do_sample": False,
        "num_beams": 3,
    }
    token_ids = florence_model.generate(**generation_kwargs)

    # Special tokens are kept because the post-processing step receives the
    # raw decoded text together with the task prompt.
    decoded = florence_processor.batch_decode(token_ids, skip_special_tokens=False)
    result = florence_processor.post_process_generation(
        decoded[0],
        task=task_prompt,
        image_size=(image.width, image.height),
    )
    return result[task_prompt]
# JSON file loader
def load_json_file(file_name):
    """Load and return the parsed JSON content of ``data/<file_name>``.

    Args:
        file_name: Name of the file inside the ``data`` directory, resolved
            relative to the current working directory.

    Returns:
        The deserialized JSON value.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    file_path = os.path.join("data", file_name)
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's default locale encoding.
    with open(file_path, "r", encoding="utf-8") as file:
        return json.load(file)
# Option lists for the prompt builder, loaded once at import time.
# The names on the left and the file names on the right are paired by position
# and are read in the listed order.
(
    ARTFORM,
    PHOTO_TYPE,
    BODY_TYPES,
    DEFAULT_TAGS,
    ROLES,
    HAIRSTYLES,
    ADDITIONAL_DETAILS,
    PHOTOGRAPHY_STYLES,
    DEVICE,
    PHOTOGRAPHER,
    ARTIST,
    DIGITAL_ARTFORM,
    PLACE,
    LIGHTING,
    CLOTHING,
    COMPOSITION,
    POSE,
    BACKGROUND,
) = map(
    load_json_file,
    (
        "artform.json",
        "photo_type.json",
        "body_types.json",
        "default_tags.json",
        "roles.json",
        "hairstyles.json",
        "additional_details.json",
        "photography_styles.json",
        "device.json",
        "photographer.json",
        "artist.json",
        "digital_artform.json",
        "place.json",
        "lighting.json",
        "clothing.json",
        "composition.json",
        "pose.json",
        "background.json",
    ),
)
# PromptGenerator class definition
class PromptGenerator:
    """Helpers for assembling and post-processing comma-separated prompts.

    Random selections are drawn from a private ``random.Random`` instance
    seeded at construction time.
    """

    # Option values that switch a field off / request a random pick. The
    # English spellings are the original keywords; the Korean spellings are the
    # values emitted by the dropdowns of the Gradio UI in this file.
    _DISABLED_VALUES = ("disabled", "비활성화")
    _RANDOM_VALUES = ("random", "랜덤")

    def __init__(self, seed=None):
        """Create a generator whose random choices are driven by *seed*."""
        self.rng = random.Random(seed)

    def split_and_choose(self, input_str):
        """Return one randomly chosen, stripped item of a comma-separated string."""
        choices = [choice.strip() for choice in input_str.split(",")]
        # `choices(..., k=1)` is kept as-is so seeded results stay unchanged.
        return self.rng.choices(choices, k=1)[0]

    def get_choice(self, input_str, default_choices):
        """Resolve one option value to the text that goes into the prompt.

        Args:
            input_str: The selected option. A "disabled" keyword (or ``None``)
                yields ``""``; a comma-separated list yields one random item;
                a "random" keyword yields one random item of
                *default_choices*; anything else is returned unchanged.
            default_choices: Candidates used for the "random" keyword.

        Returns:
            The resolved string (``""`` when nothing should be added).
        """
        if input_str is None:
            return ""
        keyword = input_str.lower()
        if keyword in self._DISABLED_VALUES:
            return ""
        if "," in input_str:
            return self.split_and_choose(input_str)
        if keyword in self._RANDOM_VALUES:
            if not default_choices:
                # Nothing to pick from; avoid an IndexError from rng.choices.
                return ""
            return self.rng.choices(default_choices, k=1)[0]
        return input_str

    def clean_consecutive_commas(self, input_string):
        """Collapse every run of commas (optionally separated by whitespace) into one."""
        # The repeated group handles runs longer than two commas in one pass;
        # a plain `,\s*,` pattern would leave "a,,,b" as "a,,b".
        return re.sub(r',(?:\s*,)+', ',', input_string)

    @staticmethod
    def _normalize_commas(text):
        """Remove whitespace around commas and merge adjacent commas."""
        text = re.sub(r'\s*,\s*', ',', text)
        return re.sub(r',+', ',', text)

    @staticmethod
    def _strip_leading_comma(text):
        """Drop a single leading comma, if present."""
        return text[1:] if text.startswith(",") else text

    @staticmethod
    def _extract_marked_section(text, marker):
        """Cut the text enclosed by the first two occurrences of *marker*.

        Returns:
            ``(remaining, section)``. When fewer than two markers are present
            the text is returned untouched together with an empty section.
        """
        start = text.find(marker)
        if start == -1:
            return text, ""
        end = text.find(marker, start + len(marker))
        if end == -1:
            return text, ""
        section = text[start + len(marker):end]
        before = text[:start].strip(", ")
        after = text[end + len(marker):].strip(", ")
        # Re-join with a comma so the last tag before the section and the
        # first tag after it do not fuse into one word.
        remaining = ",".join(part for part in (before, after) if part)
        return remaining, section

    def process_string(self, replaced, seed):
        """Split a raw prompt into its full, T5XXL, CLIP-L and CLIP-G parts.

        Text enclosed by a pair of ``BREAK_CLIPL`` markers becomes the CLIP-L
        prompt and text enclosed by a pair of ``BREAK_CLIPG`` markers becomes
        the CLIP-G prompt; both are removed from the T5XXL prompt.

        Args:
            replaced: The raw prompt string.
            seed: Passed through unchanged.

        Returns:
            ``(original, seed, t5xxl, clip_l, clip_g)`` where ``original`` is
            the whole prompt with only the marker words removed.
        """
        replaced = self._normalize_commas(replaced)
        original = replaced

        replaced, clip_l = self._extract_marked_section(replaced, "BREAK_CLIPL")
        replaced, clip_g = self._extract_marked_section(replaced, "BREAK_CLIPG")
        t5xxl = replaced

        original = original.replace("BREAK_CLIPL", "").replace("BREAK_CLIPG", "")
        original = self._normalize_commas(original)
        clip_l = self._normalize_commas(clip_l)
        clip_g = self._normalize_commas(clip_g)

        clip_l = self._strip_leading_comma(clip_l)
        clip_g = self._strip_leading_comma(clip_g)
        original = self._strip_leading_comma(original)
        t5xxl = self._strip_leading_comma(t5xxl)
        return original, seed, t5xxl, clip_l, clip_g

    def generate_prompt(self, seed, custom, subject, artform, photo_type, body_types, default_tags, roles, hairstyles,
                        additional_details, photography_styles, device, photographer, artist, digital_artform,
                        place, lighting, clothing, composition, pose, background, input_image=None):
        """Placeholder for prompt generation; the implementation is omitted here.

        NOTE(review): as written this returns ``None``. A caller that expects
        the five values produced by ``process_string`` will not get them until
        the omitted body is restored.
        """
        # Omitted functionality...
        pass

    def add_caption_to_prompt(self, prompt, caption):
        """Append *caption* to *prompt*, separated by ``", "``.

        An empty caption leaves the prompt unchanged; an empty prompt yields
        the caption alone instead of a string starting with a stray comma.
        """
        if not caption:
            return prompt
        if not prompt:
            return caption
        return f"{prompt}, {caption}"
# Text generation through Hugging Face hosted models
class HuggingFaceInferenceNode:
    """Holds one ``InferenceClient`` per selectable model and saves prompts to disk."""

    def __init__(self):
        """Create the inference clients and make sure the prompt directory exists."""
        self.clients = {
            "Mixtral": InferenceClient("NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO"),
            "Mistral": InferenceClient("mistralai/Mistral-7B-Instruct-v0.3"),
            "Llama 3": InferenceClient("meta-llama/Meta-Llama-3-8B-Instruct"),
            "Mistral-Nemo": InferenceClient("mistralai/Mistral-Nemo-Instruct-2407")
        }
        self.prompts_dir = "./prompts"
        os.makedirs(self.prompts_dir, exist_ok=True)

    def save_prompt(self, prompt):
        """Write *prompt* to a timestamped text file inside ``self.prompts_dir``.

        The file name is ``hf_`` plus the text before the first comma, with
        characters outside word characters, ``-``, ``_``, ``.`` and space
        replaced by ``_``, cut to 30 characters, followed by a
        ``YYYYMMDD_HHMMSS`` timestamp.
        """
        filename_text = "hf_" + prompt.split(',')[0].strip()
        filename_text = re.sub(r'[^\w\-_\. ]', '_', filename_text)
        filename_text = filename_text[:30]
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        base_filename = f"{filename_text}_{timestamp}.txt"
        filename = os.path.join(self.prompts_dir, base_filename)
        # Prompts may contain non-ASCII text; write UTF-8 explicitly so the
        # call cannot fail on a platform whose default encoding is narrower.
        with open(filename, "w", encoding="utf-8") as file:
            file.write(prompt)
        print(f"Prompt saved to {filename}")

    def generate(self, model, input_text, happy_talk, compress, compression_level, poster, custom_base_prompt=""):
        """Placeholder for LLM text generation; the implementation is omitted here.

        NOTE(review): as written this returns ``None``.
        """
        # Omitted functionality...
        pass
# Prompt generation functions using gpt-4o-mini and Cohere Command R+
def _create_chat_completion(**request):
    """Send a chat-completion request and return the first choice's text.

    Works with both generations of the ``openai`` package: releases from 1.0
    on removed ``openai.ChatCompletion`` in favour of
    ``openai.chat.completions``, and their message objects are not
    subscriptable, so the content is read as an attribute (which the older
    response objects support as well).
    """
    if hasattr(openai, "OpenAI"):
        # openai >= 1.0: module-level client, configured via openai.api_key.
        response = openai.chat.completions.create(**request)
    else:
        # Legacy openai < 1.0 interface.
        response = openai.ChatCompletion.create(**request)
    return response.choices[0].message.content


def call_gpt4o_mini(content, system_message, max_tokens=1000, temperature=0.7, top_p=1):
    """Ask ``gpt-4o-mini`` to answer *content* under *system_message*.

    Args:
        content: The user message.
        system_message: The system message sent before the user message.
        max_tokens: Upper bound on generated tokens.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling parameter.

    Returns:
        The text of the first returned choice.
    """
    return _create_chat_completion(
        model="gpt-4o-mini",
        messages=[
            {"role": "system", "content": system_message},
            {"role": "user", "content": content},
        ],
        max_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
    )


def call_cohere(content, temperature=0.7, max_tokens=1000):
    """Ask the ``Cohere-Command-R+`` model to answer *content*.

    NOTE(review): the request goes through the ``openai`` client, so it only
    succeeds if the configured endpoint serves a model under this name —
    confirm.

    Returns:
        The text of the first returned choice.
    """
    return _create_chat_completion(
        model="Cohere-Command-R+",
        messages=[
            {"role": "user", "content": content},
        ],
        max_tokens=max_tokens,
        temperature=temperature,
    )
# Gradio interface factory
def create_interface():
    """Build the Gradio Blocks UI for the prompt generator and return it.

    The UI offers option dropdowns feeding ``PromptGenerator.generate_prompt``,
    an image panel that captions through ``florence_caption``, and an LLM panel
    that rewrites the generated prompt with the selected model.

    NOTE(review): the nesting of rows, columns and accordions below was
    reconstructed from a copy of the source whose indentation was lost —
    confirm the layout against the running app.

    Returns:
        The ``gr.Blocks`` instance (not yet launched).
    """
    prompt_generator = PromptGenerator()
    huggingface_node = HuggingFaceInferenceNode()

    def option_dropdown(label, choices):
        """Create a dropdown with the two sentinel entries followed by *choices*."""
        return gr.Dropdown(["비활성화", "랜덤"] + choices, label=label, value="비활성화")

    with gr.Blocks(theme='Nymbo/Nymbo_Theme') as demo:
        gr.HTML("""<h1 align="center">FLUX 프롬프트 생성기</h1>
        <p><center>이미지 또는 간단한 텍스트에서 긴 프롬프트를 생성합니다. 짧은 프롬프트를 개선합니다.</center></p>""")

        with gr.Row():
            with gr.Column(scale=2):
                with gr.Accordion("기본 설정"):
                    seed = gr.Number(label="시드", value=random.randint(0, 1000000))
                    custom = gr.Textbox(label="사용자 정의 입력 프롬프트 (선택사항)")
                    subject = gr.Textbox(label="주제 (선택사항)")
                    global_option = gr.Radio(["비활성화", "랜덤"], label="모든 옵션 설정:", value="비활성화")

                with gr.Accordion("예술 형식 및 사진 유형", open=False):
                    artform = option_dropdown("예술 형식", ARTFORM)
                    photo_type = option_dropdown("사진 유형", PHOTO_TYPE)

                with gr.Accordion("캐릭터 세부사항", open=False):
                    body_types = option_dropdown("체형", BODY_TYPES)
                    default_tags = option_dropdown("기본 태그", DEFAULT_TAGS)
                    roles = option_dropdown("역할", ROLES)
                    hairstyles = option_dropdown("헤어스타일", HAIRSTYLES)
                    clothing = option_dropdown("의상", CLOTHING)

                with gr.Accordion("장면 세부사항", open=False):
                    place = option_dropdown("장소", PLACE)
                    lighting = option_dropdown("조명", LIGHTING)
                    composition = option_dropdown("구성", COMPOSITION)
                    pose = option_dropdown("포즈", POSE)
                    background = option_dropdown("배경", BACKGROUND)

                with gr.Accordion("스타일 및 아티스트", open=False):
                    additional_details = option_dropdown("추가 세부 사항", ADDITIONAL_DETAILS)
                    photography_styles = option_dropdown("사진 스타일", PHOTOGRAPHY_STYLES)
                    device = option_dropdown("장비", DEVICE)
                    photographer = option_dropdown("사진작가", PHOTOGRAPHER)
                    artist = option_dropdown("아티스트", ARTIST)
                    digital_artform = option_dropdown("디지털 예술 형식", DIGITAL_ARTFORM)

                generate_button = gr.Button("프롬프트 생성")

            with gr.Column(scale=2):
                with gr.Accordion("이미지 및 설명", open=False):
                    input_image = gr.Image(label="입력 이미지 (선택사항)")
                    caption_output = gr.Textbox(label="생성된 설명", lines=3)
                    create_caption_button = gr.Button("설명 생성")
                    add_caption_button = gr.Button("프롬프트에 설명 추가")

                with gr.Accordion("프롬프트 생성", open=True):
                    output = gr.Textbox(label="생성된 프롬프트 / 입력 텍스트", lines=4)
                    t5xxl_output = gr.Textbox(label="T5XXL 출력", visible=True)
                    clip_l_output = gr.Textbox(label="CLIP L 출력", visible=True)
                    clip_g_output = gr.Textbox(label="CLIP G 출력", visible=True)

            with gr.Column(scale=2):
                with gr.Accordion("LLM을 사용한 프롬프트 생성", open=False):
                    model = gr.Dropdown(["Mixtral", "Mistral", "Llama 3", "Mistral-Nemo", "gpt-4o-mini", "Cohere-Command-R+"], label="모델", value="Llama 3")
                    happy_talk = gr.Checkbox(label="행복한 대화", value=True)
                    compress = gr.Checkbox(label="압축", value=True)
                    compression_level = gr.Radio(["부드럽게", "중간", "강하게"], label="압축 레벨", value="강하게")
                    poster = gr.Checkbox(label="포스터 형식", value=False)
                    custom_base_prompt = gr.Textbox(label="사용자 정의 기본 프롬프트", lines=5)
                generate_text_button = gr.Button("LLM으로 프롬프트 생성")
                text_output = gr.Textbox(label="생성된 텍스트", lines=10)

        # Every option dropdown, in the order used by the "set all" radio.
        option_dropdowns = [
            artform, photo_type, body_types, default_tags, roles, hairstyles, clothing,
            place, lighting, composition, pose, background, additional_details,
            photography_styles, device, photographer, artist, digital_artform,
        ]

        def create_caption(image):
            """Caption the uploaded image, or return an empty string without one."""
            if image is None:
                return ""
            return florence_caption(image)

        def generate_text(model, input_text, happy_talk, compress, compression_level, poster, custom_base_prompt):
            """Route the request to the backend that matches the selected model."""
            if model == "gpt-4o-mini":
                return call_gpt4o_mini(input_text, custom_base_prompt)
            if model == "Cohere-Command-R+":
                return call_cohere(input_text)
            # The remaining dropdown entries are the Hugging Face hosted models.
            return huggingface_node.generate(
                model, input_text, happy_talk, compress, compression_level, poster, custom_base_prompt
            )

        def update_all_options(choice):
            """Set every option dropdown to *choice*."""
            return {dropdown: gr.update(value=choice) for dropdown in option_dropdowns}

        create_caption_button.click(
            create_caption,
            inputs=[input_image],
            outputs=[caption_output]
        )

        generate_button.click(
            prompt_generator.generate_prompt,
            # input_image is the last parameter of generate_prompt and must be
            # supplied along with the option values.
            inputs=[seed, custom, subject, artform, photo_type, body_types, default_tags, roles, hairstyles,
                    additional_details, photography_styles, device, photographer, artist, digital_artform,
                    place, lighting, clothing, composition, pose, background, input_image],
            # The hidden Number absorbs the second returned value (the seed).
            outputs=[output, gr.Number(visible=False), t5xxl_output, clip_l_output, clip_g_output]
        )

        add_caption_button.click(
            prompt_generator.add_caption_to_prompt,
            inputs=[output, caption_output],
            outputs=[output]
        )

        generate_text_button.click(
            generate_text,
            inputs=[model, output, happy_talk, compress, compression_level, poster, custom_base_prompt],
            outputs=text_output
        )

        global_option.change(
            update_all_options,
            inputs=[global_option],
            outputs=option_dropdowns
        )

    return demo
if __name__ == "__main__":
    # Build the UI and start the Gradio server only when run as a script.
    demo = create_interface()
    demo.launch()