sflindrs's picture
Update app.py
d8277be verified
import gradio as gr
from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig
from PIL import Image
import torch
import spaces
import json
import re
import deepl
# Load the Molmo processor and model once, at import time. Both come from the
# same checkpoint and are loaded with identical options.
_MODEL_ID = 'allenai/Molmo-7B-D-0924'
_LOAD_KWARGS = dict(
    trust_remote_code=True,  # Molmo ships custom modeling/processing code
    torch_dtype='auto',
    device_map='auto',
)
processor = AutoProcessor.from_pretrained(_MODEL_ID, **_LOAD_KWARGS)
model = AutoModelForCausalLM.from_pretrained(_MODEL_ID, **_LOAD_KWARGS)
# NOTE(review): this helper does no GPU work. The @spaces.GPU() decorator looks
# intended for the function that calls model.generate_from_batch
# (process_image_and_text). Confirm, then remove it from here.
@spaces.GPU()
def wrap_json_in_markdown(text):
    """
    Wrap each valid top-level JSON object/array in ``text`` in a ```json block.

    The text is scanned left to right:

    - An opening ``{`` or ``[`` starts a candidate, and brackets are matched
      with a stack until the candidate's outermost bracket closes.
    - If that span parses with ``json.loads``, it is replaced by a fenced
      Markdown ```json block holding ``json.dumps(parsed, indent=4)``. The
      default ``ensure_ascii=True`` is kept, so non-ASCII characters come out
      as ``\\uXXXX`` escapes.
    - Otherwise the span is left untouched and scanning resumes after it.
    - A mismatched closing bracket abandons the current candidate.

    Brackets inside JSON string values are not treated specially, so JSON
    whose strings contain ``{``, ``}``, ``[`` or ``]`` may not be recognised.

    Args:
        text (str): Free-form text that may contain JSON.

    Returns:
        str: ``text`` with every recognised JSON span wrapped in a code block.
    """
    closing_to_opening = {'}': '{', ']': '['}
    pieces = []
    stack = []
    start = None  # index of the current candidate's opening bracket, or None
    emitted = 0  # text[:emitted] has already been copied into `pieces`
    for i, char in enumerate(text):
        if char in '{[':
            if start is None:
                start = i
            stack.append(char)
        elif char in closing_to_opening and start is not None:
            # Invariant: while a candidate is open the stack is non-empty.
            if stack.pop() != closing_to_opening[char]:
                # Mismatched bracket: abandon the candidate and clear the
                # stack, so leftover openers cannot stop the next candidate
                # from ever closing.
                start = None
                stack.clear()
            elif not stack:
                # The candidate's outermost bracket just closed.
                try:
                    parsed = json.loads(text[start:i + 1])
                except json.JSONDecodeError:
                    pass  # not JSON (e.g. prose in brackets); leave it as-is
                else:
                    pieces.append(text[emitted:start])
                    pieces.append(f"\n```json\n{json.dumps(parsed, indent=4)}\n```\n")
                    emitted = i + 1
                start = None
    pieces.append(text[emitted:])
    return ''.join(pieces)
def decode_unicode_sequences(unicode_seq):
    """
    Decode ``\\uXXXX`` escape sequences (e.g. ``\\u4F60\\u597D``) into characters.

    A high-surrogate escape immediately followed by a low-surrogate escape
    (the form ``json.dumps`` uses for characters outside the Basic
    Multilingual Plane, e.g. ``\\ud83d\\ude00``) is combined into the single
    character it encodes. Every other escape, including an unpaired surrogate,
    is decoded on its own. Text that is not an escape sequence is unchanged.

    Args:
        unicode_seq (str): A string containing Unicode escape sequences.

    Returns:
        str: The decoded Unicode string.
    """
    # Alternative 1: a UTF-16 surrogate pair (high D800-DBFF, low DC00-DFFF).
    # Alternative 2: any single \uXXXX escape.
    unicode_escape_pattern = re.compile(
        r'\\u([dD][89abAB][0-9a-fA-F]{2})\\u([dD][c-fC-F][0-9a-fA-F]{2})'
        r'|\\u([0-9a-fA-F]{4})'
    )

    def replace_match(match):
        high, low, single = match.groups()
        if single is not None:
            return chr(int(single, 16))
        # Recombine the surrogate pair into one code point; decoding the two
        # halves separately would yield lone surrogates that cannot be encoded.
        return chr(0x10000 + ((int(high, 16) - 0xD800) << 10) + (int(low, 16) - 0xDC00))

    return unicode_escape_pattern.sub(replace_match, unicode_seq)
def is_mandarin(text):
    """
    Report whether ``text`` contains at least one CJK Unified Ideograph.

    Only the basic block U+4E00..U+9FFF is checked. This is a heuristic for
    "contains Chinese characters", not real language detection.

    Args:
        text (str): The text to inspect.

    Returns:
        bool: True if any character falls in U+4E00..U+9FFF, else False.
    """
    return any('\u4e00' <= ch <= '\u9fff' for ch in text)
