Spaces:
Runtime error
Runtime error
# pip install imgkit | |
# pip install html2image | |
import base64 | |
import random | |
import uuid | |
from io import BytesIO | |
import imgkit | |
import os | |
import pathlib | |
import re | |
import gradio as gr | |
import requests | |
from PIL import Image, ImageChops, ImageDraw | |
from gradio_client import Client | |
import torch | |
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline, Pipeline | |
# Credentials and remote endpoints used by the app.
HF_TOKEN = os.getenv("HF_TOKEN")
if not HF_TOKEN:
    # Fail fast at import time: every language-model call sends this token.
    raise RuntimeError("HF_TOKEN environment variable is required to call remote API.")
# Hosted zephyr-7b-beta text-generation endpoint, called through query().
API_URL = "https://api-inference.huggingface.co/models/HuggingFaceH4/zephyr-7b-beta"
headers = {"Authorization": f"Bearer {HF_TOKEN}"}
# Remote latent-consistency Space used by generate_image() to draw monster illustrations.
# Constructing the Client talks to that Space, so an unreachable Space would otherwise stop
# the whole app from starting. Image generation is optional: generate_image() wraps
# client.predict in try/except and falls back to 'placeholder.png', so leaving client as
# None keeps the rest of the app usable.
try:
    client = Client("https://latent-consistency-super-fast-lcm-lora-sd1-5.hf.space/")
except Exception as client_error:
    print(f'Failed to connect to image generation client: {client_error}')
    client = None
def init_speech_to_text_model() -> Pipeline:
    """Build the distil-whisper automatic-speech-recognition pipeline.

    Runs on the first CUDA device in float16 when CUDA is available, otherwise on
    the CPU in float32. Each transcription is capped at 128 new tokens.
    """
    checkpoint = "distil-whisper/distil-medium.en"
    on_gpu = torch.cuda.is_available()
    target_device = "cuda:0" if on_gpu else "cpu"
    dtype = torch.float16 if on_gpu else torch.float32

    asr_model = AutoModelForSpeechSeq2Seq.from_pretrained(
        checkpoint,
        torch_dtype=dtype,
        low_cpu_mem_usage=True,
        use_safetensors=True,
    )
    asr_model.to(target_device)
    asr_processor = AutoProcessor.from_pretrained(checkpoint)

    return pipeline(
        "automatic-speech-recognition",
        model=asr_model,
        tokenizer=asr_processor.tokenizer,
        feature_extractor=asr_processor.feature_extractor,
        max_new_tokens=128,
        torch_dtype=dtype,
        device=target_device,
    )
# Speech-to-text pipeline, built once at import time and reused by transcribe() for every recording.
whisper_pipe = init_speech_to_text_model()
def query(payload: dict, timeout: float = 120.0):
    """POST `payload` to the text-generation endpoint at API_URL and return the decoded JSON.

    Args:
        payload: JSON request body, e.g. {"inputs": prompt, "parameters": {...}}.
        timeout: seconds to wait for the endpoint before requests raises a Timeout.

    Returns:
        The decoded JSON response. If the body is not valid JSON (for example an HTML
        gateway error page), a dict with an "error" key is returned instead. That is the
        same error shape generate_text() already checks for.
    """
    # Without a timeout, requests.post can block the Gradio worker indefinitely.
    response = requests.post(API_URL, headers=headers, json=payload, timeout=timeout)
    try:
        return response.json()
    except ValueError:
        # requests signals an undecodable body with a ValueError subclass.
        return {"error": f"HTTP {response.status_code}: {response.text[:200]}"}
# System-prompt rules shared by the "create" and "edit" prompts. The embedded example
# stat block shows the model the "Header: value" layout that format_html() parses.
_MONSTER_FORMAT_RULES = """# RULES
- In your response always generate a new monster.
- Only generate one monster, no other dialogue.
- Surround monster info in triple backticks (```).
- Format the monster text using headers like in the example below:
```
Name: Jabberwock
Type: Medium humanoid (human), neutral evil
Description: A Jabberwock is a creature of the Deep Sea.
Stats:
Armor Class: 15 (breastplate)
Hit Points: 22 (5d8 + 5)
Speed: 30 ft
STR: 8 (-1)
DEX: 14 (+2)
CON: 12 (+1)
INT: 2 (-4)
WIS: 10 (+0)
CHA: 4 (-3)
Skills: Perception +3
Senses: darkvision 60 ft., passive Perception 14
Languages: —
Challenge: 1/4 (50 XP)
Passives:
Legendary Resistance (3/Day): If the Jabberwock fails a saving throw, it can choose to succeed instead
Actions:
Bite: Melee Weapon Attack: +5 to hit, reach 5 ft., one target. Hit: 6 (1d8 + 3) piercing damage.
Claws (Recharge 5-6): Melee Weapon Attack: +5 to hit, reach 5 ft., one target. Hit: 6 (1d10 + 3) slashing damage.
```</s>
"""


