Niki Zhang commited on
Commit
aea9e97
·
verified ·
1 Parent(s): c6473d6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +203 -1774
app.py CHANGED
@@ -1,1797 +1,226 @@
1
- from io import BytesIO
2
- from math import inf
 
3
  import os
4
- import base64
5
- import json
6
  import gradio as gr
7
- import numpy as np
8
- from gradio import processing_utils
9
- import requests
10
- from packaging import version
11
- from PIL import Image, ImageDraw
12
- import functools
13
- import emoji
14
- from langchain.llms.openai import OpenAI
15
- from caption_anything.model import CaptionAnything
16
- from caption_anything.utils.image_editing_utils import create_bubble_frame
17
- from caption_anything.utils.utils import mask_painter, seg_model_map, prepare_segmenter, image_resize
18
- from caption_anything.utils.parser import parse_augment
19
- from caption_anything.captioner import build_captioner
20
- from caption_anything.text_refiner import build_text_refiner
21
- from caption_anything.segmenter import build_segmenter
22
- from chatbox import ConversationBot, build_chatbot_tools, get_new_image_name
23
- from segment_anything import sam_model_registry
24
- import easyocr
25
  import re
26
- import edge_tts
27
- import asyncio
28
- import cv2
29
- # import tts
30
-
31
- ###############################################################################
32
- ############# this part is for 3D generate #############
33
- ###############################################################################
34
-
35
-
36
- # import spaces #
37
-
38
- import os
39
- import imageio
40
  import numpy as np
41
- import torch
42
- import rembg
43
- from PIL import Image
44
- from torchvision.transforms import v2
45
- from pytorch_lightning import seed_everything
46
- from omegaconf import OmegaConf
47
- from einops import rearrange, repeat
48
- from tqdm import tqdm
49
- from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler
50
-
51
- from src.utils.train_util import instantiate_from_config
52
- from src.utils.camera_util import (
53
- FOV_to_intrinsics,
54
- get_zero123plus_input_cameras,
55
- get_circular_camera_poses,
56
- )
57
- from src.utils.mesh_util import save_obj, save_glb
58
- from src.utils.infer_util import remove_background, resize_foreground, images_to_video
59
-
60
- import tempfile
61
- from functools import partial
62
-
63
- from huggingface_hub import hf_hub_download
64
-
65
-
66
-
67
-
68
- def get_render_cameras(batch_size=1, M=120, radius=2.5, elevation=10.0, is_flexicubes=False):
69
- """
70
- Get the rendering camera parameters.
71
- """
72
- c2ws = get_circular_camera_poses(M=M, radius=radius, elevation=elevation)
73
- if is_flexicubes:
74
- cameras = torch.linalg.inv(c2ws)
75
- cameras = cameras.unsqueeze(0).repeat(batch_size, 1, 1, 1)
76
- else:
77
- extrinsics = c2ws.flatten(-2)
78
- intrinsics = FOV_to_intrinsics(50.0).unsqueeze(0).repeat(M, 1, 1).float().flatten(-2)
79
- cameras = torch.cat([extrinsics, intrinsics], dim=-1)
80
- cameras = cameras.unsqueeze(0).repeat(batch_size, 1, 1)
81
- return cameras
82
-
83
-
84
- def images_to_video(images, output_path, fps=30):
85
- # images: (N, C, H, W)
86
- os.makedirs(os.path.dirname(output_path), exist_ok=True)
87
- frames = []
88
- for i in range(images.shape[0]):
89
- frame = (images[i].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8).clip(0, 255)
90
- assert frame.shape[0] == images.shape[2] and frame.shape[1] == images.shape[3], \
91
- f"Frame shape mismatch: {frame.shape} vs {images.shape}"
92
- assert frame.min() >= 0 and frame.max() <= 255, \
93
- f"Frame value out of range: {frame.min()} ~ {frame.max()}"
94
- frames.append(frame)
95
- imageio.mimwrite(output_path, np.stack(frames), fps=fps, codec='h264')
96
-
97
-
98
- ###############################################################################
99
- # Configuration.
100
- ###############################################################################
101
-
102
- import shutil
103
-
104
- def find_cuda():
105
- # Check if CUDA_HOME or CUDA_PATH environment variables are set
106
- cuda_home = os.environ.get('CUDA_HOME') or os.environ.get('CUDA_PATH')
107
-
108
- if cuda_home and os.path.exists(cuda_home):
109
- return cuda_home
110
-
111
- # Search for the nvcc executable in the system's PATH
112
- nvcc_path = shutil.which('nvcc')
113
-
114
- if nvcc_path:
115
- # Remove the 'bin/nvcc' part to get the CUDA installation path
116
- cuda_path = os.path.dirname(os.path.dirname(nvcc_path))
117
- return cuda_path
118
-
119
- return None
120
-
121
- cuda_path = find_cuda()
122
-
123
- if cuda_path:
124
- print(f"CUDA installation found at: {cuda_path}")
125
- else:
126
- print("CUDA installation not found")
127
-
128
- config_path = 'configs/instant-nerf-base.yaml'
129
- config = OmegaConf.load(config_path)
130
- config_name = os.path.basename(config_path).replace('.yaml', '')
131
- model_config = config.model_config
132
- infer_config = config.infer_config
133
-
134
- IS_FLEXICUBES = True if config_name.startswith('instant-mesh') else False
135
-
136
- device = torch.device('cuda')
137
-
138
- # load diffusion model
139
- print('Loading diffusion model ...')
140
- pipeline = DiffusionPipeline.from_pretrained(
141
- "sudo-ai/zero123plus-v1.2",
142
- custom_pipeline="zero123plus",
143
- torch_dtype=torch.float16,
144
- )
145
- pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
146
- pipeline.scheduler.config, timestep_spacing='trailing'
147
- )
148
-
149
- # load custom white-background UNet
150
- unet_ckpt_path = hf_hub_download(repo_id="TencentARC/InstantMesh", filename="diffusion_pytorch_model.bin", repo_type="model")
151
- state_dict = torch.load(unet_ckpt_path, map_location='cpu')
152
- pipeline.unet.load_state_dict(state_dict, strict=True)
153
-
154
- pipeline = pipeline.to(device)
155
-
156
- # load reconstruction model
157
- print('Loading reconstruction model ...')
158
- model_ckpt_path = hf_hub_download(repo_id="TencentARC/InstantMesh", filename="instant_nerf_base.ckpt", repo_type="model")
159
- model0 = instantiate_from_config(model_config)
160
- state_dict = torch.load(model_ckpt_path, map_location='cpu')['state_dict']
161
- state_dict = {k[14:]: v for k, v in state_dict.items() if k.startswith('lrm_generator.') and 'source_camera' not in k}
162
- model0.load_state_dict(state_dict, strict=True)
163
-
164
- model0 = model0.to(device)
165
-
166
- print('Loading Finished!')
167
-
168
-
169
- def check_input_image(input_image):
170
- if input_image is None:
171
- raise gr.Error("No image uploaded!")
172
- image = None
173
- else:
174
- image = Image.open(input_image)
175
- return image
176
-
177
- def preprocess(input_image, do_remove_background):
178
-
179
- rembg_session = rembg.new_session() if do_remove_background else None
180
-
181
- if do_remove_background:
182
- input_image = remove_background(input_image, rembg_session)
183
- input_image = resize_foreground(input_image, 0.85)
184
-
185
- return input_image
186
-
187
-
188
- # @spaces.GPU
189
- def generate_mvs(input_image, sample_steps, sample_seed):
190
-
191
- seed_everything(sample_seed)
192
-
193
- # sampling
194
- z123_image = pipeline(
195
- input_image,
196
- num_inference_steps=sample_steps
197
- ).images[0]
198
-
199
- show_image = np.asarray(z123_image, dtype=np.uint8)
200
- show_image = torch.from_numpy(show_image) # (960, 640, 3)
201
- show_image = rearrange(show_image, '(n h) (m w) c -> (n m) h w c', n=3, m=2)
202
- show_image = rearrange(show_image, '(n m) h w c -> (n h) (m w) c', n=2, m=3)
203
- show_image = Image.fromarray(show_image.numpy())
204
-
205
- return z123_image, show_image
206
-
207
-
208
- # @spaces.GPU
209
- def make3d(images):
210
-
211
- global model0
212
- if IS_FLEXICUBES:
213
- model0.init_flexicubes_geometry(device)
214
- model0 = model0.eval()
215
-
216
- images = np.asarray(images, dtype=np.float32) / 255.0
217
- images = torch.from_numpy(images).permute(2, 0, 1).contiguous().float() # (3, 960, 640)
218
- images = rearrange(images, 'c (n h) (m w) -> (n m) c h w', n=3, m=2) # (6, 3, 320, 320)
219
-
220
- input_cameras = get_zero123plus_input_cameras(batch_size=1, radius=4.0).to(device)
221
- render_cameras = get_render_cameras(batch_size=1, radius=2.5, is_flexicubes=IS_FLEXICUBES).to(device)
222
-
223
- images = images.unsqueeze(0).to(device)
224
- images = v2.functional.resize(images, (320, 320), interpolation=3, antialias=True).clamp(0, 1)
225
-
226
- mesh_fpath = tempfile.NamedTemporaryFile(suffix=f".obj", delete=False).name
227
- print(mesh_fpath)
228
- mesh_basename = os.path.basename(mesh_fpath).split('.')[0]
229
- mesh_dirname = os.path.dirname(mesh_fpath)
230
- video_fpath = os.path.join(mesh_dirname, f"{mesh_basename}.mp4")
231
- mesh_glb_fpath = os.path.join(mesh_dirname, f"{mesh_basename}.glb")
232
-
233
- with torch.no_grad():
234
- # get triplane
235
- planes = model0.forward_planes(images, input_cameras)
236
-
237
- # # get video
238
- # chunk_size = 20 if IS_FLEXICUBES else 1
239
- # render_size = 384
240
-
241
- # frames = []
242
- # for i in tqdm(range(0, render_cameras.shape[1], chunk_size)):
243
- # if IS_FLEXICUBES:
244
- # frame = model.forward_geometry(
245
- # planes,
246
- # render_cameras[:, i:i+chunk_size],
247
- # render_size=render_size,
248
- # )['img']
249
- # else:
250
- # frame = model.synthesizer(
251
- # planes,
252
- # cameras=render_cameras[:, i:i+chunk_size],
253
- # render_size=render_size,
254
- # )['images_rgb']
255
- # frames.append(frame)
256
- # frames = torch.cat(frames, dim=1)
257
-
258
- # images_to_video(
259
- # frames[0],
260
- # video_fpath,
261
- # fps=30,
262
- # )
263
-
264
- # print(f"Video saved to {video_fpath}")
265
-
266
- # get mesh
267
- mesh_out = model0.extract_mesh(
268
- planes,
269
- use_texture_map=False,
270
- **infer_config,
271
- )
272
-
273
- vertices, faces, vertex_colors = mesh_out
274
- vertices = vertices[:, [1, 2, 0]]
275
-
276
- save_glb(vertices, faces, vertex_colors, mesh_glb_fpath)
277
- save_obj(vertices, faces, vertex_colors, mesh_fpath)
278
-
279
- print(f"Mesh saved to {mesh_fpath}")
280
-
281
- return mesh_fpath, mesh_glb_fpath
282
-
283
-
284
- ###############################################################################
285
- ############# above part is for 3D generate #############
286
- ###############################################################################
287
-
288
-
289
- ###############################################################################
290
- ############# this part is for text to image #############
291
- ###############################################################################
292
-
293
- # Use environment variables for flexibility
294
- MODEL_ID = os.getenv("MODEL_ID", "sd-community/sdxl-flash")
295
- MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "4096"))
296
- USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", "0") == "1"
297
- ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
298
- BATCH_SIZE = int(os.getenv("BATCH_SIZE", "1")) # Allow generating multiple images at once
299
-
300
- # Determine device and load model outside of function for efficiency
301
- device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
302
- pipe = StableDiffusionXLPipeline.from_pretrained(
303
- MODEL_ID,
304
- torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
305
- use_safetensors=True,
306
- add_watermarker=False,
307
- ).to(device)
308
- pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
309
-
310
- # Torch compile for potential speedup (experimental)
311
- if USE_TORCH_COMPILE:
312
- pipe.compile()
313
-
314
- # CPU offloading for larger RAM capacity (experimental)
315
- if ENABLE_CPU_OFFLOAD:
316
- pipe.enable_model_cpu_offload()
317
-
318
- MAX_SEED = np.iinfo(np.int32).max
319
-
320
- def save_image(img):
321
- unique_name = str(uuid.uuid4()) + ".png"
322
- img.save(unique_name)
323
- return unique_name
324
 
