import io
import os
from typing import List

import PIL.Image
import requests
import torch
from diffusers import AutoencoderTiny, StableDiffusionPipeline

from streamdiffusion import StreamDiffusion
from streamdiffusion.image_utils import postprocess_image


def download_image(url: str) -> PIL.Image.Image:
    response = requests.get(url)
    image = PIL.Image.open(io.BytesIO(response.content))
    return image


class StreamDiffusionWrapper:
    def __init__(
        self,
        model_id: str,
        lcm_lora_id: str,
        vae_id: str,
        device: str,
        dtype: torch.dtype,
        t_index_list: List[int],
        warmup: int,
        safety_checker: bool,
    ):
        self.device = device
        self.dtype = dtype
        self.prompt = ""
        # One denoising step is produced per entry in t_index_list,
        # so the streaming batch size matches its length.
        self.batch_size = len(t_index_list)

        self.stream = self._load_model(
            model_id=model_id,
            lcm_lora_id=lcm_lora_id,
            vae_id=vae_id,
            t_index_list=t_index_list,
            warmup=warmup,
        )

        self.safety_checker = None
        if safety_checker:
            from transformers import CLIPFeatureExtractor
            from diffusers.pipelines.stable_diffusion.safety_checker import (
                StableDiffusionSafetyChecker,
            )

            self.safety_checker = StableDiffusionSafetyChecker.from_pretrained(
                "CompVis/stable-diffusion-safety-checker"
            ).to(self.device)
            self.feature_extractor = CLIPFeatureExtractor.from_pretrained(
                "openai/clip-vit-base-patch32"
            )
            # Black image returned in place of any output flagged as NSFW.
            self.nsfw_fallback_img = PIL.Image.new("RGB", (512, 512), (0, 0, 0))

        self.stream.prepare("")

    def _load_model(
        self,
        model_id: str,
        lcm_lora_id: str,
        vae_id: str,
        t_index_list: List[int],
        warmup: int,
    ) -> StreamDiffusion:
        # Accept either a local checkpoint file or a Hugging Face model id.
        if os.path.exists(model_id):
            pipe: StableDiffusionPipeline = StableDiffusionPipeline.from_single_file(
                model_id
            ).to(device=self.device, dtype=self.dtype)
        else:
            pipe: StableDiffusionPipeline = StableDiffusionPipeline.from_pretrained(
                model_id
            ).to(device=self.device, dtype=self.dtype)

        stream = StreamDiffusion(
            pipe=pipe,
            t_index_list=t_index_list,
            torch_dtype=self.dtype,
            is_drawing=True,
        )
        stream.load_lcm_lora(lcm_lora_id)
        stream.fuse_lora()
        # Swap in the tiny autoencoder for faster decoding.
        stream.vae = AutoencoderTiny.from_pretrained(vae_id).to(
            device=pipe.device, dtype=pipe.dtype
        )

        # Prefer TensorRT acceleration, fall back to Stable Fast, then to the
        # unaccelerated pipeline.
        try:
            from streamdiffusion.acceleration.tensorrt import accelerate_with_tensorrt

            stream = accelerate_with_tensorrt(
                stream,
                "engines",
                max_batch_size=self.batch_size,
                engine_build_options={"build_static_batch": False},
            )
            print("TensorRT acceleration enabled.")
        except Exception:
            print("TensorRT acceleration has failed. Trying to use Stable Fast.")
            try:
                from streamdiffusion.acceleration.sfast import (
                    accelerate_with_stable_fast,
                )

                stream = accelerate_with_stable_fast(stream)
                print("StableFast acceleration enabled.")
            except Exception:
                print("StableFast acceleration has failed. Using normal mode.")

        stream.prepare(
            "",
            num_inference_steps=50,
            generator=torch.manual_seed(2),
        )

        # Warm up the pipeline so compiled kernels and engines are built
        # before the first real request.
        for _ in range(warmup):
            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)
            start.record()
            stream.txt2img()
            end.record()
            torch.cuda.synchronize()

        return stream

    def __call__(self, prompt: str) -> PIL.Image.Image:
        if self.prompt != prompt:
            self.stream.update_prompt(prompt)
            self.prompt = prompt
            # Flush the streaming batch so latents produced under the
            # previous prompt do not leak into the next output.
            for _ in range(self.batch_size):
                self.stream.txt2img()

        x_output = self.stream.txt2img()
        image = postprocess_image(x_output, output_type="pil")[0]

        if self.safety_checker:
            safety_checker_input = self.feature_extractor(
                image, return_tensors="pt"
            ).to(self.device)
            _, has_nsfw_concept = self.safety_checker(
                images=x_output,
                clip_input=safety_checker_input.pixel_values.to(self.dtype),
            )
            image = self.nsfw_fallback_img if has_nsfw_concept[0] else image

        return image


if __name__ == "__main__":
    # Minimal smoke test. The model, LoRA, VAE, and t_index_list values below
    # are example values commonly used with StreamDiffusion; substitute the
    # checkpoints and schedule you actually want to run.
    wrapper = StreamDiffusionWrapper(
        model_id="KBlueLeaf/kohaku-v2.1",
        lcm_lora_id="latent-consistency/lcm-lora-sdv1-5",
        vae_id="madebyollin/taesd",
        device="cuda",
        dtype=torch.float16,
        t_index_list=[0, 16, 32, 45],
        warmup=10,
        safety_checker=False,
    )
    image = wrapper("a photo of a cat wearing a space suit")
    image.save("output.png")
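
    # Optional usage sketch: generate a short sequence of frames by calling
    # the wrapper repeatedly. Each prompt change triggers the flush loop in
    # __call__ before a fresh frame is produced. The prompts and output
    # directory below are placeholder values, not part of the wrapper itself.
    prompts = [
        "a watercolor painting of a lighthouse at dawn",
        "a watercolor painting of a lighthouse at dusk",
        "a watercolor painting of a lighthouse at night",
    ]
    os.makedirs("frames", exist_ok=True)
    for frame_idx, frame_prompt in enumerate(prompts):
        frame = wrapper(frame_prompt)
        frame.save(os.path.join("frames", f"frame_{frame_idx:03d}.png"))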