jbilcke-hf committed
Commit 030159f · verified · 1 Parent(s): bc67239

Create handler.py

Files changed (1)
  1. handler.py +134 -0
handler.py ADDED
@@ -0,0 +1,134 @@
from typing import Dict, Any, Union, List
import torch
from diffusers import (
    CogVideoXPipeline,
    CogVideoXDPMScheduler,
    CogVideoXVideoToVideoPipeline,
    CogVideoXImageToVideoPipeline
)
from diffusers.utils import load_video, load_image
from PIL import Image
import base64
import io
import numpy as np


class EndpointHandler:
    def __init__(self, path: str = ""):
        """Initialize the CogVideoX pipelines.

        Args:
            path (str): Path to the model weights
        """
        # Initialize pipeline with bfloat16 for optimal performance as recommended in docs
        self.pipe = CogVideoXPipeline.from_pretrained(
            path or "THUDM/CogVideoX-5b",
            torch_dtype=torch.bfloat16
        ).to("cuda")

        # Set up the scheduler with trailing timesteps as shown in example
        self.pipe.scheduler = CogVideoXDPMScheduler.from_config(
            self.pipe.scheduler.config,
            timestep_spacing="trailing"
        )

        # Initialize video-to-video pipeline, reusing the already-loaded components
        self.pipe_video = CogVideoXVideoToVideoPipeline.from_pretrained(
            path or "THUDM/CogVideoX-5b",
            transformer=self.pipe.transformer,
            vae=self.pipe.vae,
            scheduler=self.pipe.scheduler,
            tokenizer=self.pipe.tokenizer,
            text_encoder=self.pipe.text_encoder,
            torch_dtype=torch.bfloat16
        ).to("cuda")

        # Initialize image-to-video pipeline (own transformer, shared VAE/text encoder)
        self.pipe_image = CogVideoXImageToVideoPipeline.from_pretrained(
            path or "THUDM/CogVideoX-5b-I2V",
            vae=self.pipe.vae,
            scheduler=self.pipe.scheduler,
            tokenizer=self.pipe.tokenizer,
            text_encoder=self.pipe.text_encoder,
            torch_dtype=torch.bfloat16
        ).to("cuda")

    def _decode_base64_to_image(self, base64_string: str) -> Image.Image:
        """Convert a base64 string to a PIL Image."""
        image_data = base64.b64decode(base64_string)
        image = Image.open(io.BytesIO(image_data))
        return image

    def _encode_video_to_base64(self, video_frames: List[np.ndarray]) -> str:
        """Encode video frames as an MP4 and return it as a base64 string."""
        import imageio
        import tempfile
        # output_type="np" yields float frames in [0, 1]; convert to uint8 before encoding
        frames = [(np.asarray(frame) * 255).round().astype(np.uint8) for frame in video_frames]
        # imageio's ffmpeg writer expects a file path, so encode through a temporary file
        with tempfile.NamedTemporaryFile(suffix=".mp4") as tmp:
            imageio.mimsave(tmp.name, frames, fps=8)
            with open(tmp.name, "rb") as video_file:
                video_bytes = video_file.read()
        return base64.b64encode(video_bytes).decode("utf-8")

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Process the input data and generate a video using CogVideoX.

        Args:
            data (Dict[str, Any]): Input data containing:
                - prompt (str): Text prompt for generation
                - image (str, optional): Base64-encoded image for image-to-video
                - video (str, optional): Base64-encoded video for video-to-video
                - num_inference_steps (int, optional): Number of inference steps
                - guidance_scale (float, optional): Guidance scale for generation

        Returns:
            Dict[str, Any]: Generated video as a base64 string
        """
        # Extract parameters from input
        prompt = data.get("prompt", "")
        num_inference_steps = data.get("num_inference_steps", 50)
        guidance_scale = data.get("guidance_scale", 7.0)

        # Set up generation parameters
        generation_kwargs = {
            "prompt": prompt,
            "num_inference_steps": num_inference_steps,
            "guidance_scale": guidance_scale,
            "num_videos_per_prompt": 1,
            "use_dynamic_cfg": True,
            "output_type": "np",  # Get numpy array output
        }

        # Handle the different input types
        if "image" in data:
            # Image-to-video generation
            input_image = self._decode_base64_to_image(data["image"])
            input_image = input_image.resize((720, 480))  # Resize as per example
            image = load_image(input_image)
            video_frames = self.pipe_image(
                image=image,
                **generation_kwargs
            ).frames[0]

        elif "video" in data:
            # Video-to-video generation
            # TODO: Implement video loading from base64
            # For now, return an error
            return {"error": "Video to video generation not yet implemented"}

        else:
            # Text-to-video generation
            generation_kwargs["num_frames"] = 49  # As per example
            video_frames = self.pipe(**generation_kwargs).frames[0]

        # Convert the output to base64
        video_base64 = self._encode_video_to_base64(video_frames)

        return {
            "video": video_base64
        }

    def cleanup(self):
        """Clean up the models and free GPU memory."""
        # Move models to CPU to free GPU memory
        self.pipe.to("cpu")
        self.pipe_video.to("cpu")
        self.pipe_image.to("cpu")
        # Clear the CUDA cache
        torch.cuda.empty_cache()
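
For reference, a minimal local smoke test of the handler could look like the sketch below. It is not part of this commit and assumes a CUDA GPU, an installed imageio/imageio-ffmpeg, and an arbitrary output filename:

# Hypothetical usage sketch, not included in handler.py
import base64
from handler import EndpointHandler

handler = EndpointHandler()  # empty path falls back to THUDM/CogVideoX-5b

# Text-to-video request; the payload keys match the __call__ docstring above
result = handler({
    "prompt": "A panda strumming a guitar in a bamboo forest",
    "num_inference_steps": 50,
    "guidance_scale": 7.0,
})

# The handler returns the generated clip as a base64-encoded MP4
with open("output.mp4", "wb") as f:
    f.write(base64.b64decode(result["video"]))

handler.cleanup()  # release GPU memory when done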