From 503210c776dc1fa9b1b1cd1e3b422b114be454eb Mon Sep 17 00:00:00 2001 From: Jan Rosa Date: Fri, 26 Aug 2022 12:24:06 -0700 Subject: [PATCH 1/2] different bias initialization --- econ_layers/layers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/econ_layers/layers.py b/econ_layers/layers.py index 00ba934..2e67dcc 100644 --- a/econ_layers/layers.py +++ b/econ_layers/layers.py @@ -25,7 +25,7 @@ def __init__(self, rescale_index: int = 0, bias=False): self.rescale_index = rescale_index if bias: self.bias = torch.nn.Parameter(torch.Tensor(1)) # only a scalar here - torch.nn.init.zeros_(self.bias) + torch.nn.init.ones_(self.bias) else: self.bias = 0.0 # register_parameter('bias', None) # necessary? From 1c9b25aaf5202c91d3df0862aaa917ccf10031ce Mon Sep 17 00:00:00 2001 From: Jan Rosa Date: Fri, 26 Aug 2022 12:37:19 -0700 Subject: [PATCH 2/2] dead end --- econ_layers/layers.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/econ_layers/layers.py b/econ_layers/layers.py index 2e67dcc..89a9f14 100644 --- a/econ_layers/layers.py +++ b/econ_layers/layers.py @@ -17,18 +17,18 @@ def __init__( def forward(self, input): return torch.cat([input.pow(m) for m in torch.arange(1, self.n_moments + 1)], 1) - + # rescaling by a specific element of a given input class RescaleOutputsByInput(nn.Module): def __init__(self, rescale_index: int = 0, bias=False): super().__init__() self.rescale_index = rescale_index if bias: - self.bias = torch.nn.Parameter(torch.Tensor(1)) # only a scalar here - torch.nn.init.ones_(self.bias) + self.bias = torch.nn.Parameter(torch.Tensor(1)) # only a scalar here + torch.nn.init.ones_(self.bias) else: - self.bias = 0.0 # register_parameter('bias', None) # necessary? - + self.bias = 0.0 # register_parameter('bias', None) # necessary? + def forward(self, x, y): if x.dim() == 1: return x[self.rescale_index] * y + self.bias