-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathrealworld_main.py
215 lines (174 loc) · 8.3 KB
/
realworld_main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
"""Mini main for testing algorithms. """
import numpy as np
import jax
import jax.numpy as jnp
from easydict import EasyDict as edict
import os
from core.contextual_bandit import contextual_bandit_runner
from algorithms.neural_offline_bandit import ExactNeuraLCBV2, NeuralGreedyV2, ApproxNeuraLCBV2
from algorithms.lin_lcb import LinLCB
from algorithms.kern_lcb import KernLCB
from algorithms.uniform_sampling import UniformSampling
from algorithms.neural_lin_lcb import ExactNeuralLinLCBV2, ExactNeuralLinGreedyV2, ApproxNeuralLinLCBV2, ApproxNeuralLinGreedyV2, \
ApproxNeuralLinLCBJointModel, NeuralLinGreedyJointModel
from data.realworld_data import *
from absl import flags, app
# Command-line flags for the real-world offline bandit experiments.
# Flag names, types and defaults are part of the script's CLI and are kept
# unchanged; only help strings that misdescribed the accepted values are fixed.
FLAGS = flags.FLAGS

# --- Data / behaviour policy -------------------------------------------------
flags.DEFINE_string('data_type', 'mushroom', 'Dataset to sample from')
# main() also accepts 'online' in addition to the two offline policies.
flags.DEFINE_string('policy', 'eps-greedy', 'Offline policy, eps-greedy/subset/online')
flags.DEFINE_float('eps', 0.1, 'Probability of selecting a random action in eps-greedy')
flags.DEFINE_float('subset_r', 0.5, 'The ratio of the action spaces to be selected in offline data')
flags.DEFINE_integer('num_contexts', 15000, 'Number of contexts for training.')
flags.DEFINE_integer('num_test_contexts', 10000, 'Number of contexts for test.')

# --- Runner behaviour --------------------------------------------------------
flags.DEFINE_boolean('verbose', True, 'verbose')
flags.DEFINE_boolean('debug', True, 'debug')
flags.DEFINE_boolean('normalize', False, 'normalize the regret')
flags.DEFINE_integer('update_freq', 1, 'Update frequency')
flags.DEFINE_integer('freq_summary', 10, 'Summary frequency')
flags.DEFINE_integer('test_freq', 10, 'Test frequency')
# The values listed here are exactly the groups that main() knows how to build.
flags.DEFINE_string('algo_group', 'approx-neural',
                    'Algorithm group: approx-neural/neural-greedy/baseline/kern/neurallinlcb')
flags.DEFINE_integer('num_sim', 10, 'Number of simulations')
# NOTE(review): this flag is not read anywhere in this file (the results path
# uses data.noise_std instead) -- confirm whether another module consumes it.
flags.DEFINE_float('noise_std', 0.1, 'Noise std')

# --- Training ----------------------------------------------------------------
flags.DEFINE_integer('chunk_size', 500, 'Chunk size')
flags.DEFINE_integer('batch_size', 32, 'Batch size')
flags.DEFINE_integer('num_steps', 100, 'Number of steps to train NN.')
flags.DEFINE_integer('buffer_s', -1, 'Size in the train data buffer.')
flags.DEFINE_bool('data_rand', True,
                  'Whether to randomly sample a data batch or use the latest samples in the buffer')
flags.DEFINE_float('rbf_sigma', 1, 'RBF sigma for KernLCB')  # [0.1, 1, 10]

# --- NeuraLCB ----------------------------------------------------------------
flags.DEFINE_float('beta', 0.1, 'confidence parameter')  # [0.01, 0.05, 0.1, 0.5, 1, 5, 10]
flags.DEFINE_float('lr', 1e-3, 'learning rate')
flags.DEFINE_float('lambd0', 0.1, 'minimum eigenvalue')
flags.DEFINE_float('lambd', 1e-4, 'regularization parameter')
#================================================================
# Network parameters
#================================================================
def _policy_prefix():
    """Return the behaviour-policy tag used in the results directory name.

    The tag is the policy name followed by the parameter that characterises
    it: ``FLAGS.eps`` for 'eps-greedy' and 'online', ``FLAGS.subset_r`` for
    'subset'.

    Raises:
        NotImplementedError: if ``FLAGS.policy`` is not a handled policy.
    """
    if FLAGS.policy in ('eps-greedy', 'online'):
        return '{}{}'.format(FLAGS.policy, FLAGS.eps)
    if FLAGS.policy == 'subset':
        # The flag is registered as `subset_r`; the previous `subset_ratio`
        # lookup referred to a flag that does not exist and failed at runtime.
        return '{}{}'.format(FLAGS.policy, FLAGS.subset_r)
    raise NotImplementedError('{} not implemented'.format(FLAGS.policy))