325
- def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
326
- if randomize_seed:
327
- seed = random.randint(0, MAX_SEED)
328
- return seed
329
-
330
- # @spaces.GPU(duration=30, queue=False)
331
- def generate(
332
- prompt: str,
333
- negative_prompt: str = "",
334
- use_negative_prompt: bool = False,
335
- seed: int = 1,
336
- width: int = 1024,
337
- height: int = 1024,
338
- guidance_scale: float = 3,
339
- num_inference_steps: int = 30,
340
- randomize_seed: bool = False,
341
- num_images: int = 4, # Number of images to generate
342
- use_resolution_binning: bool = True,
343
- progress=gr.Progress(track_tqdm=True),
344
- ):
345
- seed = int(randomize_seed_fn(seed, randomize_seed))
346
- generator = torch.Generator(device=device).manual_seed(seed)
347
-
348
- # Improved options handling
349
- options = {
350
- "prompt": [prompt] * num_images,
351
- "negative_prompt": [negative_prompt] * num_images if use_negative_prompt else None,
352
- "width": width,
353
- "height": height,
354
- "guidance_scale": guidance_scale,
355
- "num_inference_steps": num_inference_steps,
356
- "generator": generator,
357
- "output_type": "pil",
358
- }
359
-
360
- # Use resolution binning for faster generation with less VRAM usage
361
- # if use_resolution_binning:
362
- # options["use_resolution_binning"] = True
363
-
364
- # Generate images potentially in batches
365
- images = []
366
- for i in range(0, num_images, BATCH_SIZE):
367
- batch_options = options.copy()
368
- batch_options["prompt"] = options["prompt"][i:i+BATCH_SIZE]
369
- if "negative_prompt" in batch_options:
370
- batch_options["negative_prompt"] = options["negative_prompt"][i:i+BATCH_SIZE]
371
- images.extend(pipe(**batch_options).images)
372
 
373
- image_paths = [save_image(img) for img in images]
374
- return image_paths, seed
375
 
376
- examples = [
377
- "a cat eating a piece of cheese",
378
- "a ROBOT riding a BLUE horse on Mars, photorealistic, 4k",
379
- "Ironman VS Hulk, ultrarealistic",
380
- "Astronaut in a jungle, cold color palette, oil pastel, detailed, 8k",
381
- "An alien holding a sign board containing the word 'Flash', futuristic, neonpunk",
382
- "Kids going to school, Anime style"
383
- ]
384
 
 
385
 
386
 
 
 
387
 
388
- ###############################################################################
389
- ############# above part is for text to image #############
390
- ###############################################################################
391
 
 
392
 
393
- css = """
394
- #warning {background-color: #FFCCCB}
395
- .chatbot {
396
- padding: 0 !important;
397
- margin: 0 !important;
398
- }
399
- """
400
- filtered_language_dict = {
401
- 'English': 'en-US-JennyNeural',
402
- 'Chinese': 'zh-CN-XiaoxiaoNeural',
403
- 'French': 'fr-FR-DeniseNeural',
404
- 'Spanish': 'es-MX-DaliaNeural',
405
- 'Arabic': 'ar-SA-ZariyahNeural',
406
- 'Portuguese': 'pt-BR-FranciscaNeural',
407
- 'Cantonese': 'zh-HK-HiuGaaiNeural'
408
- }
409
 
410
- focus_map = {
411
- "CFV-D":0,
412
- "CFV-DA":1,
413
- "CFV-DAI":2,
414
- "PFV-DDA":3
415
- }
416
 
417
- '''
418
- prompt_list = [
419
- 'Wiki_caption: {Wiki_caption}, you have to generate a caption according to the image and wiki caption. Around {length} words of {sentiment} sentiment in {language}.',
420
- 'Wiki_caption: {Wiki_caption}, you have to select sentences from wiki caption that describe the surrounding objects that may be associated with the picture object. Around {length} words of {sentiment} sentiment in {language}.',
421
- 'Wiki_caption: {Wiki_caption}. You have to choose sentences from the wiki caption that describe unrelated objects to the image. Around {length} words of {sentiment} sentiment in {language}.',
422
- 'Wiki_caption: {Wiki_caption}. You have to choose sentences from the wiki caption that describe unrelated objects to the image. Around {length} words of {sentiment} sentiment in {language}.'
423
- ]
424
 
425
- prompt_list = [
426
- 'Wiki_caption: {Wiki_caption}, you have to help me understand what is about the selected object and list one fact (describes the object but does not include analysis)as markdown outline with appropriate emojis that describes what you see according to the image and wiki caption. Around {length} words of {sentiment} sentiment in {language}.',
427
- 'Wiki_caption: {Wiki_caption}, you have to help me understand what is about the selected object and list one fact and one analysis as markdown outline with appropriate emojis that describes what you see according to the image and wiki caption. Around {length} words of {sentiment} sentiment in {language}.',
428
- 'Wiki_caption: {Wiki_caption}, you have to help me understand what is about the selected object and list one fact and one analysis and one interpret as markdown outline with appropriate emojis that describes what you see according to the image and wiki caption. Around {length} words of {sentiment} sentiment in {language}.',
429
- 'Wiki_caption: {Wiki_caption}, you have to help me understand what is about the selected object and the objects that may be related to the selected object and list one fact of selected object, one fact of related object and one analysis as markdown outline with appropriate emojis that describes what you see according to the image and wiki caption. Around {length} words of {sentiment} sentiment in {language}.'
430
- ]
431
- '''
432
- prompt_list = [
433
- 'Wiki_caption: {Wiki_caption}, you have to help me understand what is about the selected object and list one fact (describes the selected object but does not include analysis)as markdown outline with appropriate emojis that describes what you see according to the image and wiki caption. Each point listed is to be in {language} language, with a response length of about {length} words.',
434
- 'Wiki_caption: {Wiki_caption}, you have to help me understand what is about the selected object and list one fact and one analysis as markdown outline with appropriate emojis that describes what you see according to the image and wiki caption. Each point listed is to be in {language} language, with a response length of about {length} words.',
435
- 'Wiki_caption: {Wiki_caption}, you have to help me understand what is about the selected object and list one fact and one analysis and one interpret as markdown outline with appropriate emojis that describes what you see according to the image and wiki caption. Each point listed is to be in {language} language, with a response length of about {length} words.',
436
- 'Wiki_caption: {Wiki_caption}, you have to help me understand what is about the selected object and the objects that may be related to the selected object and list one fact of selected object, one fact of related object and one analysis as markdown outline with appropriate emojis that describes what you see according to the image and wiki caption. Each point listed is to be in {language} language, with a response length of about {length} words.'
437
- ]
438
-
439
-
440
- gpt_state = 0
441
- VOICE = "en-GB-SoniaNeural"
442
- article = """
443
- <div style='margin:20px auto;'>
444
- <p>By using this demo you agree to the terms of the Coqui Public Model License at https://coqui.ai/cpml</p>
445
- </div>
446
  """
447
 
448
- args = parse_augment()
449
- args.segmenter = "huge"
450
- args.segmenter_checkpoint = "sam_vit_h_4b8939.pth"
451
- args.clip_filter = True
452
- if args.segmenter_checkpoint is None:
453
- _, segmenter_checkpoint = prepare_segmenter(args.segmenter)
454
- else:
455
- segmenter_checkpoint = args.segmenter_checkpoint
456
-
457
- shared_captioner = build_captioner(args.captioner, args.device, args)
458
- shared_sam_model = sam_model_registry[seg_model_map[args.segmenter]](checkpoint=segmenter_checkpoint).to(args.device)
459
- ocr_lang = ["ch_tra", "en"]
460
- shared_ocr_reader = easyocr.Reader(ocr_lang)
461
- tools_dict = {e.split('_')[0].strip(): e.split('_')[1].strip() for e in args.chat_tools_dict.split(',')}
462
- shared_chatbot_tools = build_chatbot_tools(tools_dict)
463
-
464
-
465
- class ImageSketcher(gr.Image):
466
- """
467
- Fix the bug of gradio.Image that cannot upload with tool == 'sketch'.
468
- """
469
-
470
- is_template = True # Magic to make this work with gradio.Block, don't remove unless you know what you're doing.
471
-
472
- def __init__(self, **kwargs):
473
- super().__init__(tool="sketch", **kwargs)
474
-
475
- def preprocess(self, x):
476
- if self.tool == 'sketch' and self.source in ["upload", "webcam"]:
477
- assert isinstance(x, dict)
478
- if x['mask'] is None:
479
- decode_image = processing_utils.decode_base64_to_image(x['image'])
480
- width, height = decode_image.size
481
- mask = np.zeros((height, width, 4), dtype=np.uint8)
482
- mask[..., -1] = 255
483
- mask = self.postprocess(mask)
484
- x['mask'] = mask
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
485
 
