Skip to content

Commit

Permalink
Generate Gabor filters from image statistics
Browse files Browse the repository at this point in the history
  • Loading branch information
hunse committed Oct 28, 2016
1 parent a6bf9f1 commit 031b124
Show file tree
Hide file tree
Showing 3 changed files with 126 additions and 5 deletions.
41 changes: 41 additions & 0 deletions examples/vision/gabors_for_images.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
"""
Generate Gabor filters for a set of images based on the statistics
of those images.
"""

import matplotlib.pyplot as plt
import numpy as np

from nengo.dists import Uniform
from nengo_extras.data import load_mnist, patches_from_images
from nengo_extras.vision import Gabor, gabors_for_images, gabors_for_patches


images = load_mnist('~/data/mnist.pkl.gz')[0][0].reshape(-1, 28, 28)
# images = Gabor(theta=Uniform(-0.1, 0.1), freq=Uniform(0.5, 1.5)).generate(10000, (28, 28))
# images = Gabor(theta=Uniform(-0.1, 0.1), freq=Uniform(2., 3.)).generate(10000, (28, 28))
# images = Gabor(theta=Uniform(-0.1, 0.1), freq=Uniform(1., 2.)).generate(10000, (28, 28))
# images = Gabor(theta=Uniform(-0.1, 0.1), freq=Uniform(5., 6.)).generate(10000, (28, 28))

patches = patches_from_images(images, 10000, (11, 11))

gabors1 = gabors_for_images(images, 1000, images.shape[-2:])
gabors2 = gabors_for_images(images, 1000, (11, 11))
# gabors2 = gabors_for_patches(images, 1000, (11, 11))

def spectrum(images):
F = np.fft.fft2(images)
Fmean = np.abs(F).mean(0)
Fmean[0, 0] = 0
return np.fft.fftshift(Fmean)

plt.figure()
plt.subplot(221)
plt.imshow(spectrum(images), interpolation='none')
plt.subplot(222)
plt.imshow(spectrum(patches), interpolation='none')
plt.subplot(223)
plt.imshow(spectrum(gabors1), interpolation='none')
plt.subplot(224)
plt.imshow(spectrum(gabors2), interpolation='none')
plt.show()
19 changes: 19 additions & 0 deletions nengo_extras/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,25 @@ def load_mnist(filepath=None, validation=False):
return train_set, test_set


def patches_from_images(images, n_patches, patch_size, rng=np.random):
"""Exctract patches of a given size randomly from images"""
assert images.ndim in (3, 4)
nc0 = None if images.ndim == 3 else images.shape[1]
images = images[:, None, :, :] if nc0 is None else images
n_images, nc, ni, nj = images.shape

pi, pj = patch_size
k = rng.randint(n_images, size=n_patches)
i = rng.randint(ni - pi, size=n_patches)
j = rng.randint(nj - pj, size=n_patches)

patches = np.zeros((n_patches, nc, pi, pj), dtype=images.dtype)
for p, (kk, ii, jj) in enumerate(zip(k, i, j)):
patches[p] = images[kk, :, ii:ii+pi, jj:jj+pj]

return patches[:, 0, :, :] if nc0 is None else patches


def spasafe_names(label_names):
vocab_names = [
(name.split(',')[0] if ',' in name else name).upper().replace(' ', '_')
Expand Down
71 changes: 66 additions & 5 deletions nengo_extras/vision.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import absolute_import

import numpy as np

from nengo.dists import Choice, Uniform, DistributionParam
Expand Down Expand Up @@ -25,11 +27,11 @@ def __init__(self, theta=Uniform(-np.pi, np.pi), freq=Uniform(0.2, 2),

def generate(self, n, shape, rng=np.random, norm=1.):
assert isinstance(shape, tuple) and len(shape) == 2
thetas = self.theta.sample(n, rng=rng)[:, None, None]
freqs = self.freq.sample(n, rng=rng)[:, None, None]
phases = self.phase.sample(n, rng=rng)[:, None, None]
sigma_xs = self.sigma_x.sample(n, rng=rng)[:, None, None]
sigma_ys = self.sigma_y.sample(n, rng=rng)[:, None, None]
thetas = self.theta.sample(n, d=1, rng=rng).reshape(-1, 1, 1)
freqs = self.freq.sample(n, d=1, rng=rng).reshape(-1, 1, 1)
phases = self.phase.sample(n, d=1, rng=rng).reshape(-1, 1, 1)
sigma_xs = self.sigma_x.sample(n, d=1, rng=rng).reshape(-1, 1, 1)
sigma_ys = self.sigma_y.sample(n, d=1, rng=rng).reshape(-1, 1, 1)

x, y = np.linspace(-1, 1, shape[1]), np.linspace(-1, 1, shape[0])
X, Y = np.meshgrid(x, y)
Expand Down Expand Up @@ -97,3 +99,62 @@ def populate(self, filters, rng=np.random, flatten=False):
output[k, :, i[k]:i[k]+shape[0], j[k]:j[k]+shape[1]] = filters[k]

return output.reshape(n, -1) if flatten else output


def image_freq_mixture(images):
"""Create a mixture model distribution for frequencies in an image set.
"""
from nengo_extras.dists import MultivariateGaussian, Mixture

assert images.ndim == 3
n, ni, nj = images.shape

I = np.fft.fft2(images)
S = np.abs(I).mean(axis=0)
S[0, 0] = 0
S /= S.sum()
S = np.fft.fftshift(S, axes=(-2, -1))

dists = []
fis = np.fft.fftshift(np.fft.fftfreq(ni))
fjs = np.fft.fftshift(np.fft.fftfreq(nj))
var_i = (0.5/ni)**2
var_j = (0.5/nj)**2
for fi in fis:
for fj in fjs:
dist = MultivariateGaussian((fi, fj), (var_i, var_j))
dists.append(dist)

return Mixture(dists, p=S.ravel())


def gabors_for_images(images, n_gabors, gabor_size, rng=np.random):
"""Return Gabor encoders with statistics matching images
Currently, this takes statistics across the whole image, and creates
Gabors based on that. However, it would be more effective to determine
Gabors for each location in the image, based on the statistics of that
location.
"""
from nengo_extras.dists import Tile
f_dist = image_freq_mixture(images)
f_pts = f_dist.sample(n_gabors, d=2, rng=rng)
f_pts = f_pts * (0.5 * np.array(gabor_size)) # `Gabor` freqs in window units

thetas = np.arctan2(f_pts[:, 0], f_pts[:, 1])
freqs = np.sqrt((f_pts**2).sum(axis=1))

gabor = Gabor(theta=Tile(thetas), freq=Tile(freqs))
gabors = gabor.generate(n_gabors, gabor_size, rng=rng)

return gabors


def gabors_for_patches(images, n_gabors, patch_size, n_patches=None, rng=np.random):
"""Return Gabor encoders with statistics matching image patches"""
from .data import patches_from_images

if n_patches is None:
n_patches = 5 * n_gabors
patches = patches_from_images(images, n_patches, patch_size, rng=rng)
return gabors_for_images(patches, n_gabors, patch_size, rng=rng)

0 comments on commit 031b124

Please sign in to comment.