Bils commited on
Commit
5f601de
·
verified ·
1 Parent(s): 73e3afa

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -8
app.py CHANGED
@@ -25,12 +25,7 @@ from TTS.api import TTS
25
  # Diffusers for sound design generation
26
  from diffusers import DiffusionPipeline, AudioLDMPipeline
27
  import diffusers
28
-
29
- # Monkey-patch: Create a patched pipeline class so that any reference to AudioLDM2Pipeline is resolved correctly.
30
- class PatchedAudioLDM2Pipeline(AudioLDMPipeline):
31
- pass
32
-
33
- setattr(diffusers, "AudioLDM2Pipeline", PatchedAudioLDM2Pipeline)
34
 
35
  # ---------------------------------------------------------------------
36
  # Setup Logging and Environment Variables
@@ -107,11 +102,16 @@ def get_tts_model(model_name: str = "tts_models/en/ljspeech/tacotron2-DDC"):
107
  def get_sound_design_pipeline(model_name: str, token: str):
108
  """
109
  Returns a cached DiffusionPipeline for sound design if available;
110
- otherwise, it loads and caches the pipeline using the patched pipeline class.
 
 
 
111
  """
 
 
112
  if model_name in SOUND_DESIGN_PIPELINES:
113
  return SOUND_DESIGN_PIPELINES[model_name]
114
- pipe = DiffusionPipeline.from_pretrained(model_name, pipeline_class=PatchedAudioLDM2Pipeline, use_auth_token=token)
115
  SOUND_DESIGN_PIPELINES[model_name] = pipe
116
  return pipe
117
 
 
25
  # Diffusers for sound design generation
26
  from diffusers import DiffusionPipeline, AudioLDMPipeline
27
  import diffusers
28
+ from packaging import version
 
 
 
 
 
29
 
30
  # ---------------------------------------------------------------------
31
  # Setup Logging and Environment Variables
 
102
  def get_sound_design_pipeline(model_name: str, token: str):
103
  """
104
  Returns a cached DiffusionPipeline for sound design if available;
105
+ otherwise, it loads and caches the pipeline.
106
+
107
+ NOTE: AudioLDM2Pipeline is available only in diffusers>=0.21.0.
108
+ Since your requirements fix diffusers==0.20.2, this function will raise an error.
109
  """
110
+ if version.parse(diffusers.__version__) < version.parse("0.21.0"):
111
+ raise ValueError("AudioLDM2 requires diffusers>=0.21.0. Please upgrade your diffusers package.")
112
  if model_name in SOUND_DESIGN_PIPELINES:
113
  return SOUND_DESIGN_PIPELINES[model_name]
114
+ pipe = DiffusionPipeline.from_pretrained(model_name, pipeline_class=AudioLDMPipeline, use_auth_token=token)
115
  SOUND_DESIGN_PIPELINES[model_name] = pipe
116
  return pipe
117