bnn_wrapper.py
"""Wrapping """
import logging
import torch
from typing import Callable, Dict, Optional, Tuple
from .parameters import is_parameter_handled, sample_parameters, estimate_parameters_nll
from .parameters import StateDict, NLLs, ParamsKey
from .predictive import sample_predictive, predictive_likelihoods


class BayesianNeuralNetwork:
    """Manages sampling and NLL calculation for parameters of a native module."""

    def __init__(self, module: torch.nn.Module) -> None:
self.parameters2sampler = {} # posterior sampling
self.variational_params = [] # parameters of the samplers
self.parameters2nllfunc = {} # prior densities
self._module = module
        self.predictive_distribution_sampler = None  # sampler over outputs; set by user
        self.predictive_distribution_log_lik = None  # output log-likelihood; set by user

    def set_posterior_sampler(
self,
parameters: ParamsKey,
sampler: Callable,
        variational_params: Dict[str, torch.Tensor],
) -> None:
"""Register a sampler for a parameter or parameters."""
if parameters in self.parameters2sampler:
            raise ValueError(f"{parameters} is already handled!")
self.parameters2sampler[parameters] = sampler
prefix = parameters if isinstance(parameters, str) else "_".join(parameters)
self.variational_params.extend(
(prefix + ":" + vn, vp) for vn, vp in variational_params.items()
)
logging.info(
f"[{self.__class__.__name__}] posterior for {parameters} set to {sampler}({variational_params.keys()})"
)
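
    # A sketch of the sampler contract assumed here (the authoritative
    # definition lives in `.parameters.sample_parameters`): for a requested
    # number of samples, a sampler yields parameter samples along with their
    # posterior NLLs. A hypothetical mean-field Gaussian with variational
    # parameters `mu`/`rho` might look like:
    #
    #   def sampler(n_samples=1):
    #       std = torch.nn.functional.softplus(rho)
    #       eps = torch.randn((n_samples, *mu.shape))
    #       samples = mu + std * eps
    #       nll = 0.5 * eps**2 + std.log() + 0.5 * math.log(2 * math.pi)
    #       return samples, nll.flatten(1).sum(-1)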

    def set_posterior_samplers(
self,
create_sampler_func: Callable,
filter: Callable = lambda parameter_name: True,
) -> None:
"""Register samplers for selected parameters (e.g. with 'bias' in name)."""
for parameter_name, parameter_value in self._module.named_parameters():
if filter(parameter_name) and not self.is_parameter_already_handled(
parameter_name
):
(
sampler,
variational_params,
_,
) = create_sampler_func(parameter_value)
self.set_posterior_sampler(parameter_name, sampler, variational_params)

    def set_prior_density(self, parameters: ParamsKey, nll_func: Callable) -> None:
"""Register NLL calculation for a parameter or parameters."""
if parameters in self.parameters2nllfunc:
            raise ValueError(f"{parameters} is already handled!")
self.parameters2nllfunc[parameters] = nll_func
logging.info(
f"[{self.__class__.__name__}] prior for {parameters} set to {nll_func}"
)

    def set_prior_densities(
        self,
        create_density_func: Callable,
        filter: Callable = lambda parameter_name: True,
    ) -> None:
        """Register NLL calculation for selected parameters (e.g. with 'bias' in name)."""
for parameter_name, parameter_value in self._module.named_parameters():
if filter(parameter_name):
nllfunc = create_density_func(parameter_value.shape)
self.set_prior_density(parameter_name, nllfunc)
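
    # A sketch of a prior-density factory usable with `set_prior_densities`
    # (assumed interface; the exact contract lives in `.parameters`): it takes
    # a parameter shape and returns an NLL callable over posterior samples.
    # For example, a hypothetical standard-normal prior:
    #
    #   def make_standard_normal_prior(shape):
    #       def nll(samples):  # samples: (n_posterior_samples, *shape)
    #           return (0.5 * samples**2 + 0.5 * math.log(2 * math.pi)).flatten(1).sum(-1)
    #       return nll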

    def is_parameter_already_handled(self, parameter_name: str) -> bool:
return is_parameter_handled(self.parameters2sampler.items(), parameter_name)

    def sample_posterior(self, n_samples: int = 1) -> Tuple[StateDict, NLLs]:
"""Returns samples + NLLs from pre-registered samplers."""
parameters_samples, posterior_nlls = sample_parameters(
self.parameters2sampler.items(), n_samples=n_samples
)
posterior_nlls = torch.stack(
list(posterior_nlls.values())
) # out shape: n_param_groups x n_samples
return parameters_samples, posterior_nlls

    def prior_nll(self, parameters_samples: StateDict) -> torch.Tensor:
"""Returns samples' NLLs for pre-registered priors."""
prior_nlls = estimate_parameters_nll(
self.parameters2nllfunc, parameters_samples
)
prior_nlls = torch.stack(
list(prior_nlls.values())
) # out shape: n_param_groups x n_posterior_samples
return prior_nlls
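
    # Note: combined with `sample_posterior`, this enables a Monte-Carlo KL
    # estimate for variational inference, since KL(q || p) = E_q[log q - log p]
    # can be approximated by mean(prior_nlls - posterior_nlls) over matching
    # parameter samples (a usage sketch; nothing here enforces it).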

    def _get_samples(
        self, parameters_samples: Optional[StateDict], n_samples: int
    ) -> StateDict:
        """Fall back to fresh posterior samples when none are provided."""
if not parameters_samples:
parameters_samples, _ = self.sample_posterior(n_samples)
return parameters_samples

    def sample_predictive(
self,
input_x: torch.Tensor,
parameters_samples: Optional[StateDict] = None,
n_samples: int = 1,
n_predictive_samples: int = 1,
**sample_predictive_kwargs,
    ):
        """Sample from the predictive distribution for `input_x` (uses `self.predictive_distribution_sampler`)."""
parameters_samples = self._get_samples(parameters_samples, n_samples)
return sample_predictive(
input_x,
self._module,
parameters_samples,
self.predictive_distribution_sampler,
n_samples=n_predictive_samples,
**sample_predictive_kwargs,
)

    def predictive_likelihoods(
self,
input_x: torch.Tensor,
output_x: torch.Tensor,
parameters_samples: Optional[StateDict] = None,
n_samples: int = 1,
**predictive_likelihoods_kwargs,
    ):
        """Return predictive likelihoods of `output_x` given `input_x` (uses `self.predictive_distribution_log_lik`)."""
parameters_samples = self._get_samples(parameters_samples, n_samples)
return predictive_likelihoods(
input_x,
output_x,
self._module,
parameters_samples,
self.predictive_distribution_log_lik,
**predictive_likelihoods_kwargs,
)
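

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the library API). The factory
# below is a hypothetical mean-field Gaussian; the real sampler contract is
# defined in `.parameters` and may differ. Run as a module
# (`python -m <package>.bnn_wrapper`) so the relative imports resolve.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import math

    def make_gaussian_sampler(param: torch.Tensor):
        # Hypothetical factory returning the (sampler, variational_params, _)
        # triple expected by `set_posterior_samplers`.
        mu = param.detach().clone().requires_grad_(True)
        rho = torch.full_like(param, -3.0, requires_grad=True)

        def sampler(n_samples: int = 1):
            std = torch.nn.functional.softplus(rho)
            eps = torch.randn((n_samples, *param.shape))
            samples = mu + std * eps
            # Gaussian NLL of the drawn samples, summed over parameter dims.
            nll = 0.5 * eps ** 2 + std.log() + 0.5 * math.log(2 * math.pi)
            return samples, nll.flatten(1).sum(-1)

        return sampler, {"mu": mu, "rho": rho}, None

    net = torch.nn.Linear(4, 2)
    bnn = BayesianNeuralNetwork(net)
    bnn.set_posterior_samplers(make_gaussian_sampler)
    # Variational parameters are registered under "<param-name>:<var-name>".
    print([name for name, _ in bnn.variational_params])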