def _build_monster_prompt(card_text: str, user_request: str) -> str:
    """Return the chat-formatted prompt for either a create or an edit request.

    The edit prompt embeds the current card. It is used whenever `card_text` is
    non-empty and differs from the module-level `starting_text` example; otherwise
    a new monster is requested.
    """
    # Prompt must apply the correct chat template for the model, see:
    # https://huggingface.co/docs/transformers/main/en/chat_templating
    if card_text and card_text != starting_text:
        return (
            "<|system|>\n"
            "You edit Dungeons & Dragons monsters based on the user's request.\n"
            f"{_MONSTER_FORMAT_RULES}"
            "<|user|>\n"
            "# CARD TO EDIT\n"
            "```\n"
            f"{card_text}\n"
            "```\n"
            "# EDIT REQUEST\n"
            f"{user_request}</s>\n"
            "<|assistant|>\n"
        )
    return (
        "<|system|>\n"
        "You create Dungeons & Dragons monsters based on the user's request.\n"
        f"{_MONSTER_FORMAT_RULES}"
        "<|user|>\n"
        f"{user_request}</s>\n"
        "<|assistant|>\n"
    )


def generate_text(card_text: str, user_request: str) -> (str, str, str):
    """Ask the language model to create a new monster or edit the current one.

    Args:
        card_text: current contents of the card-text box.
        user_request: the user's typed or transcribed request.

    Returns:
        (assistant_reply, new_card_text, None):
        - assistant_reply: the raw reply, for the chat panel.
        - new_card_text: the text inside the first pair of ``` fences, stripped and
          newline-terminated. If the reply has no fences, the whole reply is used.
        - None: clears the microphone input.

    Raises:
        gr.Error: if the inference API returns an error payload.
    """
    prompt = _build_monster_prompt(card_text, user_request)
    print(f"Calling API with prompt:\n{prompt}")
    params = {"max_new_tokens": 512}
    output = query({"inputs": prompt, "parameters": params})
    if 'error' in output:
        print(f'Language model call failed: {output["error"]}')
        # gr.Warning is a notification function, not an exception class, so raising its
        # return value fails with TypeError. gr.Error is the exception Gradio shows to users.
        raise gr.Error(f'Language model call failed: {output["error"]}')
    generated_text = output[0]["generated_text"]
    print(f'API RESPONSE SIZE: {len(generated_text)}')
    # The generated text is expected to echo the prompt, so keep only what follows the
    # prompt's assistant tag. Fall back to the whole text instead of an IndexError when
    # the tag is absent.
    segments = generated_text.split('<|assistant|>')
    assistant_reply = segments[1] if len(segments) > 1 else generated_text
    print(f'ASSISTANT REPLY:\n{assistant_reply}')
    fenced = assistant_reply.split('```')
    if len(fenced) < 2:
        return assistant_reply, assistant_reply, None
    return assistant_reply, fenced[1].strip() + '\n', None
def extract_text_for_header(text, header):
    """Return the rest of the line after the first "<header>: " in `text`, or '' if absent.

    `header` is matched literally: regex metacharacters in it are escaped. The match is
    not anchored to the start of a line.
    """
    match = re.search(re.escape(header) + r": (.*)", text)
    if match is None:
        return ''
    return match.group(1)
def remove_section(html, html_class):
    """Delete the template list item whose class attribute is exactly `html_class`.

    The item spans from `<li class="html_class"` up to the next "li>" (normally its
    closing `</li>`). Every copy of that exact text is removed. `html` is returned
    unchanged if no such item exists.
    """
    # Raw string avoids invalid-escape warnings for \w / \W.
    # re.escape keeps characters such as '.' in the class name literal.
    pattern = r'<li class="' + re.escape(html_class) + r'"[\w\W]*?li>'
    match = re.search(pattern, html)
    if match is not None:
        html = html.replace(match.group(0), '')
    return html