486
- return super().preprocess(x)
487
-
488
-
489
- def build_caption_anything_with_models(args, api_key="", captioner=None, sam_model=None, ocr_reader=None, text_refiner=None,
490
- session_id=None):
491
- segmenter = build_segmenter(args.segmenter, args.device, args, model=sam_model)
492
- captioner = captioner
493
- if session_id is not None:
494
- print('Init caption anything for session {}'.format(session_id))
495
- return CaptionAnything(args, api_key, captioner=captioner, segmenter=segmenter, ocr_reader=ocr_reader, text_refiner=text_refiner)
496
-
497
-
498
- def validate_api_key(api_key):
499
- api_key = str(api_key).strip()
500
- print(api_key)
501
- try:
502
- test_llm = OpenAI(model_name="gpt-3.5-turbo", temperature=0, openai_api_key=api_key)
503
- response = test_llm("Test API call")
504
- print(response)
505
- return True
506
- except Exception as e:
507
- print(f"API key validation failed: {e}")
508
- return False
509
-
510
-
511
- def init_openai_api_key(api_key=""):
512
- text_refiner = None
513
- visual_chatgpt = None
514
- if api_key and len(api_key) > 30:
515
- print(api_key)
516
- if validate_api_key(api_key):
517
- try:
518
- text_refiner = build_text_refiner(args.text_refiner, args.device, args, api_key)
519
- assert len(text_refiner.llm('hi')) > 0 # test
520
- visual_chatgpt = ConversationBot(shared_chatbot_tools, api_key)
521
- except Exception as e:
522
- print(f"Error initializing TextRefiner or ConversationBot: {e}")
523
- text_refiner = None
524
- visual_chatgpt = None
525
- else:
526
- print("Invalid API key.")
527
- else:
528
- print("API key is too short.")
529
- print(text_refiner)
530
- openai_available = text_refiner is not None
531
- if openai_available:
532
-
533
- global gpt_state
534
- gpt_state=1
535
- # return [gr.update(visible=True)]+[gr.update(visible=False)]+[gr.update(visible=True)]*3+[gr.update(visible=False)]+ [gr.update(visible=False)]*3 + [text_refiner, visual_chatgpt, None]+[gr.update(visible=True)]*3
536
- return [gr.update(visible=True)]+[gr.update(visible=False)]+[gr.update(visible=True)]*3+[gr.update(visible=False)]+ [gr.update(visible=False)]*3 + [text_refiner, visual_chatgpt, None]+[gr.update(visible=True)]*2
537
- else:
538
- gpt_state=0
539
- # return [gr.update(visible=False)]*7 + [gr.update(visible=True)]*2 + [text_refiner, visual_chatgpt, 'Your OpenAI API Key is not available']+[gr.update(visible=False)]*3
540
- return [gr.update(visible=False)]*7 + [gr.update(visible=True)]*2 + [text_refiner, visual_chatgpt, 'Your OpenAI API Key is not available']+[gr.update(visible=False)]*2
541
-
542
- def init_wo_openai_api_key():
543
- global gpt_state
544
- gpt_state=0
545
- # return [gr.update(visible=False)]*4 + [gr.update(visible=True)]+ [gr.update(visible=False)]+[gr.update(visible=True)]+[gr.update(visible=False)]*2 + [None, None, None]+[gr.update(visible=False)]*3
546
- return [gr.update(visible=False)]*4 + [gr.update(visible=True)]+ [gr.update(visible=False)]+[gr.update(visible=True)]+[gr.update(visible=False)]*2 + [None, None, None]+[gr.update(visible=False)]*2
547
-
548
- def get_click_prompt(chat_input, click_state, click_mode):
549
- inputs = json.loads(chat_input)
550
- if click_mode == 'Continuous':
551
- points = click_state[0]
552
- labels = click_state[1]
553
- for input in inputs:
554
- points.append(input[:2])
555
- labels.append(input[2])
556
- elif click_mode == 'Single':
557
- points = []
558
- labels = []
559
- for input in inputs:
560
- points.append(input[:2])
561
- labels.append(input[2])
562
- click_state[0] = points
563
- click_state[1] = labels
564
- else:
565
- raise NotImplementedError
566
-
567
- prompt = {
568
- "prompt_type": ["click"],
569
- "input_point": click_state[0],
570
- "input_label": click_state[1],
571
- "multimask_output": "True",
572
- }
573
- return prompt
574
-
575
-
576
- def update_click_state(click_state, caption, click_mode):
577
- if click_mode == 'Continuous':
578
- click_state[2].append(caption)
579
- elif click_mode == 'Single':
580
- click_state[2] = [caption]
581
- else:
582
- raise NotImplementedError
583
-
584
- async def chat_input_callback(*args):
585
- visual_chatgpt, chat_input, click_state, state, aux_state ,language , autoplay = args
586
- if visual_chatgpt is not None:
587
- state, _, aux_state, _ = visual_chatgpt.run_text(chat_input, state, aux_state)
588
- last_text, last_response = state[-1]
589
- print("last response",last_response)
590
- audio = await texttospeech(last_response,language,autoplay)
591
- return state, state, aux_state,audio
592
- else:
593
- response = "Text refiner is not initilzed, please input openai api key."
594
- state = state + [(chat_input, response)]
595
- audio = await texttospeech(response,language,autoplay)
596
- return state, state, None,audio
597
-
598
-
599
-
600
- def upload_callback(image_input, state, visual_chatgpt=None, openai_api_key=None,language="English"):
601
- if isinstance(image_input, dict): # if upload from sketcher_input, input contains image and mask
602
- image_input, mask = image_input['image'], image_input['mask']
603
-
604
- click_state = [[], [], []]
605
- image_input = image_resize(image_input, res=1024)
606
-
607
- model = build_caption_anything_with_models(
608
- args,
609
- api_key="",
610
- captioner=shared_captioner,
611
- sam_model=shared_sam_model,
612
- ocr_reader=shared_ocr_reader,
613
- session_id=iface.app_id
614
- )
615
- model.segmenter.set_image(image_input)
616
- image_embedding = model.image_embedding
617
- original_size = model.original_size
618
- input_size = model.input_size
619
-
620
- if visual_chatgpt is not None:
621
- print('upload_callback: add caption to chatGPT memory')
622
- new_image_path = get_new_image_name('chat_image', func_name='upload')
623
- image_input.save(new_image_path)
624
- visual_chatgpt.current_image = new_image_path
625
- img_caption = model.captioner.inference(image_input, filter=False, args={'text_prompt':''})['caption']
626
- Human_prompt = f'\nHuman: The description of the image with path {new_image_path} is: {img_caption}. This information helps you to understand this image, but you should use tools to finish following tasks, rather than directly imagine from my description. If you understand, say \"Received\". \n'
627
- AI_prompt = "Received."
628
- visual_chatgpt.global_prompt = Human_prompt + 'AI: ' + AI_prompt
629
- visual_chatgpt.agent.memory.buffer = visual_chatgpt.agent.memory.buffer + visual_chatgpt.global_prompt
630
- parsed_data = get_image_gpt(openai_api_key, new_image_path,"Please provide the name, artist, year of creation, and material used for this painting. Return the information in dictionary format without any newline characters. If any information is unavailable, return \"None\" for that field. Format as follows: { \"name\": \"Name of the painting\",\"artist\": \"Name of the artist\", \"year\": \"Year of creation\", \"material\": \"Material used in the painting\" }.")
631
- parsed_data = json.loads(parsed_data.replace("'", "\""))
632
- name, artist, year, material= parsed_data["name"],parsed_data["artist"],parsed_data["year"], parsed_data["material"]
633
- # artwork_info = f"<div>Painting: {name}<br>Artist name: {artist}<br>Year: {year}<br>Material: {material}</div>"
634
- paragraph = get_image_gpt(openai_api_key, new_image_path,f"What's going on in this picture? in {language}")
635
-
636
- state = [
637
- (
638
- None,
639
- f"🤖 Hi, I am EyeSee. Let's explore this painting {name} together. You can click on the area you're interested in and choose from four types of information: Description, Analysis, Interpretation, and Judgment. Based on your selection, I will provide you with the relevant information."
640
- )
641
- ]
642
-
643
- return state, state, image_input, click_state, image_input, image_input, image_input, image_embedding, \
644
- original_size, input_size, f"Name: {name}", f"Artist: {artist}", f"Year: {year}", f"Material: {material}",f"Name: {name}", f"Artist: {artist}", f"Year: {year}", f"Material: {material}",paragraph
645
-
646
-
647
 
648
 
