#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""Positionwise feed forward layer definition."""

import torch

from funasr_detach.models.transformer.layer_norm import LayerNorm


class PositionwiseFeedForwardDecoderSANM(torch.nn.Module):
    """Positionwise feed forward layer with a normalization step before the output projection.

    The layer computes ``w_2(norm(dropout(activation(w_1(x)))))``.

    Args:
        idim (int): Input dimension.
        hidden_units (int): The number of hidden units.
        dropout_rate (float): Dropout rate.
        adim (int, optional): Output dimension. When ``None`` (default) the
            output dimension equals ``idim``.
        activation (torch.nn.Module, optional): Activation applied after the
            first linear layer. When ``None`` (default) a new
            ``torch.nn.ReLU()`` is created for this instance.

    """

    def __init__(self, idim, hidden_units, dropout_rate, adim=None, activation=None):
        """Construct a PositionwiseFeedForwardDecoderSANM object."""
        super().__init__()
        # Submodules are created in a fixed order (w_1, w_2, dropout,
        # activation, norm) so registration order and the order in which
        # random initialisation is drawn stay stable.
        self.w_1 = torch.nn.Linear(idim, hidden_units)
        # The output projection has no bias term.
        self.w_2 = torch.nn.Linear(
            hidden_units, idim if adim is None else adim, bias=False
        )
        self.dropout = torch.nn.Dropout(dropout_rate)
        # Build the default activation per instance rather than in the
        # signature: a default evaluated at definition time would be one
        # module object shared by every layer that relies on the default.
        self.activation = torch.nn.ReLU() if activation is None else activation
        # Project-local LayerNorm, constructed with the hidden size.
        self.norm = LayerNorm(hidden_units)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the feed forward transformation.

        Args:
            x (torch.Tensor): Input tensor whose last dimension is ``idim``.

        Returns:
            torch.Tensor: Output tensor whose last dimension is ``adim`` if it
            was given, otherwise ``idim``.

        """
        hidden = self.activation(self.w_1(x))
        hidden = self.dropout(hidden)
        # Normalization is applied to the hidden representation, before w_2.
        return self.w_2(self.norm(hidden))