# (header in the monster text, template placeholder) for fields that are always filled in.
_CARD_FIELDS = (
    ('Name', '{name}'),
    ('Type', '{monster_type}'),
    ('Armor Class', '{armor_class}'),
    ('Hit Points', '{hit_points}'),
    ('Speed', '{speed}'),
    ('STR', '{str_stat}'),
    ('DEX', '{dex_stat}'),
    ('CON', '{con_stat}'),
    ('INT', '{int_stat}'),
    ('WIS', '{wis_stat}'),
    ('CHA', '{cha_stat}'),
)

# (header, template placeholder, <li> class) for optional fields. When the header has no
# value, the template section carrying that class is removed from the card.
_OPTIONAL_CARD_FIELDS = (
    ('Saving Throws', '{saving_throws}', 'monster-saves'),
    ('Skills', '{skills}', 'monster-skills'),
    ('Damage Vulnerabilities', '{damage_vulnerabilities}', 'monster-vulnerabilities'),
    ('Damage Resistances', '{damage_resistances}', 'monster-resistances'),
    ('Damage Immunities', '{damage_immunities}', 'monster-immunities'),
    ('Condition Immunities', '{condition_immunities}', 'monster-conditions'),
    ('Senses', '{senses}', 'monster-senses'),
    ('Languages', '{languages}', 'monster-languages'),
    ('Challenge', '{challenge}', 'monster-challenge'),
)

# Attack keywords rendered in italics. A single alternation pass (longest phrase first)
# wraps "Melee or Ranged Weapon Attack:" once. Sequential str.replace calls would wrap
# its "Ranged Weapon Attack:" suffix a second time.
_ITALIC_PHRASES_RE = re.compile(
    r'(Melee or Ranged Weapon Attack:|Melee Weapon Attack:|Ranged Weapon Attack:|Hit:)'
)


def _format_entries(monster_text, section_header, stop_word, css_class):
    """Render the "Name: detail" lines that follow `section_header:` as card HTML.

    Lines are read until one whose name contains `stop_word`, which marks the start of
    the next section. Lines without a colon are skipped. If the section text contains
    no colon at all, its raw text is wrapped in a single entry instead. All monster
    text is HTML-escaped.
    """
    from html import escape as html_escape

    match = re.search(re.escape(section_header) + r":\n([\w\W]*)", monster_text)
    section_text = match.group(1) if match else ''
    if ':' not in section_text:
        return f'<div class="{css_class}"><p>{html_escape(section_text)}</p></div>'
    entries = []
    for line in section_text.split('\n'):
        # partition keeps any further colons in the detail ("Melee Weapon Attack: ...").
        entry_name, sep, detail = line.partition(':')
        if not sep or entry_name == section_header:
            continue
        if stop_word in entry_name:
            break
        entries.append(
            f'<div class="{css_class}"><p><span class="name">{html_escape(entry_name)}</span> '
            f'<span class="detail">{html_escape(detail)}</span></p></div>'
        )
    return ''.join(entries)


def format_html(monster_text, image_data):
    """Fill monsterMakerTemplate.html with the fields parsed from `monster_text`.

    Args:
        monster_text: monster description made of "Header: value" lines
            (the layout of starting_text).
        image_data: base64-encoded PNG bytes, which are embedded as a data: URI, or any
            other value, which is inserted into the template verbatim.

    Returns:
        The filled-in card HTML as a string.

    All text taken from `monster_text` is HTML-escaped. It comes from the user-editable
    card box or the language model, and html_to_png renders the result with
    wkhtmltoimage's "enable-local-file-access" option, so raw markup must not reach it.
    """
    from html import escape as html_escape

    print('FORMATTING MONSTER TEXT')
    # see giffyglyph's monster maker https://giffyglyph.com/monstermaker/app/
    # Different Formatting style examples and some json export formats
    card = pathlib.Path('monsterMakerTemplate.html').read_text(encoding='utf-8')
    if isinstance(image_data, (bytes, bytearray)):
        card = card.replace('{image_data}', f'data:image/png;base64,{image_data.decode("utf-8")}')
    else:
        card = card.replace('{image_data}', f'{image_data}')

    for header, placeholder in _CARD_FIELDS:
        card = card.replace(placeholder, html_escape(extract_text_for_header(monster_text, header)))

    for header, placeholder, section_class in _OPTIONAL_CARD_FIELDS:
        value = extract_text_for_header(monster_text, header)
        card = card.replace(placeholder, html_escape(value))
        if not value:
            card = remove_section(card, section_class)

    description = extract_text_for_header(monster_text, 'Description')
    card = card.replace('{description}', html_escape(description))

    # Passives are read until the "Actions:" line; actions until a "Passive..." line.
    card = card.replace('{passives}', _format_entries(monster_text, 'Passives', 'Action', 'monster-trait'))
    card = card.replace('{actions}', _format_entries(monster_text, 'Actions', 'Passive', 'monster-action'))

    # TODO: Legendary actions, reactions, make column count for format an option (1 or 2 column layout)
    card = _ITALIC_PHRASES_RE.sub(r'<i>\1</i>', card)
    print('FORMATTING MONSTER TEXT COMPLETE')
    return card
