jbilcke-hf HF staff committed on
Commit
bcdcfae
·
verified ·
1 Parent(s): 6cf9e6f

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +31 -27
handler.py CHANGED
@@ -31,26 +31,27 @@ class EndpointHandler:
31
  timestep_spacing="trailing"
32
  )
33
 
34
- # Initialize video-to-video pipeline
35
- self.pipe_video = CogVideoXVideoToVideoPipeline.from_pretrained(
36
- path or "jbilcke-hf/CogVideoX-Fun-V1.5-5b-for-InferenceEndpoints",
37
- transformer=self.pipe.transformer,
38
- vae=self.pipe.vae,
39
- scheduler=self.pipe.scheduler,
40
- tokenizer=self.pipe.tokenizer,
41
- text_encoder=self.pipe.text_encoder,
42
- torch_dtype=torch.bfloat16
43
- ).to("cuda")
44
-
45
- # Initialize image-to-video pipeline
46
- self.pipe_image = CogVideoXImageToVideoPipeline.from_pretrained(
47
- path or "THUDM/CogVideoX1.5-5B-I2V",
48
- vae=self.pipe.vae,
49
- scheduler=self.pipe.scheduler,
50
- tokenizer=self.pipe.tokenizer,
51
- text_encoder=self.pipe.text_encoder,
52
- torch_dtype=torch.bfloat16
53
- ).to("cuda")
 
54
 
55
  def _decode_base64_to_image(self, base64_string: str) -> Image.Image:
56
  """Convert base64 string to PIL Image."""
@@ -101,16 +102,19 @@ class EndpointHandler:
101
  input_image = self._decode_base64_to_image(data["image"])
102
  input_image = input_image.resize((720, 480)) # Resize as per example
103
  image = load_image(input_image)
104
- video_frames = self.pipe_image(
105
- image=image,
106
- **generation_kwargs
107
- ).frames[0]
 
 
 
108
 
109
  elif "video" in data:
110
  # Video to video generation
111
  # TODO: Implement video loading from base64
112
  # For now, returning error
113
- return {"error": "Video to video generation not yet implemented"}
114
 
115
  else:
116
  # Text to video generation
@@ -128,7 +132,7 @@ class EndpointHandler:
128
  """Cleanup the model and free GPU memory."""
129
  # Move models to CPU to free GPU memory
130
  self.pipe.to("cpu")
131
- self.pipe_video.to("cpu")
132
- self.pipe_image.to("cpu")
133
  # Clear CUDA cache
134
  torch.cuda.empty_cache()
 
31
  timestep_spacing="trailing"
32
  )
33
 
34
+ # These two pipelines (generated by Claude) are interesting, but loading them all at once takes too much memory.
35
+ # # Initialize video-to-video pipeline
36
+ # self.pipe_video = CogVideoXVideoToVideoPipeline.from_pretrained(
37
+ # path or "jbilcke-hf/CogVideoX-Fun-V1.5-5b-for-InferenceEndpoints",
38
+ # transformer=self.pipe.transformer,
39
+ # vae=self.pipe.vae,
40
+ # scheduler=self.pipe.scheduler,
41
+ # tokenizer=self.pipe.tokenizer,
42
+ # text_encoder=self.pipe.text_encoder,
43
+ # torch_dtype=torch.bfloat16
44
+ # ).to("cuda")
45
+ #
46
+ # # Initialize image-to-video pipeline
47
+ # self.pipe_image = CogVideoXImageToVideoPipeline.from_pretrained(
48
+ # path or "THUDM/CogVideoX1.5-5B-I2V",
49
+ # vae=self.pipe.vae,
50
+ # scheduler=self.pipe.scheduler,
51
+ # tokenizer=self.pipe.tokenizer,
52
+ # text_encoder=self.pipe.text_encoder,
53
+ # torch_dtype=torch.bfloat16
54
+ # ).to("cuda")
55
 
56
  def _decode_base64_to_image(self, base64_string: str) -> Image.Image:
57
  """Convert base64 string to PIL Image."""
 
102
  input_image = self._decode_base64_to_image(data["image"])
103
  input_image = input_image.resize((720, 480)) # Resize as per example
104
  image = load_image(input_image)
105
+
106
+ #raise ValueError("image to video isn't supported yet (takes up too much RAM right now)")
107
+ return {"error": "Image to video generation not yet supported"}
108
+ #video_frames = self.pipe_image(
109
+ # image=image,
110
+ # **generation_kwargs
111
+ #).frames[0]
112
 
113
  elif "video" in data:
114
  # Video to video generation
115
  # TODO: Implement video loading from base64
116
  # For now, returning error
117
+ return {"error": "Video to video generation not yet supported"}
118
 
119
  else:
120
  # Text to video generation
 
132
  """Cleanup the model and free GPU memory."""
133
  # Move models to CPU to free GPU memory
134
  self.pipe.to("cpu")
135
+ #self.pipe_video.to("cpu")
136
+ #self.pipe_image.to("cpu")
137
  # Clear CUDA cache
138
  torch.cuda.empty_cache()