File size: 4,768 Bytes
9b25f0e
8ff201d
b4fa047
 
 
8ff201d
b4fa047
 
 
 
 
 
 
 
8ff201d
b4fa047
 
 
 
 
 
 
 
 
8ff201d
b4fa047
8ff201d
b4fa047
 
 
 
 
 
c5d3fee
b4fa047
 
 
 
13669f6
 
 
 
 
 
8ff201d
13669f6
 
 
 
 
 
 
8ff201d
13669f6
 
 
 
 
 
 
8ff201d
13669f6
 
 
 
 
 
 
 
8ff201d
 
 
b4fa047
13669f6
 
8ff201d
13669f6
45159cb
 
 
b4fa047
13669f6
b4fa047
13669f6
b4fa047
13669f6
45159cb
b29a047
45159cb
13669f6
45159cb
3f86c47
13669f6
3f86c47
8ff201d
b4fa047
45159cb
13669f6
 
 
 
8ff201d
13669f6
 
b29a047
 
 
 
 
 
 
 
 
 
8ff201d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import html
from pathlib import Path

import gradio as gr
import spaces
from mistral_inference.transformer import Transformer
from mistral_inference.generate import generate
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
from mistral_common.protocol.instruct.messages import UserMessage, TextChunk, ImageURLChunk
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from huggingface_hub import snapshot_download

# Fetch the Pixtral checkpoint files into ~/mistral_models/Pixtral.
mistral_models_path = Path.home() / 'mistral_models' / 'Pixtral'
mistral_models_path.mkdir(parents=True, exist_ok=True)

snapshot_download(
    repo_id="mistral-community/pixtral-12b-240910",
    # Only the weights, the model params and the tokenizer definition are needed.
    allow_patterns=["params.json", "consolidated.safetensors", "tekken.json"],
    local_dir=mistral_models_path,
)

# Build the tokenizer from the downloaded tekken.json and load the model
# from the same directory; both are shared module-level objects.
tokenizer = MistralTokenizer.from_file(f"{mistral_models_path}/tekken.json")
model = Transformer.from_folder(mistral_models_path)

# Inference
@spaces.GPU
def mistral_inference(prompt, image_url):
    """Ask Pixtral about an image and return its text answer.

    A single-turn chat request is built whose user message contains the
    image (referenced by URL) followed by the prompt text. It is encoded
    with the module-level ``tokenizer`` and generated with the module-level
    ``model``.

    Args:
        prompt: Instruction text for the model.
        image_url: URL handed to ``ImageURLChunk``.

    Returns:
        The decoded text of the single generated sequence.
    """
    user_turn = UserMessage(content=[
        ImageURLChunk(image_url=image_url),
        TextChunk(text=prompt),
    ])
    encoded = tokenizer.encode_chat_completion(
        ChatCompletionRequest(messages=[user_turn])
    )

    # Generation stops at the tokenizer's end-of-sequence id.
    stop_id = tokenizer.instruct_tokenizer.tokenizer.eos_id

    # Batch of one: one token sequence paired with its list of images.
    generated, _ = generate(
        [encoded.tokens],
        model,
        images=[encoded.images],
        max_tokens=1024,
        temperature=0.35,
        eos_id=stop_id,
    )
    return tokenizer.decode(generated[0])

# UI label text per language. Every tuple in _LABEL_TEXT lists its strings
# in the same order as _LABEL_FIELDS.
_LABEL_FIELDS = ('title', 'text_prompt', 'image_url', 'output', 'image_display', 'submit')
_LABEL_TEXT = {
    'en': (
        "Pixtral Model Image Description",
        "Text Prompt",
        "Image URL",
        "Model Output",
        "Input Image",
        "Run Inference",
    ),
    'zh': (
        "Pixtral模型图像描述",
        "文本提示",
        "图片网址",
        "模型输出",
        "输入图片",
        "运行推理",
    ),
    'jp': (
        "Pixtralモデルによる画像説明生成",
        "テキストプロンプト",
        "画像URL",
        "モデルの出力結果",
        "入力された画像",
        "推論を実行",
    ),
}


def get_labels(language):
    """Return the UI label strings for *language*.

    Args:
        language: Language code; 'en', 'zh' and 'jp' are supported.

    Returns:
        A new dict mapping 'title', 'text_prompt', 'image_url', 'output',
        'image_display' and 'submit' to the label text for that language.

    Raises:
        KeyError: If *language* is not one of the supported codes.
    """
    return dict(zip(_LABEL_FIELDS, _LABEL_TEXT[language]))

