# policy_tbl.R
# Thompson sampling for a continuum-armed bandit: rewards are modelled with a
# Bayesian quadratic regression y ~ N(b1 + b2*x + b3*x^2, err), with the
# Gaussian posterior over coefficients kept in information form. J holds the
# information vector (precision-weighted mean) and P the precision matrix,
# so the coefficient posterior is N(P^-1 J, P^-1).
ThompsonBayesianLinearPolicy <- R6::R6Class(
  portable = FALSE,
  class = FALSE,
  inherit = Policy,
  public = list(
    class_name = "ThompsonBayesianLinearPolicy",
    J = NULL,    # information vector (1 x 3)
    P = NULL,    # precision matrix (3 x 3)
    err = NULL,  # observation noise variance
    initialize = function(J = matrix(c(0, 0.25, -0.25), nrow = 1, ncol = 3, byrow = TRUE),
                          P = diag(c(2, 2, 5)),
                          err = 1) {
      super$initialize()
      self$J <- J
      self$P <- P
      self$err <- err
    },
    set_parameters = function(context_params) {
      # Store the prior as the policy's parameter set theta.
      self$theta <- list('J' = self$J, 'P' = self$P, 'err' = self$err)
    },
    get_action = function(t, context) {
      # Posterior covariance and mean: sigma = P^-1, mu = P^-1 J.
      sigma <- solve(self$theta$P, tol = 1e-200)
      mu <- sigma %*% matrix(self$theta$J)
      # Thompson step: draw one coefficient vector from the posterior.
      betas <- contextual::mvrnorm(n = 1, mu, sigma)
      # The sampled quadratic b1 + b2*x + b3*x^2 has its vertex at
      # x = -b2 / (2*b3) (a maximum when b3 < 0); play it, clipped to [0, 1].
      action$choice <- -(betas[2] / (2 * betas[3]))
      if (action$choice > 1) {
        action$choice <- 1
      } else if (action$choice < 0) {
        action$choice <- 0
      }
      action
    },
    set_reward = function(t, context, action, reward) {
      # Conjugate information-form update for the design row x = (1, x, x^2):
      # J <- J + x * y / err  and  P <- P + x^T x.
      y <- reward$reward
      x <- action$choice
      x <- matrix(c(1, x, x^2), nrow = 1, ncol = 3, byrow = TRUE)
      self$theta$J <- (x * y) / self$theta$err + self$theta$J
      self$theta$P <- t(x) %*% x + self$theta$P
      self$theta
    }
  )
)
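
# ------------------------------------------------------------------------
# Usage sketch (an addition, not part of the original policy_tbl.R): one way
# this policy might be run with contextual's Agent/Simulator API against a
# continuum-armed bandit. ContinuumBanditSketch and the reward function f
# below are illustrative assumptions, not part of this file; the reward
# curve is chosen to peak inside [0, 1] because the policy clips its
# choices to that interval.

library(contextual)

ContinuumBanditSketch <- R6::R6Class(
  class = FALSE,
  inherit = Bandit,
  public = list(
    class_name = "ContinuumBanditSketch",
    arm_function = NULL,
    initialize = function(FUN) {
      self$arm_function <- FUN
      super$initialize()
      self$d <- 1   # context dimension (unused here)
      self$k <- 1   # a single continuous arm
    },
    get_context = function(t) {
      list(k = self$k, d = self$d)
    },
    get_reward = function(t, context, action) {
      # Noisy scalar reward for the continuous choice in [0, 1].
      list(reward = self$arm_function(action$choice))
    }
  )
)

# Concave noisy reward, maximal at x = 0.6.
f       <- function(x) -(x - 0.6)^2 + 0.1 * rnorm(length(x))
bandit  <- ContinuumBanditSketch$new(FUN = f)
agent   <- Agent$new(ThompsonBayesianLinearPolicy$new(), bandit)
history <- Simulator$new(agents = agent, horizon = 500, simulations = 20)$run()
plot(history, type = "average", regret = FALSE)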