zR commited on
Commit
d366590
1 Parent(s): cc979ab
Files changed (2) hide show
  1. app.py +2 -2
  2. rife_model.py +0 -6
app.py CHANGED
@@ -21,7 +21,7 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
21
  hf_hub_download(repo_id="ai-forever/Real-ESRGAN", filename="RealESRGAN_x4.pth", local_dir="model_real_esran")
22
  snapshot_download(repo_id="AlexWortega/RIFE", local_dir="model_rife")
23
 
24
- pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16).to(device)
25
  pipe.scheduler = CogVideoXDDIMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
26
 
27
  os.makedirs("./output", exist_ok=True)
@@ -110,7 +110,7 @@ def infer(
110
  num_frames=49,
111
  output_type="pt",
112
  guidance_scale=guidance_scale,
113
- generator=torch.Generator(device=device).manual_seed(seed),
114
  ).frames
115
 
116
  return (video_pt, seed)
 
21
  hf_hub_download(repo_id="ai-forever/Real-ESRGAN", filename="RealESRGAN_x4.pth", local_dir="model_real_esran")
22
  snapshot_download(repo_id="AlexWortega/RIFE", local_dir="model_rife")
23
 
24
+ pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
25
  pipe.scheduler = CogVideoXDDIMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
26
 
27
  os.makedirs("./output", exist_ok=True)
 
110
  num_frames=49,
111
  output_type="pt",
112
  guidance_scale=guidance_scale,
113
+ generator=torch.Generator(device="cpu").manual_seed(seed),
114
  ).frames
115
 
116
  return (video_pt, seed)
rife_model.py CHANGED
@@ -79,12 +79,6 @@ def ssim_interpolation_rife(model, samples, exp=1, upscale_amount=1, output_devi
79
 
80
 
81
  def load_rife_model(model_path):
82
- torch.set_grad_enabled(False)
83
- if torch.cuda.is_available():
84
- torch.backends.cudnn.enabled = True
85
- torch.backends.cudnn.benchmark = True
86
- torch.set_default_tensor_type(torch.cuda.FloatTensor)
87
-
88
  model = Model()
89
  model.load_model(model_path, -1)
90
  model.eval()
 
79
 
80
 
81
  def load_rife_model(model_path):
 
 
 
 
 
 
82
  model = Model()
83
  model.load_model(model_path, -1)
84
  model.eval()