649
- def inference_click(image_input, point_prompt, click_mode, enable_wiki, language, sentiment, factuality,
650
- length, image_embedding, state, click_state, original_size, input_size, text_refiner, visual_chatgpt,
651
- out_state, click_index_state, input_mask_state, input_points_state, input_labels_state, evt: gr.SelectData):
652
- click_index = evt.index
653
-
654
- if point_prompt == 'Positive':
655
- coordinate = "[[{}, {}, 1]]".format(str(click_index[0]), str(click_index[1]))
656
- else:
657
- coordinate = "[[{}, {}, 0]]".format(str(click_index[0]), str(click_index[1]))
658
-
659
- prompt = get_click_prompt(coordinate, click_state, click_mode)
660
- input_points = prompt['input_point']
661
- input_labels = prompt['input_label']
662
-
663
- controls = {'length': length,
664
- 'sentiment': sentiment,
665
- 'factuality': factuality,
666
- 'language': language}
667
-
668
- model = build_caption_anything_with_models(
669
- args,
670
- api_key="",
671
- captioner=shared_captioner,
672
- sam_model=shared_sam_model,
673
- ocr_reader=shared_ocr_reader,
674
- text_refiner=text_refiner,
675
- session_id=iface.app_id
676
- )
677
-
678
- model.setup(image_embedding, original_size, input_size, is_image_set=True)
679
-
680
- enable_wiki = True if enable_wiki in ['True', 'TRUE', 'true', True, 'Yes', 'YES', 'yes'] else False
681
- out = model.inference(image_input, prompt, controls, disable_gpt=True, enable_wiki=enable_wiki, verbose=True, args={'clip_filter': False})[0]
682
-
683
- state = state + [("Image point: {}, Input label: {}".format(prompt["input_point"], prompt["input_label"]), None)]
684
- update_click_state(click_state, out['generated_captions']['raw_caption'], click_mode)
685
- text = out['generated_captions']['raw_caption']
686
- input_mask = np.array(out['mask'].convert('P'))
687
- image_input_nobackground = mask_painter(np.array(image_input), input_mask,background_alpha=0)
688
- image_input_withbackground=mask_painter(np.array(image_input), input_mask)
689
-
690
- click_index_state = click_index
691
- input_mask_state = input_mask
692
- input_points_state = input_points
693
- input_labels_state = input_labels
694
- out_state = out
695
-
696
- if visual_chatgpt is not None:
697
- print('inference_click: add caption to chatGPT memory')
698
- new_crop_save_path = get_new_image_name('chat_image', func_name='crop')
699
- Image.open(out["crop_save_path"]).save(new_crop_save_path)
700
- point_prompt = f'You should primarly use tools on the selected regional image (description: {text}, path: {new_crop_save_path}), which is a part of the whole image (path: {visual_chatgpt.current_image}). If human mentioned some objects not in the selected region, you can use tools on the whole image.'
701
- visual_chatgpt.point_prompt = point_prompt
702
-
703
-
704
- print("new crop save",new_crop_save_path)
705
-
706
- yield state, state, click_state, image_input_nobackground, image_input_withbackground, click_index_state, input_mask_state, input_points_state, input_labels_state, out_state,new_crop_save_path,image_input_nobackground
707
-
708
-
709
-
710
-
711
-
712
- async def submit_caption(state, text_refiner, length, sentiment, factuality, language,
713
- out_state, click_index_state, input_mask_state, input_points_state, input_labels_state,
714
- autoplay,paragraph,focus_type,openai_api_key,new_crop_save_path):
715
- print("state",state)
716
-
717
- click_index = click_index_state
718
-
719
- # if pre_click_index==click_index:
720
- # click_index = (click_index[0] - 1, click_index[1] - 1)
721
- # pre_click_index = click_index
722
- # else:
723
- # pre_click_index = click_index
724
- print("click_index",click_index)
725
- print("input_points_state",input_points_state)
726
- print("input_labels_state",input_labels_state)
727
-
728
- prompt=generate_prompt(paragraph,focus_type,length,sentiment,factuality,language)
729
-
730
- print("Prompt:", prompt)
731
- print("click",click_index)
732
-
733
- # image_input = create_bubble_frame(np.array(image_input), generated_caption, click_index, input_mask,
734
- # input_points=input_points, input_labels=input_labels)
735
-
736
-
737
- if not args.disable_gpt and text_refiner:
738
- print("new crop save",new_crop_save_path)
739
- focus_info=get_image_gpt(openai_api_key,new_crop_save_path,prompt)
740
- if focus_info.startswith('"') and focus_info.endswith('"'):
741
- focus_info=focus_info[1:-1]
742
- focus_info=focus_info.replace('#', '')
743
- # state = state + [(None, f"Wiki: {paragraph}")]
744
- state = state + [(None, f"{focus_info}")]
745
- print("new_cap",focus_info)
746
- read_info = re.sub(r'[#[\]!*]','',focus_info)
747
- read_info = emoji.replace_emoji(read_info,replace="")
748
- print("read info",read_info)
749
-
750
- # refined_image_input = create_bubble_frame(np.array(origin_image_input), focus_info, click_index, input_mask,
751
- # input_points=input_points, input_labels=input_labels)
752
- try:
753
- audio_output = await texttospeech(read_info, language,autoplay)
754
- # return state, state, refined_image_input, click_index_state, input_mask_state, input_points_state, input_labels_state, out_state, waveform_visual, audio_output
755
- return state, state, click_index_state, input_mask_state, input_points_state, input_labels_state, out_state, audio_output
756
-
757
- except Exception as e:
758
- state = state + [(None, f"Error during TTS prediction: {str(e)}")]
759
- print(f"Error during TTS prediction: {str(e)}")
760
- # return state, state, refined_image_input, click_index_state, input_mask_state, input_points_state, input_labels_state, out_state, None, None
761
- return state, state, click_index_state, input_mask_state, input_points_state, input_labels_state, out_state, audio_output
762
-
763
- else:
764
- try:
765
- audio_output = await texttospeech(focus_info, language, autoplay)
766
- # waveform_visual, audio_output = tts.predict(generated_caption, input_language, input_audio, input_mic, use_mic, agree)
767
- # return state, state, image_input, click_index_state, input_mask_state, input_points_state, input_labels_state, out_state, waveform_visual, audio_output
768
- return state, state, click_index_state, input_mask_state, input_points_state, input_labels_state, out_state, audio_output
769
-
770
- except Exception as e:
771
- state = state + [(None, f"Error during TTS prediction: {str(e)}")]
772
- print(f"Error during TTS prediction: {str(e)}")
773
- return state, state, click_index_state, input_mask_state, input_points_state, input_labels_state, out_state, None, None
774
-
775
- def generate_prompt(focus_type, paragraph,length, sentiment, factuality, language):
776
-
777
- mapped_value = focus_map.get(focus_type, -1)
778
-
779
- controls = {
780
- 'length': length,
781
- 'sentiment': sentiment,
782
- 'factuality': factuality,
783
- 'language': language
784
- }
785
-
786
- if mapped_value != -1:
787
- prompt = prompt_list[mapped_value].format(
788
- Wiki_caption=paragraph,
789
- length=controls['length'],
790
- sentiment=controls['sentiment'],
791
- language=controls['language']
792
- )
793
- else:
794
- prompt = "Invalid focus type."
795
-
796
- if controls['factuality'] == "Imagination":
797
- prompt += " Assuming that I am someone who has viewed a lot of art and has a lot of experience viewing art. Explain artistic features (composition, color, style, or use of light) and discuss the symbolism of the content and its influence on later artistic movements."
798
-
799
- return prompt
800
-
801
-
802
- def encode_image(image_path):
803
- with open(image_path, "rb") as image_file:
804
- return base64.b64encode(image_file.read()).decode('utf-8')
805
-
806
- def get_image_gpt(api_key, image_path,prompt,enable_wiki=None):
807
- # Getting the base64 string
808
- base64_image = encode_image(image_path)
809
-
810
-
811
-
812
- headers = {
813
- "Content-Type": "application/json",
814
- "Authorization": f"Bearer {api_key}"
815
- }
816
-
817
- prompt_text = prompt
818
-
819
- payload = {
820
- "model": "gpt-4o",
821
- "messages": [
822
- {
823
- "role": "user",
824
- "content": [
825
- {
826
- "type": "text",
827
- "text": prompt_text
828
- },
829
- {
830
- "type": "image_url",
831
- "image_url": {
832
- "url": f"data:image/jpeg;base64,{base64_image}"
833
- }
834
- }
835
- ]
836
- }
837
- ],
838
- "max_tokens": 300
839
- }
840
-
841
- # Sending the request to the OpenAI API
842
- response = requests.post("https://api.openai.com/v1/chat/completions", headers=headers, json=payload)
843
- result = response.json()
844
- print(result)
845
- content = result['choices'][0]['message']['content']
846
- # Assume the model returns a valid JSON string in 'content'
847
- try:
848
- return content
849
- except json.JSONDecodeError:
850
- return {"error": "Failed to parse model output"}
851
-
852
-
853
-
854
-
855
- def get_sketch_prompt(mask: Image.Image):
856
- """
857
- Get the prompt for the sketcher.
858
- TODO: This is a temporary solution. We should cluster the sketch and get the bounding box of each cluster.
859
- """
860
-
861
- mask = np.asarray(mask)[..., 0]
862
-
863
- # Get the bounding box of the sketch
864
- y, x = np.where(mask != 0)
865
- x1, y1 = np.min(x), np.min(y)
866
- x2, y2 = np.max(x), np.max(y)
867
-
868
- prompt = {
869
- 'prompt_type': ['box'],
870
- 'input_boxes': [
871
- [x1, y1, x2, y2]
872
- ]
873
- }
874
-
875
- return prompt
876
-
877
- submit_traj=0
878
-
879
- async def inference_traject(origin_image,sketcher_image, enable_wiki, language, sentiment, factuality, length, image_embedding, state,
880
- original_size, input_size, text_refiner,focus_type,paragraph,openai_api_key,autoplay,trace_type):
881
- image_input, mask = sketcher_image['image'], sketcher_image['mask']
882
-
883
- crop_save_path=""
884
-
885
- prompt = get_sketch_prompt(mask)
886
- boxes = prompt['input_boxes']
887
- boxes = boxes[0]
888
- global submit_traj
889
- submit_traj=1
890
-
891
- controls = {'length': length,
892
- 'sentiment': sentiment,
893
- 'factuality': factuality,
894
- 'language': language}
895
-
896
- model = build_caption_anything_with_models(
897
- args,
898
- api_key="",
899
- captioner=shared_captioner,
900
- sam_model=shared_sam_model,
901
- ocr_reader=shared_ocr_reader,
902
- text_refiner=text_refiner,
903
- session_id=iface.app_id
904
- )
905
-
906
- model.setup(image_embedding, original_size, input_size, is_image_set=True)
907
-
908
- enable_wiki = True if enable_wiki in ['True', 'TRUE', 'true', True, 'Yes', 'YES', 'yes'] else False
909
- out = model.inference(image_input, prompt, controls, disable_gpt=True, enable_wiki=enable_wiki,verbose=True)[0]
910
-
911
- print(trace_type)
912
-
913
- if trace_type=="Trace+Seg":
914
- input_mask = np.array(out['mask'].convert('P'))
915
- image_input = mask_painter(np.array(image_input), input_mask, background_alpha=0 )
916
- crop_save_path=out['crop_save_path']
917
-
918
- else:
919
- image_input = Image.fromarray(np.array(origin_image))
920
- draw = ImageDraw.Draw(image_input)
921
- draw.rectangle(boxes, outline='red', width=2)
922
- cropped_image = origin_image.crop(boxes)
923
- cropped_image.save('temp.png')
924
- crop_save_path='temp.png'
925
-
926
- print("crop_svae_path",out['crop_save_path'])
927
-
928
- # Update components and states
929
- state.append((f'Box: {boxes}', None))
930
-
931
- # fake_click_index = (int((boxes[0][0] + boxes[0][2]) / 2), int((boxes[0][1] + boxes[0][3]) / 2))
932
- # image_input = create_bubble_frame(image_input, "", fake_click_index, input_mask)
933
-
934
- prompt=generate_prompt(focus_type, paragraph, length, sentiment, factuality, language)
935
- width, height = sketcher_image['image'].size
936
- sketcher_image['mask'] = np.zeros((height, width, 4), dtype=np.uint8)
937
- sketcher_image['mask'][..., -1] = 255
938
- sketcher_image['image']=image_input
939
-
940
-
941
- if not args.disable_gpt and text_refiner:
942
- focus_info=get_image_gpt(openai_api_key,crop_save_path,prompt)
943
- if focus_info.startswith('"') and focus_info.endswith('"'):
944
- focus_info=focus_info[1:-1]
945
- focus_info=focus_info.replace('#', '')
946
- state = state + [(None, f"{focus_info}")]
947
- print("new_cap",focus_info)
948
- read_info = re.sub(r'[#[\]!*]','',focus_info)
949
- read_info = emoji.replace_emoji(read_info,replace="")
950
- print("read info",read_info)
951
-
952
- # refined_image_input = create_bubble_frame(np.array(origin_image_input), focus_info, click_index, input_mask,
953
- # input_points=input_points, input_labels=input_labels)
954
- try:
955
- audio_output = await texttospeech(read_info, language,autoplay)
956
- # return state, state, refined_image_input, click_index_state, input_mask_state, input_points_state, input_labels_state, out_state, waveform_visual, audio_output
957
- return state, state,image_input,audio_output
958
-
959
-
960
- except Exception as e:
961
- state = state + [(None, f"Error during TTS prediction: {str(e)}")]
962
- print(f"Error during TTS prediction: {str(e)}")
963
- # return state, state, refined_image_input, click_index_state, input_mask_state, input_points_state, input_labels_state, out_state, None, None
964
- return state, state, image_input,audio_output
965
-
966
-
967
- else:
968
- try:
969
- audio_output = await texttospeech(focus_info, language, autoplay)
970
- # waveform_visual, audio_output = tts.predict(generated_caption, input_language, input_audio, input_mic, use_mic, agree)
971
- # return state, state, image_input, click_index_state, input_mask_state, input_points_state, input_labels_state, out_state, waveform_visual, audio_output
972
- return state, state, image_input,audio_output
973
-
974
-
975
- except Exception as e:
976
- state = state + [(None, f"Error during TTS prediction: {str(e)}")]
977
- print(f"Error during TTS prediction: {str(e)}")
978
- return state, state, image_input,audio_output
979
-
980
-
981
- def clear_chat_memory(visual_chatgpt, keep_global=False):
982
- if visual_chatgpt is not None:
983
- visual_chatgpt.memory.clear()
984
- visual_chatgpt.point_prompt = ""
985
- if keep_global:
986
- visual_chatgpt.agent.memory.buffer = visual_chatgpt.global_prompt
987
- else:
988
- visual_chatgpt.current_image = None
989
- visual_chatgpt.global_prompt = ""
990
-
991
-
992
- def export_chat_log(chat_state):
993
- try:
994
- if not chat_state:
995
- return None
996
- chat_log = "\n".join(f"{entry[0]}\n{entry[1]}" for entry in chat_state if entry)
997
- print("export log...")
998
- print("chat_log", chat_log)
999
- with tempfile.NamedTemporaryFile(delete=False, suffix=".txt") as temp_file:
1000
- temp_file.write(chat_log.encode('utf-8'))
1001
- temp_file_path = temp_file.name
1002
- print(temp_file_path)
1003
- return temp_file_path
1004
- except Exception as e:
1005
- print(f"An error occurred while exporting the chat log: {e}")
1006
- return None
1007
-
1008
-
1009
- async def cap_everything(paragraph, visual_chatgpt,language,autoplay):
1010
-
1011
- # state = state + [(None, f"Caption Everything: {paragraph}")]
1012
- Human_prompt = f'\nThe description of the image with path {visual_chatgpt.current_image} is:\n{paragraph}\nThis information helps you to understand this image, but you should use tools to finish following tasks, rather than directly imagine from my description. If you understand, say \"Received\". \n'
1013
- AI_prompt = "Received."
1014
- visual_chatgpt.global_prompt = Human_prompt + 'AI: ' + AI_prompt
1015
- visual_chatgpt.agent.memory.buffer = visual_chatgpt.agent.memory.buffer + visual_chatgpt.global_prompt
1016
- # waveform_visual, audio_output=tts.predict(paragraph, input_language, input_audio, input_mic, use_mic, agree)
1017
- audio_output=await texttospeech(paragraph,language,autoplay)
1018
- return paragraph,audio_output
1019
-
1020
- def cap_everything_withoutsound(image_input, visual_chatgpt, text_refiner,paragraph):
1021
-
1022
- model = build_caption_anything_with_models(
1023
- args,
1024
- api_key="",
1025
- captioner=shared_captioner,
1026
- sam_model=shared_sam_model,
1027
- ocr_reader=shared_ocr_reader,
1028
- text_refiner=text_refiner,
1029
- session_id=iface.app_id
1030
- )
1031
- paragraph = model.inference_cap_everything(image_input, verbose=True)
1032
- # state = state + [(None, f"Caption Everything: {paragraph}")]
1033
- Human_prompt = f'\nThe description of the image with path {visual_chatgpt.current_image} is:\n{paragraph}\nThis information helps you to understand this image, but you should use tools to finish following tasks, rather than directly imagine from my description. If you understand, say \"Received\". \n'
1034
- AI_prompt = "Received."
1035
- visual_chatgpt.global_prompt = Human_prompt + 'AI: ' + AI_prompt
1036
- visual_chatgpt.agent.memory.buffer = visual_chatgpt.agent.memory.buffer + visual_chatgpt.global_prompt
1037
- return paragraph
1038
-
1039
-
1040
-
1041
- def get_style():
1042
- current_version = version.parse(gr.__version__)
1043
- if current_version <= version.parse('3.24.1'):
1044
- style = '''
1045
- #image_sketcher{min-height:500px}
1046
- #image_sketcher [data-testid="image"], #image_sketcher [data-testid="image"] > div{min-height: 500px}
1047
- #image_upload{min-height:500px}
1048
- #image_upload [data-testid="image"], #image_upload [data-testid="image"] > div{min-height: 500px}
1049
- .custom-language {
1050
- width: 20%;
1051
- }
1052
-
1053
- .custom-autoplay {
1054
- width: 40%;
1055
- }
1056
-
1057
- .custom-output {
1058
- width: 30%;
1059
- }
1060
-
1061
- '''
1062
- elif current_version <= version.parse('3.27'):
1063
- style = '''
1064
- #image_sketcher{min-height:500px}
1065
- #image_upload{min-height:500px}
1066
- .custom-language {
1067
- width: 20%;
1068
- }
1069
-
1070
- .custom-autoplay {
1071
- width: 40%;
1072
- }
1073
-
1074
- .custom-output {
1075
- width: 30%;
1076
- }
1077
- '''
1078
- else:
1079
- style = None
1080
-
1081
- return style
1082
-
1083
- # def handle_like_dislike(like_data, like_state, dislike_state):
1084
- # if like_data.liked:
1085
- # if like_data.index not in like_state:
1086
- # like_state.append(like_data.index)
1087
- # message = f"Liked: {like_data.value} at index {like_data.index}"
1088
- # else:
1089
- # message = "You already liked this item"
1090
- # else:
1091
- # if like_data.index not in dislike_state:
1092
- # dislike_state.append(like_data.index)
1093
- # message = f"Disliked: {like_data.value} at index {like_data.index}"
1094
- # else:
1095
- # message = "You already disliked this item"
1096
-
1097
- # return like_state, dislike_state
1098
-
1099
- async def texttospeech(text,language,autoplay):
1100
- voice=filtered_language_dict[language]
1101
- communicate = edge_tts.Communicate(text, voice)
1102
- file_path="output.wav"
1103
- await communicate.save(file_path)
1104
- with open(file_path, "rb") as audio_file:
1105
- audio_bytes = BytesIO(audio_file.read())
1106
- audio = base64.b64encode(audio_bytes.read()).decode("utf-8")
1107
- print("tts....")
1108
- audio_style = 'style="width:210px;"'
1109
- if autoplay:
1110
- audio_player = f'<audio src="data:audio/wav;base64,{audio}" controls autoplay {audio_style}></audio>'
1111
- else:
1112
- audio_player=None
1113
- # audio_player = f'<audio src="data:audio/wav;base64,{audio}" controls {audio_style}></audio>'
1114
- return audio_player
1115
-
1116
-
1117
- def create_ui():
1118
- title = """<p><h1 align="center">EyeSee Anything in Art</h1></p>
1119
- """
1120
- description = """<p>Gradio demo for EyeSee Anything in Art, image to dense captioning generation with various language styles. To use it, simply upload your image, or click one of the examples to load them. """
1121
-
1122
- examples = [
1123
- ["test_images/ambass.jpg"],
1124
- ["test_images/pearl.jpg"],
1125
- ["test_images/Picture0.png"],
1126
- ["test_images/Picture1.png"],
1127
- ["test_images/Picture2.png"],
1128
- ["test_images/Picture3.png"],
1129
- ["test_images/Picture4.png"],
1130
- ["test_images/Picture5.png"],
1131
-
1132
- ]
1133
-
1134
- with gr.Blocks(
1135
- css=get_style(),
1136
- theme=gr.themes.Base()
1137
- ) as iface:
1138
  state = gr.State([])