def get_savename(directory, name, extension):
    """Return a file name "<name>.<extension>" that does not yet exist in `directory`.

    On a collision a numeric suffix is appended: name-1.ext, name-2.ext, and so on.
    Only the file name is returned, not the joined path.
    """
    save_name = f"{name}.{extension}"
    i = 1
    while os.path.exists(os.path.join(directory, save_name)):
        # Build every candidate from `name` itself so that hyphenated names
        # ("Half-Orc") keep their full stem instead of being cut at the first '-'.
        save_name = f"{name}-{i}.{extension}"
        i += 1
    return save_name
def trim(im, border):
    """Crop away the margin of `im` whose pixels equal the solid colour `border`.

    Returns the cropped image, or `im` itself when every pixel matches `border` and
    there is nothing to keep.
    """
    bg = Image.new(im.mode, im.size, border)
    diff = ImageChops.difference(im, bg)
    # Bounding box of the pixels that differ from the border colour; None if none differ.
    bbox = diff.getbbox()
    if bbox:
        return im.crop(bbox)
    # Return the image unchanged rather than falling off the end: callers
    # (crop_background -> html_to_png) return this value as the rendered card.
    return im
def crop_background(image):
    """Whiten the background reachable from the top-right pixel, then trim it away.

    The flood fill modifies `image` in place. The returned value is the trimmed image.
    """
    fill_colour = (255, 255, 255)
    seed_pixel = (image.size[0] - 1, 0)
    # Pixels within thresh=50 of the seed pixel's colour are also replaced, so a
    # slightly uneven background is whitened too.
    ImageDraw.floodfill(image, seed_pixel, fill_colour, thresh=50)
    return trim(image, fill_colour)
def html_to_png(card_name, html):
    """Render the card `html` to a PNG under ./rendered_cards and return it as a cropped RGB image.

    Rendering is first attempted with imgkit (wkhtmltoimage). If that raises, html2image
    (meant for local Windows runs) is tried. Failures of either renderer are logged, not
    raised: processing continues as long as an image file exists at the target path.

    Args:
        card_name: base name for the output file (uniquified with get_savename).
        html: the full card HTML.

    Returns:
        The rendered card as an RGB PIL image with its background cropped off.

    Raises:
        FileNotFoundError: if no renderer produced the image file.
    """
    render_dir = 'rendered_cards'
    # Neither renderer is guaranteed to create the output directory.
    os.makedirs(render_dir, exist_ok=True)
    save_name = get_savename(render_dir, card_name, 'png')
    print('CONVERTING HTML CARD TO PNG IMAGE')
    path = os.path.join(render_dir, save_name)
    try:
        # NOTE(review): these stylesheets look carried over from a card-game renderer;
        # confirm they exist in this Space.
        css = ['./css/mana.css', './css/keyrune.css',
               './css/mtg_custom.css']
        imgkit.from_string(html, path, {"xvfb": "", "enable-local-file-access": ""}, css=css)
    except Exception as imgkit_error:
        # NOTE(review): presumably imgkit can raise even when wkhtmltoimage still wrote
        # the file (e.g. a missing stylesheet above), so this is not treated as fatal.
        # The existence check below decides.
        print(f'imgkit rendering failed: {imgkit_error}')
        try:
            # For Windows local, requires 'html2image' package from pip.
            from html2image import Html2Image
            hti = Html2Image(output_path=render_dir)
            paths = hti.screenshot(html_str=html,
                                   css_file='monstermaker.css',
                                   save_as=save_name, size=(800, 1440))
            print(paths)
            path = paths[0]
        except Exception as fallback_error:
            print(f'html2image fallback failed: {fallback_error}')
    if not os.path.exists(path):
        raise FileNotFoundError(f'Card rendering produced no image at {path}')
    print('OPENING IMAGE FROM FILE')
    # convert() returns a new, fully loaded image, so the file can be closed right away.
    with Image.open(path) as rendered:
        img = rendered.convert("RGB")
    print('CROPPING BACKGROUND')
    img = crop_background(img)
    print('CONVERTING HTML CARD TO PNG IMAGE COMPLETE')
    return img
