# Spaces: Runtime error (page-status banner captured along with this file; not part of the code)
import numpy as np | |
import torch | |
from torch import nn | |
from transformers import AutoModel, AutoModelForMaskedLM, AutoTokenizer | |
class DualRegressionModel(nn.Module):
    """
    Transformer backbone with two independent linear regression heads.

    The first-token vector of the backbone's last hidden layer is fed to one
    ``Linear(hidden_size, 1)`` head for latitude and another for longitude.
    When the batch carries targets, an MSE loss is computed per head and the
    two losses are combined according to ``loss_aggreatation``.
    """

    def __init__(
        self,
        model_name_or_path: str = "camembert/camembert-base",
        loss_aggreatation: str = "mean",
    ):
        """
        Load the backbone and tokenizer and create the two regression heads.

        :param model_name_or_path: The name or path of the pretrained model to
            load. If it contains ``"bart"``, only the encoder of the loaded
            model is kept; otherwise a masked-LM model is loaded.
        :param loss_aggreatation: How the latitude and longitude losses are
            combined in ``forward``: ``"mean"`` or ``"sum"``. (The misspelled
            parameter name is kept for backward compatibility.)
        """
        super().__init__()
        if "bart" in model_name_or_path:
            seq2seq = AutoModel.from_pretrained(
                model_name_or_path, output_hidden_states=True
            )
            # Only the encoder half is used as the feature extractor.
            self.model = seq2seq.encoder
        else:
            self.model = AutoModelForMaskedLM.from_pretrained(
                model_name_or_path, output_hidden_states=True
            )
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        self.loss_aggreatation = loss_aggreatation
        # Two separate regression heads, one per task (latitude and longitude).
        hidden_size = self.model.config.hidden_size
        self.lat_regression_head = torch.nn.Linear(hidden_size, 1)
        self.long_regression_head = torch.nn.Linear(hidden_size, 1)
        # NOTE: attribute name (sic) kept so existing external references
        # to ``model.crierion`` keep working.
        self.crierion = torch.nn.MSELoss()

    def forward(
        self,
        batch,
    ):
        """
        Predict latitude/longitude and, when targets are present, the loss.

        :param batch: Mapping with ``"input_ids"`` and ``"attention_mask"``.
            If it also contains ``"latitude"`` and ``"longitude"`` targets, a
            loss is computed; if it contains neither, only predictions are
            returned.
        :return: Dict with ``"latitude"`` and ``"longitude"`` predictions
            (each the raw head output with a trailing dimension of 1) and,
            when targets were given, ``"loss"``.
        :raises KeyError: If only one of the two target keys is present.
        :raises ValueError: If ``loss_aggreatation`` is neither ``"mean"`` nor
            ``"sum"`` when a loss has to be computed.
        """
        # Loss mode is entered as soon as *any* target key is present; a batch
        # with only one of them then fails on the lookup below (KeyError).
        has_targets = bool(batch.keys() & {"longitude", "latitude"})
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        if has_targets:
            latitudes = batch["latitude"]
            longitudes = batch["longitude"]

        # First-token representation taken from the last hidden layer.
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        pooled = outputs.hidden_states[-1][:, 0, :]

        lat_predictions = self.lat_regression_head(pooled)
        long_predictions = self.long_regression_head(pooled)
        result = {"latitude": lat_predictions, "longitude": long_predictions}
        if not has_targets:
            return result

        # squeeze(-1) removes only the head's output dimension. A bare
        # squeeze() would also drop the batch dimension when the batch size is
        # 1, leaving a 0-d prediction that mismatches the (1,) target.
        lat_flat = lat_predictions.squeeze(-1)
        long_flat = long_predictions.squeeze(-1)
        # Cast targets to the prediction dtype so e.g. float64 labels do not
        # cause a dtype mismatch in the loss.
        lat_loss = self.crierion(lat_flat, latitudes.to(lat_flat.dtype))
        long_loss = self.crierion(long_flat, longitudes.to(long_flat.dtype))

        if self.loss_aggreatation == "mean":
            loss = (lat_loss + long_loss) / 2
        elif self.loss_aggreatation == "sum":
            loss = lat_loss + long_loss
        else:
            raise ValueError("Only mean and sum are supported for loss aggregation")
        result["loss"] = loss
        return result

    def save_model(self, path):
        """
        Save the model's state dict to the specified path.

        :param path: The path where the model is saved. E.g. "model.pt"
        """
        torch.save(self.state_dict(), path)

    def load_model(self, path):
        """
        Load a state dict previously written by ``save_model`` into this model.

        The tensors are mapped to the CPU while loading.

        :param path: The path where the model is saved. E.g. "model.pt"
        """
        # SECURITY NOTE: torch.load unpickles the file; only load checkpoints
        # from trusted sources (or pass weights_only=True where the installed
        # torch version supports it).
        state_dict = torch.load(path, map_location=torch.device("cpu"))
        self.load_state_dict(state_dict)