bluelike committed on
Commit
75993f4
1 Parent(s): e9d154d

Update modeling_qwen.py

Browse files
Files changed (1) hide show
  1. modeling_qwen.py +9 -6
modeling_qwen.py CHANGED
@@ -564,7 +564,13 @@ class QWenModel(QWenPreTrainedModel):
564
 
565
  images = self.visual.encode(images)
566
  assert images.shape[0] == len(images)
 
 
 
 
 
567
  else:
 
568
  images = None
569
 
570
  output_attentions = (
@@ -623,11 +629,6 @@ class QWenModel(QWenPreTrainedModel):
623
 
624
  if inputs_embeds is None:
625
  inputs_embeds = self.wte(input_ids)
626
- if self.training and images == None: ### Compatible with plain text data training
627
- fake_images=torch.zeros(1,3,224,224).to(
628
- dtype=self.visual.conv1.weight.dtype, device=self.visual.conv1.weight.device)
629
- image_embeds = self.visual(fake_images)
630
- inputs_embeds = inputs_embeds + image_embeds.mean()*0
631
 
632
  if batch_size <= 0:
633
  raise ValueError("batch_size has to be defined and > 0")
@@ -657,7 +658,9 @@ class QWenModel(QWenPreTrainedModel):
657
  rotary_pos_emb[idx] = rotary_pos_emb[idx].to(hidden_states.device)
658
 
659
  hidden_states = self.drop(hidden_states).clone()
660
- if images is not None:
 
 
661
  for idx, (i, a, b) in enumerate(img_pos):
662
  hidden_states[i][a + 1 : b] = images[idx]
663
  output_shape = input_shape + (hidden_states.size(-1),)
 
564
 
565
  images = self.visual.encode(images)
566
  assert images.shape[0] == len(images)
567
+ fake_images = None
568
+ elif self.training:
569
+ fake_images=torch.zeros(1,3,224,224).to(
570
+ dtype=self.visual.conv1.weight.dtype, device=self.visual.conv1.weight.device)
571
+ images = self.visual(fake_images)
572
  else:
573
+ fake_images = None
574
  images = None
575
 
576
  output_attentions = (
 
629
 
630
  if inputs_embeds is None:
631
  inputs_embeds = self.wte(input_ids)
 
 
 
 
 
632
 
633
  if batch_size <= 0:
634
  raise ValueError("batch_size has to be defined and > 0")
 
658
  rotary_pos_emb[idx] = rotary_pos_emb[idx].to(hidden_states.device)
659
 
660
  hidden_states = self.drop(hidden_states).clone()
661
+ if fake_images is not None:
662
+ hidden_states = hidden_states + images.mean()*0
663
+ elif images is not None:
664
  for idx, (i, a, b) in enumerate(img_pos):
665
  hidden_states[i][a + 1 : b] = images[idx]
666
  output_shape = input_shape + (hidden_states.size(-1),)