fix attention mask warning
I literally cannot test this code, but normally it should work.
modeling_llamavision.py (CHANGED, +10 -1)
```diff
@@ -105,8 +105,17 @@ class Llamavision(PreTrainedModel):
 
         with torch.no_grad():
             inputs_embeds = self.input_embeds(prompt, image_embeds, tokenizer)
+
+            attention_mask = torch.ones(
+                inputs_embeds.shape[:2],
+                dtype=torch.long,
+                device=inputs_embeds.device
+            )
+
             output_ids = self.text_model.generate(
-                inputs_embeds=inputs_embeds, **generate_config
+                inputs_embeds=inputs_embeds,
+                attention_mask=attention_mask,
+                **generate_config
             )
 
         return tokenizer.batch_decode(output_ids, skip_special_tokens=True)
```
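For anyone hitting the same warning elsewhere: when `generate()` is called with `inputs_embeds` and no `input_ids`, transformers cannot derive an attention mask from the pad token, so it logs a warning and assumes every position is a real token. Because the prompt embeddings here contain no padding, an explicit all-ones mask over `(batch, seq_len)` is the correct way to silence it. Below is a minimal, self-contained sketch of the same pattern against a stock causal LM; the gpt2 checkpoint and the prompt are stand-ins, not part of this repo.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Stand-in model; the actual repo wires self.text_model to a Llama checkpoint.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

# Build prompt embeddings directly, the way input_embeds() does for
# the combined text + image token sequence.
input_ids = tokenizer("a photo of", return_tensors="pt").input_ids
inputs_embeds = model.get_input_embeddings()(input_ids)

# All-ones mask over (batch, seq_len): nothing in the prompt is padding,
# so every position should be attended to. Passing it explicitly stops
# generate() from warning that it had to guess the mask.
attention_mask = torch.ones(
    inputs_embeds.shape[:2], dtype=torch.long, device=inputs_embeds.device
)

output_ids = model.generate(
    inputs_embeds=inputs_embeds,
    attention_mask=attention_mask,
    max_new_tokens=20,
    pad_token_id=tokenizer.eos_token_id,  # gpt2 has no pad token by default
)
print(tokenizer.batch_decode(output_ids, skip_special_tokens=True))
```

One detail worth knowing: in recent transformers versions, when only `inputs_embeds` is supplied to a decoder-only model, `generate()` returns just the newly generated token ids, which is why `batch_decode(output_ids, ...)` yields the answer without the prompt.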