File size: 779 Bytes
f6b6982
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
from fastai.tabular.all import *

def create_params(size, std=0.01):
    """Return a trainable ``nn.Parameter`` of shape ``size`` drawn from N(0, std**2).

    Args:
        size: shape of the parameter, either a sequence of ints (e.g. ``[n, k]``)
            or a single int for a 1-D parameter.
        std: standard deviation of the normal initialisation (mean is 0).
            Defaults to 0.01, the value previously hard-coded.

    Returns:
        An ``nn.Parameter`` (``requires_grad=True``) with the requested shape.
    """
    # A bare int would otherwise fail in ``torch.zeros(*size)`` (ints can't be unpacked).
    if isinstance(size, int):
        size = (size,)
    return nn.Parameter(torch.zeros(*size).normal_(0, std))

class DotProductBias(Module):
    """Collaborative-filtering model: user/item dot product plus per-entity biases.

    Each user and each item owns a learned latent vector of length ``n_factors``
    and a learned scalar bias. A prediction is
    ``<user_vec, item_vec> + user_bias + item_bias``, passed through fastai's
    ``sigmoid_range`` with the bounds in ``y_range``.

    Args:
        n_users: number of rows in the user factor and bias tables.
        n_items: number of rows in the item factor and bias tables.
        n_factors: length of each latent vector.
        y_range: ``(low, high)`` pair unpacked into ``sigmoid_range``.
    """

    def __init__(self, n_users, n_items, n_factors, y_range=(0, 1.5)):
        super().__init__()
        self.y_range = y_range
        # Keep this creation order: it fixes both the RNG draws used for
        # initialisation and the parameter registration (state_dict) order.
        self.user_factors = create_params([n_users, n_factors])
        self.user_bias = create_params([n_users])
        self.item_factors = create_params([n_items, n_factors])
        self.item_bias = create_params([n_items])

    def forward(self, x):
        """Score every (user, item) row of ``x``.

        Column 0 of ``x`` indexes users and column 1 indexes items. Presumably
        ``x`` is an integer tensor of shape (batch, 2); confirm against the
        DataLoaders. For 2-D ``x`` the result has one score per row.
        """
        user_idx = x[:, 0]
        item_idx = x[:, 1]
        user_vecs = self.user_factors[user_idx]
        item_vecs = self.item_factors[item_idx]
        interaction = (user_vecs * item_vecs).sum(dim=1)
        offset = self.user_bias[user_idx] + self.item_bias[item_idx]
        return sigmoid_range(interaction + offset, *self.y_range)