1139
- out_state = gr.State(None)
1140
- click_state = gr.State([[], [], []])
1141
- origin_image = gr.State(None)
1142
- image_embedding = gr.State(None)
1143
- text_refiner = gr.State(None)
1144
- visual_chatgpt = gr.State(None)
1145
- original_size = gr.State(None)
1146
- input_size = gr.State(None)
1147
- paragraph = gr.State("")
1148
  aux_state = gr.State([])
1149
- click_index_state = gr.State((0, 0))
1150
- input_mask_state = gr.State(np.zeros((1, 1)))
1151
- input_points_state = gr.State([])
1152
- input_labels_state = gr.State([])
1153
- new_crop_save_path = gr.State(None)
1154
- image_input_nobackground = gr.State(None)
1155
-
1156
- gr.Markdown(title)
1157
- gr.Markdown(description)
1158
- with gr.Row(align="right", visible=False, elem_id="top_row") as top_row:
1159
- language = gr.Dropdown(
1160
- ['English', 'Chinese', 'French', "Spanish", "Arabic", "Portuguese", "Cantonese"],
1161
- value="English", label="Language", interactive=True, scale=0.2, elem_classes="custom-language"
1162
- )
1163
- auto_play = gr.Checkbox(
1164
- label="Check to autoplay audio", value=False, scale=0.4, elem_classes="custom-autoplay"
1165
- )
1166
- output_audio = gr.HTML(
1167
- label="Synthesised Audio", scale=0.3, elem_classes="custom-output"
1168
- )
1169
-
1170
-
1171
- # with gr.Row(align="right",visible=False) as language_select:
1172
- # language = gr.Dropdown(
1173
- # ['English', 'Chinese', 'French', "Spanish", "Arabic", "Portuguese", "Cantonese"],
1174
- # value="English", label="Language", interactive=True)
1175
-
1176
- # with gr.Row(align="right",visible=False) as autoplay:
1177
- # auto_play = gr.Checkbox(label="Check to autoplay audio", value=False,scale=0.4)
1178
- # output_audio = gr.HTML(label="Synthesised Audio",scale=0.6)
1179
-
1180
- with gr.Row():
1181
-
1182
- with gr.Column(scale=1.0):
1183
- with gr.Column(visible=False) as modules_not_need_gpt:
1184
- with gr.Tab("Base(GPT Power)",visible=False) as base_tab:
1185
- image_input_base = gr.Image(type="pil", interactive=True, elem_id="image_upload")
1186
- example_image = gr.Image(type="pil", interactive=False, visible=False)
1187
- with gr.Row():
1188
- name_label_base = gr.Button(value="Name: ")
1189
- artist_label_base = gr.Button(value="Artist: ")
1190
- year_label_base = gr.Button(value="Year: ")
1191
- material_label_base = gr.Button(value="Material: ")
1192
-
1193
- with gr.Tab("Click") as click_tab:
1194
- image_input = gr.Image(type="pil", interactive=True, elem_id="image_upload")
1195
- example_image = gr.Image(type="pil", interactive=False, visible=False)
1196
- with gr.Row():
1197
- name_label = gr.Button(value="Name: ")
1198
- artist_label = gr.Button(value="Artist: ")
1199
- year_label = gr.Button(value="Year: ")
1200
- material_label = gr.Button(value="Material: ")
1201
- with gr.Row(scale=1.0):
1202
- with gr.Row(scale=0.8):
1203
- focus_type = gr.Radio(
1204
- choices=["CFV-D", "CFV-DA", "CFV-DAI","PFV-DDA"],
1205
- value="CFV-D",
1206
- label="Information Type",
1207
- interactive=True)
1208
- with gr.Row(scale=0.2):
1209
- submit_button_click=gr.Button(value="Submit", interactive=True,variant='primary',size="sm")
1210
- with gr.Row(scale=1.0):
1211
- with gr.Row(scale=0.4):
1212
- point_prompt = gr.Radio(
1213
- choices=["Positive", "Negative"],
1214
- value="Positive",
1215
- label="Point Prompt",
1216
- interactive=True)
1217
- click_mode = gr.Radio(
1218
- choices=["Continuous", "Single"],
1219
- value="Continuous",
1220
- label="Clicking Mode",
1221
- interactive=True)
1222
- with gr.Row(scale=0.4):
1223
- clear_button_click = gr.Button(value="Clear Clicks", interactive=True)
1224
- clear_button_image = gr.Button(value="Clear Image", interactive=True)
1225
-
1226
- with gr.Tab("Trajectory (beta)") as traj_tab:
1227
- sketcher_input = ImageSketcher(type="pil", interactive=True, brush_radius=10,
1228
- elem_id="image_sketcher")
1229
- example_image = gr.Image(type="pil", interactive=False, visible=False)
1230
- with gr.Row():
1231
- submit_button_sketcher = gr.Button(value="Submit", interactive=True)
1232
- clear_button_sketcher = gr.Button(value="Clear Sketch", interactive=True)
1233
- with gr.Row(scale=1.0):
1234
- with gr.Row(scale=0.8):
1235
- focus_type_sketch = gr.Radio(
1236
- choices=["CFV-D", "CFV-DA", "CFV-DAI","PFV-DDA"],
1237
- value="CFV-D",
1238
- label="Information Type",
1239
- interactive=True)
1240
- Input_sketch = gr.Radio(
1241
- choices=["Trace+Seg", "Trace"],
1242
- value="Trace+Seg",
1243
- label="Trace Type",
1244
- interactive=True)
1245
-
1246
- with gr.Column(visible=False) as modules_need_gpt1:
1247
- with gr.Row(scale=1.0):
1248
- sentiment = gr.Radio(
1249
- choices=["Positive", "Natural", "Negative"],
1250
- value="Natural",
1251
- label="Sentiment",
1252
- interactive=True,
1253
- )
1254
- with gr.Row(scale=1.0):
1255
- factuality = gr.Radio(
1256
- choices=["Factual", "Imagination"],
1257
- value="Factual",
1258
- label="Factuality",
1259
- interactive=True,
1260
- )
1261
- length = gr.Slider(
1262
- minimum=10,
1263
- maximum=80,
1264
- value=10,
1265
- step=1,
1266
- interactive=True,
1267
- label="Generated Caption Length",
1268
- )
1269
- # 是否启用wiki内容整合到caption中
1270
- enable_wiki = gr.Radio(
1271
- choices=["Yes", "No"],
1272
- value="No",
1273
- label="Expert",
1274
- interactive=True)
1275
- with gr.Column(visible=True) as modules_not_need_gpt3:
1276
- gr.Examples(
1277
- examples=examples,
1278
- inputs=[example_image],
1279
- )
1280
-
1281
-
1282
-
1283
-
1284
-
1285
- with gr.Column(scale=0.5):
1286
- with gr.Column(visible=True) as module_key_input:
1287
- openai_api_key = gr.Textbox(
1288
- placeholder="Input openAI API key",
1289
- show_label=False,
1290
- label="OpenAI API Key",
1291
- lines=1,
1292
- type="password")
1293
- with gr.Row(scale=0.5):
1294
- enable_chatGPT_button = gr.Button(value="Run with ChatGPT", interactive=True, variant='primary')
1295
- disable_chatGPT_button = gr.Button(value="Run without ChatGPT (Faster)", interactive=True,
1296
- variant='primary')
1297
- with gr.Column(visible=False) as module_notification_box:
1298
- notification_box = gr.Textbox(lines=1, label="Notification", max_lines=5, show_label=False)
1299
-
1300
- with gr.Column() as modules_need_gpt0:
1301
- with gr.Column(visible=False,scale=1.0) as modules_need_gpt2:
1302
- paragraph_output = gr.Textbox(lines=16, label="Describe Everything", max_lines=16)
1303
- cap_everything_button = gr.Button(value="Caption Everything in a Paragraph", interactive=True)
1304
-
1305
- with gr.Column(visible=False) as modules_not_need_gpt2:
1306
- with gr.Blocks():
1307
- chatbot = gr.Chatbot(label="Chatbox", elem_classes="chatbot",likeable=True).style(height=600, scale=0.5)
1308
- with gr.Column(visible=False) as modules_need_gpt3:
1309
- chat_input = gr.Textbox(show_label=False, placeholder="Enter text and press Enter").style(
1310
- container=False)
1311
- with gr.Row():
1312
- clear_button_text = gr.Button(value="Clear Text", interactive=True)
1313
- submit_button_text = gr.Button(value="Send", interactive=True, variant="primary")
1314
- upvote_btn = gr.Button(value="👍 Upvote", interactive=True)
1315
- downvote_btn = gr.Button(value="👎 Downvote", interactive=True)
1316
-
1317
- with gr.Row():
1318
- export_button = gr.Button(value="Export Chat Log", interactive=True, variant="primary")
1319
- with gr.Row():
1320
- chat_log_file = gr.File(label="Download Chat Log")
1321
-
1322
- # TTS interface hidden initially
1323
- with gr.Column(visible=False) as tts_interface:
1324
- input_text = gr.Textbox(label="Text Prompt", value="Hello, World !, here is an example of light voice cloning. Try to upload your best audio samples quality")
1325
- input_language = gr.Dropdown(label="Language", choices=["en", "es", "fr", "de", "it", "pt", "pl", "tr", "ru", "nl", "cs", "ar", "zh-cn"], value="en")
1326
- input_audio = gr.Audio(label="Reference Audio", type="filepath", value="examples/female.wav")
1327
- input_mic = gr.Audio(source="microphone", type="filepath", label="Use Microphone for Reference")
1328
- use_mic = gr.Checkbox(label="Check to use Microphone as Reference", value=False)
1329
- agree = gr.Checkbox(label="Agree", value=True)
1330
- output_waveform = gr.Video(label="Waveform Visual")
1331
- # output_audio = gr.HTML(label="Synthesised Audio")
1332
-
1333
- with gr.Row():
1334
- submit_tts = gr.Button(value="Submit", interactive=True)
1335
- clear_tts = gr.Button(value="Clear", interactive=True)
1336
- ###############################################################################
1337
- ############# this part is for text to image #############
1338
- ###############################################################################
1339
-
1340
- with gr.Row(variant="panel") as text2image_model:
1341
-
1342
- with gr.Column():
1343
-
1344
- with gr.Row():
1345
- prompt = gr.Text(
1346
- label="Prompt",
1347
- show_label=False,
1348
- max_lines=1,
1349
- placeholder="Enter your prompt",
1350
- container=False,
1351
- )
1352
- run_button = gr.Button("Run", scale=0)
1353
-
1354
- with gr.Accordion("Advanced options", open=True):
1355
- num_images = gr.Slider(
1356
- label="Number of Images",
1357
- minimum=1,
1358
- maximum=4,
1359
- step=1,
1360
- value=4,
1361
- )
1362
- with gr.Row():
1363
- use_negative_prompt = gr.Checkbox(label="Use negative prompt", value=True)
1364
- negative_prompt = gr.Text(
1365
- label="Negative prompt",
1366
- max_lines=5,
1367
- lines=4,
1368
- placeholder="Enter a negative prompt",
1369
- value="(deformed, distorted, disfigured:1.3), poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, (mutated hands and fingers:1.4), disconnected limbs, mutation, mutated, ugly, disgusting, blurry, amputation, NSFW",
1370
- visible=True,
1371
- )
1372
- seed = gr.Slider(
1373
- label="Seed",
1374
- minimum=0,
1375
- maximum=MAX_SEED,
1376
- step=1,
1377
- value=0,
1378
- )
1379
- randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
1380
- with gr.Row(visible=True):
1381
- width = gr.Slider(
1382
- label="Width",
1383
- minimum=512,
1384
- maximum=MAX_IMAGE_SIZE,
1385
- step=64,
1386
- value=1024,
1387
- )
1388
- height = gr.Slider(
1389
- label="Height",
1390
- minimum=512,
1391
- maximum=MAX_IMAGE_SIZE,
1392
- step=64,
1393
- value=1024,
1394
- )
1395
- with gr.Row():
1396
- guidance_scale = gr.Slider(
1397
- label="Guidance Scale",
1398
- minimum=0.1,
1399
- maximum=6,
1400
- step=0.1,
1401
- value=3.0,
1402
- )
1403
- num_inference_steps = gr.Slider(
1404
- label="Number of inference steps",
1405
- minimum=1,
1406
- maximum=15,
1407
- step=1,
1408
- value=8,
1409
- )
1410
- with gr.Column():
1411
- result = gr.Gallery(
1412
- label="Result",
1413
- columns=[2],
1414
- rows=[2],
1415
- show_label=False,
1416
- allow_preview=True,
1417
- object_fit="contain",
1418
- height="auto",
1419
- preview=True,
1420
- show_share_button=True,
1421
- show_download_button=True
1422
- )
1423
-
1424
-
1425
-
1426
- # gr.Examples(
1427
- # examples=examples,
1428
- # inputs=prompt,
1429
- # cache_examples=False
1430
- # )
1431
-
1432
- use_negative_prompt.change(
1433
- fn=lambda x: gr.update(visible=x),
1434
- inputs=use_negative_prompt,
1435
- outputs=negative_prompt,
1436
- api_name=False,
1437
- )
1438
-
1439
- # gr.on(
1440
- # triggers=[
1441
- # prompt.submit,
1442
- # negative_prompt.submit,
1443
- # run_button.click,
1444
- # ],
1445
- # fn=generate,
1446
- # inputs=[
1447
- # prompt,
1448
- # negative_prompt,
1449
- # use_negative_prompt,
1450
- # seed,
1451
- # width,
1452
- # height,
1453
- # guidance_scale,
1454
- # num_inference_steps,
1455
- # randomize_seed,
1456
- # num_images
1457
- # ],
1458
- # outputs=[result, seed],
1459
- # api_name="run",
1460
- # )
1461
- run_button.click(
1462
- fn=generate,
1463
- inputs=[
1464
- prompt,
1465
- negative_prompt,
1466
- use_negative_prompt,
1467
- seed,
1468
- width,
1469
- height,
1470
- guidance_scale,
1471
- num_inference_steps,
1472
- randomize_seed,
1473
- num_images
1474
- ],
1475
- outputs=[result, seed]
1476
- )
1477
-
1478
- ###############################################################################
1479
- ############# above part is for text to image #############
1480
- ###############################################################################
1481
-
1482
-
1483
- ###############################################################################
1484
- # this part is for 3d generate.
1485
- ###############################################################################
1486
-
1487
- with gr.Row(variant="panel",visible=False) as d3_model:
1488
- with gr.Column():
1489
- with gr.Row():
1490
- input_image = gr.Image(
1491
- label="Input Image",
1492
- image_mode="RGBA",
1493
- sources="upload",
1494
- #width=256,
1495
- #height=256,
1496
- type="pil",
1497
- elem_id="content_image",
1498
- )
1499
- processed_image = gr.Image(
1500
- label="Processed Image",
1501
- image_mode="RGBA",
1502
- #width=256,
1503
- #height=256,
1504
- type="pil",
1505
- interactive=False
1506
- )
1507
- with gr.Row():
1508
- with gr.Group():
1509
- do_remove_background = gr.Checkbox(
1510
- label="Remove Background", value=True
1511
- )
1512
- sample_seed = gr.Number(value=42, label="Seed Value", precision=0)
1513
-
1514
- sample_steps = gr.Slider(
1515
- label="Sample Steps",
1516
- minimum=30,
1517
- maximum=75,
1518
- value=75,
1519
- step=5
1520
- )
1521
-
1522
- with gr.Row():
1523
- submit = gr.Button("Generate", elem_id="generate", variant="primary")
1524
-
1525
- with gr.Row(variant="panel"):
1526
- gr.Examples(
1527
- examples=[
1528
- os.path.join("examples", img_name) for img_name in sorted(os.listdir("examples"))
1529
- ],
1530
- inputs=[input_image],
1531
- label="Examples",
1532
- cache_examples=False,
1533
- examples_per_page=16
1534
- )
1535
-
1536
- with gr.Column():
1537
-
1538
- with gr.Row():
1539
-
1540
- with gr.Column():
1541
- mv_show_images = gr.Image(
1542
- label="Generated Multi-views",
1543
- type="pil",
1544
- width=379,
1545
- interactive=False
1546
- )
1547
-
1548
- # with gr.Column():
1549
- # output_video = gr.Video(
1550
- # label="video", format="mp4",
1551
- # width=379,
1552
- # autoplay=True,
1553
- # interactive=False
1554
- # )
1555
-
1556
- with gr.Row():
1557
- with gr.Tab("OBJ"):
1558
- output_model_obj = gr.Model3D(
1559
- label="Output Model (OBJ Format)",
1560
- interactive=False,
1561
- )
1562
- gr.Markdown("Note: Downloaded .obj model will be flipped. Export .glb instead or manually flip it before usage.")
1563
- with gr.Tab("GLB"):
1564
- output_model_glb = gr.Model3D(
1565
- label="Output Model (GLB Format)",
1566
- interactive=False,
1567
- )
1568
- gr.Markdown("Note: The model shown here has a darker appearance. Download to get correct results.")
1569
-
1570
-
1571
-
1572
-
1573
- mv_images = gr.State()
1574
-
1575
- # chatbot.like(handle_like_dislike, inputs=[like_state, dislike_state], outputs=[like_state, dislike_state])
1576
-
1577
- submit.click(fn=check_input_image, inputs=[new_crop_save_path], outputs=[processed_image]).success(
1578
- fn=generate_mvs,
1579
- inputs=[processed_image, sample_steps, sample_seed],
1580
- outputs=[mv_images, mv_show_images]
1581
-
1582
- ).success(
1583
- fn=make3d,
1584
- inputs=[mv_images],
1585
- outputs=[output_model_obj, output_model_glb]
1586
- )
1587
-
1588
- ###############################################################################
1589
- # above part is for 3d generate.
1590
- ###############################################################################
1591
-
1592
-
1593
- def clear_tts_fields():
1594
- return [gr.update(value=""), gr.update(value=""), None, None, gr.update(value=False), gr.update(value=True), None, None]
1595
-
1596
- # submit_tts.click(
1597
- # tts.predict,
1598
- # inputs=[input_text, input_language, input_audio, input_mic, use_mic, agree],
1599
- # outputs=[output_waveform, output_audio],
1600
- # queue=True
1601
- # )
1602
-
1603
- clear_tts.click(
1604
- clear_tts_fields,
1605
- inputs=None,
1606
- outputs=[input_text, input_language, input_audio, input_mic, use_mic, agree, output_waveform, output_audio],
1607
- queue=False
1608
- )
1609
-
1610
-
1611
-
1612
-
1613
- clear_button_sketcher.click(
1614
- lambda x: (x),
1615
- [origin_image],
1616
- [sketcher_input],
1617
- queue=False,
1618
- show_progress=False
1619
- )
1620
-
1621
-
1622
-
1623
-
1624
-
1625
- openai_api_key.submit(init_openai_api_key, inputs=[openai_api_key],
1626
- outputs=[modules_need_gpt0, modules_need_gpt1, modules_need_gpt2, modules_need_gpt3, modules_not_need_gpt,
1627
- modules_not_need_gpt2, tts_interface,module_key_input ,module_notification_box, text_refiner, visual_chatgpt, notification_box,d3_model,top_row])
1628
- enable_chatGPT_button.click(init_openai_api_key, inputs=[openai_api_key],
1629
- outputs=[modules_need_gpt0, modules_need_gpt1, modules_need_gpt2, modules_need_gpt3,
1630
- modules_not_need_gpt,
1631
- modules_not_need_gpt2, tts_interface,module_key_input,module_notification_box, text_refiner, visual_chatgpt, notification_box,d3_model,top_row])
1632
- disable_chatGPT_button.click(init_wo_openai_api_key,
1633
- outputs=[modules_need_gpt0, modules_need_gpt1, modules_need_gpt2, modules_need_gpt3,
1634
- modules_not_need_gpt,
1635
- modules_not_need_gpt2, tts_interface,module_key_input, module_notification_box, text_refiner, visual_chatgpt, notification_box,d3_model,top_row])
1636
-
1637
- enable_chatGPT_button.click(
1638
- lambda: (None, [], [], [[], [], []], "", "", ""),
1639
- [],
1640
- [image_input, chatbot, state, click_state, paragraph_output, origin_image],
1641
- queue=False,
1642
- show_progress=False
1643
- )
1644
- openai_api_key.submit(
1645
- lambda: (None, [], [], [[], [], []], "", "", ""),
1646
- [],
1647
- [image_input, chatbot, state, click_state, paragraph_output, origin_image],
1648
- queue=False,
1649
- show_progress=False
1650
- )
1651
-
1652
- cap_everything_button.click(cap_everything, [paragraph, visual_chatgpt, language,auto_play],
1653
- [paragraph_output,output_audio])
1654
-
1655
- clear_button_click.click(
1656
- lambda x: ([[], [], []], x),
1657
- [origin_image],
1658
- [click_state, image_input],
1659
- queue=False,
1660
- show_progress=False
1661
- )
1662
- clear_button_click.click(functools.partial(clear_chat_memory, keep_global=True), inputs=[visual_chatgpt])
1663
- clear_button_image.click(
1664
- lambda: (None, [], [], [[], [], []], "", "", ""),
1665
- [],
1666
- [image_input, chatbot, state, click_state, paragraph_output, origin_image],
1667
- queue=False,
1668
- show_progress=False
1669
- )
1670
- clear_button_image.click(clear_chat_memory, inputs=[visual_chatgpt])
1671
- clear_button_text.click(
1672
- lambda: ([], [], [[], [], [], []]),
1673
- [],
1674
- [chatbot, state, click_state],
1675
- queue=False,
1676
- show_progress=False
1677
- )
1678
- clear_button_text.click(clear_chat_memory, inputs=[visual_chatgpt])
1679
-
1680
- image_input.clear(
1681
- lambda: (None, [], [], [[], [], []], "", "", ""),
1682
- [],
1683
- [image_input, chatbot, state, click_state, paragraph_output, origin_image],
1684
- queue=False,
1685
- show_progress=False
1686
- )
1687
-
1688
- image_input.clear(clear_chat_memory, inputs=[visual_chatgpt])
1689
-
1690
-
1691
-
1692
-
1693
- image_input_base.upload(upload_callback, [image_input_base, state, visual_chatgpt,openai_api_key],
1694
- [chatbot, state, origin_image, click_state, image_input, image_input_base, sketcher_input,
1695
- image_embedding, original_size, input_size,name_label,artist_label,year_label,material_label,name_label_base, artist_label_base, year_label_base, material_label_base,paragraph])
1696
-
1697
- image_input.upload(upload_callback, [image_input, state, visual_chatgpt, openai_api_key],
1698
- [chatbot, state, origin_image, click_state, image_input, image_input_base, sketcher_input,
1699
- image_embedding, original_size, input_size,name_label,artist_label,year_label,material_label,name_label_base, artist_label_base, year_label_base, material_label_base,paragraph])
1700
- sketcher_input.upload(upload_callback, [sketcher_input, state, visual_chatgpt, openai_api_key],
1701
- [chatbot, state, origin_image, click_state, image_input, image_input_base, sketcher_input,
1702
- image_embedding, original_size, input_size,name_label,artist_label,year_label,material_label,name_label_base, artist_label_base, year_label_base, material_label_base,paragraph])
1703
- chat_input.submit(chat_input_callback, [visual_chatgpt, chat_input, click_state, state, aux_state,language,auto_play],
1704
- [chatbot, state, aux_state,output_audio])
1705
- chat_input.submit(lambda: "", None, chat_input)
1706
- submit_button_text.click(chat_input_callback, [visual_chatgpt, chat_input, click_state, state, aux_state,language,auto_play],
1707
- [chatbot, state, aux_state,output_audio])
1708
- submit_button_text.click(lambda: "", None, chat_input)
1709
- example_image.change(upload_callback, [example_image, state, visual_chatgpt, openai_api_key],
1710
- [chatbot, state, origin_image, click_state, image_input, image_input_base, sketcher_input,
1711
- image_embedding, original_size, input_size,name_label,artist_label,year_label,material_label,name_label_base, artist_label_base, year_label_base, material_label_base,paragraph])
1712
-
1713
- example_image.change(clear_chat_memory, inputs=[visual_chatgpt])
1714
-
1715
- def on_click_tab_selected():
1716
- if gpt_state ==1:
1717
- print(gpt_state)
1718
- print("using gpt")
1719
- return [gr.update(visible=True)]*2+[gr.update(visible=False)]*2
1720
- else:
1721
- print("no gpt")
1722
- print("gpt_state",gpt_state)
1723
- return [gr.update(visible=False)]+[gr.update(visible=True)]+[gr.update(visible=False)]*2
1724
-
1725
- def on_base_selected():
1726
- if gpt_state ==1:
1727
- print(gpt_state)
1728
- print("using gpt")
1729
- return [gr.update(visible=True)]*2+[gr.update(visible=False)]*2
1730
- else:
1731
- print("no gpt")
1732
- return [gr.update(visible=False)]*4
1733
-
1734
-
1735
- traj_tab.select(on_click_tab_selected, outputs=[modules_need_gpt1,modules_not_need_gpt2,modules_need_gpt0,modules_need_gpt2])
1736
- click_tab.select(on_click_tab_selected, outputs=[modules_need_gpt1,modules_not_need_gpt2,modules_need_gpt0,modules_need_gpt2])
1737
- base_tab.select(on_base_selected, outputs=[modules_need_gpt0,modules_need_gpt2,modules_not_need_gpt2,modules_need_gpt1])
1738
-
1739
-
1740
-
1741
-
1742
- image_input.select(
1743
- inference_click,
1744
- inputs=[
1745
- origin_image, point_prompt, click_mode, enable_wiki, language, sentiment, factuality, length,
1746
- image_embedding, state, click_state, original_size, input_size, text_refiner, visual_chatgpt,
1747
- out_state, click_index_state, input_mask_state, input_points_state, input_labels_state,
1748
- ],
1749
- outputs=[chatbot, state, click_state, image_input, input_image, click_index_state, input_mask_state, input_points_state, input_labels_state, out_state,new_crop_save_path,image_input_nobackground],
1750
- show_progress=False, queue=True
1751
- )
1752
-
1753
-
1754
- submit_button_click.click(
1755
- submit_caption,
1756
- inputs=[
1757
- state, text_refiner,length, sentiment, factuality, language,
1758
- out_state, click_index_state, input_mask_state, input_points_state, input_labels_state,
1759
- auto_play,paragraph,focus_type,openai_api_key,new_crop_save_path
1760
- ],
1761
- outputs=[
1762
- chatbot, state, click_index_state, input_mask_state, input_points_state, input_labels_state, out_state,
1763
- output_audio
1764
- ],
1765
- show_progress=True,
1766
- queue=True
1767
- )
1768
-
1769
-
1770
- submit_button_sketcher.click(
1771
- inference_traject,
1772
- inputs=[
1773
- origin_image,sketcher_input, enable_wiki, language, sentiment, factuality, length, image_embedding, state,
1774
- original_size, input_size, text_refiner,focus_type_sketch,paragraph,openai_api_key,auto_play,Input_sketch
1775
- ],
1776
- outputs=[chatbot, state, sketcher_input,output_audio],
1777
- show_progress=False, queue=True
1778
- )
1779
-
1780
- export_button.click(
1781
- export_chat_log,
1782
- inputs=[state],
1783
- outputs=[chat_log_file],
1784
- queue=True
1785
- )
1786
-
1787
-
1788
-
1789
-
1790
-
1791
- return iface
1792
-
1793
-
1794
- if __name__ == '__main__':
1795
- iface = create_ui()
1796
- iface.queue(concurrency_count=5, api_open=False, max_size=10)
1797
- iface.launch(server_name="0.0.0.0", enable_queue=True)
 
