-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathcreate_solver.py
58 lines (42 loc) · 2.01 KB
/
create_solver.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
def create_solver(train_net_path, test_net_path, base_lr, id,
                  solver_path='solver.prototxt'):
    """Build a Caffe SGD solver definition and write it as a prototxt file.

    Args:
        train_net_path: Path to the training network definition; stored in
            the solver's `train_net` field.
        test_net_path: Path to the test network definition; appended to the
            solver's `test_net` list.
        base_lr: Initial learning rate for SGD.
        id: String appended to the snapshot directory prefix so each run
            gets its own snapshot file names. (The name shadows the builtin
            `id`; it is kept for compatibility with existing callers.)
        solver_path: Where to write the solver definition. Defaults to
            'solver.prototxt' in the current working directory.

    Returns:
        The name of the file the solver definition was written to.
    """
    # Imported locally so that caffe is only required when a solver is
    # actually being created.
    from caffe.proto import caffe_pb2

    s = caffe_pb2.SolverParameter()

    # Locations of the train and test networks.
    s.train_net = train_net_path
    s.test_net.append(test_net_path)

    # Run a test pass every 100000 training iterations (the same value as
    # `max_iter` below), using 10 batches per test pass.
    s.test_interval = 100000
    s.test_iter.append(10)

    # The number of iterations over which to average the gradient.
    # Effectively boosts the training batch size by the given factor, without
    # affecting memory utilization.
    s.iter_size = 1

    # Number of times to update the net (training iterations).
    s.max_iter = 100000

    # Solve using the stochastic gradient descent (SGD) algorithm.
    # Other choices include 'Adam' and 'RMSProp'.
    s.type = 'SGD'

    # Initial learning rate for SGD.
    s.base_lr = base_lr

    # `lr_policy` defines how the learning rate changes during training.
    # Here, we 'step' the learning rate by multiplying it by a factor `gamma`
    # every `stepsize` iterations.
    s.lr_policy = 'step'
    s.gamma = 0.1
    s.stepsize = 20000

    # Other SGD hyperparameters. A non-zero `momentum` takes a weighted
    # average of the current gradient and previous gradients to make learning
    # more stable. L2 weight decay regularizes learning, to help prevent the
    # model from overfitting.
    s.momentum = 0.9
    s.weight_decay = 0.004

    # Display interval for the training loss. 1000000 exceeds `max_iter`, so
    # this interval is never reached again after the start of training.
    s.display = 1000000

    # Snapshots are files used to store networks we've trained. Snapshot
    # every 500 iterations, i.e. 200 snapshots over the 100000 iterations.
    s.snapshot = 500
    s.snapshot_prefix = '../../datasets/102flowers/snapshots/' + id

    # Train on the GPU. Using the CPU to train large networks is very slow.
    s.solver_mode = caffe_pb2.SolverParameter.GPU

    # str() of a protobuf message yields its text-format representation,
    # which is the format Caffe reads as a solver prototxt.
    with open(solver_path, 'w') as f:
        f.write(str(s))

    # `f.name` is still available after the file has been closed.
    return f.name