def get_initial_card():
    """Open the bundled sample card that fills the "Rendered Card" panel before any render."""
    sample_card_path = 'SampleCard.png'
    sample_card = Image.open(sample_card_path)
    return sample_card
def pil_to_base64(image):
    """Encode `image` as PNG and return the base64 encoding as bytes (not str)."""
    print('CONVERTING PIL IMAGE TO BASE64 STRING')
    with BytesIO() as png_buffer:
        image.save(png_buffer, format="PNG")
        encoded = base64.b64encode(png_buffer.getvalue())
    print('CONVERTING PIL IMAGE TO BASE64 STRING COMPLETE')
    return encoded
def generate_card(image: str, card_text: str):
    """Render `card_text` together with the illustration file at `image` into a card image.

    Args:
        image: file path of the card illustration (the filepath-typed gr.Image). When it
            is empty, e.g. after the Clear button, 'placeholder.png' is used instead.
        card_text: monster text made of "Header: value" lines.

    Returns:
        The cropped PIL image produced by html_to_png().
    """
    if not image:
        image = 'placeholder.png'
    # Close the illustration file once it has been re-encoded.
    with Image.open(image) as illustration:
        image_data = pil_to_base64(illustration)
    html = format_html(card_text, image_data)
    name_match = re.search(r'Name: (.*)', card_text)
    name = name_match.group(1) if name_match else ''
    # The name comes from user- or LLM-written text but becomes a file name. Replace
    # anything other than word characters, hyphens and spaces so it cannot introduce
    # path separators ('../') or characters that are invalid in Windows file names.
    file_stem = re.sub(r'[^\w\- ]', '_', name).strip() or 'monster'
    return html_to_png(file_stem, html)
def transcribe(audio: str) -> (str, str):
    """Transcribe the recorded audio file with the shared Whisper pipeline.

    Returns the recognised text, which goes to the prompt box, and None, which
    clears the microphone input.
    """
    transcription = whisper_pipe(audio)["text"]
    return transcription, None
# Example monster shown in the card-text box on startup. generate_text() compares the
# current card text against it to choose between creating a new monster and editing one.
starting_text = """Name: Jabberwock
Type: Medium humanoid (human), neutral evil
Description: A Jabberwock is a creature of the Deep Sea.
Stats:
Armor Class: 15 (breastplate)
Hit Points: 22 (5d8 + 5)
Speed: 30 ft
STR: 8 (-1)
DEX: 14 (+2)
CON: 12 (+1)
INT: 2 (-4)
WIS: 10 (+0)
CHA: 4 (-3)
Skills: Perception +3
Senses: darkvision 60 ft., passive Perception 14
Languages: —
Challenge: 1/4 (50 XP)
Passives:
Legendary Resistance (3/Day): If the Jabberwock fails a saving throw, it can choose to succeed instead
Actions:
Bite: Melee Weapon Attack: +5 to hit, reach 5 ft., one target. Hit: 6 (1d8 + 3) piercing damage.
Claws (Recharge 5-6): Melee Weapon Attack: +5 to hit, reach 5 ft., one target. Hit: 6 (1d10 + 3) slashing damage."""
def generate_image(card_text: str):
    """Request an illustration for the monster described by `card_text`.

    A text-to-image prompt is built from the card's "Type:" and "Name:" lines and sent
    to the remote latent-consistency Space through the module-level `client`.

    Returns:
        Whatever client.predict returns, which feeds the filepath-typed card-image
        component. If the remote call fails for any reason, 'placeholder.png' is
        returned instead.
    """
    name = extract_text_for_header(card_text, 'Name')
    card_type = extract_text_for_header(card_text, 'Type')
    # Skip missing fields instead of raising IndexError on a card without Name/Type lines.
    subject = ' '.join(part for part in (card_type, name) if part)
    prompt = f"fantasy illustration of a {subject}, by Greg Rutkowski"
    print(f'Calling image generation with prompt: {prompt}')
    try:
        result = client.predict(
            prompt,  # str in 'parameter_5' Textbox component
            0.3,  # float (numeric value between 0.0 and 5) in 'Guidance' Slider component
            4,  # float (numeric value between 2 and 10) in 'Steps' Slider component
            random.randint(0, 12013012031030),  # numeric value between 0 and 12013012031030 in 'Seed' Slider component
            api_name="/predict"
        )
        print(result)
        return result
    except Exception as e:
        # Best effort: any failure (including an unavailable client) falls back to the placeholder.
        print(f'Failed to generate image from client: {e}')
        return 'placeholder.png'
