apsys committed on
Commit
ab478b1
1 Parent(s): b7cb866

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +4 -2
main.py CHANGED
@@ -54,11 +54,14 @@ def convert_to_16_bit_wav(data):
54
  raise ValueError("Audio data cannot be converted to " "16-bit int format.")
55
  return data
56
 
 
 
 
57
  # Load the model with INT8 quantization
58
  model = AutoModelForCausalLM.from_pretrained(
59
  model_path,
60
  cache_dir=".",
61
- load_in_8bit=True, # Enable loading in INT8
62
  device_map="auto" # Automatically map model to available devices
63
  )
64
 
@@ -67,7 +70,6 @@ config_path = "audiotokenizer/speechtokenizer_hubert_avg_config.json"
67
  ckpt_path = "audiotokenizer/SpeechTokenizer.pt"
68
  quantizer = SpeechTokenizer.load_from_checkpoint(config_path, ckpt_path)
69
  quantizer.eval()
70
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
71
 
72
  # Перемещение всех слоев квантизатора на устройство и их заморозка
73
  def freeze_entire_model(model):
 
54
  raise ValueError("Audio data cannot be converted to " "16-bit int format.")
55
  return data
56
 
57
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
58
+
59
+
60
  # Load the model with INT8 quantization
61
  model = AutoModelForCausalLM.from_pretrained(
62
  model_path,
63
  cache_dir=".",
64
+ load_in_8bit=True if 'cuda' in device else False, # Enable loading in INT8
65
  device_map="auto" # Automatically map model to available devices
66
  )
67
 
 
70
  ckpt_path = "audiotokenizer/SpeechTokenizer.pt"
71
  quantizer = SpeechTokenizer.load_from_checkpoint(config_path, ckpt_path)
72
  quantizer.eval()
 
73
 
74
  # Перемещение всех слоев квантизатора на устройство и их заморозка
75
  def freeze_entire_model(model):