# Gradio interface
def process_input(text, image_url):
    """Run inference and build the preview HTML for the input image.

    Args:
        text: Prompt text from the prompt textbox.
        image_url: Image URL typed into the URL textbox (user-controlled).

    Returns:
        A ``(model_output, image_html)`` pair: the model's text answer and an
        ``<img>`` tag previewing the image at width 300.
    """
    result = mistral_inference(text, image_url)
    # The URL comes straight from the user, so escape it (quotes included)
    # before placing it inside src="..."; otherwise a crafted value could
    # close the attribute and inject arbitrary markup into the HTML output.
    safe_url = html.escape(image_url, quote=True)
    return result, f'<img src="{safe_url}" alt="Input Image" width="300">'

def update_ui(language):
    """Relabel the UI for the selected language.

    Returns one update per output of the language dropdown's ``change``
    listener, in this order: title, text_input, image_input, result_output,
    image_output, submit_button.

    Bare strings must not be returned for the Textbox/HTML outputs: in
    Gradio a plain return value replaces the component's *value*. That would
    wipe the user's prompt, the image URL, the model output and the image
    preview. ``gr.update(label=...)`` changes only the label.

    Args:
        language: Language code accepted by ``get_labels``.

    Returns:
        A 6-tuple of component updates.
    """
    labels = get_labels(language)
    return (
        # The title Markdown was created as "## ...", so keep the heading.
        "## " + labels['title'],
        gr.update(label=labels['text_prompt']),
        gr.update(label=labels['image_url']),
        gr.update(label=labels['output']),
        # The HTML preview may not display its label; updating only the label
        # at least leaves the rendered image intact.
        gr.update(label=labels['image_display']),
        # A Button's value is its caption, so setting the value is intended.
        gr.update(value=labels['submit']),
    )

# Default image URL: prefills the URL textbox and the initial image preview.
# NOTE(review): the hostname "huggingface.co." has a trailing dot (also in the
# second example below) — probably a typo for "huggingface.co"; confirm it resolves.
initial_url = "https://huggingface.co./spaces/aixsatoshi/Pixtral-12B/resolve/main/llamagiant.jpg"

# Build the UI. Components appear in the order they are created, so the
# statement order below defines the page layout.
with gr.Blocks() as demo:
    language_choice = gr.Dropdown(choices=['en', 'zh', 'jp'], label="Select Language", value='en')
    
    title = gr.Markdown("## Pixtral Model Image Description")
    with gr.Row():
        text_input = gr.Textbox(label="Text Prompt", placeholder="e.g. Describe the image.")
        image_input = gr.Textbox(label="Image URL", value=initial_url)  # prefilled with the default URL

    # Model output area
    result_output = gr.Textbox(label="Model Output", lines=8, max_lines=20)  # sized by line count (originally meant to approximate a 500px height — TODO confirm)
    image_output = gr.HTML(f'<img src="{initial_url}" alt="Input Image" width="300">')  # shows the default image from startup

    submit_button = gr.Button("Run Inference")

    # process_input returns (model text, image HTML), matching these outputs.
    submit_button.click(process_input, inputs=[text_input, image_input], outputs=[result_output, image_output])


    # Relabel the UI when the language changes; the outputs order must match
    # the tuple returned by update_ui.
    language_choice.change(
        fn=update_ui, 
        inputs=[language_choice], 
        outputs=[title, text_input, image_input, result_output, image_output, submit_button]
    )

    # Example (prompt, image URL) pairs that fill the two input textboxes.
    examples = [
        ["Describe the scene.", "https://assets.st-note.com/production/uploads/images/138094970/rectangle_large_type_2_bc1a73623dc0e9bf8799832ddb4cd53e.png"],
        ["Describe the image.", "https://huggingface.co./datasets/patrickvonplaten/random_img/resolve/main/yosemite.png"],
        ["Describe the random generated image.", "https://picsum.photos/seed/picsum/200/300"],
        ["Describe the image.", "https://picsum.photos/id/32/512/512"]
    ]

    gr.Examples(examples=examples, inputs=[text_input, image_input], label="Example Inputs")

# Start the Gradio app; this runs at import time.
demo.launch()