File size: 5,576 Bytes
bc26ef9
6a29418
 
 
 
 
4b0034d
 
 
b8af58f
 
 
 
 
b498cbf
 
4b0034d
6a29418
 
 
 
8fbd4bf
 
c23213e
de88b8c
7fbd132
6bcf12d
81efcaa
5df9f05
6a29418
b8af58f
 
5c51951
b8af58f
 
 
 
2a15275
b8af58f
 
 
 
 
 
 
 
83985af
b8af58f
ea4d143
 
b8af58f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
de88b8c
 
7fbd132
 
81efcaa
de88b8c
 
13b7eb3
de88b8c
 
 
 
 
 
83985af
101ae48
a3459d2
de88b8c
4b0034d
 
de88b8c
b498cbf
 
 
 
 
 
 
b8af58f
6e96524
b8af58f
 
 
 
 
 
 
 
 
e6e56d8
8333d6b
958ff1e
8333d6b
b8af58f
 
 
 
 
 
22e068f
958ff1e
 
6e96524
958ff1e
fcba90f
e23839a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import spaces
import gradio as gr
import numpy as np
import requests
import base64
import os
from datetime import datetime
from pytz import timezone

import torch
import diffusers
from diffusers import DDPMPipeline
from transformers import AutoTokenizer, AutoModel

from utils import normalize_formula

# Timezone for request-log timestamps. "EST" in the tz database is a fixed
# UTC-05:00 offset with no daylight-saving adjustment.
tz = timezone('EST')

# Remote rendering API settings, read from the environment (None when unset).
API_ENDPOINT = os.getenv('API_ENDPOINT')
API_KEY = os.getenv('API_KEY')

print (API_ENDPOINT)
# Never echo the secret itself into the Space logs; only report whether it is set.
print ('API_KEY set:', API_KEY is not None)

# HTML snippets rendered at the top of the demo page.
title = "<h1><center>Markup-to-Image Diffusion Models with Scheduled Sampling</center></h1>"
authors = "<center>Yuntian Deng, Noriyuki Kojima, Alexander M. Rush</center>"
info = '<center><a href="https://openreview.net/pdf?id=81VJDmOE2ol">Paper</a> <a href="https://github.com/da03/markup2im">Code</a></center>'
# Alternate notice for CPU-only deployments:
#notice = "<p><center><strong>Notice:</strong> Due to resource constraints, we've transitioned from GPU to CPU processing for this demo, which results in significantly longer inference times. We appreciate your understanding.</center></p>"
notice = "<p><center>Acknowledgment: This demo is powered by GPU resources supported by the Hugging Face Community Grant.</center></p>"


# setup
def setup(compile_unet=False):
    """Load the diffusion pipeline and build a LaTeX -> hidden-state encoder.

    Args:
        compile_unet: If True, wrap every UNet down/up block in
            ``torch.compile``. Defaults to False, which matches the previous
            hard-coded (disabled) behaviour.

    Returns:
        ``(img_pipe, forward_encoder)``: the loaded ``DDPMPipeline`` and a
        closure that maps a LaTeX string to the text encoder's masked
        ``last_hidden_state``.
    """
    img_pipe = DDPMPipeline.from_pretrained("yuntian-deng/latex2im_ss_finetunegptneo")
    # Only the tokenizer is loaded from this hub id; the encoder weights come
    # from the pipeline's UNet (below) instead of AutoModel.from_pretrained.
    model_type = "EleutherAI/gpt-neo-125M"
    encoder = img_pipe.unet.text_encoder
    if compile_unet:
        unet = img_pipe.unet
        for blocks in (unet.down_blocks, unet.up_blocks):
            # Iterate over a snapshot so in-place replacement is safe.
            for idx, block in enumerate(list(blocks)):
                blocks[idx] = torch.compile(block)
    tokenizer = AutoTokenizer.from_pretrained(model_type, max_length=1024)
    eos_id = tokenizer.encode(tokenizer.eos_token)[0]

    def forward_encoder(latex):
        """Tokenize ``latex``, append EOS, and return masked encoder states.

        The token sequence is truncated to 1024 tokens before the EOS token
        is appended, so the encoder sees at most 1025 positions.
        """
        # Device is resolved (and the pipeline moved) on every call;
        # presumably because under spaces.GPU a GPU is only attached inside
        # the decorated request handler — confirm before hoisting this.
        device = ("cuda" if torch.cuda.is_available() else "cpu")
        img_pipe.to(device)
        encoded = tokenizer(latex, return_tensors='pt', truncation=True, max_length=1024)
        input_ids = encoded['input_ids']
        input_ids = torch.cat((input_ids, torch.LongTensor([eos_id,]).unsqueeze(0)), dim=-1)
        input_ids = input_ids.to(device)
        attention_mask = encoded['attention_mask']
        # The appended EOS token must be attended to as well.
        attention_mask = torch.cat((attention_mask, torch.LongTensor([1,]).unsqueeze(0)), dim=-1)
        attention_mask = attention_mask.to(device)
        with torch.no_grad():
            outputs = encoder(input_ids=input_ids, attention_mask=attention_mask)
            last_hidden_state = outputs.last_hidden_state
            # Zero out padded positions (defensive; shouldn't be necessary).
            last_hidden_state = attention_mask.unsqueeze(-1) * last_hidden_state
        return last_hidden_state
    return img_pipe, forward_encoder

