|
import torch |
|
from transformers import HubertModel, HubertConfig |
|
import safetensors.torch as st |
|
from safetensors import safe_open |
|
import json |
|
|
|
class HubertModelWithFinalProj(HubertModel):
    """HuBERT model extended with the fairseq-style ``final_proj`` layer.

    The stock ``transformers`` ``HubertModel`` does not include the final
    projection present in contentvec/fairseq checkpoints; this subclass adds
    it so such weights can be loaded and used for feature extraction.
    """

    def __init__(self, config):
        super().__init__(config)
        # Projection applied on top of the encoder output for "v1" feature
        # extraction: hidden_size -> classifier_proj_size.
        self.final_proj = torch.nn.Linear(config.hidden_size, config.classifier_proj_size)

    @staticmethod
    def load_safetensors(path: str, device="cpu"):
        """Load a model from a ``.safetensors`` file written by ``save_safetensors``.

        The model configuration is rebuilt from the file's metadata entry
        ``"config"`` (a JSON-serialized ``HubertConfig`` dict).

        Args:
            path: path to a ``.safetensors`` checkpoint.
            device: device the loaded model is moved to (weights are read on CPU).

        Returns:
            A ``HubertModelWithFinalProj`` with the loaded weights, on ``device``.

        Raises:
            ValueError: if ``path`` has the wrong extension or the file has no
                ``"config"`` metadata entry.
        """
        # Explicit raise instead of ``assert``: asserts are stripped under -O.
        if not path.endswith(".safetensors"):
            raise ValueError(f"{path} must end with '.safetensors'")

        with safe_open(path, framework="pt", device="cpu") as f:
            metadata = f.metadata()
            # safe_open().metadata() returns None when the file was written
            # without metadata; fail with a clear message instead of letting
            # metadata["config"] raise an opaque TypeError.
            if metadata is None or "config" not in metadata:
                raise ValueError(
                    f"{path} has no 'config' metadata entry; "
                    "was it saved with save_safetensors()?"
                )
            state_dict = {key: f.get_tensor(key) for key in f.keys()}

        model = HubertModelWithFinalProj(HubertConfig.from_dict(json.loads(metadata["config"])))
        model.load_state_dict(state_dict=state_dict)
        return model.to(device)

    def save_safetensors(self, path: str):
        """Serialize the model weights and JSON-encoded config to ``path``.

        The config is embedded in the safetensors metadata (key ``"config"``)
        so ``load_safetensors`` can rebuild the model without a separate file.

        Raises:
            ValueError: if ``path`` does not end with ``'.safetensors'``.
        """
        if not path.endswith(".safetensors"):
            raise ValueError(f"{path} must end with '.safetensors'")

        state_dict = self.state_dict()
        payload = st.save(state_dict, dict(config=json.dumps(self.config.to_dict())))
        with open(path, "wb") as f:
            f.write(payload)

    def extract_features(self, source: torch.Tensor, version="v2", **kwargs):
        """Extract contentvec-style features from a waveform, without gradients.

        Args:
            source: input audio tensor; presumably (batch, samples) —
                TODO confirm against callers.
            version: ``"v1"`` takes encoder layer 9 and applies ``final_proj``;
                any other value (the default, ``"v2"``) returns the layer-12
                hidden states unchanged.
            **kwargs: accepted for interface compatibility; currently unused.

        Returns:
            The extracted feature tensor.
        """
        with torch.no_grad():
            # hidden_states[0] is the embedding/conv output, so index k is
            # the output of encoder layer k.
            output_layer = 9 if version == "v1" else 12
            # NOTE(review): config.torch_dtype may be None on hand-built
            # configs — confirm upstream always sets it before calling this.
            output = self(source.to(self.config.torch_dtype), output_hidden_states=True)["hidden_states"][output_layer]
            features = self.final_proj(output) if version == "v1" else output
        return features