1
+ # Copyright (c) Microsoft
2
+ # Modified from Visual ChatGPT Project https://github.com/microsoft/TaskMatrix/blob/main/visual_chatgpt.py
3
+
4
  import os
 
 
5
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  import re
7
+ import uuid
8
+ from PIL import Image, ImageDraw, ImageOps
 
 
 
 
 
 
 
 
 
 
 
 
9
  import numpy as np
10
+ import argparse
11
+ import inspect
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
+ from langchain.agents.initialize import initialize_agent
14
+ from langchain.agents.tools import Tool
15
+ from langchain.chains.conversation.memory import ConversationBufferMemory
16
+ from langchain.llms.openai import OpenAIChat
17
+ import torch
18
+ from PIL import Image, ImageDraw, ImageOps
19
+ from transformers import pipeline, BlipProcessor, BlipForConditionalGeneration, BlipForQuestionAnswering
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
+ # openai.api_version = '2020-11-07'
22
+ os.environ["OPENAI_API_VERSION"] = '2020-11-07'
23
 
24
+ VISUAL_CHATGPT_PREFIX = """
25
+ I want you to act as an art connoisseur, providing in-depth and insightful analysis on various artworks. Your responses should reflect a deep understanding of art history, techniques, and cultural contexts, offering users a rich and nuanced perspective.
 
 
 
 
 
 
26
 
27
+ You can engage in natural-sounding conversations, generate human-like text based on input, and provide relevant, coherent responses on art-related topics."""
28
 