def add_hotkeys() -> str:
    """Return the source of hotkeys.js, which demo.load() runs in the browser via its `js` argument.

    The file is decoded as UTF-8 explicitly so the result does not depend on the host locale.
    """
    return pathlib.Path("hotkeys.js").read_text(encoding="utf-8")
# Gradio UI. The first row takes the request (voice or typed) and shows the assistant's
# raw reply. The second row holds the editable card text and illustration on the left
# and the rendered card on the right.
# NOTE(review): the nesting below was reconstructed from an indentation-less copy of this
# file; confirm the Accordion / second Row placement against the deployed Space.
# NOTE(review): several labels contain mis-decoded emoji (e.g. '๐น', '๐ค', '๐ด'), and the
# placeholder examples ('Make it an instant', 'Change the color') look carried over from a
# card-game app rather than D&D monsters. Confirm and fix these literals.
with gr.Blocks(title='MonsterGen') as demo:
    gr.Markdown("# ๐น MonsterGenV2")
    gr.Markdown("## Generate and Edit D&D Monsters with a Chat Assistant")
    with gr.Row():
        with gr.Column():
            with gr.Group():
                # Recordings arrive as a file path; stop_recording below triggers transcription.
                audio_in = gr.Microphone(label="Record a voice request (click or press ctrl + ` to start/stop)",
                                         type='filepath', elem_classes=["record-btn"])
                prompt_in = gr.Textbox(label="Or type a text request and press Enter", interactive=True,
                                       placeholder="Need an idea? Try one of these:\n- Create a creature card named 'WiFi Elemental'\n- Make it an instant\n- Change the color")
            with gr.Accordion(label='๐ค Chat Assistant Response', open=False):
                # Read-only: filled with the raw LLM reply by generate_text.
                bot_text = gr.TextArea(label='Response', interactive=False)
    with gr.Row():
        with gr.Column():
            # Starts with the example monster; generate_text treats any other text as a card to edit.
            in_text = gr.TextArea(label="Card Text (Shift+Enter to submit)", value=starting_text)
            gen_image_button = gr.Button('๐ผ๏ธ Generate Card Image')
            in_image = gr.Image(label="Card Image (256px x 256px)", type='filepath', value='placeholder.png')
            render_button = gr.Button('๐ด Render Card', variant="primary")
            gr.ClearButton([audio_in, prompt_in, in_text, in_image])
        with gr.Column():
            out_image = gr.Image(label="Rendered Card", value=get_initial_card())
    # Event wiring. Each *_params dict holds the fn/inputs/outputs of one pipeline step so
    # the same step can be chained from several triggers.
    transcribe_params = {'fn': transcribe, 'inputs': [audio_in], 'outputs': [prompt_in, audio_in]}
    generate_text_params = {'fn': generate_text, 'inputs': [in_text, prompt_in],
                            'outputs': [bot_text, in_text, audio_in]}
    generate_image_params = {'fn': generate_image, 'inputs': [in_text], 'outputs': [in_image]}
    generate_card_params = {'fn': generate_card, 'inputs': [in_image, in_text], 'outputs': [out_image]}
    # Voice request: transcribe -> ask the LLM -> generate an illustration -> render the card.
    audio_in.stop_recording(**transcribe_params).then(**generate_text_params).then(**generate_image_params).then(
        **generate_card_params)
    # Typed request (Enter): the same pipeline without the transcription step.
    prompt_in.submit(**generate_text_params).then(**generate_image_params).then(**generate_card_params)
    # Shift + Enter to submit text in TextAreas: submitting the card text re-renders the card.
    in_text.submit(**generate_card_params)
    render_button.click(**generate_card_params)
    gen_image_button.click(**generate_image_params).then(**generate_card_params)
    # Run hotkeys.js in the browser on page load (presumably it implements the ctrl + ` recording shortcut).
    demo.load(None, None, None, js=add_hotkeys())
if __name__ == "__main__":
    # Enable Gradio's request queue, then start the web server with the custom favicon.
    queued_demo = demo.queue()
    queued_demo.launch(favicon_path="favicon-96x96.png")