def _build_algorithms(hparams, lin_hparams):
    """Instantiate the algorithms of ``FLAGS.algo_group``.

    Args:
        hparams: hyper-parameters passed to the neural algorithms.
        lin_hparams: hyper-parameters passed to UniformSampling, LinLCB and
            KernLCB.

    Returns:
        A pair ``(algos, algo_prefix)``: the list of algorithm instances and
        the hyper-parameter tag used as the results file stem.

    Raises:
        NotImplementedError: if ``FLAGS.algo_group`` is not a known group.
    """
    group = FLAGS.algo_group
    width = min(hparams.layer_sizes)  # reported as `m` in the file names

    if group == 'approx-neural':
        algos = [
            UniformSampling(lin_hparams),
            # NeuralGreedyV2(hparams, update_freq = FLAGS.update_freq),
            ApproxNeuraLCBV2(hparams, update_freq=FLAGS.update_freq),
        ]
        algo_prefix = 'approx-neural-gridsearch_epochs={}_m={}_layern={}_buffer={}_bs={}_lr={}_beta={}_lambda={}_lambda0={}'.format(
            hparams.num_steps, width, hparams.layer_n, hparams.buffer_s,
            hparams.batch_size, hparams.lr, hparams.beta, hparams.lambd,
            hparams.lambd0,
        )
    elif group == 'neural-greedy':
        algos = [
            UniformSampling(lin_hparams),
            NeuralGreedyV2(hparams, update_freq=FLAGS.update_freq),
        ]
        algo_prefix = 'neural-greedy-gridsearch_epochs={}_m={}_layern={}_buffer={}_bs={}_lr={}_lambda={}'.format(
            hparams.num_steps, width, hparams.layer_n, hparams.buffer_s,
            hparams.batch_size, hparams.lr, hparams.lambd,
        )
    elif group == 'baseline':
        algos = [
            UniformSampling(lin_hparams),
            LinLCB(lin_hparams),
            ## KernLCB(lin_hparams),
            # NeuralGreedyV2(hparams, update_freq = FLAGS.update_freq),
            # ApproxNeuralLinLCBV2(hparams),
            # ApproxNeuralLinGreedyV2(hparams),
            NeuralLinGreedyJointModel(hparams),
            ApproxNeuralLinLCBJointModel(hparams),
        ]
        algo_prefix = 'baseline_epochs={}_m={}_layern={}_beta={}_lambda0={}_rbf-sigma={}_maxnum={}'.format(
            hparams.num_steps, width, hparams.layer_n, hparams.beta,
            hparams.lambd0, lin_hparams.rbf_sigma, lin_hparams.max_num_sample,
        )
    elif group == 'kern':  # for tuning KernLCB
        algos = [
            UniformSampling(lin_hparams),
            KernLCB(lin_hparams),
        ]
        algo_prefix = 'kern-gridsearch_beta={}_rbf-sigma={}_maxnum={}'.format(
            hparams.beta, lin_hparams.rbf_sigma, lin_hparams.max_num_sample,
        )
    elif group == 'neurallinlcb':  # tune NeuralLinLCB separately
        algos = [
            UniformSampling(lin_hparams),
            ApproxNeuralLinLCBJointModel(hparams),
        ]
        algo_prefix = 'neurallinlcb-gridsearch_m={}_layern={}_beta={}_lambda0={}'.format(
            width, hparams.layer_n, hparams.beta, hparams.lambd0,
        )
    else:
        # Previously an unknown group fell through every `if` and surfaced
        # later as an UnboundLocalError on `algo_prefix`.
        raise NotImplementedError('{} not implemented'.format(group))

    return algos, algo_prefix


