import gradio as gr
import os, requests
import torch, torchvision, einops  # torchvision/einops are imported by some trust_remote_code checkpoints at load time
import spaces
import subprocess
from transformers import (
    AutoModel,
    AutoModelForCausalLM,
    Idefics2ForConditionalGeneration,
    LlavaForConditionalGeneration,
    LlavaNextForConditionalGeneration,
    PaliGemmaForConditionalGeneration,
)
from huggingface_hub import login
# Install flash-attn, which several of the vision checkpoints below require
subprocess.run(
    "pip install flash-attn --no-build-isolation",
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},  # preserve the existing environment (PATH etc.) instead of replacing it
    shell=True,
)
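# FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE tells flash-attn's setup to skip
# compiling its CUDA kernels at install time; the Spaces build machine has no
# GPU visible, so compiling them there would fail.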
HF_TOKEN = os.getenv("HF_TOKEN")
login(token=HF_TOKEN, add_to_git_credential=True)
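# HF_TOKEN is read from the Space's secrets; a valid token is needed both for
# the authenticated config.json fetch below and for gated checkpoints such as
# meta-llama/Meta-Llama-3-8B or google/paligemma-3b-mix-224.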
# Cache for storing loaded models and their summaries
model_cache = {}
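# (The cache lives for the lifetime of the process, so repeat queries for the
# same model skip re-downloading and re-instantiating it.)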
# Function to get the model summary
@spaces.GPU  # request a GPU (a ZeroGPU slot) for the duration of each call
def get_model_summary(model_name):
    if model_name in model_cache:
        return model_cache[model_name], ""

    try:
        # Fetch config.json to read the architecture without downloading weights
        config_url = f"https://huggingface.co./{model_name}/raw/main/config.json"
        headers = {"Authorization": f"Bearer {HF_TOKEN}"}
        response = requests.get(config_url, headers=headers)
        response.raise_for_status()
        config = response.json()
        architecture = config["architectures"][0]

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Select the correct model class based on the architecture
        if architecture == "LlavaNextForConditionalGeneration":
            model = LlavaNextForConditionalGeneration.from_pretrained(model_name, trust_remote_code=True).to(device)
        elif architecture == "LlavaForConditionalGeneration":
            model = LlavaForConditionalGeneration.from_pretrained(model_name, trust_remote_code=True).to(device)
        elif architecture == "PaliGemmaForConditionalGeneration":
            model = PaliGemmaForConditionalGeneration.from_pretrained(model_name, trust_remote_code=True).to(device)
        elif architecture == "Idefics2ForConditionalGeneration":
            model = Idefics2ForConditionalGeneration.from_pretrained(model_name, trust_remote_code=True).to(device)
        elif architecture == "MiniCPMV":
            # MiniCPMV is not exported by transformers; load it through AutoModel
            # with trust_remote_code so the repo's own modeling code is used.
            model = AutoModel.from_pretrained(model_name, trust_remote_code=True).to(device)
        else:
            model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True).to(device)

        model_summary = str(model)
        model_cache[model_name] = model_summary
        return model_summary, ""
    except Exception as e:
        return "", str(e)
# Create the Gradio Blocks interface
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            textbox = gr.Textbox(label="Model Name", placeholder="Enter the model name here OR select an example below...", lines=1)

            gr.Markdown("### Vision Models")
            vision_examples = gr.Examples(
                examples=[
                    ["microsoft/llava-med-v1.5-mistral-7b"],
                    ["llava-hf/llava-v1.6-mistral-7b-hf"],
                    ["xtuner/llava-phi-3-mini-hf"],
                    ["xtuner/llava-llama-3-8b-v1_1-transformers"],
                    ["vikhyatk/moondream2"],
                    ["openbmb/MiniCPM-Llama3-V-2_5"],
                    ["microsoft/Phi-3-vision-128k-instruct"],
                    ["google/paligemma-3b-mix-224"],
                    ["HuggingFaceM4/idefics2-8b-chatty"],
                ],
                inputs=textbox,
            )

            gr.Markdown("### Other Models")
            other_examples = gr.Examples(
                examples=[
                    ["google/gemma-7b"],
                    ["microsoft/Phi-3-mini-4k-instruct"],
                    ["meta-llama/Meta-Llama-3-8B"],
                    ["mistralai/Mistral-7B-Instruct-v0.3"],
                ],
                inputs=textbox,
            )

            submit_button = gr.Button("Submit")

        with gr.Column():
            output = gr.Textbox(label="Model Architecture", lines=20, placeholder="Model architecture will appear here...", show_copy_button=True)
            error_output = gr.Textbox(label="Error", lines=10, placeholder="Exceptions will appear here...", show_copy_button=True)

    def handle_click(model_name):
        model_summary, error_message = get_model_summary(model_name)
        return model_summary, error_message

    submit_button.click(fn=handle_click, inputs=textbox, outputs=[output, error_output])

# Launch the interface
demo.launch()
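# demo.launch() serves requests directly; under concurrent traffic, enabling
# Gradio's request queue first (demo.queue().launch()) is a common alternative.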