Niki Zhang committed on
Commit 2dc11df · verified · 1 Parent(s): aea9e97

Update app.py


Liked and disliked buttons
Recommendation system updates

Files changed (1)
  1. app.py +1854 -202
app.py CHANGED
@@ -1,226 +1,1878 @@
1
- # Copyright (c) Microsoft
2
- # Modified from Visual ChatGPT Project https://github.com/microsoft/TaskMatrix/blob/main/visual_chatgpt.py
3
-
4
  import os
5
  import gradio as gr
6
  import re
7
  import uuid
8
- from PIL import Image, ImageDraw, ImageOps
9
- import numpy as np
10
- import argparse
11
- import inspect
12
 
13
- from langchain.agents.initialize import initialize_agent
14
- from langchain.agents.tools import Tool
15
- from langchain.chains.conversation.memory import ConversationBufferMemory
16
- from langchain.llms.openai import OpenAIChat
17
  import torch
18
- from PIL import Image, ImageDraw, ImageOps
19
- from transformers import pipeline, BlipProcessor, BlipForConditionalGeneration, BlipForQuestionAnswering
20
 
21
- # openai.api_version = '2020-11-07'
22
- os.environ["OPENAI_API_VERSION"] = '2020-11-07'
 
23
 
24
- VISUAL_CHATGPT_PREFIX = """
25
- I want you to act as an art connoisseur, providing in-depth and insightful analysis on various artworks. Your responses should reflect a deep understanding of art history, techniques, and cultural contexts, offering users a rich and nuanced perspective.
 
26
 
27
- You can engage in natural-sounding conversations, generate human-like text based on input, and provide relevant, coherent responses on art-related topics."""
28
 
29
 
30
- # TOOLS:
31
- # ------
 
 
32
 
33
- # Visual ChatGPT has access to the following tools:"""
34
 
35
- VISUAL_CHATGPT_FORMAT_INSTRUCTIONS = """To use a tool, please use the following format:
36
 
37
- "Thought: Do I need to use a tool? Yes
38
- Action: the action to take, should be one of [{tool_names}], remember the action must to be one tool
39
- Action Input: the input to the action
40
- Observation: the result of the action"
41
 
42
- When you have a response to say to the Human, or if you do not need to use a tool, you MUST use the format:
43
 
44
- "Thought: Do I need to use a tool? No
45
- {ai_prefix}: [your response here]"
46
 
47
  """
48
 
49
- VISUAL_CHATGPT_SUFFIX = """
50
- Begin Chatting!
51
-
52
- Previous conversation history:
53
- {chat_history}
54
-
55
- New input: {input}
56
- As a language model, you must repeatly to use VQA tools to observe images. You response should be consistent with the outputs of the VQA tool instead of imagination. Do not repeat asking the same question.
57
-
58
- Thought: Do I need to use a tool? {agent_scratchpad} (You are strictly to use the aforementioned "Thought/Action/Action Input/Observation" format as the answer.)"""
59
-
60
- os.makedirs('chat_image', exist_ok=True)
61
-
62
-
63
- def prompts(name, description):
64
- def decorator(func):
65
- func.name = name
66
- func.description = description
67
- return func
68
- return decorator
69
-
70
- def cut_dialogue_history(history_memory, keep_last_n_words=500):
71
- if history_memory is None or len(history_memory) == 0:
72
- return history_memory
73
- tokens = history_memory.split()
74
- n_tokens = len(tokens)
75
- print(f"history_memory:{history_memory}, n_tokens: {n_tokens}")
76
- if n_tokens < keep_last_n_words:
77
- return history_memory
78
- paragraphs = history_memory.split('\n')
79
- last_n_tokens = n_tokens
80
- while last_n_tokens >= keep_last_n_words:
81
- last_n_tokens -= len(paragraphs[0].split(' '))
82
- paragraphs = paragraphs[1:]
83
- return '\n' + '\n'.join(paragraphs)
84
-
85
- def get_new_image_name(folder='chat_image', func_name="update"):
86
- this_new_uuid = str(uuid.uuid4())[:8]
87
- new_file_name = f'{func_name}_{this_new_uuid}.png'
88
- return os.path.join(folder, new_file_name)
89
-
90
- class VisualQuestionAnswering:
91
- def __init__(self, device):
92
- print(f"Initializing VisualQuestionAnswering to {device}")
93
- self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
94
- self.device = device
95
- self.processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
96
- self.model = BlipForQuestionAnswering.from_pretrained(
97
- "Salesforce/blip-vqa-base", torch_dtype=self.torch_dtype).to(self.device)
98
- # self.processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-capfilt-large")
99
- # self.model = BlipForQuestionAnswering.from_pretrained(
100
- # "Salesforce/blip-vqa-capfilt-large", torch_dtype=self.torch_dtype).to(self.device)
101
-
102
- @prompts(name="Answer Question About The Image",
103
- description="VQA tool is useful when you need an answer for a question based on an image. "
104
- "like: what is the color of an object, how many cats in this figure, where is the child sitting, what does the cat doing, why is he laughing."
105
- "The input to this tool should be a comma separated string of two, representing the image path and the question.")
106
- def inference(self, inputs):
107
- image_path, question = inputs.split(",")[0], ','.join(inputs.split(',')[1:])
108
- raw_image = Image.open(image_path).convert('RGB')
109
- inputs = self.processor(raw_image, question, return_tensors="pt").to(self.device, self.torch_dtype)
110
- out = self.model.generate(**inputs)
111
- answer = self.processor.decode(out[0], skip_special_tokens=True)
112
- print(f"\nProcessed VisualQuestionAnswering, Input Image: {image_path}, Input Question: {question}, "
113
- f"Output Answer: {answer}")
114
- return answer
115
-
116
- def build_chatbot_tools(load_dict):
117
- print(f"Initializing ChatBot, load_dict={load_dict}")
118
- models = {}
119
- # Load Basic Foundation Models
120
- for class_name, device in load_dict.items():
121
- models[class_name] = globals()[class_name](device=device)
122
-
123
- # Load Template Foundation Models
124
- for class_name, module in globals().items():
125
- if getattr(module, 'template_model', False):
126
- template_required_names = {k for k in inspect.signature(module.__init__).parameters.keys() if k!='self'}
127
- loaded_names = set([type(e).__name__ for e in models.values()])
128
- if template_required_names.issubset(loaded_names):
129
- models[class_name] = globals()[class_name](
130
- **{name: models[name] for name in template_required_names})
131
 
132
- tools = []
133
- for instance in models.values():
134
- for e in dir(instance):
135
- if e.startswith('inference'):
136
- func = getattr(instance, e)
137
- tools.append(Tool(name=func.name, description=func.description, func=func))
138
- return tools
139
-
140
- class ConversationBot:
141
- def __init__(self, tools, api_key=""):
142
- # load_dict = {'VisualQuestionAnswering':'cuda:0', 'ImageCaptioning':'cuda:1',...}
143
- llm = OpenAIChat(model_name="gpt-4o", temperature=0.7, openai_api_key=api_key, model_kwargs={"api_version": "2020-11-07"})
144
- self.llm = llm
145
- self.memory = ConversationBufferMemory(memory_key="chat_history", output_key='output')
146
- self.tools = tools
147
- self.current_image = None
148
- self.point_prompt = ""
149
- self.global_prompt = ""
150
- self.agent = initialize_agent(
151
- self.tools,
152
- self.llm,
153
- agent="conversational-react-description",
154
- verbose=True,
155
- memory=self.memory,
156
- return_intermediate_steps=True,
157
- agent_kwargs={'prefix': VISUAL_CHATGPT_PREFIX, 'format_instructions': VISUAL_CHATGPT_FORMAT_INSTRUCTIONS,
158
- 'suffix': VISUAL_CHATGPT_SUFFIX}, )
159
-
160
- def constructe_intermediate_steps(self, agent_res):
161
- ans = []
162
- for action, output in agent_res:
163
- if hasattr(action, "tool_input"):
164
- use_tool = "Yes"
165
- act = (f"Thought: Do I need to use a tool? {use_tool}\nAction: {action.tool}\nAction Input: {action.tool_input}", f"Observation: {output}")
166
- else:
167
- use_tool = "No"
168
- act = (f"Thought: Do I need to use a tool? {use_tool}", f"AI: {output}")
169
- act= list(map(lambda x: x.replace('\n', '<br>'), act))
170
- ans.append(act)
171
- return ans
172
-
173
- def run_text(self, text, state, aux_state):
174
- self.agent.memory.buffer = cut_dialogue_history(self.agent.memory.buffer, keep_last_n_words=500)
175
- if self.point_prompt != "":
176
- Human_prompt = f'\nHuman: {self.point_prompt}\n'
177
- AI_prompt = 'Ok'
178
- self.agent.memory.buffer = self.agent.memory.buffer + Human_prompt + 'AI: ' + AI_prompt
179
- self.point_prompt = ""
180
- res = self.agent({"input": text})
181
- res['output'] = res['output'].replace("\\", "/")
182
- response = re.sub('(chat_image/\S*png)', lambda m: f'![](/file={m.group(0)})*{m.group(0)}*', res['output'])
183
- state = state + [(text, response)]
184
-
185
- aux_state = aux_state + [(f"User Input: {text}", None)]
186
- aux_state = aux_state + self.constructe_intermediate_steps(res['intermediate_steps'])
187
- print(f"\nProcessed run_text, Input text: {text}\nCurrent state: {state}\n"
188
- f"Current Memory: {self.agent.memory.buffer}\n"
189
- f"Aux state: {aux_state}\n"
190
- )
191
- return state, state, aux_state, aux_state
192
 
193
 
194
- if __name__ == '__main__':
195
- parser = argparse.ArgumentParser()
196
- parser.add_argument('--load', type=str, default="VisualQuestionAnswering_cuda:0")
197
- parser.add_argument('--port', type=int, default=1015)
198
-
199
- args = parser.parse_args()
200
- load_dict = {e.split('_')[0].strip(): e.split('_')[1].strip() for e in args.load.split(',')}
201
- tools = build_chatbot_tools(load_dict)
202
- bot = ConversationBot(tools)
203
- with gr.Blocks(css="#chatbot .overflow-y-auto{height:500px}") as demo:
204
- with gr.Row():
205
- chatbot = gr.Chatbot(elem_id="chatbot", label="CATchat").style(height=1000,scale=0.5)
206
- auxwindow = gr.Chatbot(elem_id="chatbot", label="Aux Window").style(height=1000,scale=0.5)
207
  state = gr.State([])
208
  aux_state = gr.State([])
209
- with gr.Row():
210
- with gr.Column(scale=0.7):
211
- txt = gr.Textbox(show_label=False, placeholder="Enter text and press enter, or upload an image").style(
212
- container=False)
213
- with gr.Column(scale=0.15, min_width=0):
214
- clear = gr.Button("Clear")
215
- with gr.Column(scale=0.15, min_width=0):
216
- btn = gr.UploadButton("Upload", file_types=["image"])
217
-
218
- txt.submit(bot.run_text, [txt, state, aux_state], [chatbot, state, aux_state, auxwindow])
219
- txt.submit(lambda: "", None, txt)
220
- btn.upload(bot.run_image, [btn, state, txt, aux_state], [chatbot, state, txt, aux_state, auxwindow])
221
- clear.click(bot.memory.clear)
222
- clear.click(lambda: [], None, chatbot)
223
- clear.click(lambda: [], None, auxwindow)
224
- clear.click(lambda: [], None, state)
225
- clear.click(lambda: [], None, aux_state)
226
- demo.launch(server_name="0.0.0.0", server_port=args.port, share=True)
1
+ from io import BytesIO
2
+ from math import inf
 
3
  import os
4
+ import base64
5
+ import json
6
  import gradio as gr
7
+ import numpy as np
8
+ from gradio import processing_utils
9
+ import requests
10
+ from packaging import version
11
+ from PIL import Image, ImageDraw
12
+ import functools
13
+ import emoji
14
+ from langchain.llms.openai import OpenAI
15
+ from caption_anything.model import CaptionAnything
16
+ from caption_anything.utils.image_editing_utils import create_bubble_frame
17
+ from caption_anything.utils.utils import mask_painter, seg_model_map, prepare_segmenter, image_resize
18
+ from caption_anything.utils.parser import parse_augment
19
+ from caption_anything.captioner import build_captioner
20
+ from caption_anything.text_refiner import build_text_refiner
21
+ from caption_anything.segmenter import build_segmenter
22
+ from chatbox import ConversationBot, build_chatbot_tools, get_new_image_name
23
+ from segment_anything import sam_model_registry
24
+ import easyocr
25
  import re
26
+ import edge_tts
27
+
28
+ # import tts
29
+
30
+ ###############################################################################
31
+ ############# this part is for 3D generation #############
32
+ ###############################################################################
33
+
34
+
35
+ # import spaces #
36
+
37
+ import os
38
+ # import uuid
39
+ # from diffusers import AnimateDiffPipeline, MotionAdapter, EulerDiscreteScheduler
40
+ # from diffusers.utils import export_to_video
41
+ # from safetensors.torch import load_file
42
+ #from diffusers.models.modeling_outputs import Transformer2DModelOutput
43
+
44
+
45
+ import random
46
  import uuid
47
+ import json
48
+ from diffusers import StableDiffusionXLPipeline, EulerAncestralDiscreteScheduler
49
+
50
+
51
 
52
+
53
+ import imageio
54
+ import numpy as np
 
55
  import torch
56
+ import rembg
57
+ from PIL import Image
58
+ from torchvision.transforms import v2
59
+ from pytorch_lightning import seed_everything
60
+ from omegaconf import OmegaConf
61
+ from einops import rearrange, repeat
62
+ from tqdm import tqdm
63
+ from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler
64
+
65
+ from src.utils.train_util import instantiate_from_config
66
+ from src.utils.camera_util import (
67
+ FOV_to_intrinsics,
68
+ get_zero123plus_input_cameras,
69
+ get_circular_camera_poses,
70
+ )
71
+ from src.utils.mesh_util import save_obj, save_glb
72
+ from src.utils.infer_util import remove_background, resize_foreground, images_to_video
73
+
74
+ import tempfile
75
+ from functools import partial
76
+
77
+ from huggingface_hub import hf_hub_download
78
+
79
+
80
+
81
+
82
+ def get_render_cameras(batch_size=1, M=120, radius=2.5, elevation=10.0, is_flexicubes=False):
83
+ """
84
+ Get the rendering camera parameters.
85
+ """
86
+ c2ws = get_circular_camera_poses(M=M, radius=radius, elevation=elevation)
87
+ if is_flexicubes:
88
+ cameras = torch.linalg.inv(c2ws)
89
+ cameras = cameras.unsqueeze(0).repeat(batch_size, 1, 1, 1)
90
+ else:
91
+ extrinsics = c2ws.flatten(-2)
92
+ intrinsics = FOV_to_intrinsics(50.0).unsqueeze(0).repeat(M, 1, 1).float().flatten(-2)
93
+ cameras = torch.cat([extrinsics, intrinsics], dim=-1)
94
+ cameras = cameras.unsqueeze(0).repeat(batch_size, 1, 1)
95
+ return cameras
96
+
97
+
98
+ def images_to_video(images, output_path, fps=30):
99
+ # images: (N, C, H, W)
100
+ os.makedirs(os.path.dirname(output_path), exist_ok=True)
101
+ frames = []
102
+ for i in range(images.shape[0]):
103
+ frame = (images[i].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8).clip(0, 255)
104
+ assert frame.shape[0] == images.shape[2] and frame.shape[1] == images.shape[3], \
105
+ f"Frame shape mismatch: {frame.shape} vs {images.shape}"
106
+ assert frame.min() >= 0 and frame.max() <= 255, \
107
+ f"Frame value out of range: {frame.min()} ~ {frame.max()}"
108
+ frames.append(frame)
109
+ imageio.mimwrite(output_path, np.stack(frames), fps=fps, codec='h264')
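# Illustrative usage sketch (placeholder tensor and output path, relying on the
# module-level torch import): images_to_video expects an (N, C, H, W) float
# tensor with values in [0, 1] and writes an H.264 mp4.
def _example_images_to_video():
    dummy_frames = torch.rand(12, 3, 320, 320)  # 12 dummy RGB frames, 320x320
    images_to_video(dummy_frames, "outputs/preview.mp4", fps=30)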
110
+
111
+
112
+ ###############################################################################
113
+ # Configuration.
114
+ ###############################################################################
115
+
116
+ import shutil
117
+
118
+ def find_cuda():
119
+ # Check if CUDA_HOME or CUDA_PATH environment variables are set
120
+ cuda_home = os.environ.get('CUDA_HOME') or os.environ.get('CUDA_PATH')
121
+
122
+ if cuda_home and os.path.exists(cuda_home):
123
+ return cuda_home
124
+
125
+ # Search for the nvcc executable in the system's PATH
126
+ nvcc_path = shutil.which('nvcc')
127
+
128
+ if nvcc_path:
129
+ # Remove the 'bin/nvcc' part to get the CUDA installation path
130
+ cuda_path = os.path.dirname(os.path.dirname(nvcc_path))
131
+ return cuda_path
132
+
133
+ return None
134
+
135
+ cuda_path = find_cuda()
136
+
137
+ if cuda_path:
138
+ print(f"CUDA installation found at: {cuda_path}")
139
+ else:
140
+ print("CUDA installation not found")
141
+
142
+ config_path = 'configs/instant-nerf-base.yaml'
143
+ config = OmegaConf.load(config_path)
144
+ config_name = os.path.basename(config_path).replace('.yaml', '')
145
+ model_config = config.model_config
146
+ infer_config = config.infer_config
147
+
148
+ IS_FLEXICUBES = True if config_name.startswith('instant-mesh') else False
149
+
150
+ device = torch.device('cuda')
151
+
152
+ # load diffusion model
153
+ print('Loading diffusion model ...')
154
+ pipeline = DiffusionPipeline.from_pretrained(
155
+ "sudo-ai/zero123plus-v1.2",
156
+ custom_pipeline="zero123plus",
157
+ torch_dtype=torch.float16,
158
+ )
159
+ pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
160
+ pipeline.scheduler.config, timestep_spacing='trailing'
161
+ )
162
+
163
+ # load custom white-background UNet
164
+ unet_ckpt_path = hf_hub_download(repo_id="TencentARC/InstantMesh", filename="diffusion_pytorch_model.bin", repo_type="model")
165
+ state_dict = torch.load(unet_ckpt_path, map_location='cpu')
166
+ pipeline.unet.load_state_dict(state_dict, strict=True)
167
+
168
+ pipeline = pipeline.to(device)
169
+
170
+ # load reconstruction model
171
+ print('Loading reconstruction model ...')
172
+ model_ckpt_path = hf_hub_download(repo_id="TencentARC/InstantMesh", filename="instant_nerf_base.ckpt", repo_type="model")
173
+ model0 = instantiate_from_config(model_config)
174
+ state_dict = torch.load(model_ckpt_path, map_location='cpu')['state_dict']
175
+ state_dict = {k[14:]: v for k, v in state_dict.items() if k.startswith('lrm_generator.') and 'source_camera' not in k}
176
+ model0.load_state_dict(state_dict, strict=True)
177
+
178
+ model0 = model0.to(device)
179
+
180
+ print('Loading Finished!')
181
+
182
+
183
+ def check_input_image(input_image):
184
+ if input_image is None:
185
+ raise gr.Error("No image uploaded!")
186
+ image = None
187
+ else:
188
+ image = Image.open(input_image)
189
+ return image
190
+
191
+ def preprocess(input_image, do_remove_background):
192
+
193
+ rembg_session = rembg.new_session() if do_remove_background else None
194
+
195
+ if do_remove_background:
196
+ input_image = remove_background(input_image, rembg_session)
197
+ input_image = resize_foreground(input_image, 0.85)
198
+
199
+ return input_image
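# Illustrative usage sketch (the input path is a placeholder): strip the
# background and re-center the foreground before multi-view generation.
def _example_preprocess():
    img = Image.open("example_input.png")  # hypothetical image path
    return preprocess(img, do_remove_background=True)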
200
+
201
+
202
+ # @spaces.GPU
203
+ def generate_mvs(input_image, sample_steps, sample_seed):
204
+
205
+ seed_everything(sample_seed)
206
+
207
+ # sampling
208
+ z123_image = pipeline(
209
+ input_image,
210
+ num_inference_steps=sample_steps
211
+ ).images[0]
212
+
213
+ show_image = np.asarray(z123_image, dtype=np.uint8)
214
+ show_image = torch.from_numpy(show_image) # (960, 640, 3)
215
+ show_image = rearrange(show_image, '(n h) (m w) c -> (n m) h w c', n=3, m=2)
216
+ show_image = rearrange(show_image, '(n m) h w c -> (n h) (m w) c', n=2, m=3)
217
+ show_image = Image.fromarray(show_image.numpy())
218
+
219
+ return z123_image, show_image
220
+
221
+
222
+ # @spaces.GPU
223
+ def make3d(images):
224
+
225
+ global model0
226
+ if IS_FLEXICUBES:
227
+ model0.init_flexicubes_geometry(device)
228
+ model0 = model0.eval()
229
+
230
+ images = np.asarray(images, dtype=np.float32) / 255.0
231
+ images = torch.from_numpy(images).permute(2, 0, 1).contiguous().float() # (3, 960, 640)
232
+ images = rearrange(images, 'c (n h) (m w) -> (n m) c h w', n=3, m=2) # (6, 3, 320, 320)
233
+
234
+ input_cameras = get_zero123plus_input_cameras(batch_size=1, radius=4.0).to(device)
235
+ render_cameras = get_render_cameras(batch_size=1, radius=2.5, is_flexicubes=IS_FLEXICUBES).to(device)
236
+
237
+ images = images.unsqueeze(0).to(device)
238
+ images = v2.functional.resize(images, (320, 320), interpolation=3, antialias=True).clamp(0, 1)
239
+
240
+ mesh_fpath = tempfile.NamedTemporaryFile(suffix=f".obj", delete=False).name
241
+ print(mesh_fpath)
242
+ mesh_basename = os.path.basename(mesh_fpath).split('.')[0]
243
+ mesh_dirname = os.path.dirname(mesh_fpath)
244
+ video_fpath = os.path.join(mesh_dirname, f"{mesh_basename}.mp4")
245
+ mesh_glb_fpath = os.path.join(mesh_dirname, f"{mesh_basename}.glb")
246
+
247
+ with torch.no_grad():
248
+ # get triplane
249
+ planes = model0.forward_planes(images, input_cameras)
250
+
251
+ # # get video
252
+ # chunk_size = 20 if IS_FLEXICUBES else 1
253
+ # render_size = 384
254
+
255
+ # frames = []
256
+ # for i in tqdm(range(0, render_cameras.shape[1], chunk_size)):
257
+ # if IS_FLEXICUBES:
258
+ # frame = model.forward_geometry(
259
+ # planes,
260
+ # render_cameras[:, i:i+chunk_size],
261
+ # render_size=render_size,
262
+ # )['img']
263
+ # else:
264
+ # frame = model.synthesizer(
265
+ # planes,
266
+ # cameras=render_cameras[:, i:i+chunk_size],
267
+ # render_size=render_size,
268
+ # )['images_rgb']
269
+ # frames.append(frame)
270
+ # frames = torch.cat(frames, dim=1)
271
+
272
+ # images_to_video(
273
+ # frames[0],
274
+ # video_fpath,
275
+ # fps=30,
276
+ # )
277
+
278
+ # print(f"Video saved to {video_fpath}")
279
+
280
+ # get mesh
281
+ mesh_out = model0.extract_mesh(
282
+ planes,
283
+ use_texture_map=False,
284
+ **infer_config,
285
+ )
286
+
287
+ vertices, faces, vertex_colors = mesh_out
288
+ vertices = vertices[:, [1, 2, 0]]
289
+
290
+ save_glb(vertices, faces, vertex_colors, mesh_glb_fpath)
291
+ save_obj(vertices, faces, vertex_colors, mesh_fpath)
292
+
293
+ print(f"Mesh saved to {mesh_fpath}")
294
+
295
+ return mesh_fpath, mesh_glb_fpath
296
+
297
+
298
+ ###############################################################################
299
+ ############# above part is for 3D generation #############
300
+ ###############################################################################
301
+
302
+
303
+ ###############################################################################
304
+ ############# this part is for text to image #############
305
+ ###############################################################################
306
+
307
+ # Use environment variables for flexibility
308
+ MODEL_ID = os.getenv("MODEL_ID", "sd-community/sdxl-flash")
309
+ MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "4096"))
310
+ USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", "0") == "1"
311
+ ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
312
+ BATCH_SIZE = int(os.getenv("BATCH_SIZE", "1")) # Allow generating multiple images at once
313
+
314
+ # Determine device and load model outside of function for efficiency
315
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
316
+ pipe = StableDiffusionXLPipeline.from_pretrained(
317
+ MODEL_ID,
318
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
319
+ use_safetensors=True,
320
+ add_watermarker=False,
321
+ ).to(device)
322
+ pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
323
 
324
+ # Torch compile for potential speedup (experimental)
325
+ if USE_TORCH_COMPILE:
326
+ pipe.compile()
327
 
328
+ # CPU offloading for larger RAM capacity (experimental)
329
+ if ENABLE_CPU_OFFLOAD:
330
+ pipe.enable_model_cpu_offload()
331
 
332
+ MAX_SEED = np.iinfo(np.int32).max
333
 
334
+ def save_image(img):
335
+ unique_name = str(uuid.uuid4()) + ".png"
336
+ img.save(unique_name)
337
+ return unique_name
338
 
339
+ def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
340
+ if randomize_seed:
341
+ seed = random.randint(0, MAX_SEED)
342
+ return seed
343
 
344
+ # @spaces.GPU(duration=30, queue=False)
345
+ def generate(
346
+ prompt: str,
347
+ negative_prompt: str = "",
348
+ use_negative_prompt: bool = False,
349
+ seed: int = 1,
350
+ width: int = 200,
351
+ height: int = 200,
352
+ guidance_scale: float = 3,
353
+ num_inference_steps: int = 30,
354
+ randomize_seed: bool = False,
355
+ use_resolution_binning: bool = True,
356
+ num_images: int = 4, # Number of images to generate
357
+ progress=gr.Progress(track_tqdm=True),
358
+ ):
359
+ seed = int(randomize_seed_fn(seed, randomize_seed))
360
+ generator = torch.Generator(device=device).manual_seed(seed)
361
 
362
+ # Improved options handling
363
+ options = {
364
+ "prompt": [prompt] * num_images,
365
+ "negative_prompt": [negative_prompt] * num_images if use_negative_prompt else None,
366
+ "width": width,
367
+ "height": height,
368
+ "guidance_scale": guidance_scale,
369
+ "num_inference_steps": num_inference_steps,
370
+ "generator": generator,
371
+ "output_type": "pil",
372
+ }
373
 
374
+ # Use resolution binning for faster generation with less VRAM usage
375
+ # if use_resolution_binning:
376
+ # options["use_resolution_binning"] = True
 
377
 
378
+ # Generate images potentially in batches
379
+ images = []
380
+ for i in range(0, num_images, BATCH_SIZE):
381
+ batch_options = options.copy()
382
+ batch_options["prompt"] = options["prompt"][i:i+BATCH_SIZE]
383
+ if "negative_prompt" in batch_options:
384
+ batch_options["negative_prompt"] = options["negative_prompt"][i:i+BATCH_SIZE]
385
+ images.extend(pipe(**batch_options).images)
386
 
387
+ image_paths = [save_image(img) for img in images]
388
+ return image_paths, seed
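# Illustrative usage sketch (example prompt, size and step count; not the app's
# defaults): call generate() directly, outside the Gradio UI.
def _example_generate():
    paths, used_seed = generate(
        prompt="a cat eating a piece of cheese",
        negative_prompt="blurry, low quality",
        use_negative_prompt=True,
        width=512,
        height=512,
        num_inference_steps=8,
        num_images=1,
        randomize_seed=True,
    )
    print(f"seed={used_seed}, saved to {paths}")
    return paths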
389
 
390
+ examples = [
391
+ "a cat eating a piece of cheese",
392
+ "a ROBOT riding a BLUE horse on Mars, photorealistic, 4k",
393
+ "Ironman VS Hulk, ultrarealistic",
394
+ "Astronaut in a jungle, cold color palette, oil pastel, detailed, 8k",
395
+ "An alien holding a sign board containing the word 'Flash', futuristic, neonpunk",
396
+ "Kids going to school, Anime style"
397
+ ]
398
+
399
+
400
+
401
+
402
+ ###############################################################################
403
+ ############# above part is for text to image #############
404
+ ###############################################################################
405
+
406
+
407
+ css = """
408
+ #warning {background-color: #FFCCCB}
409
+ .chatbot {
410
+ padding: 0 !important;
411
+ margin: 0 !important;
412
+ }
413
+ """
414
+ filtered_language_dict = {
415
+ 'English': 'en-US-JennyNeural',
416
+ 'Chinese': 'zh-CN-XiaoxiaoNeural',
417
+ 'French': 'fr-FR-DeniseNeural',
418
+ 'Spanish': 'es-MX-DaliaNeural',
419
+ 'Arabic': 'ar-SA-ZariyahNeural',
420
+ 'Portuguese': 'pt-BR-FranciscaNeural',
421
+ 'Cantonese': 'zh-HK-HiuGaaiNeural'
422
+ }
423
+
424
+ focus_map = {
425
+ "CFV-D":0,
426
+ "CFV-DA":1,
427
+ "CFV-DAI":2,
428
+ "PFV-DDA":3
429
+ }
430
+
431
+ '''
432
+ prompt_list = [
433
+ 'Wiki_caption: {Wiki_caption}, you have to generate a caption according to the image and wiki caption. Around {length} words of {sentiment} sentiment in {language}.',
434
+ 'Wiki_caption: {Wiki_caption}, you have to select sentences from wiki caption that describe the surrounding objects that may be associated with the picture object. Around {length} words of {sentiment} sentiment in {language}.',
435
+ 'Wiki_caption: {Wiki_caption}. You have to choose sentences from the wiki caption that describe unrelated objects to the image. Around {length} words of {sentiment} sentiment in {language}.',
436
+ 'Wiki_caption: {Wiki_caption}. You have to choose sentences from the wiki caption that describe unrelated objects to the image. Around {length} words of {sentiment} sentiment in {language}.'
437
+ ]
438
+
439
+ prompt_list = [
440
+ 'Wiki_caption: {Wiki_caption}, you have to help me understand what is about the selected object and list one fact (describes the object but does not include analysis)as markdown outline with appropriate emojis that describes what you see according to the image and wiki caption. Around {length} words of {sentiment} sentiment in {language}.',
441
+ 'Wiki_caption: {Wiki_caption}, you have to help me understand what is about the selected object and list one fact and one analysis as markdown outline with appropriate emojis that describes what you see according to the image and wiki caption. Around {length} words of {sentiment} sentiment in {language}.',
442
+ 'Wiki_caption: {Wiki_caption}, you have to help me understand what is about the selected object and list one fact and one analysis and one interpret as markdown outline with appropriate emojis that describes what you see according to the image and wiki caption. Around {length} words of {sentiment} sentiment in {language}.',
443
+ 'Wiki_caption: {Wiki_caption}, you have to help me understand what is about the selected object and the objects that may be related to the selected object and list one fact of selected object, one fact of related object and one analysis as markdown outline with appropriate emojis that describes what you see according to the image and wiki caption. Around {length} words of {sentiment} sentiment in {language}.'
444
+ ]
445
+ '''
446
+ prompt_list = [
447
+ 'Wiki_caption: {Wiki_caption}, you have to help me understand what is about the selected object and list one fact (describes the selected object but does not include analysis)as markdown outline with appropriate emojis that describes what you see according to the image and wiki caption. Each point listed is to be in {language} language, with a response length of about {length} words.',
448
+ 'Wiki_caption: {Wiki_caption}, you have to help me understand what is about the selected object and list one fact and one analysis as markdown outline with appropriate emojis that describes what you see according to the image and wiki caption. Each point listed is to be in {language} language, with a response length of about {length} words.',
449
+ 'Wiki_caption: {Wiki_caption}, you have to help me understand what is about the selected object and list one fact and one analysis and one interpret as markdown outline with appropriate emojis that describes what you see according to the image and wiki caption. Each point listed is to be in {language} language, with a response length of about {length} words.',
450
+ 'Wiki_caption: {Wiki_caption}, you have to help me understand what is about the selected object and the objects that may be related to the selected object and list one fact of selected object, one fact of related object and one analysis as markdown outline with appropriate emojis that describes what you see according to the image and wiki caption. Each point listed is to be in {language} language, with a response length of about {length} words.'
451
+ ]
452
+
453
+
454
+ gpt_state = 0
455
+ VOICE = "en-GB-SoniaNeural"
456
+ article = """
457
+ <div style='margin:20px auto;'>
458
+ <p>By using this demo you agree to the terms of the Coqui Public Model License at https://coqui.ai/cpml</p>
459
+ </div>
460
  """
461
 
462
+ args = parse_augment()
463
+ args.segmenter = "huge"
464
+ args.segmenter_checkpoint = "sam_vit_h_4b8939.pth"
465
+ args.clip_filter = True
466
+ if args.segmenter_checkpoint is None:
467
+ _, segmenter_checkpoint = prepare_segmenter(args.segmenter)
468
+ else:
469
+ segmenter_checkpoint = args.segmenter_checkpoint
470
+
471
+ shared_captioner = build_captioner(args.captioner, args.device, args)
472
+ shared_sam_model = sam_model_registry[seg_model_map[args.segmenter]](checkpoint=segmenter_checkpoint).to(args.device)
473
+ ocr_lang = ["ch_tra", "en"]
474
+ shared_ocr_reader = easyocr.Reader(ocr_lang)
475
+ tools_dict = {e.split('_')[0].strip(): e.split('_')[1].strip() for e in args.chat_tools_dict.split(',')}
476
+ shared_chatbot_tools = build_chatbot_tools(tools_dict)
477
+
478
+
479
+ class ImageSketcher(gr.Image):
480
+ """
481
+ Fix the bug where gradio.Image cannot upload when tool == 'sketch'.
482
+ """
483
+
484
+ is_template = True # Magic to make this work with gradio.Block, don't remove unless you know what you're doing.
485
+
486
+ def __init__(self, **kwargs):
487
+ super().__init__(tool="sketch", **kwargs)
488
+
489
+ def preprocess(self, x):
490
+ if self.tool == 'sketch' and self.source in ["upload", "webcam"]:
491
+ assert isinstance(x, dict)
492
+ if x['mask'] is None:
493
+ decode_image = processing_utils.decode_base64_to_image(x['image'])
494
+ width, height = decode_image.size
495
+ mask = np.zeros((height, width, 4), dtype=np.uint8)
496
+ mask[..., -1] = 255
497
+ mask = self.postprocess(mask)
498
+ x['mask'] = mask
499
 
500
+ return super().preprocess(x)
501
 
502
 
503
+ def build_caption_anything_with_models(args, api_key="", captioner=None, sam_model=None, ocr_reader=None, text_refiner=None,
504
+ session_id=None):
505
+ segmenter = build_segmenter(args.segmenter, args.device, args, model=sam_model)
506
+ captioner = captioner
507
+ if session_id is not None:
508
+ print('Init caption anything for session {}'.format(session_id))
509
+ return CaptionAnything(args, api_key, captioner=captioner, segmenter=segmenter, ocr_reader=ocr_reader, text_refiner=text_refiner)
510
+
511
+
512
+ def validate_api_key(api_key):
513
+ api_key = str(api_key).strip()
514
+ print(api_key)
515
+ try:
516
+ test_llm = OpenAI(model_name="gpt-3.5-turbo", temperature=0, openai_api_key=api_key)
517
+ response = test_llm("Test API call")
518
+ print(response)
519
+ return True
520
+ except Exception as e:
521
+ print(f"API key validation failed: {e}")
522
+ return False
523
+
524
+
525
+ def init_openai_api_key(api_key=""):
526
+ text_refiner = None
527
+ visual_chatgpt = None
528
+ if api_key and len(api_key) > 30:
529
+ print(api_key)
530
+ if validate_api_key(api_key):
531
+ try:
532
+ text_refiner = build_text_refiner(args.text_refiner, args.device, args, api_key)
533
+ assert len(text_refiner.llm('hi')) > 0 # test
534
+ visual_chatgpt = ConversationBot(shared_chatbot_tools, api_key)
535
+ except Exception as e:
536
+ print(f"Error initializing TextRefiner or ConversationBot: {e}")
537
+ text_refiner = None
538
+ visual_chatgpt = None
539
+ else:
540
+ print("Invalid API key.")
541
+ else:
542
+ print("API key is too short.")
543
+ print(text_refiner)
544
+ openai_available = text_refiner is not None
545
+ if openai_available:
546
+
547
+ global gpt_state
548
+ gpt_state=1
549
+ # return [gr.update(visible=True)]+[gr.update(visible=False)]+[gr.update(visible=True)]*3+[gr.update(visible=False)]+ [gr.update(visible=False)]*3 + [text_refiner, visual_chatgpt, None]+[gr.update(visible=True)]*3
550
+ return [gr.update(visible=True)]+[gr.update(visible=False)]+[gr.update(visible=True)]*3+[gr.update(visible=False)]+ [gr.update(visible=False)]*3 + [text_refiner, visual_chatgpt, None]+[gr.update(visible=True)]*2
551
+ else:
552
+ gpt_state=0
553
+ # return [gr.update(visible=False)]*7 + [gr.update(visible=True)]*2 + [text_refiner, visual_chatgpt, 'Your OpenAI API Key is not available']+[gr.update(visible=False)]*3
554
+ return [gr.update(visible=False)]*7 + [gr.update(visible=True)]*2 + [text_refiner, visual_chatgpt, 'Your OpenAI API Key is not available']+[gr.update(visible=False)]*2
555
+
556
+ def init_wo_openai_api_key():
557
+ global gpt_state
558
+ gpt_state=0
559
+ # return [gr.update(visible=False)]*4 + [gr.update(visible=True)]+ [gr.update(visible=False)]+[gr.update(visible=True)]+[gr.update(visible=False)]*2 + [None, None, None]+[gr.update(visible=False)]*3
560
+ return [gr.update(visible=False)]*4 + [gr.update(visible=True)]+ [gr.update(visible=False)]+[gr.update(visible=True)]+[gr.update(visible=False)]*2 + [None, None, None]+[gr.update(visible=False)]*2
561
+
562
+ def get_click_prompt(chat_input, click_state, click_mode):
563
+ inputs = json.loads(chat_input)
564
+ if click_mode == 'Continuous':
565
+ points = click_state[0]
566
+ labels = click_state[1]
567
+ for input in inputs:
568
+ points.append(input[:2])
569
+ labels.append(input[2])
570
+ elif click_mode == 'Single':
571
+ points = []
572
+ labels = []
573
+ for input in inputs:
574
+ points.append(input[:2])
575
+ labels.append(input[2])
576
+ click_state[0] = points
577
+ click_state[1] = labels
578
+ else:
579
+ raise NotImplementedError
580
+
581
+ prompt = {
582
+ "prompt_type": ["click"],
583
+ "input_point": click_state[0],
584
+ "input_label": click_state[1],
585
+ "multimask_output": "True",
586
+ }
587
+ return prompt
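# Illustrative usage sketch (placeholder coordinates): a single positive click
# at pixel (200, 150) in 'Single' mode becomes a SAM-style point prompt.
def _example_get_click_prompt():
    click_state = [[], [], []]
    prompt = get_click_prompt('[[200, 150, 1]]', click_state, 'Single')
    # prompt == {'prompt_type': ['click'], 'input_point': [[200, 150]],
    #            'input_label': [1], 'multimask_output': 'True'}
    return prompt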
588
+
589
+
590
+ def update_click_state(click_state, caption, click_mode):
591
+ if click_mode == 'Continuous':
592
+ click_state[2].append(caption)
593
+ elif click_mode == 'Single':
594
+ click_state[2] = [caption]
595
+ else:
596
+ raise NotImplementedError
597
+
598
+ async def chat_input_callback(*args):
599
+ visual_chatgpt, chat_input, click_state, state, aux_state, language, autoplay = args
600
+ if visual_chatgpt is not None:
601
+ state, _, aux_state, _ = visual_chatgpt.run_text(chat_input, state, aux_state)
602
+ last_text, last_response = state[-1]
603
+ print("last response",last_response)
604
+ if autoplay:
605
+ audio = await texttospeech(last_response,language,autoplay)
606
+ else:
607
+ audio=None
608
+ return state, state, aux_state, audio
609
+ else:
610
+ response = "Text refiner is not initilzed, please input openai api key."
611
+ state = state + [(chat_input, response)]
612
+ audio = await texttospeech(response,language,autoplay)
613
+ return state, state, None, audio
614
+
615
+
616
+
617
+ def upload_callback(image_input, state, visual_chatgpt=None, openai_api_key=None,language="English"):
618
+ if isinstance(image_input, dict): # if upload from sketcher_input, input contains image and mask
619
+ image_input, mask = image_input['image'], image_input['mask']
620
+
621
+ click_state = [[], [], []]
622
+ image_input = image_resize(image_input, res=1024)
623
+
624
+ model = build_caption_anything_with_models(
625
+ args,
626
+ api_key="",
627
+ captioner=shared_captioner,
628
+ sam_model=shared_sam_model,
629
+ ocr_reader=shared_ocr_reader,
630
+ session_id=iface.app_id
631
+ )
632
+ model.segmenter.set_image(image_input)
633
+ image_embedding = model.image_embedding
634
+ original_size = model.original_size
635
+ input_size = model.input_size
636
+
637
+ if visual_chatgpt is not None:
638
+ print('upload_callback: add caption to chatGPT memory')
639
+ new_image_path = get_new_image_name('chat_image', func_name='upload')
640
+ image_input.save(new_image_path)
641
+ visual_chatgpt.current_image = new_image_path
642
+ img_caption = model.captioner.inference(image_input, filter=False, args={'text_prompt':''})['caption']
643
+ Human_prompt = f'\nHuman: The description of the image with path {new_image_path} is: {img_caption}. This information helps you to understand this image, but you should use tools to finish following tasks, rather than directly imagine from my description. If you understand, say \"Received\". \n'
644
+ AI_prompt = "Received."
645
+ visual_chatgpt.global_prompt = Human_prompt + 'AI: ' + AI_prompt
646
+ visual_chatgpt.agent.memory.buffer = visual_chatgpt.agent.memory.buffer + visual_chatgpt.global_prompt
647
+ parsed_data = get_image_gpt(openai_api_key, new_image_path,"Please provide the name, artist, year of creation, and material used for this painting. Return the information in dictionary format without any newline characters. If any information is unavailable, return \"None\" for that field. Format as follows: { \"name\": \"Name of the painting\",\"artist\": \"Name of the artist\", \"year\": \"Year of creation\", \"material\": \"Material used in the painting\" }.")
648
+ parsed_data = json.loads(parsed_data.replace("'", "\""))
649
+ name, artist, year, material= parsed_data["name"],parsed_data["artist"],parsed_data["year"], parsed_data["material"]
650
+ # artwork_info = f"<div>Painting: {name}<br>Artist name: {artist}<br>Year: {year}<br>Material: {material}</div>"
651
+ paragraph = get_image_gpt(openai_api_key, new_image_path,f"What's going on in this picture? in {language}")
652
+
653
+ state = [
654
+ (
655
+ None,
656
+ f"🤖 Hi, I am EyeSee. Let's explore this painting {name} together. You can click on the area you're interested in and choose from four types of information: Description, Analysis, Interpretation, and Judgment. Based on your selection, I will provide you with the relevant information."
657
+ )
658
+ ]
659
+
660
+ return state, state, image_input, click_state, image_input, image_input, image_input, image_embedding, \
661
+ original_size, input_size, f"Name: {name}", f"Artist: {artist}", f"Year: {year}", f"Material: {material}",f"Name: {name}", f"Artist: {artist}", f"Year: {year}", f"Material: {material}",paragraph,artist
662
+
663
+
664
+
665
+
666
+ def inference_click(image_input, point_prompt, click_mode, enable_wiki, language, sentiment, factuality,
667
+ length, image_embedding, state, click_state, original_size, input_size, text_refiner, visual_chatgpt,
668
+ out_state, click_index_state, input_mask_state, input_points_state, input_labels_state, evt: gr.SelectData):
669
+ click_index = evt.index
670
+
671
+ if point_prompt == 'Positive':
672
+ coordinate = "[[{}, {}, 1]]".format(str(click_index[0]), str(click_index[1]))
673
+ else:
674
+ coordinate = "[[{}, {}, 0]]".format(str(click_index[0]), str(click_index[1]))
675
+
676
+ prompt = get_click_prompt(coordinate, click_state, click_mode)
677
+ input_points = prompt['input_point']
678
+ input_labels = prompt['input_label']
679
+
680
+ controls = {'length': length,
681
+ 'sentiment': sentiment,
682
+ 'factuality': factuality,
683
+ 'language': language}
684
+
685
+ model = build_caption_anything_with_models(
686
+ args,
687
+ api_key="",
688
+ captioner=shared_captioner,
689
+ sam_model=shared_sam_model,
690
+ ocr_reader=shared_ocr_reader,
691
+ text_refiner=text_refiner,
692
+ session_id=iface.app_id
693
+ )
694
+
695
+ model.setup(image_embedding, original_size, input_size, is_image_set=True)
696
+
697
+ enable_wiki = True if enable_wiki in ['True', 'TRUE', 'true', True, 'Yes', 'YES', 'yes'] else False
698
+ out = model.inference(image_input, prompt, controls, disable_gpt=True, enable_wiki=enable_wiki, verbose=True, args={'clip_filter': False})[0]
699
+
700
+ state = state + [("Image point: {}, Input label: {}".format(prompt["input_point"], prompt["input_label"]), None)]
701
+ update_click_state(click_state, out['generated_captions']['raw_caption'], click_mode)
702
+ text = out['generated_captions']['raw_caption']
703
+ input_mask = np.array(out['mask'].convert('P'))
704
+ image_input_nobackground = mask_painter(np.array(image_input), input_mask,background_alpha=0)
705
+ image_input_withbackground=mask_painter(np.array(image_input), input_mask)
706
+
707
+ click_index_state = click_index
708
+ input_mask_state = input_mask
709
+ input_points_state = input_points
710
+ input_labels_state = input_labels
711
+ out_state = out
712
+
713
+ if visual_chatgpt is not None:
714
+ print('inference_click: add caption to chatGPT memory')
715
+ new_crop_save_path = get_new_image_name('chat_image', func_name='crop')
716
+ Image.open(out["crop_save_path"]).save(new_crop_save_path)
717
+ point_prompt = f'You should primarily use tools on the selected regional image (description: {text}, path: {new_crop_save_path}), which is a part of the whole image (path: {visual_chatgpt.current_image}). If human mentioned some objects not in the selected region, you can use tools on the whole image.'
718
+ visual_chatgpt.point_prompt = point_prompt
719
+
720
+
721
+ print("new crop save",new_crop_save_path)
722
+
723
+ yield state, state, click_state, image_input_nobackground, image_input_withbackground, click_index_state, input_mask_state, input_points_state, input_labels_state, out_state,new_crop_save_path,image_input_nobackground
724
+
725
+
726
+
727
+
728
+
729
+ async def submit_caption(state, text_refiner, length, sentiment, factuality, language,
730
+ out_state, click_index_state, input_mask_state, input_points_state, input_labels_state,
731
+ autoplay,paragraph,focus_type,openai_api_key,new_crop_save_path):
732
+ print("state",state)
733
+
734
+ click_index = click_index_state
735
+
736
+ # if pre_click_index==click_index:
737
+ # click_index = (click_index[0] - 1, click_index[1] - 1)
738
+ # pre_click_index = click_index
739
+ # else:
740
+ # pre_click_index = click_index
741
+ print("click_index",click_index)
742
+ print("input_points_state",input_points_state)
743
+ print("input_labels_state",input_labels_state)
744
+
745
+ prompt = generate_prompt(focus_type, paragraph, length, sentiment, factuality, language)
746
+
747
+ print("Prompt:", prompt)
748
+ print("click",click_index)
749
+
750
+ # image_input = create_bubble_frame(np.array(image_input), generated_caption, click_index, input_mask,
751
+ # input_points=input_points, input_labels=input_labels)
752
+
753
+
754
+ if not args.disable_gpt and text_refiner:
755
+ print("new crop save",new_crop_save_path)
756
+ focus_info=get_image_gpt(openai_api_key,new_crop_save_path,prompt)
757
+ if focus_info.startswith('"') and focus_info.endswith('"'):
758
+ focus_info=focus_info[1:-1]
759
+ focus_info=focus_info.replace('#', '')
760
+ # state = state + [(None, f"Wiki: {paragraph}")]
761
+ state = state + [(None, f"{focus_info}")]
762
+ print("new_cap",focus_info)
763
+ read_info = re.sub(r'[#[\]!*]','',focus_info)
764
+ read_info = emoji.replace_emoji(read_info,replace="")
765
+ print("read info",read_info)
766
+
767
+ # refined_image_input = create_bubble_frame(np.array(origin_image_input), focus_info, click_index, input_mask,
768
+ # input_points=input_points, input_labels=input_labels)
769
+ try:
770
+ audio_output = await texttospeech(read_info, language, autoplay)
771
+ print("done")
772
+ # return state, state, refined_image_input, click_index_state, input_mask_state, input_points_state, input_labels_state, out_state, waveform_visual, audio_output
773
+ return state, state, click_index_state, input_mask_state, input_points_state, input_labels_state, out_state, audio_output
774
+
775
+ except Exception as e:
776
+ state = state + [(None, f"Error during TTS prediction: {str(e)}")]
777
+ print(f"Error during TTS prediction: {str(e)}")
778
+ # return state, state, refined_image_input, click_index_state, input_mask_state, input_points_state, input_labels_state, out_state, None, None
779
+ return state, state, click_index_state, input_mask_state, input_points_state, input_labels_state, out_state, None
780
+
781
+ else:
782
+ try:
783
+ audio_output = await texttospeech(focus_info, language, autoplay)
784
+ # waveform_visual, audio_output = tts.predict(generated_caption, input_language, input_audio, input_mic, use_mic, agree)
785
+ # return state, state, image_input, click_index_state, input_mask_state, input_points_state, input_labels_state, out_state, waveform_visual, audio_output
786
+ return state, state, click_index_state, input_mask_state, input_points_state, input_labels_state, out_state, audio_output
787
+
788
+ except Exception as e:
789
+ state = state + [(None, f"Error during TTS prediction: {str(e)}")]
790
+ print(f"Error during TTS prediction: {str(e)}")
791
+ return state, state, click_index_state, input_mask_state, input_points_state, input_labels_state, out_state, None
792
+
793
+ def generate_prompt(focus_type, paragraph,length, sentiment, factuality, language):
794
+
795
+ mapped_value = focus_map.get(focus_type, -1)
796
+
797
+ controls = {
798
+ 'length': length,
799
+ 'sentiment': sentiment,
800
+ 'factuality': factuality,
801
+ 'language': language
802
+ }
803
+
804
+ if mapped_value != -1:
805
+ prompt = prompt_list[mapped_value].format(
806
+ Wiki_caption=paragraph,
807
+ length=controls['length'],
808
+ sentiment=controls['sentiment'],
809
+ language=controls['language']
810
+ )
811
+ else:
812
+ prompt = "Invalid focus type."
813
+
814
+ if controls['factuality'] == "Imagination":
815
+ prompt += " Assuming that I am someone who has viewed a lot of art and has a lot of experience viewing art. Explain artistic features (composition, color, style, or use of light) and discuss the symbolism of the content and its influence on later artistic movements."
816
+
817
+ return prompt
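# Illustrative usage sketch (placeholder caption, length and language): build
# the "CFV-DA" (one fact plus one analysis) prompt.
def _example_generate_prompt():
    return generate_prompt(
        focus_type="CFV-DA",
        paragraph="A short wiki-style caption of the selected artwork.",
        length="40",
        sentiment="positive",
        factuality="Factual",
        language="English",
    )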
818
+
819
+
820
+ def encode_image(image_path):
821
+ with open(image_path, "rb") as image_file:
822
+ return base64.b64encode(image_file.read()).decode('utf-8')
823
+
824
+ def get_image_gpt(api_key, image_path,prompt,enable_wiki=None):
825
+ # Getting the base64 string
826
+ base64_image = encode_image(image_path)
827
+
828
+
829
+
830
+ headers = {
831
+ "Content-Type": "application/json",
832
+ "Authorization": f"Bearer {api_key}"
833
+ }
834
+
835
+ prompt_text = prompt
836
+
837
+ payload = {
838
+ "model": "gpt-4o",
839
+ "messages": [
840
+ {
841
+ "role": "user",
842
+ "content": [
843
+ {
844
+ "type": "text",
845
+ "text": prompt_text
846
+ },
847
+ {
848
+ "type": "image_url",
849
+ "image_url": {
850
+ "url": f"data:image/jpeg;base64,{base64_image}"
851
+ }
852
+ }
853
+ ]
854
+ }
855
+ ],
856
+ "max_tokens": 300
857
+ }
858
+
859
+ # Sending the request to the OpenAI API
860
+ response = requests.post("https://api.openai.com/v1/chat/completions", headers=headers, json=payload)
861
+ result = response.json()
862
+ print(result)
863
+ content = result['choices'][0]['message']['content']
864
+ # Assume the model returns a valid JSON string in 'content'
865
+ try:
866
+ return content
867
+ except json.JSONDecodeError:
868
+ return {"error": "Failed to parse model output"}
869
+
870
+
871
+
872
+
873
+ def get_sketch_prompt(mask: Image.Image):
874
+ """
875
+ Get the prompt for the sketcher.
876
+ TODO: This is a temporary solution. We should cluster the sketch and get the bounding box of each cluster.
877
+ """
878
+
879
+ mask = np.asarray(mask)[..., 0]
880
+
881
+ # Get the bounding box of the sketch
882
+ y, x = np.where(mask != 0)
883
+ x1, y1 = np.min(x), np.min(y)
884
+ x2, y2 = np.max(x), np.max(y)
885
+
886
+ prompt = {
887
+ 'prompt_type': ['box'],
888
+ 'input_boxes': [
889
+ [x1, y1, x2, y2]
890
+ ]
891
+ }
892
+
893
+ return prompt
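# Illustrative usage sketch (placeholder canvas size and rectangle): a sketched
# region on an otherwise empty mask is reduced to its bounding box.
def _example_get_sketch_prompt():
    mask = Image.new("RGB", (256, 256), 0)
    ImageDraw.Draw(mask).rectangle([50, 60, 120, 140], fill=(255, 255, 255))
    return get_sketch_prompt(mask)
    # -> {'prompt_type': ['box'], 'input_boxes': [[50, 60, 120, 140]]}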
894
+
895
+ submit_traj=0
896
+
897
+ async def inference_traject(origin_image,sketcher_image, enable_wiki, language, sentiment, factuality, length, image_embedding, state,
898
+ original_size, input_size, text_refiner,focus_type,paragraph,openai_api_key,autoplay,trace_type):
899
+ image_input, mask = sketcher_image['image'], sketcher_image['mask']
900
+
901
+ crop_save_path=""
902
+
903
+ prompt = get_sketch_prompt(mask)
904
+ boxes = prompt['input_boxes']
905
+ boxes = boxes[0]
906
+ global submit_traj
907
+ submit_traj=1
908
+
909
+ controls = {'length': length,
910
+ 'sentiment': sentiment,
911
+ 'factuality': factuality,
912
+ 'language': language}
913
+
914
+ model = build_caption_anything_with_models(
915
+ args,
916
+ api_key="",
917
+ captioner=shared_captioner,
918
+ sam_model=shared_sam_model,
919
+ ocr_reader=shared_ocr_reader,
920
+ text_refiner=text_refiner,
921
+ session_id=iface.app_id
922
+ )
923
+
924
+ model.setup(image_embedding, original_size, input_size, is_image_set=True)
925
+
926
+ enable_wiki = True if enable_wiki in ['True', 'TRUE', 'true', True, 'Yes', 'YES', 'yes'] else False
927
+ out = model.inference(image_input, prompt, controls, disable_gpt=True, enable_wiki=enable_wiki,verbose=True)[0]
928
+
929
+ print(trace_type)
930
+
931
+ if trace_type=="Trace+Seg":
932
+ input_mask = np.array(out['mask'].convert('P'))
933
+ image_input = mask_painter(np.array(image_input), input_mask, background_alpha=0 )
934
+ crop_save_path=out['crop_save_path']
935
+
936
+ else:
937
+ image_input = Image.fromarray(np.array(origin_image))
938
+ draw = ImageDraw.Draw(image_input)
939
+ draw.rectangle(boxes, outline='red', width=2)
940
+ cropped_image = origin_image.crop(boxes)
941
+ cropped_image.save('temp.png')
942
+ crop_save_path='temp.png'
943
+
944
+ print("crop_svae_path",out['crop_save_path'])
945
+
946
+ # Update components and states
947
+ state.append((f'Box: {boxes}', None))
948
+
949
+ # fake_click_index = (int((boxes[0][0] + boxes[0][2]) / 2), int((boxes[0][1] + boxes[0][3]) / 2))
950
+ # image_input = create_bubble_frame(image_input, "", fake_click_index, input_mask)
951
+
952
+ prompt=generate_prompt(focus_type, paragraph, length, sentiment, factuality, language)
953
+ width, height = sketcher_image['image'].size
954
+ sketcher_image['mask'] = np.zeros((height, width, 4), dtype=np.uint8)
955
+ sketcher_image['mask'][..., -1] = 255
956
+ sketcher_image['image']=image_input
957
+
958
+
959
+ if not args.disable_gpt and text_refiner:
960
+ focus_info=get_image_gpt(openai_api_key,crop_save_path,prompt)
961
+ if focus_info.startswith('"') and focus_info.endswith('"'):
962
+ focus_info=focus_info[1:-1]
963
+ focus_info=focus_info.replace('#', '')
964
+ state = state + [(None, f"{focus_info}")]
965
+ print("new_cap",focus_info)
966
+ read_info = re.sub(r'[#[\]!*]','',focus_info)
967
+ read_info = emoji.replace_emoji(read_info,replace="")
968
+ print("read info",read_info)
969
+
970
+ # refined_image_input = create_bubble_frame(np.array(origin_image_input), focus_info, click_index, input_mask,
971
+ # input_points=input_points, input_labels=input_labels)
972
+ try:
973
+ audio_output = await texttospeech(read_info, language,autoplay)
974
+ # return state, state, refined_image_input, click_index_state, input_mask_state, input_points_state, input_labels_state, out_state, waveform_visual, audio_output
975
+ return state, state,image_input,audio_output
976
+
977
+
978
+ except Exception as e:
979
+ state = state + [(None, f"Error during TTS prediction: {str(e)}")]
980
+ print(f"Error during TTS prediction: {str(e)}")
981
+ # return state, state, refined_image_input, click_index_state, input_mask_state, input_points_state, input_labels_state, out_state, None, None
982
+ return state, state, image_input, None
983
+
984
+
985
+ else:
986
+ try:
987
+ audio_output = await texttospeech(focus_info, language, autoplay)
988
+ # waveform_visual, audio_output = tts.predict(generated_caption, input_language, input_audio, input_mic, use_mic, agree)
989
+ # return state, state, image_input, click_index_state, input_mask_state, input_points_state, input_labels_state, out_state, waveform_visual, audio_output
990
+ return state, state, image_input,audio_output
991
+
992
+
993
+ except Exception as e:
994
+ state = state + [(None, f"Error during TTS prediction: {str(e)}")]
995
+ print(f"Error during TTS prediction: {str(e)}")
996
+ return state, state, image_input, None
997
+
998
+
999
+ def clear_chat_memory(visual_chatgpt, keep_global=False):
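+ # Clear the chat agent's conversation memory and point prompt; with keep_global=True the stored
+ # global image-description prompt is restored so follow-up questions still know the current image.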
1000
+ if visual_chatgpt is not None:
1001
+ visual_chatgpt.memory.clear()
1002
+ visual_chatgpt.point_prompt = ""
1003
+ if keep_global:
1004
+ visual_chatgpt.agent.memory.buffer = visual_chatgpt.global_prompt
1005
+ else:
1006
+ visual_chatgpt.current_image = None
1007
+ visual_chatgpt.global_prompt = ""
1008
+
1009
+
1010
+ def export_chat_log(chat_state, paragraph, liked, disliked):
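+ # Assemble a plain-text transcript (image description, dialogue, liked/disliked responses),
+ # write it to a temporary .txt file and return the file path for the download component.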
1011
+ try:
1012
+ if not chat_state:
1013
+ return None
1014
+ chat_log = f"Image Description: {paragraph}\n\n"
1015
+ for entry in chat_state:
1016
+ user_message, bot_response = entry
1017
+ if user_message and bot_response:
1018
+ chat_log += f"User: {user_message}\nBot: {bot_response}\n"
1019
+ elif user_message:
1020
+ chat_log += f"User: {user_message}\n"
1021
+ elif bot_response:
1022
+ chat_log += f"Bot: {bot_response}\n"
1023
+
1024
+ # Append the liked and disliked responses to the log
1025
+ chat_log += "\nLiked Responses:\n"
1026
+ for response in liked:
1027
+ chat_log += f"{response}\n"
1028
+
1029
+ chat_log += "\nDisliked Responses:\n"
1030
+ for response in disliked:
1031
+ chat_log += f"{response}\n"
1032
+
1033
+ print("export log...")
1034
+ print("chat_log", chat_log)
1035
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".txt") as temp_file:
1036
+ temp_file.write(chat_log.encode('utf-8'))
1037
+ temp_file_path = temp_file.name
1038
+ print(temp_file_path)
1039
+ return temp_file_path
1040
+ except Exception as e:
1041
+ print(f"An error occurred while exporting the chat log: {e}")
1042
+ return None
1043
+
1044
+
1045
+
1046
+ async def cap_everything(paragraph, visual_chatgpt,language,autoplay):
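+ # Store the whole-image paragraph in the chat agent's memory as a Human/"Received" exchange and
+ # synthesize it as audio when autoplay is enabled.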
1047
+
1048
+ # state = state + [(None, f"Caption Everything: {paragraph}")]
1049
+ Human_prompt = f'\nThe description of the image with path {visual_chatgpt.current_image} is:\n{paragraph}\nThis information helps you to understand this image, but you should use tools to finish the following tasks, rather than directly imagining from my description. If you understand, say \"Received\". \n'
1050
+ AI_prompt = "Received."
1051
+ visual_chatgpt.global_prompt = Human_prompt + 'AI: ' + AI_prompt
1052
+ visual_chatgpt.agent.memory.buffer = visual_chatgpt.agent.memory.buffer + visual_chatgpt.global_prompt
1053
+ # waveform_visual, audio_output=tts.predict(paragraph, input_language, input_audio, input_mic, use_mic, agree)
1054
+ audio_output=await texttospeech(paragraph,language,autoplay)
1055
+ return paragraph,audio_output
1056
+
1057
+ def cap_everything_withoutsound(image_input, visual_chatgpt, text_refiner,paragraph):
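+ # Variant of cap_everything that generates the paragraph itself (inference_cap_everything)
+ # and skips audio synthesis.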
1058
+
1059
+ model = build_caption_anything_with_models(
1060
+ args,
1061
+ api_key="",
1062
+ captioner=shared_captioner,
1063
+ sam_model=shared_sam_model,
1064
+ ocr_reader=shared_ocr_reader,
1065
+ text_refiner=text_refiner,
1066
+ session_id=iface.app_id
1067
+ )
1068
+ paragraph = model.inference_cap_everything(image_input, verbose=True)
1069
+ # state = state + [(None, f"Caption Everything: {paragraph}")]
1070
+ Human_prompt = f'\nThe description of the image with path {visual_chatgpt.current_image} is:\n{paragraph}\nThis information helps you to understand this image, but you should use tools to finish the following tasks, rather than directly imagining from my description. If you understand, say \"Received\". \n'
1071
+ AI_prompt = "Received."
1072
+ visual_chatgpt.global_prompt = Human_prompt + 'AI: ' + AI_prompt
1073
+ visual_chatgpt.agent.memory.buffer = visual_chatgpt.agent.memory.buffer + visual_chatgpt.global_prompt
1074
+ return paragraph
1075
+
1076
+ def handle_liked(state,like_res):
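+ # Save the latest bot response to like_res and acknowledge the upvote in the chat.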
1077
+ if state:
1078
+ like_res.append(state[-1][1])
1079
+ print(f"Last response recorded: {state[-1][1]}")
1080
+ else:
1081
+ print("No response to record.")
1082
+ state = state + [(None, f"Liked Received 👍")]
1083
+ return state,like_res
1084
+
1085
+ def handle_disliked(state,dislike_res):
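+ # Save the latest bot response to dislike_res and acknowledge the downvote in the chat.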
1086
+ if state:
1087
+ dislike_res.append(state[-1][1])
1088
+ print(f"Last response recorded: {state[-1][1]}")
1089
+ else:
1090
+ print("No response to record.")
1091
+ state = state + [(None, f"Disliked Received 🥹")]
1092
+ return state,dislike_res
1093
+
1094
+
1095
+ def get_style():
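+ # Return version-dependent CSS: Gradio releases up to 3.27 need explicit min-heights and widths
+ # for the image and language widgets; newer releases fall back to the default styling (None).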
1096
+ current_version = version.parse(gr.__version__)
1097
+ if current_version <= version.parse('3.24.1'):
1098
+ style = '''
1099
+ #image_sketcher{min-height:500px}
1100
+ #image_sketcher [data-testid="image"], #image_sketcher [data-testid="image"] > div{min-height: 500px}
1101
+ #image_upload{min-height:500px}
1102
+ #image_upload [data-testid="image"], #image_upload [data-testid="image"] > div{min-height: 500px}
1103
+ .custom-language {
1104
+ width: 20%;
1105
+ }
1106
+
1107
+ .custom-autoplay {
1108
+ width: 40%;
1109
+ }
1110
+
1111
+ .custom-output {
1112
+ width: 30%;
1113
+ }
1114
+
1115
+ '''
1116
+ elif current_version <= version.parse('3.27'):
1117
+ style = '''
1118
+ #image_sketcher{min-height:500px}
1119
+ #image_upload{min-height:500px}
1120
+ .custom-language {
1121
+ width: 20%;
1122
+ }
1123
+
1124
+ .custom-autoplay {
1125
+ width: 40%;
1126
+ }
1127
+
1128
+ .custom-output {
1129
+ width: 30%;
1130
+ }
1131
+ '''
1132
+ else:
1133
+ style = None
1134
+
1135
+ return style
1136
+
1137
+ # def handle_like_dislike(like_data, like_state, dislike_state):
1138
+ # if like_data.liked:
1139
+ # if like_data.index not in like_state:
1140
+ # like_state.append(like_data.index)
1141
+ # message = f"Liked: {like_data.value} at index {like_data.index}"
1142
+ # else:
1143
+ # message = "You already liked this item"
1144
+ # else:
1145
+ # if like_data.index not in dislike_state:
1146
+ # dislike_state.append(like_data.index)
1147
+ # message = f"Disliked: {like_data.value} at index {like_data.index}"
1148
+ # else:
1149
+ # message = "You already disliked this item"
1150
+
1151
+ # return like_state, dislike_state
1152
+
1153
+ async def texttospeech(text, language, autoplay):
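+ # Synthesize speech with edge-tts using the voice mapped from the selected language and return an
+ # autoplaying HTML <audio> element with the audio embedded as base64; returns None when autoplay
+ # is disabled or synthesis fails.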
1154
+ try:
1155
+ if autoplay:
1156
+ voice = filtered_language_dict[language]
1157
+ communicate = edge_tts.Communicate(text, voice)
1158
+ file_path = "output.wav"
1159
+ await communicate.save(file_path)
1160
+ with open(file_path, "rb") as audio_file:
1161
+ audio_bytes = BytesIO(audio_file.read())
1162
+ audio = base64.b64encode(audio_bytes.read()).decode("utf-8")
1163
+ print("TTS processing completed.")
1164
+ audio_style = 'style="width:210px;"'
1165
+ audio_player = f'<audio src="data:audio/wav;base64,{audio}" controls autoplay {audio_style}></audio>'
1166
+ else:
1167
+ audio_player = None
1168
+ print("Autoplay is disabled.")
1169
+ return audio_player
1170
+ except Exception as e:
1171
+ print(f"Error in texttospeech: {e}")
1172
+ return None
1173
+
1174
+
1175
+ def create_ui():
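+ # Build the Gradio Blocks interface: per-session states, the Base/Click/Trajectory tabs, the chat
+ # panel with voting and chat-log export, the text-to-image and 3D-generation panels, and all event wiring.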
1176
+ title = """<p><h1 align="center">EyeSee Anything in Art</h1></p>
1177
+ """
1178
+ description = """<p>Gradio demo for EyeSee Anything in Art, image to dense captioning generation with various language styles. To use it, simply upload your image, or click one of the examples to load them. """
1179
+
1180
+ examples = [
1181
+ ["test_images/ambass.jpg"],
1182
+ ["test_images/pearl.jpg"],
1183
+ ["test_images/Picture0.png"],
1184
+ ["test_images/Picture1.png"],
1185
+ ["test_images/Picture2.png"],
1186
+ ["test_images/Picture3.png"],
1187
+ ["test_images/Picture4.png"],
1188
+ ["test_images/Picture5.png"],
1189
+
1190
+ ]
1191
+
1192
+ with gr.Blocks(
1193
+ css=get_style(),
1194
+ theme=gr.themes.Base()
1195
+ ) as iface:
1196
  state = gr.State([])
1197
+ out_state = gr.State(None)
1198
+ click_state = gr.State([[], [], []])
1199
+ origin_image = gr.State(None)
1200
+ image_embedding = gr.State(None)
1201
+ text_refiner = gr.State(None)
1202
+ visual_chatgpt = gr.State(None)
1203
+ original_size = gr.State(None)
1204
+ input_size = gr.State(None)
1205
+ paragraph = gr.State("")
1206
  aux_state = gr.State([])
1207
+ click_index_state = gr.State((0, 0))
1208
+ input_mask_state = gr.State(np.zeros((1, 1)))
1209
+ input_points_state = gr.State([])
1210
+ input_labels_state = gr.State([])
1211
+ new_crop_save_path = gr.State(None)
1212
+ image_input_nobackground = gr.State(None)
1213
+ artist=gr.State(None)
1214
+ like_res=gr.State([])
1215
+ dislike_res=gr.State([])
1216
+
1217
+ gr.Markdown(title)
1218
+ gr.Markdown(description)
1219
+ with gr.Row(align="right", visible=False, elem_id="top_row") as top_row:
1220
+ language = gr.Dropdown(
1221
+ ['English', 'Chinese', 'French', "Spanish", "Arabic", "Portuguese", "Cantonese"],
1222
+ value="English", label="Language", interactive=True, scale=0.2, elem_classes="custom-language"
1223
+ )
1224
+ auto_play = gr.Checkbox(
1225
+ label="Check to autoplay audio", value=False, scale=0.4, elem_classes="custom-autoplay"
1226
+ )
1227
+ output_audio = gr.HTML(
1228
+ label="Synthesised Audio", scale=0.3, elem_classes="custom-output"
1229
+ )
1230
+
1231
+
1232
+ # with gr.Row(align="right",visible=False) as language_select:
1233
+ # language = gr.Dropdown(
1234
+ # ['English', 'Chinese', 'French', "Spanish", "Arabic", "Portuguese", "Cantonese"],
1235
+ # value="English", label="Language", interactive=True)
1236
+
1237
+ # with gr.Row(align="right",visible=False) as autoplay:
1238
+ # auto_play = gr.Checkbox(label="Check to autoplay audio", value=False,scale=0.4)
1239
+ # output_audio = gr.HTML(label="Synthesised Audio",scale=0.6)
1240
+
1241
+ with gr.Row():
1242
+
1243
+ with gr.Column(scale=1.0):
1244
+ with gr.Column(visible=False) as modules_not_need_gpt:
1245
+ with gr.Tab("Base(GPT Power)",visible=False) as base_tab:
1246
+ image_input_base = gr.Image(type="pil", interactive=True, elem_id="image_upload")
1247
+ example_image = gr.Image(type="pil", interactive=False, visible=False)
1248
+ with gr.Row():
1249
+ name_label_base = gr.Button(value="Name: ")
1250
+ artist_label_base = gr.Button(value="Artist: ")
1251
+ year_label_base = gr.Button(value="Year: ")
1252
+ material_label_base = gr.Button(value="Material: ")
1253
+
1254
+ with gr.Tab("Click") as click_tab:
1255
+ image_input = gr.Image(type="pil", interactive=True, elem_id="image_upload")
1256
+ example_image = gr.Image(type="pil", interactive=False, visible=False)
1257
+ with gr.Row():
1258
+ name_label = gr.Button(value="Name: ")
1259
+ artist_label = gr.Button(value="Artist: ")
1260
+ year_label = gr.Button(value="Year: ")
1261
+ material_label = gr.Button(value="Material: ")
1262
+ with gr.Row(scale=1.0):
1263
+ with gr.Row(scale=0.8):
1264
+ focus_type = gr.Radio(
1265
+ choices=["CFV-D", "CFV-DA", "CFV-DAI","PFV-DDA"],
1266
+ value="CFV-D",
1267
+ label="Information Type",
1268
+ interactive=True)
1269
+ with gr.Row(scale=0.2):
1270
+ submit_button_click=gr.Button(value="Submit", interactive=True,variant='primary',size="sm")
1271
+ with gr.Row(scale=1.0):
1272
+ with gr.Row(scale=0.4):
1273
+ point_prompt = gr.Radio(
1274
+ choices=["Positive", "Negative"],
1275
+ value="Positive",
1276
+ label="Point Prompt",
1277
+ interactive=True)
1278
+ click_mode = gr.Radio(
1279
+ choices=["Continuous", "Single"],
1280
+ value="Continuous",
1281
+ label="Clicking Mode",
1282
+ interactive=True)
1283
+ with gr.Row(scale=0.4):
1284
+ clear_button_click = gr.Button(value="Clear Clicks", interactive=True)
1285
+ clear_button_image = gr.Button(value="Clear Image", interactive=True)
1286
+
1287
+ with gr.Tab("Trajectory (beta)") as traj_tab:
1288
+ sketcher_input = ImageSketcher(type="pil", interactive=True, brush_radius=10,
1289
+ elem_id="image_sketcher")
1290
+ example_image = gr.Image(type="pil", interactive=False, visible=False)
1291
+ with gr.Row():
1292
+ submit_button_sketcher = gr.Button(value="Submit", interactive=True)
1293
+ clear_button_sketcher = gr.Button(value="Clear Sketch", interactive=True)
1294
+ with gr.Row(scale=1.0):
1295
+ with gr.Row(scale=0.8):
1296
+ focus_type_sketch = gr.Radio(
1297
+ choices=["CFV-D", "CFV-DA", "CFV-DAI","PFV-DDA"],
1298
+ value="CFV-D",
1299
+ label="Information Type",
1300
+ interactive=True)
1301
+ Input_sketch = gr.Radio(
1302
+ choices=["Trace+Seg", "Trace"],
1303
+ value="Trace+Seg",
1304
+ label="Trace Type",
1305
+ interactive=True)
1306
+
1307
+ with gr.Column(visible=False) as modules_need_gpt1:
1308
+ with gr.Row(scale=1.0):
1309
+ sentiment = gr.Radio(
1310
+ choices=["Positive", "Natural", "Negative"],
1311
+ value="Natural",
1312
+ label="Sentiment",
1313
+ interactive=True,
1314
+ )
1315
+ with gr.Row(scale=1.0):
1316
+ factuality = gr.Radio(
1317
+ choices=["Factual", "Imagination"],
1318
+ value="Factual",
1319
+ label="Factuality",
1320
+ interactive=True,
1321
+ )
1322
+ length = gr.Slider(
1323
+ minimum=10,
1324
+ maximum=80,
1325
+ value=10,
1326
+ step=1,
1327
+ interactive=True,
1328
+ label="Generated Caption Length",
1329
+ )
1330
+ # Whether to integrate wiki content into the caption
1331
+ enable_wiki = gr.Radio(
1332
+ choices=["Yes", "No"],
1333
+ value="No",
1334
+ label="Expert",
1335
+ interactive=True)
1336
+ with gr.Column(visible=True) as modules_not_need_gpt3:
1337
+ gr.Examples(
1338
+ examples=examples,
1339
+ inputs=[example_image],
1340
+ )
1341
+
1342
+
1343
+
1344
+
1345
+
1346
+ with gr.Column(scale=0.5):
1347
+ with gr.Column(visible=True) as module_key_input:
1348
+ openai_api_key = gr.Textbox(
1349
+ placeholder="Input openAI API key",
1350
+ show_label=False,
1351
+ label="OpenAI API Key",
1352
+ lines=1,
1353
+ type="password")
1354
+ with gr.Row(scale=0.5):
1355
+ enable_chatGPT_button = gr.Button(value="Run with ChatGPT", interactive=True, variant='primary')
1356
+ disable_chatGPT_button = gr.Button(value="Run without ChatGPT (Faster)", interactive=True,
1357
+ variant='primary')
1358
+ with gr.Column(visible=False) as module_notification_box:
1359
+ notification_box = gr.Textbox(lines=1, label="Notification", max_lines=5, show_label=False)
1360
+
1361
+ with gr.Column() as modules_need_gpt0:
1362
+ with gr.Column(visible=False,scale=1.0) as modules_need_gpt2:
1363
+ paragraph_output = gr.Textbox(lines=16, label="Describe Everything", max_lines=16)
1364
+ cap_everything_button = gr.Button(value="Caption Everything in a Paragraph", interactive=True)
1365
+
1366
+ with gr.Column(visible=False) as modules_not_need_gpt2:
1367
+ with gr.Blocks():
1368
+ chatbot = gr.Chatbot(label="Chatbox", elem_classes="chatbot",likeable=True).style(height=600, scale=0.5)
1369
+ with gr.Column(visible=False) as modules_need_gpt3:
1370
+ chat_input = gr.Textbox(show_label=False, placeholder="Enter text and press Enter").style(
1371
+ container=False)
1372
+ with gr.Row():
1373
+ clear_button_text = gr.Button(value="Clear Text", interactive=True)
1374
+ submit_button_text = gr.Button(value="Send", interactive=True, variant="primary")
1375
+ upvote_btn = gr.Button(value="👍 Upvote", interactive=True)
1376
+ downvote_btn = gr.Button(value="👎 Downvote", interactive=True)
1377
+
1378
+ with gr.Row():
1379
+ export_button = gr.Button(value="Export Chat Log", interactive=True, variant="primary")
1380
+ with gr.Row():
1381
+ chat_log_file = gr.File(label="Download Chat Log")
1382
+
1383
+ # TTS interface hidden initially
1384
+ with gr.Column(visible=False) as tts_interface:
1385
+ input_text = gr.Textbox(label="Text Prompt", value="Hello, World !, here is an example of light voice cloning. Try to upload your best audio samples quality")
1386
+ input_language = gr.Dropdown(label="Language", choices=["en", "es", "fr", "de", "it", "pt", "pl", "tr", "ru", "nl", "cs", "ar", "zh-cn"], value="en")
1387
+ input_audio = gr.Audio(label="Reference Audio", type="filepath", value="examples/female.wav")
1388
+ input_mic = gr.Audio(source="microphone", type="filepath", label="Use Microphone for Reference")
1389
+ use_mic = gr.Checkbox(label="Check to use Microphone as Reference", value=False)
1390
+ agree = gr.Checkbox(label="Agree", value=True)
1391
+ output_waveform = gr.Video(label="Waveform Visual")
1392
+ # output_audio = gr.HTML(label="Synthesised Audio")
1393
+
1394
+ with gr.Row():
1395
+ submit_tts = gr.Button(value="Submit", interactive=True)
1396
+ clear_tts = gr.Button(value="Clear", interactive=True)
1397
+ ###############################################################################
1398
+ ############# this part is for text to image #############
1399
+ ###############################################################################
1400
+
1401
+ with gr.Row(variant="panel") as text2image_model:
1402
+
1403
+ with gr.Column():
1404
+ with gr.Column():
1405
+ gr.Radio([artist], label="Artist", info="Who is the artist?🧑‍🎨")
1406
+ gr.CheckboxGroup(["Oil Painting","Printmaking","Watercolor Painting","Drawing"], label="Art Forms", info="What are the art forms?🎨"),
1407
+ gr.Radio(["Renaissance", "Baroque", "Impressionism","Modernism"], label="Period", info="Which art period?⏳"),
1408
+ # to be done
1409
+ gr.Dropdown(
1410
+ ["ran", "swam", "ate", "slept"], value=["swam", "slept"], multiselect=True, label="Items", info="Which items are you interested in?"
1411
+ )
1412
+
1413
+ with gr.Row():
1414
+ prompt = gr.Text(
1415
+ label="Prompt",
1416
+ show_label=False,
1417
+ max_lines=1,
1418
+ placeholder="Enter your prompt",
1419
+ container=False,
1420
+ )
1421
+ run_button = gr.Button("Run", scale=0)
1422
+
1423
+ with gr.Accordion("Advanced options", open=True):
1424
+ num_images = gr.Slider(
1425
+ label="Number of Images",
1426
+ minimum=1,
1427
+ maximum=4,
1428
+ step=1,
1429
+ value=4,
1430
+ )
1431
+ with gr.Row():
1432
+ use_negative_prompt = gr.Checkbox(label="Use negative prompt", value=True)
1433
+ negative_prompt = gr.Text(
1434
+ label="Negative prompt",
1435
+ max_lines=5,
1436
+ lines=4,
1437
+ placeholder="Enter a negative prompt",
1438
+ value="(deformed, distorted, disfigured:1.3), poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, (mutated hands and fingers:1.4), disconnected limbs, mutation, mutated, ugly, disgusting, blurry, amputation, NSFW",
1439
+ visible=True,
1440
+ )
1441
+ seed = gr.Slider(
1442
+ label="Seed",
1443
+ minimum=0,
1444
+ maximum=MAX_SEED,
1445
+ step=1,
1446
+ value=0,
1447
+ )
1448
+ randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
1449
+ with gr.Row(visible=True):
1450
+ width = gr.Slider(
1451
+ label="Width",
1452
+ minimum=100,
1453
+ maximum=MAX_IMAGE_SIZE,
1454
+ step=64,
1455
+ value=1024,
1456
+ )
1457
+ height = gr.Slider(
1458
+ label="Height",
1459
+ minimum=100,
1460
+ maximum=MAX_IMAGE_SIZE,
1461
+ step=64,
1462
+ value=1024,
1463
+ )
1464
+ with gr.Row():
1465
+ guidance_scale = gr.Slider(
1466
+ label="Guidance Scale",
1467
+ minimum=0.1,
1468
+ maximum=6,
1469
+ step=0.1,
1470
+ value=3.0,
1471
+ )
1472
+ num_inference_steps = gr.Slider(
1473
+ label="Number of inference steps",
1474
+ minimum=1,
1475
+ maximum=15,
1476
+ step=1,
1477
+ value=8,
1478
+ )
1479
+ with gr.Column():
1480
+ result = gr.Gallery(
1481
+ label="Result",
1482
+ columns=2,
1483
+ rows=2,
1484
+ show_label=False,
1485
+ allow_preview=True,
1486
+ object_fit="contain",
1487
+ height="auto",
1488
+ preview=True,
1489
+ show_share_button=True,
1490
+ show_download_button=True
1491
+ )
1492
+
1493
+
1494
+
1495
+ # gr.Examples(
1496
+ # examples=examples,
1497
+ # inputs=prompt,
1498
+ # cache_examples=False
1499
+ # )
1500
+
1501
+ use_negative_prompt.change(
1502
+ fn=lambda x: gr.update(visible=x),
1503
+ inputs=use_negative_prompt,
1504
+ outputs=negative_prompt,
1505
+ api_name=False,
1506
+ )
1507
+
1508
+ # gr.on(
1509
+ # triggers=[
1510
+ # prompt.submit,
1511
+ # negative_prompt.submit,
1512
+ # run_button.click,
1513
+ # ],
1514
+ # fn=generate,
1515
+ # inputs=[
1516
+ # prompt,
1517
+ # negative_prompt,
1518
+ # use_negative_prompt,
1519
+ # seed,
1520
+ # width,
1521
+ # height,
1522
+ # guidance_scale,
1523
+ # num_inference_steps,
1524
+ # randomize_seed,
1525
+ # num_images
1526
+ # ],
1527
+ # outputs=[result, seed],
1528
+ # api_name="run",
1529
+ # )
1530
+ run_button.click(
1531
+ fn=generate,
1532
+ inputs=[
1533
+ prompt,
1534
+ negative_prompt,
1535
+ use_negative_prompt,
1536
+ seed,
1537
+ width,
1538
+ height,
1539
+ guidance_scale,
1540
+ num_inference_steps,
1541
+ randomize_seed,
1542
+ num_images
1543
+ ],
1544
+ outputs=[result, seed]
1545
+ )
1546
+
1547
+ ###############################################################################
1548
+ ############# above part is for text to image #############
1549
+ ###############################################################################
1550
+
1551
+
1552
+ ###############################################################################
1553
+ # this part is for 3d generate.
1554
+ ###############################################################################
1555
+
1556
+ with gr.Row(variant="panel",visible=False) as d3_model:
1557
+ with gr.Column():
1558
+ with gr.Row():
1559
+ input_image = gr.Image(
1560
+ label="Input Image",
1561
+ image_mode="RGBA",
1562
+ sources="upload",
1563
+ #width=256,
1564
+ #height=256,
1565
+ type="pil",
1566
+ elem_id="content_image",
1567
+ )
1568
+ processed_image = gr.Image(
1569
+ label="Processed Image",
1570
+ image_mode="RGBA",
1571
+ #width=256,
1572
+ #height=256,
1573
+ type="pil",
1574
+ interactive=False
1575
+ )
1576
+ with gr.Row():
1577
+ with gr.Group():
1578
+ do_remove_background = gr.Checkbox(
1579
+ label="Remove Background", value=True
1580
+ )
1581
+ sample_seed = gr.Number(value=42, label="Seed Value", precision=0)
1582
+
1583
+ sample_steps = gr.Slider(
1584
+ label="Sample Steps",
1585
+ minimum=30,
1586
+ maximum=75,
1587
+ value=75,
1588
+ step=5
1589
+ )
1590
+
1591
+ with gr.Row():
1592
+ submit = gr.Button("Generate", elem_id="generate", variant="primary")
1593
+
1594
+ with gr.Row(variant="panel"):
1595
+ gr.Examples(
1596
+ examples=[
1597
+ os.path.join("examples", img_name) for img_name in sorted(os.listdir("examples"))
1598
+ ],
1599
+ inputs=[input_image],
1600
+ label="Examples",
1601
+ cache_examples=False,
1602
+ examples_per_page=16
1603
+ )
1604
+
1605
+ with gr.Column():
1606
+
1607
+ with gr.Row():
1608
+
1609
+ with gr.Column():
1610
+ mv_show_images = gr.Image(
1611
+ label="Generated Multi-views",
1612
+ type="pil",
1613
+ width=379,
1614
+ interactive=False
1615
+ )
1616
+
1617
+ # with gr.Column():
1618
+ # output_video = gr.Video(
1619
+ # label="video", format="mp4",
1620
+ # width=379,
1621
+ # autoplay=True,
1622
+ # interactive=False
1623
+ # )
1624
+
1625
+ with gr.Row():
1626
+ with gr.Tab("OBJ"):
1627
+ output_model_obj = gr.Model3D(
1628
+ label="Output Model (OBJ Format)",
1629
+ interactive=False,
1630
+ )
1631
+ gr.Markdown("Note: Downloaded .obj model will be flipped. Export .glb instead or manually flip it before usage.")
1632
+ with gr.Tab("GLB"):
1633
+ output_model_glb = gr.Model3D(
1634
+ label="Output Model (GLB Format)",
1635
+ interactive=False,
1636
+ )
1637
+ gr.Markdown("Note: The model shown here has a darker appearance. Download to get correct results.")
1638
+
1639
+
1640
+
1641
+
1642
+ mv_images = gr.State()
1643
+
1644
+ # chatbot.like(handle_like_dislike, inputs=[like_state, dislike_state], outputs=[like_state, dislike_state])
1645
+
1646
+ submit.click(fn=check_input_image, inputs=[new_crop_save_path], outputs=[processed_image]).success(
1647
+ fn=generate_mvs,
1648
+ inputs=[processed_image, sample_steps, sample_seed],
1649
+ outputs=[mv_images, mv_show_images]
1650
+
1651
+ ).success(
1652
+ fn=make3d,
1653
+ inputs=[mv_images],
1654
+ outputs=[output_model_obj, output_model_glb]
1655
+ )
1656
+
1657
+ ###############################################################################
1658
+ # above part is for 3d generate.
1659
+ ###############################################################################
1660
+
1661
+
1662
+ def clear_tts_fields():
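+ # Reset every TTS control to its default value (empty text, no reference audio, microphone unchecked).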
1663
+ return [gr.update(value=""), gr.update(value=""), None, None, gr.update(value=False), gr.update(value=True), None, None]
1664
+
1665
+ # submit_tts.click(
1666
+ # tts.predict,
1667
+ # inputs=[input_text, input_language, input_audio, input_mic, use_mic, agree],
1668
+ # outputs=[output_waveform, output_audio],
1669
+ # queue=True
1670
+ # )
1671
+
1672
+ clear_tts.click(
1673
+ clear_tts_fields,
1674
+ inputs=None,
1675
+ outputs=[input_text, input_language, input_audio, input_mic, use_mic, agree, output_waveform, output_audio],
1676
+ queue=False
1677
+ )
1678
+
1679
+
1680
+
1681
+
1682
+ clear_button_sketcher.click(
1683
+ lambda x: (x),
1684
+ [origin_image],
1685
+ [sketcher_input],
1686
+ queue=False,
1687
+ show_progress=False
1688
+ )
1689
+
1690
+
1691
+
1692
+
1693
+
1694
+ openai_api_key.submit(init_openai_api_key, inputs=[openai_api_key],
1695
+ outputs=[modules_need_gpt0, modules_need_gpt1, modules_need_gpt2, modules_need_gpt3, modules_not_need_gpt,
1696
+ modules_not_need_gpt2, tts_interface,module_key_input ,module_notification_box, text_refiner, visual_chatgpt, notification_box,d3_model,top_row])
1697
+ enable_chatGPT_button.click(init_openai_api_key, inputs=[openai_api_key],
1698
+ outputs=[modules_need_gpt0, modules_need_gpt1, modules_need_gpt2, modules_need_gpt3,
1699
+ modules_not_need_gpt,
1700
+ modules_not_need_gpt2, tts_interface,module_key_input,module_notification_box, text_refiner, visual_chatgpt, notification_box,d3_model,top_row])
1701
+ disable_chatGPT_button.click(init_wo_openai_api_key,
1702
+ outputs=[modules_need_gpt0, modules_need_gpt1, modules_need_gpt2, modules_need_gpt3,
1703
+ modules_not_need_gpt,
1704
+ modules_not_need_gpt2, tts_interface,module_key_input, module_notification_box, text_refiner, visual_chatgpt, notification_box,d3_model,top_row])
1705
+
1706
+ enable_chatGPT_button.click(
1707
+ lambda: (None, [], [], [[], [], []], "", ""),
1708
+ [],
1709
+ [image_input, chatbot, state, click_state, paragraph_output, origin_image],
1710
+ queue=False,
1711
+ show_progress=False
1712
+ )
1713
+ openai_api_key.submit(
1714
+ lambda: (None, [], [], [[], [], []], "", ""),
1715
+ [],
1716
+ [image_input, chatbot, state, click_state, paragraph_output, origin_image],
1717
+ queue=False,
1718
+ show_progress=False
1719
+ )
1720
+
1721
+ cap_everything_button.click(cap_everything, [paragraph, visual_chatgpt, language,auto_play],
1722
+ [paragraph_output,output_audio])
1723
+
1724
+ clear_button_click.click(
1725
+ lambda x: ([[], [], []], x),
1726
+ [origin_image],
1727
+ [click_state, image_input],
1728
+ queue=False,
1729
+ show_progress=False
1730
+ )
1731
+ clear_button_click.click(functools.partial(clear_chat_memory, keep_global=True), inputs=[visual_chatgpt])
1732
+ clear_button_image.click(
1733
+ lambda: (None, [], [], [[], [], []], "", ""),
1734
+ [],
1735
+ [image_input, chatbot, state, click_state, paragraph_output, origin_image],
1736
+ queue=False,
1737
+ show_progress=False
1738
+ )
1739
+ clear_button_image.click(clear_chat_memory, inputs=[visual_chatgpt])
1740
+ clear_button_text.click(
1741
+ lambda: ([], [], [[], [], []]),
1742
+ [],
1743
+ [chatbot, state, click_state],
1744
+ queue=False,
1745
+ show_progress=False
1746
+ )
1747
+ clear_button_text.click(clear_chat_memory, inputs=[visual_chatgpt])
1748
+
1749
+ image_input.clear(
1750
+ lambda: (None, [], [], [[], [], []], "", ""),
1751
+ [],
1752
+ [image_input, chatbot, state, click_state, paragraph_output, origin_image],
1753
+ queue=False,
1754
+ show_progress=False
1755
+ )
1756
+
1757
+ image_input.clear(clear_chat_memory, inputs=[visual_chatgpt])
1758
+
1759
+
1760
+
1761
+
1762
+ image_input_base.upload(upload_callback, [image_input_base, state, visual_chatgpt,openai_api_key],
1763
+ [chatbot, state, origin_image, click_state, image_input, image_input_base, sketcher_input,
1764
+ image_embedding, original_size, input_size,name_label,artist_label,year_label,material_label,name_label_base, artist_label_base, year_label_base, material_label_base,paragraph,artist])
1765
+
1766
+ image_input.upload(upload_callback, [image_input, state, visual_chatgpt, openai_api_key],
1767
+ [chatbot, state, origin_image, click_state, image_input, image_input_base, sketcher_input,
1768
+ image_embedding, original_size, input_size,name_label,artist_label,year_label,material_label,name_label_base, artist_label_base, year_label_base, material_label_base,paragraph,artist])
1769
+ sketcher_input.upload(upload_callback, [sketcher_input, state, visual_chatgpt, openai_api_key],
1770
+ [chatbot, state, origin_image, click_state, image_input, image_input_base, sketcher_input,
1771
+ image_embedding, original_size, input_size,name_label,artist_label,year_label,material_label,name_label_base, artist_label_base, year_label_base, material_label_base,paragraph,artist])
1772
+ chat_input.submit(chat_input_callback, [visual_chatgpt, chat_input, click_state, state, aux_state,language,auto_play],
1773
+ [chatbot, state, aux_state,output_audio])
1774
+ chat_input.submit(lambda: "", None, chat_input)
1775
+ submit_button_text.click(chat_input_callback, [visual_chatgpt, chat_input, click_state, state, aux_state,language,auto_play],
1776
+ [chatbot, state, aux_state,output_audio])
1777
+ submit_button_text.click(lambda: "", None, chat_input)
1778
+ example_image.change(upload_callback, [example_image, state, visual_chatgpt, openai_api_key],
1779
+ [chatbot, state, origin_image, click_state, image_input, image_input_base, sketcher_input,
1780
+ image_embedding, original_size, input_size,name_label,artist_label,year_label,material_label,name_label_base, artist_label_base, year_label_base, material_label_base,paragraph,artist])
1781
+
1782
+ example_image.change(clear_chat_memory, inputs=[visual_chatgpt])
1783
+
1784
+ def on_click_tab_selected():
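+ # When the Click or Trajectory tab is selected, show the GPT-dependent control groups only if an
+ # OpenAI key has been activated (gpt_state == 1).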
1785
+ if gpt_state ==1:
1786
+ print(gpt_state)
1787
+ print("using gpt")
1788
+ return [gr.update(visible=True)]*2+[gr.update(visible=False)]*2
1789
+ else:
1790
+ print("no gpt")
1791
+ print("gpt_state",gpt_state)
1792
+ return [gr.update(visible=False)]+[gr.update(visible=True)]+[gr.update(visible=False)]*2
1793
+
1794
+ def on_base_selected():
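+ # The Base tab only exposes the paragraph/caption modules when GPT is active; otherwise hide all four groups.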
1795
+ if gpt_state ==1:
1796
+ print(gpt_state)
1797
+ print("using gpt")
1798
+ return [gr.update(visible=True)]*2+[gr.update(visible=False)]*2
1799
+ else:
1800
+ print("no gpt")
1801
+ return [gr.update(visible=False)]*4
1802
+
1803
+
1804
+ traj_tab.select(on_click_tab_selected, outputs=[modules_need_gpt1,modules_not_need_gpt2,modules_need_gpt0,modules_need_gpt2])
1805
+ click_tab.select(on_click_tab_selected, outputs=[modules_need_gpt1,modules_not_need_gpt2,modules_need_gpt0,modules_need_gpt2])
1806
+ base_tab.select(on_base_selected, outputs=[modules_need_gpt0,modules_need_gpt2,modules_not_need_gpt2,modules_need_gpt1])
1807
+
1808
+
1809
+
1810
+
1811
+ image_input.select(
1812
+ inference_click,
1813
+ inputs=[
1814
+ origin_image, point_prompt, click_mode, enable_wiki, language, sentiment, factuality, length,
1815
+ image_embedding, state, click_state, original_size, input_size, text_refiner, visual_chatgpt,
1816
+ out_state, click_index_state, input_mask_state, input_points_state, input_labels_state,
1817
+ ],
1818
+ outputs=[chatbot, state, click_state, image_input, input_image, click_index_state, input_mask_state, input_points_state, input_labels_state, out_state,new_crop_save_path,image_input_nobackground],
1819
+ show_progress=False, queue=True
1820
+ )
1821
+
1822
+
1823
+ submit_button_click.click(
1824
+ submit_caption,
1825
+ inputs=[
1826
+ state, text_refiner,length, sentiment, factuality, language,
1827
+ out_state, click_index_state, input_mask_state, input_points_state, input_labels_state,
1828
+ auto_play,paragraph,focus_type,openai_api_key,new_crop_save_path
1829
+ ],
1830
+ outputs=[
1831
+ chatbot, state, click_index_state, input_mask_state, input_points_state, input_labels_state, out_state,
1832
+ output_audio
1833
+ ],
1834
+ show_progress=True,
1835
+ queue=True
1836
+ )
1837
+
1838
+
1839
+ submit_button_sketcher.click(
1840
+ inference_traject,
1841
+ inputs=[
1842
+ origin_image,sketcher_input, enable_wiki, language, sentiment, factuality, length, image_embedding, state,
1843
+ original_size, input_size, text_refiner,focus_type_sketch,paragraph,openai_api_key,auto_play,Input_sketch
1844
+ ],
1845
+ outputs=[chatbot, state, sketcher_input,output_audio],
1846
+ show_progress=False, queue=True
1847
+ )
1848
+
1849
+ export_button.click(
1850
+ export_chat_log,
1851
+ inputs=[state,paragraph,like_res,dislike_res],
1852
+ outputs=[chat_log_file],
1853
+ queue=True
1854
+ )
1855
+
1856
+ upvote_btn.click(
1857
+ handle_liked,
1858
+ inputs=[state,like_res],
1859
+ outputs=[chatbot,like_res]
1860
+ )
1861
+
1862
+ downvote_btn.click(
1863
+ handle_disliked,
1864
+ inputs=[state,dislike_res],
1865
+ outputs=[chatbot,dislike_res]
1866
+ )
1867
+
1868
+
1869
+
1870
+
1871
+
1872
+ return iface
1873
+
1874
+
1875
+ if __name__ == '__main__':
1876
+ iface = create_ui()
1877
+ iface.queue(concurrency_count=5, api_open=False, max_size=10)
1878
+ iface.launch(server_name="0.0.0.0", enable_queue=True)