img_pipe, forward_encoder = setup()

with gr.Blocks() as demo:
    gr.Markdown(title)
    gr.Markdown(authors)
    gr.Markdown(info)
    gr.Markdown(notice)
    with gr.Row():
        with gr.Column(scale=2):
            textbox = gr.Textbox(label=r'Type LaTeX formula below and click "Generate"', lines=1, max_lines=1, placeholder='Type LaTeX formula here and click "Generate"', value=r'\sum_{t=1}^T\E_{y_t \sim {\tilde P(y_t| y_0)}} \left\| \frac{y_t - \sqrt{\bar{\alpha}_t}y_0}{\sqrt{1-\bar{\alpha}_t}} - \epsilon_\theta(y_t, t)\right\|^2.')
            submit_btn = gr.Button("Generate", elem_id="btn")
        with gr.Column(scale=3):
            slider = gr.Slider(0, 1000, value=0, label='step (out of 1000)')
            image = gr.Image(label="Rendered Image", show_label=False, elem_id="image")
    inputs = [textbox]
    outputs = [slider, image, submit_btn]

    # NOTE(review): spaces.GPU's default duration is 60 s; 90 s is requested
    # here — confirm this is intended given the Space's GPU quota.
    @spaces.GPU(duration=90)
    def infer(formula):
        """Stream the rendering of ``formula`` step by step.

        Yields ``(step, image, button_update)`` triples for the outputs
        ``[slider, image, submit_btn]``. The Generate button is hidden while
        sampling runs and shown again on the final yield. On any failure a
        blank white image is shown at step 1000 and the button is restored.
        """
        current_time = datetime.now(tz)
        print (current_time, formula)
        try:
            formula_normalized = normalize_formula(formula)
        except Exception as e:
            # Best effort: fall back to the raw formula if normalization fails.
            print (e)
            formula_normalized = formula
        print ('normalized', formula_normalized)

        try:
            encoder_hidden_states = forward_encoder(formula_normalized)
            i = 0
            q = None
            for _, image_clean in img_pipe.run_clean(batch_size=1, generator=torch.manual_seed(0), encoder_hidden_states=encoder_hidden_states, output_type="numpy"):
                i += 1
                q = np.ascontiguousarray(image_clean[0]).reshape((64, 320, 3))
                yield i, q, gr.update(visible=False)
            if q is None:
                # Guard against an empty sampler; otherwise `q` would be unbound.
                raise RuntimeError('pipeline produced no images')
            yield i, q, gr.update(visible=True)
        except Exception as e:
            print ('inference failed:', repr(e))
            # uint8 255 is unambiguously white; a float array of 255s falls
            # outside the [-1, 1] range Gradio accepts for float images.
            blank = np.full((64, 320, 3), 255, dtype=np.uint8)
            yield 1000, blank, gr.update(visible=True)
    submit_btn.click(fn=infer, inputs=inputs, outputs=outputs, concurrency_limit=1)
demo.queue(max_size=20).launch()