GuanshuoXu committed
Commit b7d2c14
1 Parent(s): bcb9d7b

add model.ocr

Files changed (1):
  1. modelling_h2ovl_chat.py +134 -0
modelling_h2ovl_chat.py CHANGED
@@ -13,6 +13,7 @@ import transformers
 from .conversation import get_conv_template
 from .configuration_h2ovl_chat import H2OVLChatConfig
 from .image_process import load_single_image, load_multi_images
+import re
 
 logger = logging.get_logger(__name__)
 
@@ -338,3 +339,136 @@ class H2OVLChatModel(PreTrainedModel):
         )
 
         return outputs
+
+    def ocr(self, tokenizer, image_files, question, generation_config, max_tiles=6, history=None, return_history=False,
+            num_patches_list=None, IMG_START_TOKEN='<img>', IMG_END_TOKEN='</img>', IMG_CONTEXT_TOKEN='<IMG_CONTEXT>',
+            verbose=False):
+
+        from transformers import LogitsProcessor
+
+        class SuppressConsecutiveSpacesLogitsProcessor(LogitsProcessor):
+            # Keeps OCR output from degenerating into runs of whitespace: when the
+            # previous token was whitespace-only and whitespace again ranks first,
+            # whitespace candidates are penalized until a non-space token leads.
+            def __init__(self, tokenizer):
+                self.tokenizer = tokenizer
+
+            def __call__(self, input_ids, scores):
+                logits = scores[-1].squeeze()  # assumes batch size 1
+                _, topk_indices = torch.topk(logits, 30)
+                if input_ids.shape[1] > 1:
+                    if len(self.tokenizer.decode(input_ids[0, -1]).strip()) == 0 and topk_indices[0] == input_ids[0, -1]:
+                        # Walk the top-30 candidates in rank order and suppress
+                        # whitespace tokens until the first non-space candidate.
+                        for i in range(len(topk_indices)):
+                            if len(self.tokenizer.decode(topk_indices[i]).strip()) == 0:
+                                scores[0, topk_indices[i]] = -99999999.
+                            else:
+                                break
+                return scores
+
+        if image_files:
+            if isinstance(image_files, list):
+                pixel_values, num_patches_list = load_multi_images(image_files, max_num=max_tiles)  # load multiple images
+            else:
+                pixel_values, num_patches_list = load_single_image(image_files, max_num=max_tiles, msac=self.use_msac)  # load a single image
+        else:
+            pixel_values = None
+            num_patches_list = []
+
+        if history is None and pixel_values is not None and '<image>' not in question:
+            question = '<image>\n' + question
+
+        if num_patches_list is None:
+            num_patches_list = [pixel_values.shape[0]] if pixel_values is not None else []
+
+        assert pixel_values is None or len(pixel_values) == sum(num_patches_list)
+
+        img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
+        self.img_context_token_id = img_context_token_id
+
+        template = get_conv_template(self.template)
+        template.system_message = self.system_message
+        eos_token_id = tokenizer.convert_tokens_to_ids(template.sep)
+
+        space_suppressor = SuppressConsecutiveSpacesLogitsProcessor(tokenizer)
+
+        history = [] if history is None else history
+        for (old_question, old_answer) in history:
+            template.append_message(template.roles[0], old_question)
+            template.append_message(template.roles[1], old_answer)
+        template.append_message(template.roles[0], question)
+        template.append_message(template.roles[1], None)
+        query = template.get_prompt()
+
+        if verbose and pixel_values is not None:
+            image_bs = pixel_values.shape[0]
+            print(f'dynamic ViT batch size: {image_bs}')
+
+        # Expand each <image> placeholder into start/context/end image tokens.
+        for num_patches in num_patches_list:
+            image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.num_image_token * num_patches + IMG_END_TOKEN
+            query = query.replace('<image>', image_tokens, 1)
+
+        model_inputs = tokenizer(query, return_tensors='pt')
+        input_ids = model_inputs['input_ids'].cuda()
+        attention_mask = model_inputs['attention_mask'].cuda()
+        generation_config['eos_token_id'] = eos_token_id
+        generation_output = self.generate_ocr(
+            space_suppressor=space_suppressor,
+            pixel_values=pixel_values,
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            **generation_config
+        )
+        response = tokenizer.batch_decode(generation_output, skip_special_tokens=True)[0]
+        response = response.split(template.sep)[0].strip()
+        response = re.sub(' +', ' ', response)  # collapse any runs of spaces that slipped through
+        history.append((question, response))
+        if return_history:
+            return response, history
+        else:
+            query_to_print = query.replace(IMG_CONTEXT_TOKEN, '')
+            query_to_print = query_to_print.replace(f'{IMG_START_TOKEN}{IMG_END_TOKEN}', '<image>')
+            if verbose:
+                print(query_to_print, response)
+            return response
+
+    @torch.no_grad()
+    def generate_ocr(
+            self,
+            space_suppressor,
+            pixel_values: Optional[torch.FloatTensor] = None,
+            input_ids: Optional[torch.LongTensor] = None,
+            attention_mask: Optional[torch.LongTensor] = None,
+            visual_features: Optional[torch.FloatTensor] = None,
+            generation_config: Optional[GenerationConfig] = None,
+            output_hidden_states: Optional[bool] = None,
+            return_dict: Optional[bool] = None,
+            **generate_kwargs,
+    ) -> torch.LongTensor:
+
+        assert self.img_context_token_id is not None
+        if pixel_values is not None:
+            if visual_features is not None:
+                vit_embeds = visual_features
+            else:
+                vit_embeds = self.extract_feature(pixel_values)
+            input_embeds = self.language_model.get_input_embeddings()(input_ids)
+            B, N, C = input_embeds.shape
+            input_embeds = input_embeds.reshape(B * N, C)
+
+            input_ids = input_ids.reshape(B * N)
+            selected = (input_ids == self.img_context_token_id)
+            assert selected.sum() != 0
+            # Splice the vision embeddings into the <IMG_CONTEXT> placeholder positions.
+            input_embeds[selected] = vit_embeds.reshape(-1, C).to(input_embeds.device)
+
+            input_embeds = input_embeds.reshape(B, N, C)
+        else:
+            input_embeds = self.language_model.get_input_embeddings()(input_ids)
+
+        outputs = self.language_model.generate(
+            logits_processor=[space_suppressor],
+            inputs_embeds=input_embeds,
+            attention_mask=attention_mask,
+            generation_config=generation_config,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            use_cache=True,
+            **generate_kwargs,
+        )
+
+        return outputs
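
For context, a minimal usage sketch of the new `model.ocr` entry point (untested; the checkpoint id, image path, and generation settings below are illustrative assumptions, not part of this commit):

    import torch
    from transformers import AutoModel, AutoTokenizer

    model_path = 'h2oai/h2ovl-mississippi-800m'  # assumed checkpoint id, not from this commit
    model = AutoModel.from_pretrained(model_path, torch_dtype=torch.bfloat16,
                                      trust_remote_code=True).eval().cuda()
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, use_fast=False)

    generation_config = dict(max_new_tokens=1024, do_sample=False)
    # ocr() loads the image, builds the chat prompt, and decodes with the
    # consecutive-space suppressor enabled, so OCR output avoids runs of spaces.
    response = model.ocr(tokenizer, './receipt.jpg', '<image>\nRead the text in the image.',
                         generation_config)
    print(response)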