# Magiv2-Demo / app.py
# Last commit 9d0818e by ragavsachdeva: "Update app.py" (verified).
import gradio as gr
from PIL import Image
import numpy as np
from transformers import AutoModel
import torch
import spaces
import os
# Load the Magiv2 checkpoint from the Hugging Face Hub once, at import time, so
# every request below reuses the same weights. trust_remote_code=True allows
# transformers to execute model code shipped in the Hub repository; the custom
# methods called later (do_chapter_wide_prediction,
# visualise_single_image_prediction) presumably come from that code — confirm
# against the "ragavsachdeva/magiv2" repository.
model = AutoModel.from_pretrained(
    "ragavsachdeva/magiv2",
    trust_remote_code=True,
)
# Move the weights to the GPU and put the module in evaluation mode.
model = model.cuda().eval()
def read_image(image):
    """Load an image file as a three-channel grayscale NumPy array.

    The image is reduced to single-channel luminance ("L") and then expanded
    back to "RGB", so all three channels of the result are identical.

    Args:
        image: Anything ``PIL.Image.open`` accepts; in this app, presumably the
            file paths produced by the ``gr.Files`` inputs — confirm.

    Returns:
        A ``numpy.ndarray`` built from the RGB image (height x width x 3, uint8).
    """
    # Image.open is lazy and keeps the file handle open until the image is
    # closed or garbage-collected; the context manager closes it promptly.
    # convert() loads the pixel data and returns a new image that no longer
    # depends on the file, so it is safe to use after the `with` block exits.
    with Image.open(image) as img:
        rgb = img.convert("L").convert("RGB")
    return np.array(rgb)
def _character_names(character_bank_images, character_bank_names):
    """Return the list of character names for the given reference images.

    If the names textbox has non-blank content, it is split on commas and each
    name is stripped of surrounding whitespace. Otherwise each image's filename
    (without directory or extension) is used as its name.
    """
    if character_bank_names and character_bank_names.strip():
        return [name.strip() for name in character_bank_names.split(",")]
    # Build the list directly rather than joining and re-splitting on ",", so a
    # filename that itself contains a comma still yields exactly one name.
    return [
        os.path.splitext(os.path.basename(path))[0]
        for path in character_bank_images
    ]


def _page_transcript_lines(page_result):
    """Format one page's essential OCR texts as "<speaker>: text" lines.

    Texts whose "is_essential_text" flag is falsy are skipped. Texts with no
    entry in "text_character_associations" are attributed to "unsure".
    """
    speaker_by_text = {
        text_idx: page_result["character_names"][char_idx]
        for text_idx, char_idx in page_result["text_character_associations"]
    }
    lines = []
    # Assumes "ocr" and "is_essential_text" are parallel per-text lists.
    for text_idx, (text, is_essential) in enumerate(
        zip(page_result["ocr"], page_result["is_essential_text"])
    ):
        if not is_essential:
            continue
        speaker = speaker_by_text.get(text_idx, "unsure")
        lines.append(f"<{speaker}>: {text}")
    return lines


# A GPU is requested for up to 180 s per call, matching the UI note that jobs
# abort after 3 minutes.
@spaces.GPU(duration=180)
def process_images(chapter_pages, character_bank_images, character_bank_names):
    """Run Magiv2 over a chapter and return annotated pages plus a transcript.

    Args:
        chapter_pages: Page image files in reading order (presumably file paths
            from ``gr.Files``). ``None`` or empty returns ``([], "")``.
        character_bank_images: Reference character image files, or ``None``.
        character_bank_names: Comma-separated names, one per reference image.
            If blank, the reference images' filenames are used instead. It is
            ignored when no reference images are given.

    Returns:
        A tuple ``(output_images, transcript_text)``. ``output_images`` holds one
        ``model.visualise_single_image_prediction`` result per page.
        ``transcript_text`` has one ``<speaker>: text`` line per essential OCR
        text, in page order.

    Raises:
        gr.Error: If the number of names differs from the number of reference
            images.
    """
    if not chapter_pages:
        return [], ""

    if not character_bank_images:
        # Without reference images there is nothing to attach names to, so the
        # bank is empty on both sides (previously names became [""]).
        character_bank_images = []
        names = []
    else:
        names = _character_names(character_bank_images, character_bank_names)
        if len(names) != len(character_bank_images):
            raise gr.Error(
                f"Got {len(character_bank_images)} character image(s) but "
                f"{len(names)} name(s). Provide exactly one comma-separated "
                "name per image, or leave the names empty to use filenames."
            )

    pages = [read_image(page) for page in chapter_pages]
    character_bank = {
        "images": [read_image(ref) for ref in character_bank_images],
        "names": names,
    }

    with torch.no_grad():
        per_page_results = model.do_chapter_wide_prediction(
            pages, character_bank, use_tqdm=True, do_ocr=True
        )

    output_images = []
    transcript = []
    for page, page_result in zip(pages, per_page_results):
        output_images.append(
            model.visualise_single_image_prediction(page, page_result, filename=None)
        )
        transcript.extend(_page_transcript_lines(page_result))
    return output_images, "\n".join(transcript)
# Build the Gradio UI. The three inputs map positionally onto process_images'
# parameters (pages, reference images, names); the two outputs receive its
# (gallery images, transcript text) return value.
chapter_pages_input = gr.Files(
    label="Chapter pages in chronological order."
)
character_bank_images_input = gr.Files(
    label="Character reference images. If left empty, the transcript will say 'Other' for all characters."
)
character_bank_names_input = gr.Textbox(
    label="Character names (comma separated). If left empty, the filenames of character images will be used."
)
output_images = gr.Gallery(label="Output Images")
transcript_output = gr.Textbox(label="Transcript")

demo = gr.Interface(
    fn=process_images,
    inputs=[
        chapter_pages_input,
        character_bank_images_input,
        character_bank_names_input,
    ],
    outputs=[output_images, transcript_output],
    title="Tails Tell Tales: Chapter-Wide Manga Transcriptions With Character Names",
    description="Instructions: (i) Upload a sequence of manga pages, (ii) Upload a set of reference character images, (iii) Provide the names for each character image, (iv) Sit tight, this can take a couple of minutes (OCR model is slow). Note: The job will abort after 3mins, so don't upload too many images (30ish is fine).",
)
demo.launch()