import gradio as gr
import torch
from transformers import AutoConfig, AutoModelForCausalLM
from janus.models import MultiModalityCausalLM, VLChatProcessor
from janus.utils.io import load_pil_images
from PIL import Image
import numpy as np
import os
import time
import spaces
# Load model and processor
model_path = "deepseek-ai/Janus-Pro-7B"
config = AutoConfig.from_pretrained(model_path)
language_config = config.language_config
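# Force the plain (eager) attention path; presumably chosen so the Space does not
# depend on optional fused-attention kernels such as flash-attn.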
language_config._attn_implementation = 'eager'
vl_gpt = AutoModelForCausalLM.from_pretrained(model_path,
language_config=language_config,
trust_remote_code=True)
if torch.cuda.is_available():
vl_gpt = vl_gpt.to(torch.bfloat16).cuda()
else:
vl_gpt = vl_gpt.to(torch.float16)
vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
tokenizer = vl_chat_processor.tokenizer
cuda_device = 'cuda' if torch.cuda.is_available() else 'cpu'
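# @spaces.GPU (ZeroGPU) attaches a GPU to each call for at most `duration` seconds;
# @torch.inference_mode() disables autograd tracking for the whole function.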
@torch.inference_mode()
@spaces.GPU(duration=120)
def multimodal_understanding(image, question, seed, top_p, temperature):
# Clear CUDA cache before generating
torch.cuda.empty_cache()
# set seed
torch.manual_seed(seed)
np.random.seed(seed)
torch.cuda.manual_seed(seed)
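    # Janus chat format: <image_placeholder> marks where the image features are spliced
    # into the prompt, and the empty <|Assistant|> turn is what the model completes.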
conversation = [
{
"role": "<|User|>",
"content": f"<image_placeholder>\n{question}",
"images": [image],
},
{"role": "<|Assistant|>", "content": ""},
]
pil_images = [Image.fromarray(image)]
prepare_inputs = vl_chat_processor(
conversations=conversation, images=pil_images, force_batchify=True
).to(cuda_device, dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float16)
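    # prepare_inputs_embeds runs the vision encoder and merges the image features into
    # the text token embeddings consumed by the language model.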
inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)
outputs = vl_gpt.language_model.generate(
inputs_embeds=inputs_embeds,
attention_mask=prepare_inputs.attention_mask,
pad_token_id=tokenizer.eos_token_id,
bos_token_id=tokenizer.bos_token_id,
eos_token_id=tokenizer.eos_token_id,
max_new_tokens=4000,
do_sample=False if temperature == 0 else True,
use_cache=True,
temperature=temperature,
top_p=top_p,
)
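    # generate() received only inputs_embeds (no input_ids), so `outputs` holds just the
    # newly generated token ids and can be decoded directly.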
answer = tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True)
return answer
def generate(input_ids,
width,
height,
temperature: float = 1,
parallel_size: int = 5,
cfg_weight: float = 5,
image_token_num_per_image: int = 576,
patch_size: int = 16):
# Clear CUDA cache before generating
torch.cuda.empty_cache()
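    # Classifier-free guidance setup: build an interleaved batch of 2 * parallel_size rows.
    # Even rows keep the full prompt (conditional); odd rows pad out everything between
    # the BOS token and the trailing image-start tag (unconditional).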
tokens = torch.zeros((parallel_size * 2, len(input_ids)), dtype=torch.int).to(cuda_device)
for i in range(parallel_size * 2):
tokens[i, :] = input_ids
if i % 2 != 0:
tokens[i, 1:-1] = vl_chat_processor.pad_id
inputs_embeds = vl_gpt.language_model.get_input_embeddings()(tokens)
generated_tokens = torch.zeros((parallel_size, image_token_num_per_image), dtype=torch.int).to(cuda_device)
pkv = None
for i in range(image_token_num_per_image):
with torch.no_grad():
outputs = vl_gpt.language_model.model(inputs_embeds=inputs_embeds,
use_cache=True,
past_key_values=pkv)
pkv = outputs.past_key_values
hidden_states = outputs.last_hidden_state
logits = vl_gpt.gen_head(hidden_states[:, -1, :])
logit_cond = logits[0::2, :]
logit_uncond = logits[1::2, :]
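            # CFG: push the unconditional logits toward the conditional ones,
            # guided = uncond + cfg_weight * (cond - uncond).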
logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)
probs = torch.softmax(logits / temperature, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
generated_tokens[:, i] = next_token.squeeze(dim=-1)
next_token = torch.cat([next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1).view(-1)
img_embeds = vl_gpt.prepare_gen_img_embeds(next_token)
inputs_embeds = img_embeds.unsqueeze(dim=1)
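    # Decode the discrete image tokens back to pixels with the VQ decoder
    # (a 24x24 token grid at patch_size 16 -> 576 tokens -> a 384x384 image).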
patches = vl_gpt.gen_vision_model.decode_code(generated_tokens.to(dtype=torch.int),
shape=[parallel_size, 8, width // patch_size, height // patch_size])
return generated_tokens.to(dtype=torch.int), patches
def unpack(dec, width, height, parallel_size=5):
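    # Decoder output is NCHW in [-1, 1]; convert to NHWC uint8 in [0, 255] for display.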
dec = dec.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)
dec = np.clip((dec + 1) / 2 * 255, 0, 255)
visual_img = np.zeros((parallel_size, width, height, 3), dtype=np.uint8)
visual_img[:, :, :] = dec
return visual_img
@torch.inference_mode()
@spaces.GPU(duration=120) # Specify a duration to avoid timeout
def generate_image(prompt,
seed=None,
guidance=5,
t2i_temperature=1.0):
# Clear CUDA cache and avoid tracking gradients
torch.cuda.empty_cache()
# Set the seed for reproducible results
if seed is not None:
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
np.random.seed(seed)
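    # Janus-Pro's image head emits a 24x24 token grid, i.e. 384x384 pixels;
    # parallel_size is the number of samples drawn per prompt.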
width = 384
height = 384
parallel_size = 5
with torch.no_grad():
messages = [{'role': '<|User|>', 'content': prompt},
{'role': '<|Assistant|>', 'content': ''}]
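        # Build the prompt with the model's SFT chat template, then append the
        # image-start tag so generation switches from text tokens to image tokens.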
text = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(conversations=messages,
sft_format=vl_chat_processor.sft_format,
system_prompt='')
text = text + vl_chat_processor.image_start_tag
input_ids = torch.LongTensor(tokenizer.encode(text))
output, patches = generate(input_ids,
width // 16 * 16,
height // 16 * 16,
cfg_weight=guidance,
parallel_size=parallel_size,
temperature=t2i_temperature)
images = unpack(patches,
width // 16 * 16,
height // 16 * 16,
parallel_size=parallel_size)
return [Image.fromarray(images[i]).resize((768, 768), Image.LANCZOS) for i in range(parallel_size)]
# Custom CSS as a string
custom_css = """
.gradio-container {
font-family: 'Inter', -apple-system, sans-serif;
}
.image-preview {
min-height: 300px;
max-height: 500px;
width: 100%;
object-fit: contain;
border-radius: 8px;
border: 2px solid #eee;
}
.tab-nav {
background: white;
padding: 1rem;
border-radius: 8px;
box-shadow: 0 2px 4px rgba(0,0,0,0.05);
}
.examples-table {
font-size: 0.9rem;
}
.gr-button.gr-button-lg {
padding: 12px 24px;
font-size: 1.1rem;
}
.gr-input, .gr-select {
border-radius: 6px;
}
.gr-form {
background: white;
padding: 20px;
border-radius: 12px;
box-shadow: 0 4px 6px rgba(0,0,0,0.05);
}
.gr-panel {
border: none;
background: transparent;
}
.footer {
text-align: center;
margin-top: 2rem;
padding: 1rem;
color: #666;
}
"""
# Gradio interface with improved UI
with gr.Blocks(
theme=gr.themes.Soft(primary_hue="blue", secondary_hue="indigo"),
css=custom_css
) as demo:
gr.Markdown(
"""
# Deepseek Multimodal
### Advanced AI for Visual Understanding and Generation
This powerful multimodal AI system combines:
* **Visual Analysis**: Advanced image understanding and medical image interpretation
* **Creative Generation**: High-quality image generation from text descriptions
* **Interactive Chat**: Natural conversation about visual content
"""
)
with gr.Tabs():
# Visual Chat Tab
with gr.Tab("Visual Understanding"):
with gr.Row(equal_height=True):
with gr.Column(scale=1):
image_input = gr.Image(
label="Upload Image",
type="numpy",
elem_classes="image-preview"
)
with gr.Column(scale=1):
question_input = gr.Textbox(
label="Question or Analysis Request",
placeholder="Ask a question about the image or request detailed analysis...",
lines=3
)
with gr.Row():
und_seed_input = gr.Number(
label="Seed",
precision=0,
value=42,
container=False
)
top_p = gr.Slider(
minimum=0,
maximum=1,
value=0.95,
step=0.05,
label="Top-p",
container=False
)
temperature = gr.Slider(
minimum=0,
maximum=1,
value=0.1,
step=0.05,
label="Temperature",
container=False
)
understanding_button = gr.Button(
"Analyze Image",
variant="primary"
)
understanding_output = gr.Textbox(
label="Analysis Results",
lines=10,
show_copy_button=True
)
with gr.Accordion("Medical Analysis Examples", open=False):
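                # Note: 'fundus.webp' must be present in the Space repository for this example to load.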
gr.Examples(
examples=[
[
"""You are an AI assistant trained to analyze medical images...""",
"fundus.webp",
],
],
inputs=[question_input, image_input],
)
# Image Generation Tab
with gr.Tab("Image Generation"):
with gr.Column():
prompt_input = gr.Textbox(
label="Image Description",
placeholder="Describe the image you want to create in detail...",
lines=3
)
with gr.Row():
cfg_weight_input = gr.Slider(
minimum=1,
maximum=10,
value=5,
step=0.5,
label="Guidance Scale",
info="Higher values create images that more closely match your prompt"
)
t2i_temperature = gr.Slider(
minimum=0,
maximum=1,
value=1.0,
step=0.05,
label="Temperature",
info="Controls randomness in generation"
)
seed_input = gr.Number(
label="Seed (Optional)",
precision=0,
value=12345,
info="Set for reproducible results"
)
generation_button = gr.Button(
"Generate Images",
variant="primary"
)
image_output = gr.Gallery(
label="Generated Images",
columns=3,
rows=2,
height=500,
object_fit="contain"
)
with gr.Accordion("Generation Examples", open=False):
gr.Examples(
examples=[
"Master shifu racoon wearing drip attire as a street gangster.",
"The face of a beautiful girl",
"Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
"A glass of red wine on a reflective surface.",
"A cute and adorable baby fox with big brown eyes...",
],
inputs=prompt_input,
)
# Connect components
understanding_button.click(
multimodal_understanding,
inputs=[image_input, question_input, und_seed_input, top_p, temperature],
outputs=understanding_output
)
generation_button.click(
fn=generate_image,
inputs=[prompt_input, seed_input, cfg_weight_input, t2i_temperature],
outputs=image_output
)
# Launch the demo
if __name__ == "__main__":
demo.launch(share=True)