# Gradio demo for "Markup-to-Image Diffusion Models with Scheduled Sampling":
# renders a LaTeX formula to an image with a DDPM pipeline, streaming each
# denoising step to the UI.
import spaces  # keep first: ZeroGPU's `spaces` must be imported before torch
import gradio as gr
import numpy as np
import requests
import base64
import os
import traceback
from datetime import datetime
from pytz import timezone
import torch
import diffusers
from diffusers import DDPMPipeline
from transformers import AutoTokenizer, AutoModel

from utils import normalize_formula

# Only used to timestamp request logs.
tz = timezone('EST')

API_ENDPOINT = os.getenv('API_ENDPOINT')
API_KEY = os.getenv('API_KEY')
print (API_ENDPOINT)
# Never print the key itself: Space build/runtime logs can be visible to others.
print ('API_KEY set:', API_KEY is not None)

# NOTE(review): in this copy the header strings contain only their text; any
# HTML markup (headings, hyperlinks behind "Paper" / "Code") appears to have
# been stripped and should be restored from the upstream Space.
title = "Markup-to-Image Diffusion Models with Scheduled Sampling"
authors = "Yuntian Deng, Noriyuki Kojima, Alexander M. Rush"
info = 'Paper Code'
# Alternate notice for when the demo has to run on CPU:
#notice = "Notice: Due to resource constraints, we've transitioned from GPU to CPU processing for this demo, which results in significantly longer inference times. We appreciate your understanding."
notice = "Acknowledgment: This demo is powered by GPU resources supported by the Hugging Face Community Grant."


# setup
def setup(compile_unet=False):
    """Load the image pipeline and build a LaTeX -> conditioning encoder.

    Args:
        compile_unet: if True, wrap every UNet down/up block in
            ``torch.compile``. Defaults to False, matching the previously
            hard-coded disabled branch.

    Returns:
        ``(img_pipe, forward_encoder)``, where ``forward_encoder(latex)``
        returns the text encoder's last hidden state for the formula.
    """
    img_pipe = DDPMPipeline.from_pretrained("yuntian-deng/latex2im_ss_finetunegptneo")
    model_type = "EleutherAI/gpt-neo-125M"
    # This checkpoint's UNet carries its own text encoder, so no separate
    # AutoModel needs to be loaded.
    encoder = img_pipe.unet.text_encoder
    if compile_unet:
        for blocks in (img_pipe.unet.down_blocks, img_pipe.unet.up_blocks):
            for i, block in enumerate(blocks):
                blocks[i] = torch.compile(block)
    tokenizer = AutoTokenizer.from_pretrained(model_type, max_length=1024)
    eos_id = tokenizer.encode(tokenizer.eos_token)[0]

    def forward_encoder(latex):
        """Encode one LaTeX string into encoder hidden states.

        The formula is tokenized (truncated to 1024 tokens), an EOS token is
        appended, and the result is run through the encoder without grads.
        Returns the last hidden state, shape (1, num_tokens + 1, hidden),
        with padded positions zeroed by the attention mask.
        """
        # Device is resolved per call (not in setup), presumably so the GPU is
        # picked up when this runs inside the @spaces.GPU-decorated handler.
        device = ("cuda" if torch.cuda.is_available() else "cpu")
        img_pipe.to(device)
        encoded = tokenizer(latex, return_tensors='pt', truncation=True, max_length=1024)
        input_ids = encoded['input_ids']
        input_ids = torch.cat((input_ids, torch.LongTensor([eos_id,]).unsqueeze(0)), dim=-1)
        input_ids = input_ids.to(device)
        attention_mask = encoded['attention_mask']
        attention_mask = torch.cat((attention_mask, torch.LongTensor([1,]).unsqueeze(0)), dim=-1)
        attention_mask = attention_mask.to(device)
        with torch.no_grad():
            outputs = encoder(input_ids=input_ids, attention_mask=attention_mask)
            last_hidden_state = outputs.last_hidden_state
            last_hidden_state = attention_mask.unsqueeze(-1) * last_hidden_state  # shouldn't be necessary
        return last_hidden_state

    return img_pipe, forward_encoder


img_pipe, forward_encoder = setup()

with gr.Blocks() as demo:
    gr.Markdown(title)
    gr.Markdown(authors)
    gr.Markdown(info)
    gr.Markdown(notice)
    with gr.Row():
        with gr.Column(scale=2):
            textbox = gr.Textbox(label=r'Type LaTeX formula below and click "Generate"', lines=1, max_lines=1, placeholder='Type LaTeX formula here and click "Generate"', value=r'\sum_{t=1}^T\E_{y_t \sim {\tilde P(y_t| y_0)}} \left\| \frac{y_t - \sqrt{\bar{\alpha}_t}y_0}{\sqrt{1-\bar{\alpha}_t}} - \epsilon_\theta(y_t, t)\right\|^2.')
            submit_btn = gr.Button("Generate", elem_id="btn")
        with gr.Column(scale=3):
            slider = gr.Slider(0, 1000, value=0, label='step (out of 1000)')
            image = gr.Image(label="Rendered Image", show_label=False, elem_id="image")
    inputs = [textbox]
    outputs = [slider, image, submit_btn]

    # Caps the GPU allocation per call at 90 s.
    # NOTE(review): an earlier comment called this "the default" to avoid
    # quota issues; confirm 90 s is the intended budget.
    @spaces.GPU(duration=90)
    def infer(formula):
        """Stream the rendering of *formula*, one update per denoising step.

        Yields ``(step, image, button_update)`` for ``outputs``: the step
        counter for the slider, the pipeline's ``image_clean`` output reshaped
        to (64, 320, 3), and a visibility update that hides the Generate
        button while sampling and shows it again at the end. On failure,
        yields step 1000 with a blank white image and re-enables the button.
        """
        current_time = datetime.now(tz)
        print (current_time, formula)
        try:
            formula_normalized = normalize_formula(formula)
        except Exception as e:
            # Best effort: fall back to the raw input if normalization fails.
            print (e)
            formula_normalized = formula
        print ('normalized', formula_normalized)
        encoder_hidden_states = forward_encoder(formula_normalized)
        try:
            i = 0
            q = None
            # torch.manual_seed(0) reseeds and returns the default generator,
            # so each request is sampled deterministically.
            for _, image_clean in img_pipe.run_clean(batch_size=1, generator=torch.manual_seed(0), encoder_hidden_states=encoder_hidden_states, output_type="numpy"):
                i += 1
                q = np.ascontiguousarray(image_clean[0]).reshape((64, 320, 3))
                yield i, q, gr.update(visible=False)
            if q is None:
                raise RuntimeError('pipeline produced no images')
            yield i, q, gr.update(visible=True)
        except Exception:
            traceback.print_exc()
            # Float images must lie in [0, 1] for Gradio's conversion (the
            # success path yields floats as well); 1.0 renders as white.
            yield 1000, np.ones((64, 320, 3), dtype=np.float32), gr.update(visible=True)

    submit_btn.click(fn=infer, inputs=inputs, outputs=outputs, concurrency_limit=1)

demo.queue(max_size=20).launch()