29
 
30
+ # TOOLS:
31
+ # ------
32
 
33
+ # Visual ChatGPT has access to the following tools:"""
 
 
34
 
35
+ VISUAL_CHATGPT_FORMAT_INSTRUCTIONS = """To use a tool, please use the following format:
36
 
37
+ "Thought: Do I need to use a tool? Yes
38
+ Action: the action to take, should be one of [{tool_names}], remember the action must to be one tool
39
+ Action Input: the input to the action
40
+ Observation: the result of the action"
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
+ When you have a response to say to the Human, or if you do not need to use a tool, you MUST use the format:
 
 
 
 
 
43
 
44
+ "Thought: Do I need to use a tool? No
45
+ {ai_prefix}: [your response here]"
 
 
 
 
 
46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  """
48
 
49
+ VISUAL_CHATGPT_SUFFIX = """
50
+ Begin Chatting!
51
+
52
+ Previous conversation history:
53
+ {chat_history}
54
+
55
+ New input: {input}
56
+ As a language model, you must repeatly to use VQA tools to observe images. You response should be consistent with the outputs of the VQA tool instead of imagination. Do not repeat asking the same question.
57
+
58
+ Thought: Do I need to use a tool? {agent_scratchpad} (You are strictly to use the aforementioned "Thought/Action/Action Input/Observation" format as the answer.)"""
59
+
60
+ os.makedirs('chat_image', exist_ok=True)
61
+
62
+
63
+ def prompts(name, description):
64
+ def decorator(func):
65
+ func.name = name
66
+ func.description = description
67
+ return func
68
+ return decorator
69
+
70
+ def cut_dialogue_history(history_memory, keep_last_n_words=500):
71
+ if history_memory is None or len(history_memory) == 0:
72
+ return history_memory
73
+ tokens = history_memory.split()
74
+ n_tokens = len(tokens)
75
+ print(f"history_memory:{history_memory}, n_tokens: {n_tokens}")
76
+ if n_tokens < keep_last_n_words:
77
+ return history_memory
78
+ paragraphs = history_memory.split('\n')
79
+ last_n_tokens = n_tokens
80
+ while last_n_tokens >= keep_last_n_words:
81
+ last_n_tokens -= len(paragraphs[0].split(' '))
82
+ paragraphs = paragraphs[1:]
83
+ return '\n' + '\n'.join(paragraphs)
84
+
85
+ def get_new_image_name(folder='chat_image', func_name="update"):
86
+ this_new_uuid = str(uuid.uuid4())[:8]
87
+ new_file_name = f'{func_name}_{this_new_uuid}.png'
88
+ return os.path.join(folder, new_file_name)
89
+
90
+ class VisualQuestionAnswering:
91
+ def __init__(self, device):
92
+ print(f"Initializing VisualQuestionAnswering to {device}")
93
+ self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
94
+ self.device = device
95
+ self.processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
96
+ self.model = BlipForQuestionAnswering.from_pretrained(
97
+ "Salesforce/blip-vqa-base", torch_dtype=self.torch_dtype).to(self.device)
98
+ # self.processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-capfilt-large")
99
+ # self.model = BlipForQuestionAnswering.from_pretrained(
100
+ # "Salesforce/blip-vqa-capfilt-large", torch_dtype=self.torch_dtype).to(self.device)
101
+
102
+ @prompts(name="Answer Question About The Image",
103
+ description="VQA tool is useful when you need an answer for a question based on an image. "
104
+ "like: what is the color of an object, how many cats in this figure, where is the child sitting, what does the cat doing, why is he laughing."
105
+ "The input to this tool should be a comma separated string of two, representing the image path and the question.")
106
+ def inference(self, inputs):
107
+ image_path, question = inputs.split(",")[0], ','.join(inputs.split(',')[1:])
108
+ raw_image = Image.open(image_path).convert('RGB')
109
+ inputs = self.processor(raw_image, question, return_tensors="pt").to(self.device, self.torch_dtype)
110
+ out = self.model.generate(**inputs)
111
+ answer = self.processor.decode(out[0], skip_special_tokens=True)
112
+ print(f"\nProcessed VisualQuestionAnswering, Input Image: {image_path}, Input Question: {question}, "
113
+ f"Output Answer: {answer}")
114
+ return answer
115
+
116
+ def build_chatbot_tools(load_dict):
117
+ print(f"Initializing ChatBot, load_dict={load_dict}")
118
+ models = {}
119
+ # Load Basic Foundation Models
120
+ for class_name, device in load_dict.items():
121
+ models[class_name] = globals()[class_name](device=device)
122
+
123
+ # Load Template Foundation Models
124
+ for class_name, module in globals().items():
125
+ if getattr(module, 'template_model', False):
126
+ template_required_names = {k for k in inspect.signature(module.__init__).parameters.keys() if k!='self'}
127
+ loaded_names = set([type(e).__name__ for e in models.values()])
128
+ if template_required_names.issubset(loaded_names):
129
+ models[class_name] = globals()[class_name](
130
+ **{name: models[name] for name in template_required_names})
131
 