def translate_to_english_deepl(text, api_key):
    """
    Translate Chinese text to US English with the DeepL API.

    Args:
        text (str): The Chinese text to translate.
        api_key (str): DeepL API authentication key.

    Returns:
        str: The translated text. ``text`` is returned unchanged if it is
        empty or if the DeepL client raises ``deepl.DeepLException``.
    """
    if not text:
        # Nothing to translate; skip the API round-trip.
        return text
    try:
        translator = deepl.Translator(api_key)
        result = translator.translate_text(text, source_lang="ZH", target_lang="EN-US")
    except deepl.DeepLException as e:
        # DeepLException is the DeepL client's base error class. The previous
        # handler named `requests`, which this file never imports, so any
        # failure surfaced as NameError instead of this fallback.
        print(f"DeepL Translation error: {e}")
        return text  # best-effort: fall back to the original text
    return result.text
def process_text_deepl(input_string, api_key):
    """
    Translate runs of escaped Chinese text in ``input_string`` via DeepL.

    Each maximal run of consecutive ``\\uXXXX`` escapes is decoded as one
    unit.

    - If ``is_mandarin`` accepts the decoded text, the run is replaced by
      ``"<translation> (<decoded characters>)"``.
    - Otherwise the run is left exactly as it was, still escaped.

    Args:
        input_string (str): Text that may contain ``\\uXXXX`` escape sequences.
        api_key (str): DeepL API authentication key.

    Returns:
        str: ``input_string`` with the Chinese runs translated.
    """
    def _translate_run(found):
        escaped = found.group(0)
        chars = decode_unicode_sequences(escaped)
        if not is_mandarin(chars):
            return escaped
        english = translate_to_english_deepl(chars, api_key)
        return f"{english} ({chars})"

    return re.sub(r'(?:\\u[0-9a-fA-F]{4})+', _translate_run, input_string)
# This is the function that runs the model, so it is the one that must hold a
# GPU on ZeroGPU hardware (spaces.GPU is documented as a no-op elsewhere).
@spaces.GPU()
def process_image_and_text(image, text, translate=False):
    """
    Run Molmo on one image plus prompt and return the post-processed answer.

    Args:
        image: Image array accepted by ``PIL.Image.fromarray`` (the Gradio
            input is configured with ``type="numpy"``).
        text (str): The question / prompt about the image.
        translate (bool): Default False, which returns the same text as
            before. If True and the ``DEEPL_API_KEY`` environment variable is
            set, escaped Chinese text in the answer is translated with
            ``process_text_deepl``.

    Returns:
        str: The generated answer, with valid JSON wrapped in a Markdown
        ```json code block by ``wrap_json_in_markdown``.
    """
    import os

    # Preprocess, then move each tensor to the model's device and add a batch
    # dimension of size 1.
    inputs = processor.process(images=[Image.fromarray(image)], text=text)
    inputs = {k: v.to(model.device).unsqueeze(0) for k, v in inputs.items()}

    output = model.generate_from_batch(
        inputs,
        GenerationConfig(max_new_tokens=1024, stop_strings="<|endoftext|>"),
        tokenizer=processor.tokenizer,
    )

    # The output row starts with the prompt tokens; keep only the new ones.
    generated_tokens = output[0, inputs['input_ids'].size(1):]
    generated_text = processor.tokenizer.decode(generated_tokens, skip_special_tokens=True)
    answer = wrap_json_in_markdown(generated_text)

    if translate:
        # Read the key from the environment rather than hard-coding a secret.
        # NOTE(review): a DeepL key was previously committed in this source;
        # treat it as leaked and revoke it.
        api_key = os.environ.get("DEEPL_API_KEY")
        if api_key:
            answer = process_text_deepl(answer, api_key)
    return answer
def chatbot(image, text, history):
    """
    Gradio handler: answer ``text`` about ``image`` and record the exchange.

    Args:
        image: The uploaded image from the ``gr.Image(type="numpy")`` input,
            or None when nothing has been uploaded.
        text (str): The user's question.
        history (list[dict]): Chat history in Gradio "messages" format
            (``{"role": ..., "content": ...}`` dicts). It is mutated in place
            when an answer is produced.

    Returns:
        list[dict]: The messages to render in the ``gr.Chatbot(type="messages")``.
    """
    if image is None:
        # The Chatbot uses type="messages", so this must be a role/content dict,
        # not a legacy (user, bot) tuple. Building a new list leaves `history`
        # itself unchanged.
        return history + [{"role": "assistant", "content": "Please upload an image first."}]
    response = process_image_and_text(image, text)
    history.append({"role": "user", "content": text})
    history.append({"role": "assistant", "content": response})
    return history
# Build the Gradio UI: the image input sits beside the chat transcript, with a
# question box and a Submit button below. Clicking Submit and submitting the
# textbox both run `chatbot` with identical inputs and outputs.
with gr.Blocks() as demo:
    gr.Markdown("# Image Chatbot with Molmo-7B-D-0924")
    with gr.Row():
        image_input = gr.Image(type="numpy")
        chatbot_output = gr.Chatbot(type="messages")
    text_input = gr.Textbox(placeholder="Ask a question about the image...")
    submit_button = gr.Button("Submit")
    state = gr.State([])
    # Register the same handler on both triggers, button click first.
    for trigger in (submit_button.click, text_input.submit):
        trigger(
            chatbot,
            inputs=[image_input, text_input, state],
            outputs=[chatbot_output],
        )
demo.launch()