def main(unused_argv):
    """Run the bandit experiment selected by the command-line flags.

    Builds the dataset named by ``FLAGS.data_type``, assembles the
    hyper-parameters, instantiates the algorithms of ``FLAGS.algo_group``,
    hands everything to ``contextual_bandit_runner`` and stores the returned
    regrets and errors in ``results/<data_prefix>/<algo_prefix>.npz``.

    Args:
        unused_argv: positional arguments forwarded by ``app.run``; ignored.

    Raises:
        NotImplementedError: if the policy, the dataset or the algorithm
            group named by the flags is not handled.
    """
    #=================
    # Data
    #=================
    policy_prefix = _policy_prefix()

    data_classes = {
        'mushroom': MushroomData,
        'jester': JesterData,
        'statlog': StatlogData,
        'covertype': CoverTypeData,
        'stock': StockData,
        'adult': AdultData,
        'census': CensusData,
        'mnist': MnistData,
    }
    if FLAGS.data_type not in data_classes:
        raise NotImplementedError('{} not implemented'.format(FLAGS.data_type))
    DataClass = data_classes[FLAGS.data_type]
    data = DataClass(num_contexts=FLAGS.num_contexts,
                     num_test_contexts=FLAGS.num_test_contexts,
                     pi=FLAGS.policy,
                     eps=FLAGS.eps,
                     subset_r=FLAGS.subset_r)

    if FLAGS.data_type == 'mnist':  # Use 1000 test points for mnist
        # NOTE(review): `data` above was already constructed with the original
        # num_test_contexts, so within this function only test_freq and
        # chunk_size take effect -- confirm whether the override was meant to
        # precede the dataset construction.
        FLAGS.num_test_contexts = 1000
        FLAGS.test_freq = 100
        FLAGS.chunk_size = 1

    dataset = data.reset_data()
    context_dim = dataset[0].shape[1]
    num_actions = data.num_actions

    hparams = edict({
        'layer_sizes': [100, 100],
        's_init': 1,
        'activation': jax.nn.relu,
        'layer_n': True,
        'seed': 0,
        'context_dim': context_dim,
        'num_actions': num_actions,
        'beta': FLAGS.beta,  # [0.01, 0.05, 0.1, 0.5, 1, 5, 10]
        'lambd': FLAGS.lambd,  # regularization param: [0.1m, m, 10 m ]
        'lr': FLAGS.lr,
        # Should be lambd/m in theory but is fixed at 0.1 for simplicity; the
        # tuning effort goes mainly into beta.
        'lambd0': FLAGS.lambd0,
        'verbose': False,
        'batch_size': FLAGS.batch_size,
        'freq_summary': FLAGS.freq_summary,
        'chunk_size': FLAGS.chunk_size,
        'num_steps': FLAGS.num_steps,
        'buffer_s': FLAGS.buffer_s,
        'data_rand': FLAGS.data_rand,
        'debug_mode': 'full',  # simple/full
    })

    lin_hparams = edict({
        'context_dim': hparams.context_dim,
        'num_actions': hparams.num_actions,
        'lambd0': hparams.lambd0,
        'beta': hparams.beta,
        'rbf_sigma': FLAGS.rbf_sigma,  # 0.1, 1, 10
        'max_num_sample': 1000,
    })

    data_prefix = '{}_d={}_a={}_pi={}_std={}'.format(
        FLAGS.data_type, context_dim, num_actions, policy_prefix, data.noise_std)
    res_dir = os.path.join('results', data_prefix)
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(res_dir, exist_ok=True)

    #================================================================
    # Algorithms
    #================================================================
    algos, algo_prefix = _build_algorithms(hparams, lin_hparams)

    #==============================
    # Runner
    #==============================
    file_name = os.path.join(res_dir, algo_prefix) + '.npz'
    regrets, errs = contextual_bandit_runner(
        algos, data, FLAGS.num_sim, FLAGS.update_freq, FLAGS.test_freq,
        FLAGS.verbose, FLAGS.debug, FLAGS.normalize, file_name)

    np.savez(file_name, regrets, errs)
# Script entry point: let absl parse the command-line flags, then call main().
if __name__ == "__main__":
    app.run(main)