Skip to content

Commit

Permalink
dead end
Browse files Browse the repository at this point in the history
  • Loading branch information
janrosa1 committed Aug 26, 2022
1 parent 503210c commit 1c9b25a
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions econ_layers/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,18 @@ def __init__(
def forward(self, input):
return torch.cat([input.pow(m) for m in torch.arange(1, self.n_moments + 1)], 1)


# rescaling by a specific element of a given input
class RescaleOutputsByInput(nn.Module):
def __init__(self, rescale_index: int = 0, bias=False):
super().__init__()
self.rescale_index = rescale_index
if bias:
self.bias = torch.nn.Parameter(torch.Tensor(1)) # only a scalar here
torch.nn.init.ones_(self.bias)
self.bias = torch.nn.Parameter(torch.Tensor(1)) # only a scalar here
torch.nn.init.ones_(self.bias)
else:
self.bias = 0.0 # register_parameter('bias', None) # necessary?
self.bias = 0.0 # register_parameter('bias', None) # necessary?

def forward(self, x, y):
if x.dim() == 1:
return x[self.rescale_index] * y + self.bias
Expand Down

0 comments on commit 1c9b25a

Please sign in to comment.