import base64
import io
from typing import Dict, Any, List

import numpy as np
import torch
from diffusers import (
    CogVideoXPipeline,
    CogVideoXDPMScheduler,
)
from PIL import Image


class EndpointHandler:
    def __init__(self, path: str = ""):
        """Initialize the CogVideoX text-to-video pipeline.

        Args:
            path (str): Path to the model weights.
        """
        self.pipe = CogVideoXPipeline.from_pretrained(
            path or "jbilcke-hf/CogVideoX-Fun-V1.5-5b-for-InferenceEndpoints",
            torch_dtype=torch.bfloat16,
        ).to("cuda")

        # "trailing" timestep spacing matches the scheduler setup used in
        # the CogVideoX reference examples, and pairs with the dynamic CFG
        # option passed at generation time.
        self.pipe.scheduler = CogVideoXDPMScheduler.from_config(
            self.pipe.scheduler.config,
            timestep_spacing="trailing",
        )
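
        # Optional memory savers for smaller GPUs (left disabled here; a
        # sketch following the diffusers CogVideoX docs, not part of the
        # original handler):
        #   self.pipe.enable_model_cpu_offload()
        #   self.pipe.vae.enable_tiling()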
    def _decode_base64_to_image(self, base64_string: str) -> Image.Image:
        """Convert a base64 string to a PIL Image."""
        image_data = base64.b64decode(base64_string)
        return Image.open(io.BytesIO(image_data))

    def _encode_video_to_base64(self, video_frames: List[np.ndarray]) -> str:
        """Encode video frames as a base64 MP4 string."""
        import imageio
        import os
        import tempfile

        # With output_type="np" the pipeline returns float frames in
        # [0, 1]; imageio expects uint8.
        frames = [(frame * 255).round().astype(np.uint8) for frame in video_frames]

        # imageio's FFMPEG writer needs a real file path rather than an
        # in-memory buffer, so round-trip through a temporary file.
        fd, temp_path = tempfile.mkstemp(suffix=".mp4")
        os.close(fd)
        try:
            imageio.mimsave(temp_path, frames, fps=8)
            with open(temp_path, "rb") as f:
                video_bytes = f.read()
        finally:
            os.remove(temp_path)
        return base64.b64encode(video_bytes).decode("utf-8")
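
    # Hypothetical client-side usage (illustration only; `response` stands
    # for the dict this handler returns): the payload decodes straight back
    # to a playable file.
    #   with open("out.mp4", "wb") as f:
    #       f.write(base64.b64decode(response["video"]))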

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Process the input data and generate a video with CogVideoX.

        Args:
            data (Dict[str, Any]): Input payload containing:
                - prompt (str): Text prompt for generation
                - image (str, optional): Base64-encoded image for image-to-video
                - video (str, optional): Base64-encoded video for video-to-video
                - num_inference_steps (int, optional): Number of denoising steps (default: 50)
                - guidance_scale (float, optional): Classifier-free guidance scale (default: 7.0)

        Returns:
            Dict[str, Any]: The generated video as a base64-encoded MP4 string.
        """
        prompt = data.get("prompt", "")
        num_inference_steps = data.get("num_inference_steps", 50)
        guidance_scale = data.get("guidance_scale", 7.0)

        generation_kwargs = {
            "prompt": prompt,
            "num_inference_steps": num_inference_steps,
            "guidance_scale": guidance_scale,
            "num_videos_per_prompt": 1,
            # Dynamic CFG decays the guidance scale over the denoising
            # schedule; it pairs with the DPM scheduler set in __init__.
            "use_dynamic_cfg": True,
            "output_type": "np",
        }

        if "image" in data:
            # Not wired up yet: a working implementation would decode the
            # base64 image via self._decode_base64_to_image, resize it to
            # 720x480, and feed it to CogVideoXImageToVideoPipeline.
            return {"error": "Image to video generation not yet supported"}

        elif "video" in data:
            # Likewise, video-to-video would go through
            # CogVideoXVideoToVideoPipeline with diffusers.utils.load_video.
            return {"error": "Video to video generation not yet supported"}
        else:
            # Text-to-video: 49 frames is the CogVideoX default clip length
            # (roughly six seconds at 8 fps).
            generation_kwargs["num_frames"] = 49
            video_frames = self.pipe(**generation_kwargs).frames[0]

        video_base64 = self._encode_video_to_base64(video_frames)

        return {"video": video_base64}

    def cleanup(self):
        """Cleanup the model and free GPU memory."""
        self.pipe.to("cpu")
        torch.cuda.empty_cache()
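

# Minimal local smoke test (a sketch: assumes a CUDA GPU with enough VRAM to
# load the 5B model in bfloat16; the prompt below is only an example).
if __name__ == "__main__":
    handler = EndpointHandler()
    result = handler({"prompt": "A panda strumming a tiny guitar in a bamboo grove"})
    if "video" in result:
        # Write the base64 payload back out as a playable MP4.
        with open("sample.mp4", "wb") as f:
            f.write(base64.b64decode(result["video"]))
    handler.cleanup()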