from typing import Dict, Any, Union, List
import torch
from diffusers import (
CogVideoXPipeline,
CogVideoXDPMScheduler,
CogVideoXVideoToVideoPipeline,
CogVideoXImageToVideoPipeline
)
from diffusers.utils import load_video, load_image
from PIL import Image
import base64
import io
import numpy as np


class EndpointHandler:
    def __init__(self, path: str = ""):
        """Initialize the CogVideoX pipeline.

        Args:
            path (str): Path to the model weights
        """
        # Initialize the pipeline in bfloat16, the dtype recommended in the CogVideoX docs
        self.pipe = CogVideoXPipeline.from_pretrained(
            path or "jbilcke-hf/CogVideoX-Fun-V1.5-5b-for-InferenceEndpoints",
            torch_dtype=torch.bfloat16
        ).to("cuda")

        # Set up the scheduler with trailing timesteps, as in the reference example
        self.pipe.scheduler = CogVideoXDPMScheduler.from_config(
            self.pipe.scheduler.config,
            timestep_spacing="trailing"
        )

        # The two pipelines below (generated by Claude) are useful, but loading all of them
        # at once takes too much memory, so they stay disabled for now.
        # # Initialize video-to-video pipeline
        # self.pipe_video = CogVideoXVideoToVideoPipeline.from_pretrained(
        #     path or "jbilcke-hf/CogVideoX-Fun-V1.5-5b-for-InferenceEndpoints",
        #     transformer=self.pipe.transformer,
        #     vae=self.pipe.vae,
        #     scheduler=self.pipe.scheduler,
        #     tokenizer=self.pipe.tokenizer,
        #     text_encoder=self.pipe.text_encoder,
        #     torch_dtype=torch.bfloat16
        # ).to("cuda")
        #
        # # Initialize image-to-video pipeline
        # self.pipe_image = CogVideoXImageToVideoPipeline.from_pretrained(
        #     path or "THUDM/CogVideoX1.5-5B-I2V",
        #     vae=self.pipe.vae,
        #     scheduler=self.pipe.scheduler,
        #     tokenizer=self.pipe.tokenizer,
        #     text_encoder=self.pipe.text_encoder,
        #     torch_dtype=torch.bfloat16
        # ).to("cuda")
    def _decode_base64_to_image(self, base64_string: str) -> Image.Image:
        """Convert a base64 string to a PIL Image."""
        image_data = base64.b64decode(base64_string)
        image = Image.open(io.BytesIO(image_data))
        return image
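    # NOTE: the helper below is a sketch, not part of the original handler. It mirrors
    # _decode_base64_to_image for the video-to-video TODO further down, and assumes the
    # client sends a base64-encoded MP4. It relies on diffusers' load_video, which reads
    # a path or URL and returns the frames as a list of PIL images.
    def _decode_base64_to_video(self, base64_string: str) -> List[Image.Image]:
        """Convert a base64-encoded MP4 into a list of PIL frames (sketch, currently unused)."""
        import tempfile

        video_data = base64.b64decode(base64_string)
        # load_video expects a path (or URL), so write the bytes to a temporary file first
        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp_file:
            tmp_file.write(video_data)
            tmp_path = tmp_file.name
        return load_video(tmp_path)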
    def _encode_video_to_base64(self, video_frames: List[np.ndarray]) -> str:
        """Encode video frames as a base64 MP4 string."""
        # Lazy import: imageio (with the imageio-ffmpeg backend) is only needed here
        import imageio

        # output_type="np" yields float frames in [0, 1]; convert to uint8 before encoding
        frames = [(np.clip(frame, 0, 1) * 255).astype(np.uint8) for frame in video_frames]
        output_bytes = io.BytesIO()
        imageio.mimsave(output_bytes, frames, format='mp4', fps=8)
        return base64.b64encode(output_bytes.getvalue()).decode('utf-8')
    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Process the input data and generate a video with CogVideoX.

        Args:
            data (Dict[str, Any]): Input data containing:
                - prompt (str): Text prompt for generation
                - image (str, optional): Base64-encoded image for image-to-video
                - video (str, optional): Base64-encoded video for video-to-video
                - num_inference_steps (int, optional): Number of inference steps
                - guidance_scale (float, optional): Guidance scale for generation

        Returns:
            Dict[str, Any]: Generated video as a base64 string
        """
        # Extract parameters from the input
        prompt = data.get("prompt", "")
        num_inference_steps = data.get("num_inference_steps", 50)
        guidance_scale = data.get("guidance_scale", 7.0)

        # Set up the generation parameters
        generation_kwargs = {
            "prompt": prompt,
            "num_inference_steps": num_inference_steps,
            "guidance_scale": guidance_scale,
            "num_videos_per_prompt": 1,
            "use_dynamic_cfg": True,
            "output_type": "np",  # Return frames as numpy arrays
        }
        # Handle the different input types
        if "image" in data:
            # Image-to-video generation (disabled for now: the extra pipeline takes up too much RAM)
            input_image = self._decode_base64_to_image(data["image"])
            input_image = input_image.resize((720, 480))  # Resize as in the reference example
            image = load_image(input_image)
            return {"error": "Image to video generation not yet supported"}
            # video_frames = self.pipe_image(
            #     image=image,
            #     **generation_kwargs
            # ).frames[0]
        elif "video" in data:
            # Video-to-video generation
            # TODO: implement video loading from base64 (see the _decode_base64_to_video sketch above);
            # for now, return an error
            return {"error": "Video to video generation not yet supported"}
        else:
            # Text-to-video generation
            generation_kwargs["num_frames"] = 49  # As in the reference example
            video_frames = self.pipe(**generation_kwargs).frames[0]

        # Convert the output to base64
        video_base64 = self._encode_video_to_base64(video_frames)

        return {
            "video": video_base64
        }
    def cleanup(self):
        """Cleanup the model and free GPU memory."""
        # Move models to CPU to free GPU memory
        self.pipe.to("cpu")
        # self.pipe_video.to("cpu")
        # self.pipe_image.to("cpu")

        # Clear CUDA cache
        torch.cuda.empty_cache()
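

# A minimal local smoke test (not part of the original handler): it assumes a CUDA GPU with
# enough memory for the 5B model, and shows the request/response shape the endpoint expects,
# i.e. a JSON dict with a "prompt" that is answered with a base64-encoded MP4.
if __name__ == "__main__":
    handler = EndpointHandler()

    result = handler(
        {
            "prompt": "A panda playing a guitar in a bamboo forest, cinematic lighting",
            "num_inference_steps": 50,
            "guidance_scale": 7.0,
        }
    )

    if "video" in result:
        # Decode the base64 payload back into an MP4 file
        with open("sample_output.mp4", "wb") as f:
            f.write(base64.b64decode(result["video"]))
        print("Wrote sample_output.mp4")
    else:
        print("Generation failed:", result.get("error"))

    handler.cleanup()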