cocktailpeanut committed
Commit f2e9dae · 1 Parent(s): d6861ba

Files changed (2):
  1. app.py +9 -1
  2. requirements.txt +6 -4
app.py CHANGED
@@ -25,6 +25,7 @@ import gradio as gr
 import librosa
 import torch
 import torchaudio
+import devicetorch
 
 torchaudio.set_audio_backend("soundfile")
 
@@ -190,6 +191,9 @@ def inference(
     if torch.cuda.is_available():
         torch.cuda.empty_cache()
         gc.collect()
+    elif torch.backends.mps.is_available():
+        torch.mps.empty_cache()
+        gc.collect()
 
 
 def inference_with_auto_rerank(
@@ -341,6 +345,9 @@ def change_if_load_asr_model(if_load):
     if torch.cuda.is_available():
         torch.cuda.empty_cache()
         gc.collect()
+    elif torch.backends.mps.is_available():
+        torch.mps.empty_cache()
+        gc.collect()
     return gr.Checkbox(label="Load faster whisper model", value=if_load)
 
 
@@ -602,7 +609,8 @@ def parse_args():
         default="checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
     )
     parser.add_argument("--decoder-config-name", type=str, default="firefly_gan_vq")
-    parser.add_argument("--device", type=str, default="cuda")
+    #parser.add_argument("--device", type=str, default="cuda")
+    parser.add_argument("--device", type=str, default=devicetorch.get(torch))
     parser.add_argument("--half", action="store_true")
     parser.add_argument("--compile", action="store_true",default=True)
     parser.add_argument("--max-gradio-length", type=int, default=0)
requirements.txt CHANGED
@@ -1,5 +1,5 @@
-torch==2.3.0
-torchaudio
+#torch==2.3.0
+#torchaudio
 transformers>=4.35.2
 datasets>=2.14.5
 lightning>=2.1.0
@@ -22,9 +22,11 @@ vector_quantize_pytorch>=1.14.7
 samplerate>=0.2.1
 resampy>=0.4.3
 spaces>=0.26.1
-einx[torch]==0.2.0
+#einx[torch]==0.2.0
+einx[torch]
 opencc
 faster-whisper
 ormsgpack
 ffmpeg
-soundfile
+soundfile
+devicetorch
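On the requirements side, the torch==2.3.0 / torchaudio pins are commented out (presumably so a platform-appropriate torch build can be installed separately), einx[torch] is unpinned, and soundfile and devicetorch are added. A quick post-install sanity check, assuming all three packages import cleanly (this script is illustrative, not part of the commit):

# Illustrative post-install check; not part of the commit.
import soundfile
import torch
import devicetorch

print("torch:", torch.__version__)
print("soundfile:", soundfile.__version__)
print("selected device:", devicetorch.get(torch))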