Update modeling_qwen.py
Browse files- modeling_qwen.py +9 -6
modeling_qwen.py
CHANGED
@@ -564,7 +564,13 @@ class QWenModel(QWenPreTrainedModel):
|
|
564 |
|
565 |
images = self.visual.encode(images)
|
566 |
assert images.shape[0] == len(images)
|
|
|
|
|
|
|
|
|
|
|
567 |
else:
|
|
|
568 |
images = None
|
569 |
|
570 |
output_attentions = (
|
@@ -623,11 +629,6 @@ class QWenModel(QWenPreTrainedModel):
|
|
623 |
|
624 |
if inputs_embeds is None:
|
625 |
inputs_embeds = self.wte(input_ids)
|
626 |
-
if self.training and images == None: ### Compatible with plain text data training
|
627 |
-
fake_images=torch.zeros(1,3,224,224).to(
|
628 |
-
dtype=self.visual.conv1.weight.dtype, device=self.visual.conv1.weight.device)
|
629 |
-
image_embeds = self.visual(fake_images)
|
630 |
-
inputs_embeds = inputs_embeds + image_embeds.mean()*0
|
631 |
|
632 |
if batch_size <= 0:
|
633 |
raise ValueError("batch_size has to be defined and > 0")
|
@@ -657,7 +658,9 @@ class QWenModel(QWenPreTrainedModel):
|
|
657 |
rotary_pos_emb[idx] = rotary_pos_emb[idx].to(hidden_states.device)
|
658 |
|
659 |
hidden_states = self.drop(hidden_states).clone()
|
660 |
-
if
|
|
|
|
|
661 |
for idx, (i, a, b) in enumerate(img_pos):
|
662 |
hidden_states[i][a + 1 : b] = images[idx]
|
663 |
output_shape = input_shape + (hidden_states.size(-1),)
|
|
|
564 |
|
565 |
images = self.visual.encode(images)
|
566 |
assert images.shape[0] == len(images)
|
567 |
+
fake_images = None
|
568 |
+
elif self.training:
|
569 |
+
fake_images=torch.zeros(1,3,224,224).to(
|
570 |
+
dtype=self.visual.conv1.weight.dtype, device=self.visual.conv1.weight.device)
|
571 |
+
images = self.visual(fake_images)
|
572 |
else:
|
573 |
+
fake_images = None
|
574 |
images = None
|
575 |
|
576 |
output_attentions = (
|
|
|
629 |
|
630 |
if inputs_embeds is None:
|
631 |
inputs_embeds = self.wte(input_ids)
|
|
|
|
|
|
|
|
|
|
|
632 |
|
633 |
if batch_size <= 0:
|
634 |
raise ValueError("batch_size has to be defined and > 0")
|
|
|
658 |
rotary_pos_emb[idx] = rotary_pos_emb[idx].to(hidden_states.device)
|
659 |
|
660 |
hidden_states = self.drop(hidden_states).clone()
|
661 |
+
if fake_images is not None:
|
662 |
+
hidden_states = hidden_states + images.mean()*0
|
663 |
+
elif images is not None:
|
664 |
for idx, (i, a, b) in enumerate(img_pos):
|
665 |
hidden_states[i][a + 1 : b] = images[idx]
|
666 |
output_shape = input_shape + (hidden_states.size(-1),)
|