diff --git a/econ_layers/layers.py b/econ_layers/layers.py index 00ba934..89a9f14 100644 --- a/econ_layers/layers.py +++ b/econ_layers/layers.py @@ -17,18 +17,18 @@ def __init__( def forward(self, input): return torch.cat([input.pow(m) for m in torch.arange(1, self.n_moments + 1)], 1) - + # rescaling by a specific element of a given input class RescaleOutputsByInput(nn.Module): def __init__(self, rescale_index: int = 0, bias=False): super().__init__() self.rescale_index = rescale_index if bias: - self.bias = torch.nn.Parameter(torch.Tensor(1)) # only a scalar here - torch.nn.init.zeros_(self.bias) + self.bias = torch.nn.Parameter(torch.Tensor(1)) # only a scalar here + torch.nn.init.ones_(self.bias) else: - self.bias = 0.0 # register_parameter('bias', None) # necessary? - + self.bias = 0.0 # register_parameter('bias', None) # necessary? + def forward(self, x, y): if x.dim() == 1: return x[self.rescale_index] * y + self.bias