Spaces:
Running
on
L4
Running
on
L4
import os | |
# add a command for installing flash-attn | |
os.system('pip install flash-attn --no-build-isolation') | |
os.system("pip install gradio==4.44.1") | |
import pygame | |
import numpy as np | |
import gradio as gr | |
import time | |
import torch | |
from PIL import Image | |
from transformers import AutoModelForCausalLM, AutoProcessor | |
import re | |
import random | |
pygame.mixer.quit() # Disable sound | |
# Constants | |
WIDTH, HEIGHT = 800, 800 | |
GRID_SIZE = 80 | |
WHITE = (255, 255, 255) | |
GREEN = (34, 139, 34) # Forest green - more like an apple | |
RED = (200, 50, 50) | |
BLACK = (0, 0, 0) | |
GRAY = (128, 128, 128) | |
YELLOW = (218, 165, 32) # Golden yellow color | |
# Directions | |
UP = (0, -1) | |
DOWN = (0, 1) | |
LEFT = (-1, 0) | |
RIGHT = (1, 0) | |
STATIC = (0, 0) | |
ACTIONS = ["up", "down", "left", "right", "static"] | |
# Load AI Model | |
dtype = torch.bfloat16 | |
magma_model_id = "microsoft/Magma-8B" | |
magam_model = AutoModelForCausalLM.from_pretrained(magma_model_id, trust_remote_code=True, torch_dtype=dtype) | |
magma_processor = AutoProcessor.from_pretrained(magma_model_id, trust_remote_code=True) | |
magam_model.to("cuda") | |
magma_img = pygame.image.load("./assets/images/magma_game_thin.png") | |
magma_img = pygame.transform.scale(magma_img, (GRID_SIZE, GRID_SIZE)) | |
class MagmaFindGPU: | |
def __init__(self): | |
self.reset() | |
self.step_count = 0 | |
def reset(self): | |
self.snake = [(5, 5)] | |
self.direction = RIGHT | |
self.score = 0 | |
self.game_over = False | |
self.step_count = 0 | |
self.place_target() | |
def place_target(self): | |
while True: | |
target_x = np.random.randint(1, WIDTH // GRID_SIZE - 1) | |
target_y = np.random.randint(1, HEIGHT // GRID_SIZE - 1) | |
if (target_x, target_y) not in self.snake: | |
self.target = (target_x, target_y) | |
break | |
def step(self, action): | |
if action == "up": | |
self.direction = UP | |
elif action == "down": | |
self.direction = DOWN | |
elif action == "left": | |
self.direction = LEFT | |
elif action == "right": | |
self.direction = RIGHT | |
elif action == "static": | |
self.direction = STATIC | |
if self.game_over: | |
self.reset() | |
return self.render(), self.score | |
new_head = (self.snake[0][0] + self.direction[0], self.snake[0][1] + self.direction[1]) | |
if new_head[0] < 0 or new_head[1] < 0 or new_head[0] >= WIDTH // GRID_SIZE or new_head[1] >= HEIGHT // GRID_SIZE: | |
self.game_over = True | |
return self.render(), self.score | |
self.snake = [new_head] # Keep only the head (single block snake) | |
self.step_count += 1 | |
# Check if the target is covered by four surrounding squares | |
head_x, head_y = self.snake[0] | |
neighbors = set([(head_x, head_y - 1), (head_x, head_y + 1), (head_x - 1, head_y), (head_x + 1, head_y)]) | |
if neighbors.issuperset(set([self.target])): | |
self.score += 1 | |
self.place_target() | |
return self.render(), self.score | |
def render(self): | |
pygame.init() | |
surface = pygame.Surface((WIDTH, HEIGHT)) | |
surface.fill(BLACK) | |
head_x, head_y = self.snake[0] | |
surface.blit(magma_img, (head_x * GRID_SIZE, head_y * GRID_SIZE)) | |
# pygame.draw.rect(surface, RED, (self.snake[0][0] * GRID_SIZE, self.snake[0][1] * GRID_SIZE, GRID_SIZE, GRID_SIZE)) | |
pygame.draw.rect(surface, GREEN, (self.target[0] * GRID_SIZE, self.target[1] * GRID_SIZE, GRID_SIZE, GRID_SIZE)) | |
# Draw four surrounding squares with labels | |
head_x, head_y = self.snake[0] | |
neighbors = [(head_x, head_y - 1), (head_x, head_y + 1), (head_x - 1, head_y), (head_x + 1, head_y)] | |
labels = ["1", "2", "3", "4"] | |
font = pygame.font.Font(None, 48) | |
# clone surface | |
surface_nomark = surface.copy() | |
for i, (nx, ny) in enumerate(neighbors): | |
if 0 <= nx < WIDTH // GRID_SIZE and 0 <= ny < HEIGHT // GRID_SIZE: | |
pygame.draw.rect(surface, RED, (nx * GRID_SIZE, ny * GRID_SIZE, GRID_SIZE, GRID_SIZE), GRID_SIZE) | |
# pygame.draw.rect(surface_nomark, RED, (nx * GRID_SIZE, ny * GRID_SIZE, GRID_SIZE, GRID_SIZE), GRID_SIZE) | |
text = font.render(labels[i], True, WHITE) | |
text_rect = text.get_rect(center=(nx * GRID_SIZE + GRID_SIZE // 2, ny * GRID_SIZE + GRID_SIZE // 2)) | |
surface.blit(text, text_rect) | |
return np.array(pygame.surfarray.array3d(surface_nomark)).swapaxes(0, 1), np.array(pygame.surfarray.array3d(surface)).swapaxes(0, 1) | |
def get_state(self): | |
return self.render() | |
game = MagmaFindGPU() | |
def play_game(): | |
state, state_som = game.get_state() | |
pil_img = Image.fromarray(state_som) | |
convs = [ | |
{"role": "system", "content": "You are an agent that can see, talk, and act. Avoid hitting the wall."}, | |
{"role": "user", "content": "<image_start><image><image_end>\nWhich mark is closer to green block? Answer with a single number."}, | |
] | |
prompt = magma_processor.tokenizer.apply_chat_template(convs, tokenize=False, add_generation_prompt=True) | |
inputs = magma_processor(images=[pil_img], texts=prompt, return_tensors="pt") | |
inputs['pixel_values'] = inputs['pixel_values'].unsqueeze(0) | |
inputs['image_sizes'] = inputs['image_sizes'].unsqueeze(0) | |
inputs = inputs.to("cuda").to(dtype) | |
generation_args = { | |
"max_new_tokens": 10, | |
"temperature": 0.3, | |
"do_sample": True, | |
"use_cache": True, | |
"num_beams": 1, | |
} | |
with torch.inference_mode(): | |
generate_ids = magam_model.generate(**inputs, **generation_args) | |
generate_ids = generate_ids[:, inputs["input_ids"].shape[-1] :] | |
action = magma_processor.decode(generate_ids[0], skip_special_tokens=True).strip() | |
# extract mark id fro action use re | |
match = re.search(r'\d+', action) | |
if match: | |
action = match.group(0) | |
if action.isdigit() and 1 <= int(action) <= 4: | |
action = ACTIONS[int(action) - 1] | |
else: | |
# random choose one from the pool | |
action = random.choice(ACTIONS[:-1]) | |
else: | |
action = random.choice(ACTIONS[:-1]) | |
img, score = game.step(action) | |
img = img[0] | |
return img, f"Score: {score}" | |
def reset_game(): | |
game.reset() | |
return game.render()[0], "Score: 0" | |
MARKDOWN = """ | |
<div align="center"> | |
<img src="./assets/images/logo.png" alt="Magma Logo" style="margin-right: 5px; height: 80px;margin-top: -10px;"> | |
<h2>Magma: A Foundation Model for Multimodal AI Agents</h2> | |
\[[arXiv Paper](https://www.arxiv.org/pdf/2502.13130)\] \[[Project Page](https://microsoft.github.io/Magma/)\] \[[Github Repo](https://github.com/microsoft/Magma)\] \[[Hugging Face Model](https://huggingface.co./microsoft/Magma-8B)\] | |
This demo is powered by [Gradio](https://gradio.app/). | |
<b>Goal: Collects the green blocks by automatically moving up, down, left and right.</b> | |
</div> | |
""" | |
with gr.Blocks() as interface: | |
gr.Markdown(MARKDOWN) | |
with gr.Row(): | |
image_output = gr.Image(label="Game Screen") | |
with gr.Column(): | |
score_output = gr.Text(label="Score", elem_classes="large-text") | |
gr.HTML(""" | |
<style> | |
.large-text textarea { | |
font-size: 24px !important; | |
} | |
</style> | |
""") | |
start_btn = gr.Button("Start/Reset Game") | |
interface.load(fn=play_game, every=1, inputs=[], outputs=[image_output, score_output]) | |
start_btn.click(fn=reset_game, inputs=[], outputs=[image_output, score_output]) | |
interface.launch() | |