huz-relay committed on
Commit
e5307bd
·
1 Parent(s): 69c7cd1

Replace model with Idefics2ForConditionalGeneration

Browse files
Files changed (1) hide show
  1. handler.py +7 -4
handler.py CHANGED
@@ -1,5 +1,5 @@
1
  from typing import Any, Dict, List
2
- from transformers import Idefics2Processor, Idefics2Model
3
  import torch
4
 
5
 
@@ -8,8 +8,9 @@ class EndpointHandler:
8
  # Preload all the elements you are going to need at inference.
9
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
10
  self.processor = Idefics2Processor.from_pretrained(path)
11
- self.model = Idefics2Model.from_pretrained(path)
12
  self.model.to(self.device)
 
13
 
14
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
15
  """
@@ -25,11 +26,13 @@ class EndpointHandler:
25
  # process image
26
  inputs = self.processor(images=image, return_tensors="pt").to(self.device)
27
  print("inputs reached")
28
- output = self.model.forward(input_ids=inputs.input_ids)
29
  print("generated")
30
 
31
  # run prediction
32
- generated_text = self.processor.batch_decode(output, skip_special_tokens=True)
 
 
33
  print("decoded")
34
 
35
  # decode output
 
1
  from typing import Any, Dict, List
2
+ from transformers import Idefics2Processor, Idefics2ForConditionalGeneration
3
  import torch
4
 
5
 
 
8
  # Preload all the elements you are going to need at inference.
9
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
10
  self.processor = Idefics2Processor.from_pretrained(path)
11
+ self.model = Idefics2ForConditionalGeneration.from_pretrained(path)
12
  self.model.to(self.device)
13
+ print("Initialisation finished!")
14
 
15
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
16
  """
 
26
  # process image
27
  inputs = self.processor(images=image, return_tensors="pt").to(self.device)
28
  print("inputs reached")
29
+ generated_ids = self.model.generate(**inputs)
30
  print("generated")
31
 
32
  # run prediction
33
+ generated_text = self.processor.batch_decode(
34
+ generated_ids, skip_special_tokens=True
35
+ )
36
  print("decoded")
37
 
38
  # decode output