132
+ tools = []
133
+ for instance in models.values():
134
+ for e in dir(instance):
135
+ if e.startswith('inference'):
136
+ func = getattr(instance, e)
137
+ tools.append(Tool(name=func.name, description=func.description, func=func))
138
+ return tools
139
+
140
+ class ConversationBot:
141
+ def __init__(self, tools, api_key=""):
142
+ # load_dict = {'VisualQuestionAnswering':'cuda:0', 'ImageCaptioning':'cuda:1',...}
143
+ llm = OpenAIChat(model_name="gpt-4o", temperature=0.7, openai_api_key=api_key, model_kwargs={"api_version": "2020-11-07"})
144
+ self.llm = llm
145
+ self.memory = ConversationBufferMemory(memory_key="chat_history", output_key='output')
146
+ self.tools = tools
147
+ self.current_image = None
148
+ self.point_prompt = ""
149
+ self.global_prompt = ""
150
+ self.agent = initialize_agent(
151
+ self.tools,
152
+ self.llm,
153
+ agent="conversational-react-description",
154
+ verbose=True,
155
+ memory=self.memory,
156
+ return_intermediate_steps=True,
157
+ agent_kwargs={'prefix': VISUAL_CHATGPT_PREFIX, 'format_instructions': VISUAL_CHATGPT_FORMAT_INSTRUCTIONS,
158
+ 'suffix': VISUAL_CHATGPT_SUFFIX}, )
159
+
160
+ def constructe_intermediate_steps(self, agent_res):
161
+ ans = []
162
+ for action, output in agent_res:
163
+ if hasattr(action, "tool_input"):
164
+ use_tool = "Yes"
165
+ act = (f"Thought: Do I need to use a tool? {use_tool}\nAction: {action.tool}\nAction Input: {action.tool_input}", f"Observation: {output}")
166
+ else:
167
+ use_tool = "No"
168
+ act = (f"Thought: Do I need to use a tool? {use_tool}", f"AI: {output}")
169
+ act= list(map(lambda x: x.replace('\n', '<br>'), act))
170
+ ans.append(act)
171
+ return ans
172
+
173
+ def run_text(self, text, state, aux_state):
174
+ self.agent.memory.buffer = cut_dialogue_history(self.agent.memory.buffer, keep_last_n_words=500)
175
+ if self.point_prompt != "":
176
+ Human_prompt = f'\nHuman: {self.point_prompt}\n'
177
+ AI_prompt = 'Ok'
178
+ self.agent.memory.buffer = self.agent.memory.buffer + Human_prompt + 'AI: ' + AI_prompt
179
+ self.point_prompt = ""
180
+ res = self.agent({"input": text})
181
+ res['output'] = res['output'].replace("\\", "/")
182
+ response = re.sub('(chat_image/\S*png)', lambda m: f'![](/file={m.group(0)})*{m.group(0)}*', res['output'])
183
+ state = state + [(text, response)]
184
+
185
+ aux_state = aux_state + [(f"User Input: {text}", None)]
186
+ aux_state = aux_state + self.constructe_intermediate_steps(res['intermediate_steps'])
187
+ print(f"\nProcessed run_text, Input text: {text}\nCurrent state: {state}\n"
188
+ f"Current Memory: {self.agent.memory.buffer}\n"
189
+ f"Aux state: {aux_state}\n"
190
+ )
191
+ return state, state, aux_state, aux_state
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
192
 
193
 
194
+ if __name__ == '__main__':
195
+ parser = argparse.ArgumentParser()
196
+ parser.add_argument('--load', type=str, default="VisualQuestionAnswering_cuda:0")
197
+ parser.add_argument('--port', type=int, default=1015)
198
+
199
+ args = parser.parse_args()
200
+ load_dict = {e.split('_')[0].strip(): e.split('_')[1].strip() for e in args.load.split(',')}
201
+ tools = build_chatbot_tools(load_dict)
202
+ bot = ConversationBot(tools)
203
+ with gr.Blocks(css="#chatbot .overflow-y-auto{height:500px}") as demo:
204
+ with gr.Row():
205
+ chatbot = gr.Chatbot(elem_id="chatbot", label="CATchat").style(height=1000,scale=0.5)
206
+ auxwindow = gr.Chatbot(elem_id="chatbot", label="Aux Window").style(height=1000,scale=0.5)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
207
  state = gr.State([])
 
 
 
 
 
 
 
 
 
208
  aux_state = gr.State([])
209
+ with gr.Row():
210
+ with gr.Column(scale=0.7):
211
+ txt = gr.Textbox(show_label=False, placeholder="Enter text and press enter, or upload an image").style(
212
+ container=False)
213
+ with gr.Column(scale=0.15, min_width=0):
214
+ clear = gr.Button("Clear")
215
+ with gr.Column(scale=0.15, min_width=0):
216
+ btn = gr.UploadButton("Upload", file_types=["image"])
217
+
218
+ txt.submit(bot.run_text, [txt, state, aux_state], [chatbot, state, aux_state, auxwindow])
219
+ txt.submit(lambda: "", None, txt)
220
+ btn.upload(bot.run_image, [btn, state, txt, aux_state], [chatbot, state, txt, aux_state, auxwindow])
221
+ clear.click(bot.memory.clear)
222
+ clear.click(lambda: [], None, chatbot)
223
+ clear.click(lambda: [], None, auxwindow)
224
+ clear.click(lambda: [], None, state)
225
+ clear.click(lambda: [], None, aux_state)
226
+ demo.launch(server_name="0.0.0.0", server_port=args.port, share=True)