File size: 4,610 Bytes
f6fe860
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import gradio as gr
import torch
from transformers import AutoConfig, AutoModelForCausalLM
from janus.models import MultiModalityCausalLM, VLChatProcessor
from PIL import Image
import numpy as np
import spaces  

# Load the model and processor
model_path = "deepseek-ai/Janus-Pro-7B"

# The language sub-config is pulled out of the full config so that its
# attention implementation can be overridden to 'eager' before loading.
config = AutoConfig.from_pretrained(model_path)
language_config = config.language_config
language_config._attn_implementation = 'eager'

vl_gpt = AutoModelForCausalLM.from_pretrained(
    model_path, language_config=language_config, trust_remote_code=True
)

# Device string used by generate() when allocating tensors.
cuda_device = 'cuda' if torch.cuda.is_available() else 'cpu'

# bfloat16 on the GPU when one is available, float16 on the CPU otherwise.
if cuda_device == 'cuda':
    vl_gpt = vl_gpt.to(torch.bfloat16).cuda()
else:
    vl_gpt = vl_gpt.to(torch.float16)

vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
tokenizer = vl_chat_processor.tokenizer

# Helper functions
def generate(input_ids, width, height, cfg_weight=5, temperature=1.0, parallel_size=5, patch_size=16):
    """Sample image tokens with classifier-free guidance and decode them.

    The prompt is replicated ``parallel_size * 2`` times: even rows keep the
    prompt (conditional branch), odd rows have every token except the first
    and last replaced by the processor's pad id (unconditional branch). At
    each step the two branches' logits are combined as
    ``uncond + cfg_weight * (cond - uncond)`` and one token is drawn per image.

    Args:
        input_ids: 1-D tensor of prompt token ids.
        width: Target image width in pixels.
        height: Target image height in pixels.
        cfg_weight: Guidance weight applied to (cond - uncond).
        temperature: Softmax temperature. Values <= 0 select the most likely
            token at every step (greedy decoding) instead of dividing by zero.
        parallel_size: Number of images generated in one batch.
        patch_size: Side length, in pixels, of the area one token covers.

    Returns:
        Whatever ``vl_gpt.gen_vision_model.decode_code`` returns for the
        sampled tokens; ``unpack`` in this file treats it as a tensor laid
        out (batch, channels, H, W).
    """
    torch.cuda.empty_cache()

    # One token per patch. For the 384x384 / patch 16 call in this file this
    # is 576, and it always matches the grid passed to decode_code below.
    # NOTE(review): whether the model supports other resolutions is not
    # visible from this file -- confirm before calling with other sizes.
    grid_w = width // patch_size
    grid_h = height // patch_size
    image_token_count = grid_w * grid_h

    # Broadcast the prompt into every row, then blank out the odd rows.
    tokens = torch.zeros((parallel_size * 2, len(input_ids)), dtype=torch.int).to(cuda_device)
    tokens[:] = input_ids
    tokens[1::2, 1:-1] = vl_chat_processor.pad_id

    inputs_embeds = vl_gpt.language_model.get_input_embeddings()(tokens)
    generated_tokens = torch.zeros((parallel_size, image_token_count), dtype=torch.int).to(cuda_device)

    pkv = None
    with torch.no_grad():
        for i in range(image_token_count):
            outputs = vl_gpt.language_model.model(
                inputs_embeds=inputs_embeds, use_cache=True, past_key_values=pkv
            )
            pkv = outputs.past_key_values
            logits = vl_gpt.gen_head(outputs.last_hidden_state[:, -1, :])

            # Classifier-free guidance: even rows are conditional, odd rows
            # are unconditional.
            logit_cond = logits[0::2, :]
            logit_uncond = logits[1::2, :]
            logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)

            if temperature > 0:
                probs = torch.softmax(logits / temperature, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
            else:
                # logits / 0 would give NaN probabilities and make
                # multinomial fail; fall back to the arg-max token.
                next_token = torch.argmax(logits, dim=-1, keepdim=True)
            generated_tokens[:, i] = next_token.squeeze(dim=-1)

            # Feed the same sampled token to both branches of each image.
            next_token = torch.cat([next_token.unsqueeze(dim=1)] * 2, dim=1).view(-1)
            img_embeds = vl_gpt.prepare_gen_img_embeds(next_token)
            inputs_embeds = img_embeds.unsqueeze(dim=1)

    patches = vl_gpt.gen_vision_model.decode_code(
        generated_tokens.to(dtype=torch.int),
        shape=[parallel_size, 8, grid_w, grid_h]
    )
    return patches

def unpack(patches, width, height, parallel_size=5):
    """Convert decoded patches into a list of PIL images.

    ``patches`` is moved to the CPU as float32, its axes are reordered with
    ``transpose(0, 2, 3, 1)``, and values are mapped from [-1, 1] to
    [0, 255] with clipping. ``width`` and ``height`` are accepted but not
    used. The first ``parallel_size`` entries are returned as images.
    """
    pixels = patches.to(torch.float32).cpu().numpy()
    pixels = pixels.transpose(0, 2, 3, 1)
    pixels = np.clip((pixels + 1) / 2 * 255, 0, 255)

    images = []
    for index in range(parallel_size):
        frame = pixels[index].astype(np.uint8)
        images.append(Image.fromarray(frame))
    return images

@torch.inference_mode()
@spaces.GPU(duration=120)
def generate_image(prompt, seed=None, guidance=5, t2i_temperature=1.0):
    """Generate a batch of images for ``prompt`` and return them as a list.

    Args:
        prompt: Text used as the user turn of the conversation.
        seed: When not None, seeds torch, torch.cuda and numpy's RNGs.
        guidance: Forwarded to ``generate`` as ``cfg_weight``.
        t2i_temperature: Forwarded to ``generate`` as ``temperature``.

    Returns:
        The list produced by ``unpack`` for a fixed batch of five 384x384
        images.
    """
    torch.cuda.empty_cache()

    if seed is not None:
        for seed_fn in (torch.manual_seed, torch.cuda.manual_seed, np.random.seed):
            seed_fn(seed)

    # Fixed output geometry and batch size for this demo.
    width = height = 384
    parallel_size = 5

    conversation = [
        {'role': '<|User|>', 'content': prompt},
        {'role': '<|Assistant|>', 'content': ''},
    ]
    sft_prompt = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(
        conversations=conversation,
        sft_format=vl_chat_processor.sft_format,
        system_prompt='',
    )
    # The image start tag is appended after the formatted conversation.
    text = sft_prompt + vl_chat_processor.image_start_tag

    input_ids = torch.LongTensor(tokenizer.encode(text))
    patches = generate(
        input_ids,
        width,
        height,
        cfg_weight=guidance,
        temperature=t2i_temperature,
        parallel_size=parallel_size,
    )
    return unpack(patches, width, height, parallel_size)

# Gradio interface
def create_interface():
    """Assemble the Gradio Blocks UI and return it without launching.

    The layout is a heading, four input controls (prompt, seed, guidance
    weight, temperature), a button and a gallery. Clicking the button calls
    ``generate_image`` with the four control values and shows the result in
    the gallery.
    """
    with gr.Blocks() as blocks:
        gr.Markdown("# Text-to-Image Generation")

        # Created in display order; the same order is the positional
        # argument order of generate_image.
        controls = [
            gr.Textbox(label="Prompt (describe the image)"),
            gr.Number(label="Seed (Optional)", value=12345, precision=0),
            gr.Slider(label="CFG Guidance Weight", minimum=1, maximum=10, value=5, step=0.5),
            gr.Slider(label="Temperature", minimum=0, maximum=1, value=1.0, step=0.05),
        ]

        run_button = gr.Button("Generate Images")
        gallery = gr.Gallery(label="Generated Images", columns=2, height=300)

        run_button.click(generate_image, inputs=controls, outputs=gallery)

    return blocks

# Build the UI when the module is executed and start the Gradio server.
# NOTE(review): share=True presumably requests a public share link --
# confirm against the Gradio launch() documentation for the target host.
demo = create_interface()
demo.launch(share=True)