GuanshuoXu committed
Commit b7d2c14
1 Parent(s): bcb9d7b

add model.ocr

modelling_h2ovl_chat.py  +134  -0
modelling_h2ovl_chat.py
CHANGED
@@ -13,6 +13,7 @@ import transformers
 from .conversation import get_conv_template
 from .configuration_h2ovl_chat import H2OVLChatConfig
 from .image_process import load_single_image, load_multi_images
+import re
 
 logger = logging.get_logger(__name__)
 
@@ -338,3 +339,136 @@ class H2OVLChatModel(PreTrainedModel):
         )
 
         return outputs
+
+    def ocr(self, tokenizer, image_files, question, generation_config, max_tiles=6, history=None, return_history=False,
+            num_patches_list=None, IMG_START_TOKEN='<img>', IMG_END_TOKEN='</img>', IMG_CONTEXT_TOKEN='<IMG_CONTEXT>',
+            verbose=False):
+
+        from transformers import LogitsProcessor
+        class SuppressConsecutiveSpacesLogitsProcessor(LogitsProcessor):
+            def __init__(self, tokenizer):
+                self.tokenizer = tokenizer
+            def __call__(self, input_ids, scores):
+                logits = scores[-1].squeeze()
+                _, topk_indices = torch.topk(logits, 30)
+                if input_ids.shape[1] > 1:
+                    # If the last generated token is whitespace-only and is also the current
+                    # top-1 candidate, suppress the leading whitespace-only tokens in the top 30.
+                    if len(self.tokenizer.decode(input_ids[0, -1]).strip()) == 0 and topk_indices[0] == input_ids[0, -1]:
+                        for i in range(len(topk_indices)):
+                            if len(self.tokenizer.decode(topk_indices[i]).strip()) == 0:
+                                scores[0, topk_indices[i]] = -99999999.
+                            else:
+                                break
+                return scores
+
+        if image_files:
+            if isinstance(image_files, list):
+                pixel_values, num_patches_list = load_multi_images(image_files, max_num=max_tiles)  # Load multiple images
+            else:
+                pixel_values, num_patches_list = load_single_image(image_files, max_num=max_tiles, msac=self.use_msac)  # Load single image
+        else:
+            pixel_values = None
+            num_patches_list = []
+
+
+        if history is None and pixel_values is not None and '<image>' not in question:
+            question = '<image>\n' + question
+
+        if num_patches_list is None:
+            num_patches_list = [pixel_values.shape[0]] if pixel_values is not None else []
+
+        assert pixel_values is None or len(pixel_values) == sum(num_patches_list)
+
+        img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
+        self.img_context_token_id = img_context_token_id
+
+        template = get_conv_template(self.template)
+        template.system_message = self.system_message
+        eos_token_id = tokenizer.convert_tokens_to_ids(template.sep)
+
+        space_suppressor = SuppressConsecutiveSpacesLogitsProcessor(tokenizer)
+
+        history = [] if history is None else history
+        for (old_question, old_answer) in history:
+            template.append_message(template.roles[0], old_question)
+            template.append_message(template.roles[1], old_answer)
+        template.append_message(template.roles[0], question)
+        template.append_message(template.roles[1], None)
+        query = template.get_prompt()
+
+        if verbose and pixel_values is not None:
+            image_bs = pixel_values.shape[0]
+            print(f'dynamic ViT batch size: {image_bs}')
+
+        for num_patches in num_patches_list:
+            image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.num_image_token * num_patches + IMG_END_TOKEN
+            query = query.replace('<image>', image_tokens, 1)
+
+        model_inputs = tokenizer(query, return_tensors='pt')
+        input_ids = model_inputs['input_ids'].cuda()
+        attention_mask = model_inputs['attention_mask'].cuda()
+        generation_config['eos_token_id'] = eos_token_id
+        generation_output = self.generate_ocr(
+            space_suppressor=space_suppressor,
+            pixel_values=pixel_values,
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            **generation_config
+        )
+        response = tokenizer.batch_decode(generation_output, skip_special_tokens=True)[0]
+        response = response.split(template.sep)[0].strip()
+        response = re.sub(' +', ' ', response)
+        history.append((question, response))
+        if return_history:
+            return response, history
+        else:
+            query_to_print = query.replace(IMG_CONTEXT_TOKEN, '')
+            query_to_print = query_to_print.replace(f'{IMG_START_TOKEN}{IMG_END_TOKEN}', '<image>')
+            if verbose:
+                print(query_to_print, response)
+            return response
+
+    @torch.no_grad()
+    def generate_ocr(
+            self,
+            space_suppressor,
+            pixel_values: Optional[torch.FloatTensor] = None,
+            input_ids: Optional[torch.FloatTensor] = None,
+            attention_mask: Optional[torch.LongTensor] = None,
+            visual_features: Optional[torch.FloatTensor] = None,
+            generation_config: Optional[GenerationConfig] = None,
+            output_hidden_states: Optional[bool] = None,
+            return_dict: Optional[bool] = None,
+            **generate_kwargs,
+    ) -> torch.LongTensor:
+
+        assert self.img_context_token_id is not None
+        if pixel_values is not None:
+            if visual_features is not None:
+                vit_embeds = visual_features
+            else:
+                vit_embeds = self.extract_feature(pixel_values)
+            input_embeds = self.language_model.get_input_embeddings()(input_ids)
+            B, N, C = input_embeds.shape
+            input_embeds = input_embeds.reshape(B * N, C)
+
+            # Scatter the vision embeddings into the IMG_CONTEXT placeholder positions.
+            input_ids = input_ids.reshape(B * N)
+            selected = (input_ids == self.img_context_token_id)
+            assert selected.sum() != 0
+            input_embeds[selected] = vit_embeds.reshape(-1, C).to(input_embeds.device)
+
+            input_embeds = input_embeds.reshape(B, N, C)
+        else:
+            input_embeds = self.language_model.get_input_embeddings()(input_ids)
+
+        outputs = self.language_model.generate(
+            logits_processor=[space_suppressor],
+            inputs_embeds=input_embeds,
+            attention_mask=attention_mask,
+            generation_config=generation_config,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            use_cache=True,
+            **generate_kwargs,
+        )
+
+        return outputs
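Note: the sketch below shows how the model.ocr entry point added in this commit might be called. It is only an illustration: the checkpoint id, dtype, and generation settings are assumptions and are not part of the commit; only the ocr(...) signature and its image-loading behaviour come from the diff above.

# Hypothetical usage sketch for the ocr() method added in this commit.
# The checkpoint id, dtype, and generation settings are assumptions, not part of the diff.
import torch
from transformers import AutoModel, AutoTokenizer

model_path = 'h2oai/h2ovl-mississippi-2b'  # assumed checkpoint; substitute the actual repo id
model = AutoModel.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
).eval().cuda()  # ocr() moves its inputs to CUDA, so the model must live on the GPU
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, use_fast=False)

# Plain dict: ocr() writes generation_config['eos_token_id'] itself before generating.
generation_config = dict(max_new_tokens=2048, do_sample=False)

# A single path goes through load_single_image (with MSAC); a list of paths goes through load_multi_images.
response = model.ocr(
    tokenizer,
    'sample_document.png',                # hypothetical image file
    'Extract all text from the image.',
    generation_config,
    max_tiles=6,
)
print(response)

The SuppressConsecutiveSpacesLogitsProcessor that generate_ocr passes to language_model.generate() keeps the decoder from emitting runs of whitespace-only tokens, a common failure mode in OCR-style transcription, and the final re.sub(' +', ' ', response) collapses any repeated spaces that still slip through.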