gokaygokay committed
Commit 13cefbc · verified · 1 Parent(s): 191860e

Update app.py

Files changed (1):
  app.py +31 -3
app.py CHANGED
@@ -2,6 +2,7 @@ import spaces
 import gradio as gr
 import torch
 from transformers import PaliGemmaForConditionalGeneration, PaliGemmaProcessor, pipeline
+from transformers import AutoProcessor, AutoModelForCausalLM
 import re
 import random
 import os
@@ -40,12 +41,35 @@ kolors_pipe.enable_model_cpu_offload()
 vlm_model = PaliGemmaForConditionalGeneration.from_pretrained("gokaygokay/sd3-long-captioner-v2").to(device).eval()
 vlm_processor = PaliGemmaProcessor.from_pretrained("gokaygokay/sd3-long-captioner-v2")
 
+# Initialize Florence model
+florence_model = AutoModelForCausalLM.from_pretrained('microsoft/Florence-2-base', trust_remote_code=True).to(device).eval()
+florence_processor = AutoProcessor.from_pretrained('microsoft/Florence-2-base', trust_remote_code=True)
+
 # Prompt Enhancer
 enhancer_medium = pipeline("summarization", model="gokaygokay/Lamini-Prompt-Enchance", device=device)
 enhancer_long = pipeline("summarization", model="gokaygokay/Lamini-Prompt-Enchance-Long", device=device)
 
 MAX_SEED = 2**32 - 1
 
+# Florence caption function
+def florence_caption(image):
+    inputs = florence_processor(text="<MORE_DETAILED_CAPTION>", images=image, return_tensors="pt").to(device)
+    generated_ids = florence_model.generate(
+        input_ids=inputs["input_ids"],
+        pixel_values=inputs["pixel_values"],
+        max_new_tokens=1024,
+        early_stopping=False,
+        do_sample=False,
+        num_beams=3,
+    )
+    generated_text = florence_processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
+    parsed_answer = florence_processor.post_process_generation(
+        generated_text,
+        task="<MORE_DETAILED_CAPTION>",
+        image_size=(image.width, image.height)
+    )
+    return parsed_answer["<MORE_DETAILED_CAPTION>"]
+
 # VLM Captioner function
 def create_captions_rich(image):
     prompt = "caption en"
@@ -112,9 +136,12 @@ def generate_image(prompt, negative_prompt, seed, randomize_seed, width, height,
 
 # Gradio Interface
 @spaces.GPU
-def process_workflow(image, text_prompt, use_vlm, use_enhancer, model_choice, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps):
+def process_workflow(image, text_prompt, use_vlm, use_enhancer, model_choice, vlm_model_choice, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps):
     if use_vlm and image is not None:
-        prompt = create_captions_rich(image)
+        if vlm_model_choice == "Long Captioner":
+            prompt = create_captions_rich(image)
+        else:  # Florence
+            prompt = florence_caption(image)
     else:
         prompt = text_prompt
 
@@ -161,6 +188,7 @@ with gr.Blocks(css=custom_css, theme=gr.themes.Soft(primary_hue="blue", secondar
         with gr.Group(elem_classes="input-group"):
             input_image = gr.Image(label="Input Image for VLM")
             use_vlm = gr.Checkbox(label="Use VLM Captioner", value=False)
+            vlm_model_choice = gr.Radio(["Long Captioner", "Florence"], label="VLM Model", value="Long Captioner")
 
         with gr.Group(elem_classes="input-group"):
             text_prompt = gr.Textbox(label="Text Prompt")
@@ -187,7 +215,7 @@ with gr.Blocks(css=custom_css, theme=gr.themes.Soft(primary_hue="blue", secondar
     generate_btn.click(
         fn=process_workflow,
         inputs=[
-            input_image, text_prompt, use_vlm, use_enhancer, model_choice,
+            input_image, text_prompt, use_vlm, use_enhancer, model_choice, vlm_model_choice,
             negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps
         ],
         outputs=[output_image, final_prompt, used_seed]
 
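For reference, a minimal standalone sketch of the Florence-2 captioning path this commit adds; the model calls mirror app.py above, while the device pick and the example.jpg input file are illustrative assumptions:

# Standalone Florence-2 captioning sketch; mirrors the calls in app.py.
# The device selection and "example.jpg" are assumptions for illustration.
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor

device = "cuda" if torch.cuda.is_available() else "cpu"
model = AutoModelForCausalLM.from_pretrained("microsoft/Florence-2-base", trust_remote_code=True).to(device).eval()
processor = AutoProcessor.from_pretrained("microsoft/Florence-2-base", trust_remote_code=True)

image = Image.open("example.jpg")  # hypothetical input image
# The task token "<MORE_DETAILED_CAPTION>" selects Florence-2's long-caption task.
inputs = processor(text="<MORE_DETAILED_CAPTION>", images=image, return_tensors="pt").to(device)
generated_ids = model.generate(
    input_ids=inputs["input_ids"],
    pixel_values=inputs["pixel_values"],
    max_new_tokens=1024,
    do_sample=False,
    num_beams=3,
)
raw = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
# post_process_generation strips task/special tokens and returns a dict keyed by task.
caption = processor.post_process_generation(raw, task="<MORE_DETAILED_CAPTION>", image_size=(image.width, image.height))
print(caption["<MORE_DETAILED_CAPTION>"])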
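A note on the wiring: Gradio passes component values to an event handler positionally, in the order of the inputs list, which is why vlm_model_choice is inserted at the same position in both the process_workflow signature and the generate_btn.click inputs. A minimal sketch with hypothetical components (handler, text, mode, btn) illustrating the constraint:

# Minimal sketch: Blocks event handlers receive values positionally from `inputs`.
import gradio as gr

def handler(text, mode):  # parameter order must match the `inputs` list below
    return f"{mode}: {text}"

with gr.Blocks() as demo:
    text = gr.Textbox(label="Text")
    mode = gr.Radio(["Long Captioner", "Florence"], label="Mode", value="Florence")
    out = gr.Textbox(label="Result")
    btn = gr.Button("Run")
    btn.click(fn=handler, inputs=[text, mode], outputs=out)

demo.launch()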