# Spaces: Running on Zero
import os

import gradio as gr
import numpy as np
import spaces
import torch
from PIL import Image
from transformers import AutoModel

# Load the model once at start-up and switch it to evaluation mode.
# NOTE(review): trust_remote_code=True executes Python code shipped in the Hub
# repository; consider pinning `revision=` to a reviewed commit.
MODEL_ID = "ragavsachdeva/magiv2"
model = AutoModel.from_pretrained(MODEL_ID, trust_remote_code=True)
model = model.cuda().eval()
def read_image(image):
    """Load an uploaded image file as a three-channel numpy array of grey pixels.

    The image is converted to single-channel luminance ("L") and then back to
    RGB, so all three channels carry the same value. Presumably this normalises
    colour and greyscale pages to the same input for the model (TODO: confirm).

    Args:
        image: A filename, path-like object or binary file object accepted by
            ``PIL.Image.open``.

    Returns:
        A ``numpy.ndarray`` of shape ``(height, width, 3)`` with dtype ``uint8``.
    """
    # A context manager guarantees the file handle is released. Pillow keeps the
    # file open after loading multi-frame formats (e.g. GIF/TIFF) otherwise.
    with Image.open(image) as img:
        return np.array(img.convert("L").convert("RGB"))
def _resolve_character_names(character_bank_images, character_bank_names):
    """Return one name per character reference image.

    If ``character_bank_names`` is ``None`` or blank, each image's filename
    (without directory or extension) is used. The names are taken directly from
    the paths rather than joined and re-split on commas, so filenames containing
    commas stay intact. Otherwise the text is split on commas and surrounding
    whitespace is stripped from each name (``"Alice, Bob"`` -> ``["Alice", "Bob"]``).
    """
    if character_bank_names is None or not character_bank_names.strip():
        return [
            os.path.splitext(os.path.basename(path))[0]
            for path in character_bank_images
        ]
    return [name.strip() for name in character_bank_names.split(",")]


def _page_transcript_lines(page_result):
    """Return ``<speaker>: text`` lines for the essential text boxes of one page.

    Reads ``page_result["ocr"]``, ``["is_essential_text"]``,
    ``["text_character_associations"]`` (pairs of text index and character index)
    and ``["character_names"]``. Essential text boxes with no associated
    character are attributed to ``unsure``.
    """
    # Map text-box index -> speaker name.
    speaker_name = {
        text_idx: page_result["character_names"][char_idx]
        for text_idx, char_idx in page_result["text_character_associations"]
    }
    is_essential = page_result["is_essential_text"]
    return [
        f"<{speaker_name.get(j, 'unsure')}>: {text}"
        for j, text in enumerate(page_result["ocr"])
        if is_essential[j]
    ]


# Request a GPU for the duration of each call; the `spaces` ZeroGPU docs require
# this on "Running on Zero" Spaces. duration=180 matches the 3-minute limit
# promised in the UI description (the default is shorter).
@spaces.GPU(duration=180)
def process_images(chapter_pages, character_bank_images, character_bank_names):
    """Run chapter-wide prediction and build a speaker-attributed transcript.

    Args:
        chapter_pages: Uploaded chapter pages in reading order (presumably file
            paths, as ``gr.Files`` passes them), or ``None``/empty.
        character_bank_images: Uploaded character reference image paths, or
            ``None``/empty for no character bank.
        character_bank_names: Comma-separated names, one per reference image.
            If ``None`` or blank, the images' filenames (without extension) are used.

    Returns:
        A ``(output_images, transcript_text)`` tuple. ``output_images`` holds one
        visualised prediction per page result. ``transcript_text`` has one
        ``<name>: text`` line per essential text box, joined by newlines.
        ``([], "")`` is returned when no pages were uploaded.

    Raises:
        gr.Error: if the number of names differs from the number of character
            reference images.
    """
    if not chapter_pages:
        return [], ""
    if not character_bank_images:
        # Names are meaningless without reference images, so ignore them too.
        character_bank_images = []
        character_bank_names = ""

    names = _resolve_character_names(character_bank_images, character_bank_names)
    if len(names) != len(character_bank_images):
        raise gr.Error(
            f"Got {len(names)} character name(s) for {len(character_bank_images)} "
            "character image(s); please provide exactly one name per image."
        )

    pages = [read_image(page) for page in chapter_pages]
    character_bank = {
        "images": [read_image(img) for img in character_bank_images],
        "names": names,
    }
    with torch.no_grad():
        per_page_results = model.do_chapter_wide_prediction(
            pages, character_bank, use_tqdm=True, do_ocr=True
        )

    output_images = []
    transcript = []
    for page, page_result in zip(pages, per_page_results):
        output_images.append(
            model.visualise_single_image_prediction(page, page_result, filename=None)
        )
        transcript.extend(_page_transcript_lines(page_result))
    return output_images, "\n".join(transcript)
# Define Gradio interface.
# Inputs are passed to process_images positionally as (chapter_pages,
# character_bank_images, character_bank_names). Its returned tuple fills
# the outputs in order: (gallery, transcript textbox).
chapter_pages_input = gr.Files(label="Chapter pages in chronological order.")
character_bank_images_input = gr.Files(
    label="Character reference images. If left empty, the transcript will say 'Other' for all characters."
)
character_bank_names_input = gr.Textbox(
    label="Character names (comma separated). If left empty, the filenames of character images will be used."
)
output_images = gr.Gallery(label="Output Images")
transcript_output = gr.Textbox(label="Transcript")

interface_inputs = [
    chapter_pages_input,
    character_bank_images_input,
    character_bank_names_input,
]
interface_outputs = [output_images, transcript_output]

demo = gr.Interface(
    fn=process_images,
    inputs=interface_inputs,
    outputs=interface_outputs,
    title="Tails Tell Tales: Chapter-Wide Manga Transcriptions With Character Names",
    description="Instructions: (i) Upload a sequence of manga pages, (ii) Upload a set of reference character images, (iii) Provide the names for each character image, (iv) Sit tight, this can take a couple of minutes (OCR model is slow). Note: The job will abort after 3mins, so don't upload too many images (30ish is fine).",
)
demo.launch()