-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsnconv.py
101 lines (85 loc) · 3.08 KB
/
snconv.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import numpy as np
import keras.backend as K
from keras.layers.convolutional import _Conv
from keras.legacy import interfaces
from keras.engine import InputSpec
class SNConv2D(_Conv):
    """2D convolution whose kernel is spectrally normalized.

    Before every convolution the kernel is divided by an estimate of its
    largest singular value (its spectral norm), computed by power
    iteration against a persistent, non-trainable vector ``u``
    (Miyato et al., "Spectral Normalization for GANs", 2018).

    The interface matches ``keras.layers.Conv2D``; the extra
    ``power_iterations`` keyword (default 1, matching the previous
    hard-coded value) is backward compatible.
    """

    @interfaces.legacy_conv2d_support
    def __init__(self, filters,
                 kernel_size,
                 strides=(1, 1),
                 padding='valid',
                 data_format=None,
                 dilation_rate=(1, 1),
                 activation=None,
                 use_bias=True,
                 kernel_initializer='glorot_uniform',
                 bias_initializer='zeros',
                 kernel_regularizer=None,
                 bias_regularizer=None,
                 activity_regularizer=None,
                 kernel_constraint=None,
                 bias_constraint=None,
                 power_iterations=1,
                 **kwargs):
        super(SNConv2D, self).__init__(
            rank=2,
            filters=filters,
            kernel_size=kernel_size,
            strides=strides,
            padding=padding,
            data_format=data_format,
            dilation_rate=dilation_rate,
            activation=activation,
            use_bias=use_bias,
            kernel_initializer=kernel_initializer,
            bias_initializer=bias_initializer,
            kernel_regularizer=kernel_regularizer,
            bias_regularizer=bias_regularizer,
            activity_regularizer=activity_regularizer,
            kernel_constraint=kernel_constraint,
            bias_constraint=bias_constraint,
            **kwargs)
        # Validate eagerly so a bad value fails at construction, not at
        # the first forward pass.
        if power_iterations < 1:
            raise ValueError("The number of power iterations should be positive integer")
        self.input_spec = InputSpec(ndim=4)
        self.Ip = power_iterations
        # Persistent left-singular-vector estimate for the power
        # iteration; shape (1, filters) matches the row count of the
        # kernel reshaped to (out_channels, in_channels * h * w).
        self.u = self.add_weight(
            name='W_u',
            shape=(1, filters),
            initializer='random_uniform',
            trainable=False
        )

    def call(self, inputs):
        """Convolve with the spectrally-normalized kernel ``W_bar()``."""
        outputs = K.conv2d(
            inputs,
            self.W_bar(),
            strides=self.strides,
            padding=self.padding,
            data_format=self.data_format,
            dilation_rate=self.dilation_rate)
        if self.use_bias:
            outputs = K.bias_add(
                outputs,
                self.bias,
                data_format=self.data_format)
        if self.activation is not None:
            return self.activation(outputs)
        return outputs

    def get_config(self):
        """Serialize the layer config; drop the fixed ``rank`` argument."""
        config = super(SNConv2D, self).get_config()
        config.pop('rank')
        # Include the new knob so from_config() round-trips faithfully.
        config['power_iterations'] = self.Ip
        return config

    def W_bar(self):
        """Return the kernel divided by its estimated spectral norm.

        Runs ``self.Ip`` power-iteration steps on the kernel reshaped to
        a 2-D matrix, then normalizes by the resulting singular-value
        estimate ``sigma``.
        """
        W_mat = K.permute_dimensions(self.kernel, (3, 2, 0, 1))  # (h, w, i, o) => (o, i, h, w)
        W_mat = K.reshape(W_mat, [K.shape(W_mat)[0], -1])  # (o, i * h * w)
        if not self.Ip >= 1:
            raise ValueError("The number of power iterations should be positive integer")
        _u = self.u
        _v = None
        for _ in range(self.Ip):
            _v = _l2normalize(K.dot(_u, W_mat))
            _u = _l2normalize(K.dot(_v, K.transpose(W_mat)))
        # sigma ~= u^T W v, the dominant singular value of W_mat.
        sigma = K.sum(K.dot(_u, W_mat) * _v)
        # BUG FIX: the original discarded the op returned by K.update, so
        # under the TensorFlow backend `u` was never actually persisted.
        # Registering it as a layer update makes Keras run it each step;
        # K.in_train_phase keeps `u` frozen during inference.
        self.add_update(K.update(self.u, K.in_train_phase(_u, self.u)))
        return self.kernel / sigma
def _l2normalize(x):
    """Scale *x* to unit Euclidean norm, epsilon-stabilized so an
    all-zero input does not divide by zero."""
    norm = K.sqrt(K.sum(K.square(x)) + K.epsilon())
    return x / norm