import os
import threading
import time
import gradio as gr
import torch
from models.pipeline import VchitectXLPipeline
import random
import numpy as np
import inspect
from typing import Any, Callable, Dict, List, Optional, Union
from transformers import (
    CLIPTextModelWithProjection,
    CLIPTokenizer,
    T5TokenizerFast,
)
from models.modeling_t5 import T5EncoderModel
from models.VchitectXL import VchitectXLTransformerModel
from transformers import AutoTokenizer, PretrainedConfig, CLIPTextModel
from diffusers.image_processor import VaeImageProcessor
from diffusers.loaders import FromSingleFileMixin, SD3LoraLoaderMixin
from diffusers.models.autoencoders import AutoencoderKL
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from diffusers.utils import (
    is_torch_xla_available,
    logging,
    replace_example_docstring,
)
from diffusers.utils.torch_utils import randn_tensor
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from op_replace import replace_all_layernorms

if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False

import math
from diffusers.utils import export_to_video
from datetime import datetime, timedelta
import spaces
import moviepy.editor as mp


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    """Configure ``scheduler`` via ``set_timesteps`` and return its schedule.

    Exactly one mode is used: custom ``timesteps``, custom ``sigmas``, or a
    plain ``num_inference_steps`` count.

    Args:
        scheduler: Scheduler whose ``set_timesteps`` method is called.
        num_inference_steps: Step count; only used when neither ``timesteps``
            nor ``sigmas`` is given.
        device: Forwarded to ``set_timesteps``.
        timesteps: Custom timestep values (scheduler must accept ``timesteps``).
        sigmas: Custom sigma values (scheduler must accept ``sigmas``).
        **kwargs: Forwarded to ``set_timesteps``.

    Returns:
        ``(timesteps, num_inference_steps)``: ``scheduler.timesteps`` after
        configuration, and the step count (recomputed as ``len(timesteps)``
        in the two custom modes).

    Raises:
        ValueError: If both ``timesteps`` and ``sigmas`` are given, or the
            scheduler's ``set_timesteps`` does not accept the requested one.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
    if timesteps is not None:
        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accepts_timesteps:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    elif sigmas is not None:
        accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accept_sigmas:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        timesteps = scheduler.timesteps
    return timesteps, num_inference_steps


import torch.fft


@torch.no_grad()
def myfft(tensor):
    """Split the centred 2-D spectrum of ``tensor`` into low/high-frequency parts.

    ``fft2`` is taken over the last two dimensions (H, W). A circular mask of
    radius ``min(H, W) // 5`` around the spectrum centre ``(H // 2, W // 2)``
    selects the low frequencies; its complement selects the high ones.

    Args:
        tensor: Real tensor with at least two dimensions; the last two are
            treated as (H, W).

    Returns:
        ``(low_freq_fft, high_freq_fft)``: complex tensors with the same shape
        as ``tensor`` that sum to the shifted spectrum. They stay in the
        shifted frequency domain; no inverse transform is applied here.
    """
    tensor_fft = torch.fft.fft2(tensor)
    # Move the zero-frequency component to the centre. No ``dim`` is passed,
    # so torch shifts *every* dimension (including any leading batch/channel
    # dims); a matching all-dimension ``ifftshift`` undoes this exactly.
    tensor_fft_shifted = torch.fft.fftshift(tensor_fft)

    H, W = tensor.shape[-2:]
    # Radius separating "low" from "high" frequencies (tunable).
    radius = min(H, W) // 5

    # Circular mask centred at (H/2, W/2), built directly on the input's device.
    # It has shape (H, W) and broadcasts over any leading dimensions.
    rows = torch.arange(H, device=tensor.device)
    cols = torch.arange(W, device=tensor.device)
    Y, X = torch.meshgrid(rows, cols, indexing="ij")
    center_x, center_y = W // 2, H // 2
    low_freq_mask = (X - center_x) ** 2 + (Y - center_y) ** 2 <= radius ** 2
    high_freq_mask = ~low_freq_mask

    low_freq_fft = tensor_fft_shifted * low_freq_mask
    high_freq_fft = tensor_fft_shifted * high_freq_mask
    return low_freq_fft, high_freq_fft


@torch.no_grad()
def acc_call(
    self,
    prompt: Union[str, List[str]] = None,
    prompt_2: Optional[Union[str, List[str]]] = None,
    prompt_3: Optional[Union[str, List[str]]] = None,
    height: Optional[int] = None,
    width: Optional[int] = None,
    frames: Optional[int] = None,
    num_inference_steps: int = 28,
    timesteps: List[int] = None,
    guidance_scale: float = 7.0,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    negative_prompt_2: Optional[Union[str, List[str]]] = None,
    negative_prompt_3: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: Optional[int] = 1,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.FloatTensor] = None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    joint_attention_kwargs: Optional[Dict[str, Any]] = None,
    clip_skip: Optional[int] = None,
    callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
    callback_on_step_end_tensor_inputs: List[str] = ["latents"],
):
    """Denoising ``__call__`` replacement that skips some unconditional passes.

    The conditional transformer pass runs on every step. The unconditional
    pass runs when ``i < 30`` or when ``i > 30 and i % 5 == 0``. On the
    remaining steps, the unconditional prediction is approximated:

    - The low/high-frequency difference (unconditional minus conditional,
      from ``myfft``) captured on the last fully evaluated step with
      ``i >= 28`` is rescaled by 1.1 (low) / 1.25 (high). The scaling
      compounds on each consecutive approximated step.
    - The rescaled difference is added to the current conditional spectrum
      and inverted back with ``ifftshift`` + ``ifft2``.

    Returns:
        A list holding ``image[0]`` from ``self.image_processor.postprocess``
        for each index along dimension 1 of the final latents. The code
        presumably treats dim 1 as the frame axis; TODO confirm.

    NOTE(review): the loop indexes ``[0]`` (negative) and ``[1]`` (positive)
    along dim 0 of the CFG-concatenated embeddings. It therefore assumes
    classifier-free guidance is enabled and the effective batch is one prompt
    with ``num_images_per_prompt == 1``. ``noise_pred`` is also only assigned
    when ``self.do_classifier_free_guidance`` is true.
    """
    from tqdm import tqdm

    height = height or self.default_sample_size * self.vae_scale_factor
    width = width or self.default_sample_size * self.vae_scale_factor
    frames = frames or 24

    # 1. Check inputs. Raise error if not correct
    self.check_inputs(
        prompt,
        prompt_2,
        prompt_3,
        height,
        width,
        negative_prompt=negative_prompt,
        negative_prompt_2=negative_prompt_2,
        negative_prompt_3=negative_prompt_3,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
        callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
    )

    self._guidance_scale = guidance_scale
    self._clip_skip = clip_skip
    self._joint_attention_kwargs = joint_attention_kwargs
    self._interrupt = False

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = self.execution_device

    # 3. Encode prompts
    (
        prompt_embeds,
        negative_prompt_embeds,
        pooled_prompt_embeds,
        negative_pooled_prompt_embeds,
    ) = self.encode_prompt(
        prompt=prompt,
        prompt_2=prompt_2,
        prompt_3=prompt_3,
        negative_prompt=negative_prompt,
        negative_prompt_2=negative_prompt_2,
        negative_prompt_3=negative_prompt_3,
        do_classifier_free_guidance=self.do_classifier_free_guidance,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
        device=device,
        clip_skip=self.clip_skip,
        num_images_per_prompt=num_images_per_prompt,
    )

    if self.do_classifier_free_guidance:
        # Negative first, positive second: index 0 is unconditional, index 1 conditional.
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
        pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)

    # 4. Prepare timesteps
    timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
    self._num_timesteps = len(timesteps)

    # 5. Prepare latent variables
    num_channels_latents = self.transformer.config.in_channels
    latents = self.prepare_latents(
        batch_size * num_images_per_prompt,
        num_channels_latents,
        height,
        width,
        frames,
        prompt_embeds.dtype,
        device,
        generator,
        latents,
    )

    # 6. Denoising loop
    for i, t in tqdm(enumerate(timesteps)):
        if self.interrupt:
            continue

        # Expand the latents if we are doing classifier-free guidance.
        latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
        # Broadcast to batch dimension in a way that's compatible with ONNX/Core ML.
        timestep = t.expand(latents.shape[0])

        # Conditional prediction: computed on every step.
        noise_pred_text = self.transformer(
            hidden_states=latent_model_input[1, :].unsqueeze(0),
            timestep=timestep,
            encoder_hidden_states=prompt_embeds[1, :].unsqueeze(0),
            pooled_projections=pooled_prompt_embeds[1, :].unsqueeze(0),
            joint_attention_kwargs=self.joint_attention_kwargs,
            return_dict=False,
        )[0]

        if i < 30 or (i > 30 and i % 5 == 0):
            # Full unconditional pass.
            noise_pred_uncond = self.transformer(
                hidden_states=latent_model_input[0, :].unsqueeze(0),
                timestep=timestep,
                encoder_hidden_states=prompt_embeds[0, :].unsqueeze(0),
                pooled_projections=pooled_prompt_embeds[0, :].unsqueeze(0),
                joint_attention_kwargs=self.joint_attention_kwargs,
                return_dict=False,
            )[0]
            if i >= 28:
                # Cache the frequency-domain gap between the unconditional and
                # conditional predictions for the approximated steps that follow.
                lf_uc, hf_uc = myfft(noise_pred_uncond.float())
                lf_c, hf_c = myfft(noise_pred_text.float())
                delta_lf = lf_uc - lf_c
                delta_hf = hf_uc - hf_c
        else:
            # Approximated unconditional pass: reuse the cached gap (rescaled,
            # compounding across consecutive approximated steps) on top of the
            # current conditional spectrum.
            lf_c, hf_c = myfft(noise_pred_text.float())
            delta_lf = delta_lf * 1.1
            delta_hf = delta_hf * 1.25
            new_lf_uc = delta_lf + lf_c
            new_hf_uc = delta_hf + hf_c
            combine_uc = new_lf_uc + new_hf_uc
            # All-dimension ifftshift mirrors myfft's all-dimension fftshift.
            combined_fft = torch.fft.ifftshift(combine_uc)
            noise_pred_uncond = torch.fft.ifft2(combined_fft).real

        # Per-step guidance scale.
        # NOTE(review): ``t.item()`` is a scheduler timestep value, but it is
        # combined with ``num_inference_steps`` (a step count); ``i`` may have
        # been intended. Kept unchanged because it alters generation output.
        self._guidance_scale = 1 + guidance_scale * (
            (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
        )

        # Perform guidance.
        if self.do_classifier_free_guidance:
            noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

        # Compute the previous noisy sample x_t -> x_t-1.
        latents_dtype = latents.dtype
        latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

        if latents.dtype != latents_dtype:
            if torch.backends.mps.is_available():
                # Some platforms (e.g. Apple MPS) misbehave due to a PyTorch bug:
                # https://github.com/pytorch/pytorch/pull/99272
                latents = latents.to(latents_dtype)

        if callback_on_step_end is not None:
            callback_kwargs = {}
            for k in callback_on_step_end_tensor_inputs:
                callback_kwargs[k] = locals()[k]
            callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

            latents = callback_outputs.pop("latents", latents)
            prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
            negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
            negative_pooled_prompt_embeds = callback_outputs.pop(
                "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
            )

        if XLA_AVAILABLE:
            xm.mark_step()

    # Undo latent scaling, then decode one slice along dim 1 at a time.
    latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
    videos = []
    for v_idx in range(latents.shape[1]):
        image = self.vae.decode(latents[:, v_idx], return_dict=False)[0]
        image = self.image_processor.postprocess(image, output_type=output_type)
        videos.append(image[0])
    return videos


from huggingface_hub import login

# Only authenticate when a token is configured: login(token=None) falls back
# to an interactive prompt, which cannot be answered in a headless process.
hf_token = os.getenv("HF_TOKEN")
if hf_token:
    login(token=hf_token)

dtype = torch.float16
device = "cuda" if torch.cuda.is_available() else "cpu"
pipe = VchitectXLPipeline("Vchitect/Vchitect-XL-2B", device)
pipe.acc_call = acc_call.__get__(pipe) import types # pipe.__call__ = types.MethodType(acc_call, pipe) pipe.__class__.__call__ = acc_call os.makedirs("./output", exist_ok=True) os.makedirs("./gradio_tmp", exist_ok=True) @spaces.GPU(duration=120) def infer(prompt: str, progress=gr.Progress(track_tqdm=True)): torch.cuda.empty_cache() with torch.cuda.amp.autocast(dtype=torch.bfloat16): video = pipe( prompt, negative_prompt="", num_inference_steps=50, guidance_scale=7.5, width=768, height=432, #480x288 624x352 432x240 768x432 frames=16 ) return video def save_video(tensor): timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") video_path = f"./output/{timestamp}.mp4" os.makedirs(os.path.dirname(video_path), exist_ok=True) export_to_video(tensor, video_path) return video_path def convert_to_gif(video_path): clip = mp.VideoFileClip(video_path) clip = clip.set_fps(8) clip = clip.resize(height=240) gif_path = video_path.replace(".mp4", ".gif") clip.write_gif(gif_path, fps=8) return gif_path def delete_old_files(): while True: now = datetime.now() cutoff = now - timedelta(minutes=10) directories = ["./output", "./gradio_tmp"] for directory in directories: for filename in os.listdir(directory): file_path = os.path.join(directory, filename) if os.path.isfile(file_path): file_mtime = datetime.fromtimestamp(os.path.getmtime(file_path)) if file_mtime < cutoff: os.remove(file_path) time.sleep(600) threading.Thread(target=delete_old_files, daemon=True).start() with gr.Blocks() as demo: gr.Markdown("""