Training

Continuing from the discussion of inference, we show how to train JPEG models. Starting with boilerplate code

In [1]:
!pip install torch opt_einsum tabulate torchvision

import torch
from torchvision import datasets, transforms
import opt_einsum as oe
import numpy as np
import matplotlib.pyplot as plt
from tabulate import tabulate
Collecting torch
  Downloading https://files.pythonhosted.org/packages/7e/60/66415660aa46b23b5e1b72bc762e816736ce8d7260213e22365af51e8f9c/torch-1.0.0-cp36-cp36m-manylinux1_x86_64.whl (591.8MB)
    100% |████████████████████████████████| 591.8MB 31kB/s 
tcmalloc: large alloc 1073750016 bytes == 0x60f48000 @  0x7f8bda3f82a4 0x591a07 0x5b5d56 0x502e9a 0x506859 0x502209 0x502f3d 0x506859 0x504c28 0x502540 0x502f3d 0x506859 0x504c28 0x502540 0x502f3d 0x506859 0x504c28 0x502540 0x502f3d 0x507641 0x502209 0x502f3d 0x506859 0x504c28 0x502540 0x502f3d 0x507641 0x504c28 0x502540 0x502f3d 0x507641
Collecting opt_einsum
  Downloading https://files.pythonhosted.org/packages/f6/d6/44792ec668bcda7d91913c75237314e688f70415ab2acd7172c845f0b24f/opt_einsum-2.3.2.tar.gz (59kB)
    100% |████████████████████████████████| 61kB 22.0MB/s 
Requirement already satisfied: tabulate in /usr/local/lib/python3.6/dist-packages (0.8.2)
Collecting torchvision
  Downloading https://files.pythonhosted.org/packages/ca/0d/f00b2885711e08bd71242ebe7b96561e6f6d01fdb4b9dcf4d37e2e13c5e1/torchvision-0.2.1-py2.py3-none-any.whl (54kB)
    100% |████████████████████████████████| 61kB 19.7MB/s 
Requirement already satisfied: numpy>=1.7 in /usr/local/lib/python3.6/dist-packages (from opt_einsum) (1.14.6)
Collecting pillow>=4.1.1 (from torchvision)
  Downloading https://files.pythonhosted.org/packages/92/e3/217dfd0834a51418c602c96b110059c477260c7fee898542b100913947cf/Pillow-5.4.0-cp36-cp36m-manylinux1_x86_64.whl (2.0MB)
    100% |████████████████████████████████| 2.0MB 1.9MB/s 
Requirement already satisfied: six in /usr/local/lib/python3.6/dist-packages (from torchvision) (1.11.0)
Building wheels for collected packages: opt-einsum
  Running setup.py bdist_wheel for opt-einsum ... - done
  Stored in directory: /root/.cache/pip/wheels/51/3e/a3/b351fae0cbf15373c2136a54a70f43fea5fe91d8168a5faaa4
Successfully built opt-einsum
Installing collected packages: torch, opt-einsum, pillow, torchvision
  Found existing installation: Pillow 4.0.0
    Uninstalling Pillow-4.0.0:
      Successfully uninstalled Pillow-4.0.0
Successfully installed opt-einsum-2.3.2 pillow-5.4.0 torch-1.0.0 torchvision-0.2.1

and the standard definition of the tensors we will use for compression and decompression

In [0]:
def A(alpha):
    if alpha == 0:
        return 1.0 / np.sqrt(2)
    else:
        return 1


def D():
    D_t = torch.zeros([8, 8, 8, 8], dtype=torch.float)

    for i in range(8):
        for j in range(8):
            for alpha in range(8):
                for beta in range(8):
                    scale_a = A(alpha)
                    scale_b = A(beta)

                    coeff_x = np.cos(((2 * i + 1) * alpha * np.pi) / 16)
                    coeff_y = np.cos(((2 * j + 1) * beta * np.pi) / 16)

                    D_t[i, j, alpha, beta] = 0.25 * scale_a * scale_b * coeff_x * coeff_y

    return D_t


def D_n(n_freqs):
    D_t = torch.zeros([8, 8, 8, 8], dtype=torch.float)

    for i in range(8):
        for j in range(8):
            for alpha in range(8):
                for beta in range(8):
                    if alpha + beta <= n_freqs:
                        scale_a = A(alpha)
                        scale_b = A(beta)

                        coeff_x = np.cos(((2 * i + 1) * alpha * np.pi) / 16)
                        coeff_y = np.cos(((2 * j + 1) * beta * np.pi) / 16)

                        D_t[i, j, alpha, beta] = 0.25 * scale_a * scale_b * coeff_x * coeff_y

    return D_t


def Z():
    z = np.array([[ 0,  1,  5,  6, 14, 15, 27, 28],
                  [ 2,  4,  7, 13, 16, 26, 29, 42],
                  [ 3,  8, 12, 17, 25, 30, 41, 43],
                  [ 9, 11, 18, 24, 31, 40, 44, 53],
                  [10, 19, 23, 32, 39, 45, 52, 54],
                  [20, 22, 33, 38, 46, 51, 55, 60],
                  [21, 34, 37, 47, 50, 56, 59, 61],
                  [35, 36, 48, 49, 57, 58, 62, 63]], dtype=float)

    Z_t = torch.zeros([8, 8, 64], dtype=torch.float)

    for alpha in range(8):
        for beta in range(8):
            for gamma in range(64):
                if z[alpha, beta] == gamma:
                    Z_t[alpha, beta, gamma] = 1

    return Z_t


def S():
    q = np.array([ 8, 16, 16, 19, 16, 19, 22, 22, 22, 22, 22, 22, 26, 24, 26, 27,
                   27, 27, 26, 26, 26, 26, 27, 27, 27, 29, 29, 29, 34, 34, 34, 29,
                   29, 29, 27, 27, 29, 29, 32, 32, 34, 34, 37, 38, 37, 35, 35, 34,
                   35, 38, 38, 40, 40, 40, 48, 48, 46, 46, 56, 56, 58, 69, 69, 83], dtype=float)

    S_t = torch.zeros([64, 64], dtype=torch.float)

    for gamma in range(64):
        for k in range(64):
            if gamma == k:
                S_t[gamma, k] = 1.0 / q[k]

    return S_t


def S_i():
    q = np.array([ 8, 16, 16, 19, 16, 19, 22, 22, 22, 22, 22, 22, 26, 24, 26, 27,
                   27, 27, 26, 26, 26, 26, 27, 27, 27, 29, 29, 29, 34, 34, 34, 29,
                   29, 29, 27, 27, 29, 29, 32, 32, 34, 34, 37, 38, 37, 35, 35, 34,
                   35, 38, 38, 40, 40, 40, 48, 48, 46, 46, 56, 56, 58, 69, 69, 83], dtype=float)

    S_t = torch.zeros([64, 64], dtype=torch.float)

    for gamma in range(64):
        for k in range(64):
            if gamma == k:
                S_t[gamma, k] = q[k]

    return S_t


def B(shape, block_size):
    blocks_shape = (shape[0] // block_size[0], shape[1] // block_size[1])

    B_t = torch.zeros([shape[0], shape[1], blocks_shape[0], blocks_shape[1], block_size[0], block_size[1]], dtype=torch.float)

    for s_x in range(shape[0]):
        for s_y in range(shape[1]):
            for x in range(blocks_shape[0]):
                for y in range(blocks_shape[1]):
                    for i in range(block_size[0]):
                        for j in range(block_size[1]):
                            if x * block_size[0] + i == s_x and y * block_size[1] + j == s_y:
                                B_t[s_x, s_y, x, y, i, j] = 1.0

    return B_t

Layers

The layers from the inference notes are copied here with the exception of batch normalization. For training, the gradient for all layers is computed using autograd, so there is no need to program a backward pass explicitly. Batch normalization, however, has a different formulation at training time that we need to take into account.

In [0]:
class AvgPool(torch.nn.modules.Module):
    def __init__(self):
        super(AvgPool, self).__init__()

    def forward(self, input):
        result = torch.mean(input[:, :, :, :, 0].view(-1, input.shape[1], input.shape[2]*input.shape[3]), 2)
        return result

class Conv2d(torch.nn.modules.Module):
    def __init__(self, conv_spatial, J):
        super(Conv2d, self).__init__()

        self.stride = conv_spatial.stride
        self.padding = conv_spatial.padding
        self.weight = torch.nn.Parameter(conv_spatial.weight.clone())

        self.register_buffer('J', J[0])
        self.register_buffer('J_i', J[1])

        J_batched = self.J_i.contiguous().view(np.prod(self.J_i.shape[0:3]), 1, *self.J_i.shape[3:5])
        self.register_buffer('J_batched', J_batched)

        self.make_apply_op()

        self.jpeg_op = None

    def make_apply_op(self):
        input_shape = [0, self.weight.shape[1], *self.J_i.shape[0:3]]
        jpeg_op_shape = [self.weight.shape[0], self.weight.shape[1], *self.J_i.shape[0:3], *self.J.shape[0:2]]

        self.apply_conv = oe.contract_expression('mnxyksr,srabc,tnxyk->tmabc',  jpeg_op_shape, self.J, input_shape, constants=[1], optimize='optimal')
        self.apply_conv.evaluate_constants(backend='torch')

    def _apply(self, fn):
        s = super(Conv2d, self)._apply(fn)
        s.make_apply_op()
        return s

    def explode(self):
        out_channels = self.weight.shape[0]
        in_channels = self.weight.shape[1]

        jpeg_op = torch.nn.functional.conv2d(self.J_batched, self.weight.view(out_channels * in_channels, 1, self.weight.shape[2], self.weight.shape[3]), padding=self.padding, stride=self.stride)
        jpeg_op = jpeg_op.permute(1, 0, 2, 3)
        jpeg_op = jpeg_op.view(out_channels, in_channels, *self.J_i.shape[0:3], *(np.array(self.J_i.shape[3:5]) // self.stride))

        return jpeg_op

    def explode_pre(self):
        self.jpeg_op = self.explode()

    def forward(self, input):
        if self.jpeg_op is not None:
            jpeg_op = self.jpeg_op
        else:
            jpeg_op = self.explode()

        return self.apply_conv(jpeg_op, input, backend='torch')


class ASMReLU(torch.nn.modules.Module):
    def __init__(self, n_freqs):
        super(ASMReLU, self).__init__()
        C_n = torch.einsum('ijab,abg,gk->ijk', [D_n(n_freqs), Z(), S_i()])
        self.register_buffer('C_n', C_n)

        Hm = torch.einsum('ijab,ijuv,abg,gk,uvh,hl->ijkl', [D(), D(), Z(), S_i(), Z(), S()])
        self.register_buffer('Hm', Hm)

        self.make_masking_ops()

    def make_masking_ops(self):
        self.annm_op = oe.contract_expression('ijk,tmxyk->tmxyij', self.C_n, [0, 0, 0, 0, 64], constants=[0], optimize='optimal')
        self.annm_op.evaluate_constants(backend='torch')

        self.hsm_op = oe.contract_expression('ijkl,tmxyk,tmxyij->tmxyl', self.Hm, [0, 0, 0, 0, 64], [0, 0, 0, 0, 8, 8], constants=[0], optimize='optimal')
        self.hsm_op.evaluate_constants(backend='torch')

    def _apply(self, fn):
        s = super(ASMReLU, self)._apply(fn)
        s.make_masking_ops()
        return s

    def annm(self, x):
        appx_im = self.annm_op(x, backend='torch')
        mask = torch.zeros_like(appx_im)
        mask[appx_im >= 0] = 1
        return mask

    def half_spatial_mask(self, x, m):
        return self.hsm_op(x, m, backend='torch')

    def forward(self, input):
        annm = self.annm(input)
        out_comp = self.half_spatial_mask(input, annm)
        return out_comp

Batch Normalization Revisited

The previous definition of batch normaliztion was suitable for inference but not for training. It used precomputed $\gamma$ and $\beta$ parameters as well as the running mean and variance. To compute this ourselves, is straightforward using the batch mean and variance that was derived in the batch normalization notes.

In [0]:
class BatchNorm(torch.nn.modules.Module):
    def __init__(self, bn):
        super(BatchNorm, self).__init__()

        self.register_buffer('running_mean', bn.running_mean.clone())
        self.register_buffer('running_var', bn.running_var.clone())
        self.register_buffer('S_i', S_i())

        self.momentum = bn.momentum
        self.eps = bn.eps

        self.gamma = torch.nn.Parameter(bn.weight.clone())
        self.beta = torch.nn.Parameter(bn.bias.clone())

    def forward(self, input):
        if self.training:
            channels = input.shape[1]
            
            input_channelwise = input.permute(1, 0, 2, 3, 4).clone()
            
            # Compute the batch mean for each channel
            block_means = input_channelwise[:, :, :, :, 0].contiguous().view(channels, -1)
            batch_mean = torch.mean(block_means, 1)

            # Compute the batch variance for each channel
            input_dequantized = torch.einsum('mtxyk,gk->mtxyg', [input_channelwise, self.S_i])
            input_dequantized[:, :, :, :, 0] = 0  # zero mean
            block_variances = torch.mean(input_dequantized**2, 4).view(channels, -1)
            batch_var = torch.mean(block_variances + block_means**2, 1) - batch_mean**2
            
            # Apply bessel correction to match pytorch i dont think this is really necessary 
            bessel_correction_factor = input.shape[0] * input.shape[2] * input.shape[3] * 64
            bessel_correction_factor = bessel_correction_factor / (bessel_correction_factor - 1)
            batch_var *= bessel_correction_factor
            batch_var = batch_var
            
            # Update running stats
            self.running_mean = self.running_mean * (1 - self.momentum) + batch_mean * self.momentum
            self.running_var = self.running_var * (1 - self.momentum) + batch_var * self.momentum

            # Apply parameters
            invstd = 1. / torch.sqrt(batch_var + self.eps).view(1, -1, 1, 1, 1)
            mean = batch_mean.view(1, -1, 1, 1)
        else:
            invstd = 1. / torch.sqrt(self.running_var + self.eps).view(1, -1, 1, 1, 1)
            mean = self.running_mean.view(1, -1, 1, 1)
           
        g = self.gamma.view(1, -1, 1, 1, 1)
        b = self.beta.view(1, -1, 1, 1)
        
        input[:, :, :, :, 0] = input[:, :, :, :, 0] - mean
        input = input * invstd
        input = input * g
        input[:, :, :, :, 0] = input[:, :, :, :, 0] + b

        return input

Sanity Check

The first test we perform is a simple sanity check. We generate a small batch of random images and JPEG compress them. Then we perform training over the batch and show that the means, variances, and learned paramters are the same for the spatial and JPEG model.

First we generate the batch

In [0]:
def show_image(m, ax=None):
    c_img = np.zeros((m.shape[0], m.shape[1], 3))
    
    max_gr0 = np.max(m[m > 0])
    
    if len(m[m < 0]) > 0:
        min_le0 = np.min(m[m < 0])
    else:
        min_le0 = 0
    
    c_img[m < 0] = np.array([[0.0, 1.0, 0.0]]) * (m[m < 0] / min_le0).reshape(-1, 1)
    c_img[m == 0] = np.array([0.0, 0.0, 1.0]) 
    c_img[m > 0] = np.array([[1.0, 0.0, 0.0]]) * (m[m > 0] / max_gr0).reshape(-1, 1)
    
    plt.grid(False)
    
    if ax is None:
        return plt.imshow(c_img)
    else:
        ax.grid(False)
        return ax.imshow(c_img)
    

def generate_batch(n, c):
    return torch.Tensor(np.random.randint(0, 255, size=(n, c, 8, 8)).astype(float))



def show_batch(batch):
    plt.figure(figsize=(15, 10))
    for b in range(batch.shape[0]):
        for c in range(batch.shape[1]):
            plt.subplot(batch.shape[0], batch.shape[1], b * batch.shape[1] + c + 1)
            show_image(batch[b, c, :, :])
            
            
spatial_batch = [generate_batch(128, 16) for _ in range(300)]

Then we create the JPEG compresed batch

In [0]:
C = torch.einsum('ijab,abg,gk->ijk', (D(), Z(), S()))
C_i = torch.einsum('ijab,abg,gk->ijk', (D(), Z(), S_i()))


def codec(image_size, block_size=(8, 8)):
    B_i = B(image_size, block_size)
    J = torch.einsum('srxyij,ijk->srxyk', (B_i, C))
    J_i = torch.einsum('srxyij,ijk->xyksr', (B_i, C_i))
    return J, J_i


def encode(batch, block_size=(8, 8), device=None):
    J, _ = codec(batch.shape[2:], block_size)

    if device is not None:
        batch = batch.to(device)
        J = J.to(device)

    jpeg_batch = torch.einsum('srxyk,ncsr->ncxyk', (J, batch))
    return jpeg_batch


def decode(batch, device=None):
    block_size = int(np.sqrt(batch.shape[4]))
    image_size = (int(batch.shape[2] * block_size), int(batch.shape[3] * block_size))

    _, J_i = codec(image_size, (block_size, block_size))

    if device is not None:
        batch = batch.to(device)
        J_i = J_i.to(device)
        
    spatial_batch = torch.einsum('xyksr,ncxyk->ncsr', (J_i, batch))
    return spatial_batch


jpeg_batch = [encode(sb) for sb in spatial_batch]

We initialize a batch norm layer for spatial and JPEG

In [0]:
spatial_bn = torch.nn.BatchNorm2d(spatial_batch[0].shape[1])
jpeg_bn = BatchNorm(spatial_bn)

Then perform training over the batch

In [0]:
spatial_bn.train()
jpeg_bn.train()

spatial_res = [spatial_bn(sb) for sb in spatial_batch[:-2]]
jpeg_res = [jpeg_bn(jb) for jb in jpeg_batch[:-2]] 

Finally we print the means, variances, and parameters and show that they are identical.

In [15]:
print(spatial_bn.running_mean)
print(jpeg_bn.running_mean)

print(spatial_bn.running_var)
print(jpeg_bn.running_var)

print(spatial_bn.weight)
print(jpeg_bn.gamma)

print(spatial_bn.bias)
print(jpeg_bn.beta)
tensor([126.8763, 126.8900, 126.7725, 126.9456, 127.0475, 127.0415, 127.2132,
        126.7650, 127.0972, 126.9396, 126.9821, 127.1473, 126.9219, 127.2403,
        127.1094, 126.8491])
tensor([126.8763, 126.8900, 126.7725, 126.9456, 127.0475, 127.0414, 127.2132,
        126.7650, 127.0972, 126.9396, 126.9820, 127.1472, 126.9219, 127.2403,
        127.1094, 126.8490])
tensor([5401.3335, 5402.7729, 5415.4565, 5400.1646, 5419.3457, 5415.6445,
        5418.3081, 5427.9434, 5427.9272, 5417.5508, 5417.3433, 5397.3442,
        5422.1782, 5416.1758, 5426.9795, 5411.3955])
tensor([5401.3325, 5402.7729, 5415.4551, 5400.1626, 5419.3438, 5415.6436,
        5418.3071, 5427.9424, 5427.9243, 5417.5493, 5417.3428, 5397.3428,
        5422.1777, 5416.1738, 5426.9785, 5411.3940])
Parameter containing:
tensor([0.2110, 0.7162, 0.7512, 0.7032, 0.3454, 0.1638, 0.8307, 0.7682, 0.0744,
        0.8562, 0.0685, 0.8670, 0.7056, 0.0095, 0.9275, 0.2877],
       requires_grad=True)
Parameter containing:
tensor([0.2110, 0.7162, 0.7512, 0.7032, 0.3454, 0.1638, 0.8307, 0.7682, 0.0744,
        0.8562, 0.0685, 0.8670, 0.7056, 0.0095, 0.9275, 0.2877],
       requires_grad=True)
Parameter containing:
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       requires_grad=True)
Parameter containing:
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       requires_grad=True)

Training Test

We conclude by defining the same ResNet based network that was used in the inference section. We then train this model on the JPEG compressed MNIST dataset and show that it is converging and that its accuracy matches the spatial domain model.

First the block and network definitions for JPEG

In [0]:
class JpegResBlock(torch.nn.Module):
    def __init__(self, spatial_resblock, n_freqs, J_in, J_out, relu_layer=ASMReLU):
        super(JpegResBlock, self).__init__()

        J_down = (J_out[0], J_in[1])

        self.conv1 = Conv2d(spatial_resblock.conv1, J_down)
        self.conv2 = Conv2d(spatial_resblock.conv2, J_out)

        self.bn1 = BatchNorm(spatial_resblock.bn1)
        self.bn2 = BatchNorm(spatial_resblock.bn2)

        self.relu = relu_layer(n_freqs=n_freqs)

        if spatial_resblock.downsampler is not None:
            self.downsampler = Conv2d(spatial_resblock.downsampler, J_down)
            self.bn_ds = BatchNorm(spatial_resblock.bn_ds)
        else:
            self.downsampler = None

    def explode_all(self):
        self.conv1.explode_pre()
        self.conv2.explode_pre()

        if self.downsampler is not None:
            self.downsampler.explode_pre()

    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsampler is not None:
            residual = self.downsampler(x)
            residual = self.bn_ds(residual)
        else:
            residual = x

        out += residual
        out = self.relu(out)
        return out
    
    
class JpegResNet(torch.nn.Module):
    def __init__(self, spatial_model, n_freqs, relu_layer=ASMReLU):
        super(JpegResNet, self).__init__()

        J_32 = codec((32, 32))
        J_16 = codec((16, 16))
        J_8 = codec((8, 8))

        self.block1 = JpegResBlock(spatial_model.block1, n_freqs=n_freqs, J_in=J_32, J_out=J_32, relu_layer=relu_layer)
        self.block2 = JpegResBlock(spatial_model.block2, n_freqs=n_freqs, J_in=J_32, J_out=J_16, relu_layer=relu_layer)
        self.block3 = JpegResBlock(spatial_model.block3, n_freqs=n_freqs, J_in=J_16, J_out=J_8, relu_layer=relu_layer)

        self.averagepooling = AvgPool()
        self.fc = spatial_model.fc

    def explode_all(self):
        self.block1.explode_all()
        self.block2.explode_all()
        self.block3.explode_all()

    def forward(self, x):
        out = self.block1(x)
        out = self.block2(out)
        out = self.block3(out)

        out = self.averagepooling(out)
        out = out.view(x.size(0), -1)

        out = self.fc(out)

        return out

and the corresponding defintions for spatial

In [0]:
class SpatialResBlock(torch.nn.Module):
    def __init__(self, in_channels, out_channels, downsample=True):
        super(SpatialResBlock, self).__init__()

        self.downsample = downsample

        stride = 2 if downsample else 1

        self.conv1 = torch.nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.conv2 = torch.nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1, bias=False)

        self.bn1 = torch.nn.BatchNorm2d(out_channels)
        self.bn2 = torch.nn.BatchNorm2d(out_channels)
        self.relu = torch.nn.ReLU(inplace=True)

        if downsample or in_channels != out_channels:
            stride = 2 if downsample else 1
            self.downsampler = torch.nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=stride, padding=0, bias=False)
            self.bn_ds = torch.nn.BatchNorm2d(out_channels)
        else:
            self.downsampler = None

    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsampler is not None:
            residual = self.downsampler(x)
            residual = self.bn_ds(residual)
        else:
            residual = x

        out += residual
        out = self.relu(out)

        return out
    
    
class SpatialResNet(torch.nn.Module):
    def __init__(self, channels, classes):
        super(SpatialResNet, self).__init__()

        self.block1 = SpatialResBlock(in_channels=channels, out_channels=16, downsample=False)
        self.block2 = SpatialResBlock(in_channels=16, out_channels=32)
        self.block3 = SpatialResBlock(in_channels=32, out_channels=64)

        self.averagepooling = torch.nn.AvgPool2d(8, stride=1)
        self.fc = torch.nn.Linear(64, classes)

    def forward(self, x):
        out = self.block1(x)
        out = self.block2(out)
        out = self.block3(out)

        out = self.averagepooling(out)
        out = out.view(x.size(0), -1)

        out = self.fc(out)

        return out

Utilities for training and testing

In [0]:
def train(model, device, train_loader, optimizer, epoch, doencode):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        if doencode:
            data, target = encode(data, device=device), target.to(device)
        else:
            data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = torch.nn.functional.cross_entropy(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % 100 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                       100. * batch_idx / len(train_loader), loss.item()))


def test(model, device, test_loader, doencode):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            if doencode:
                data, target = encode(data, device=device), target.to(device)
            else:
                data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += torch.nn.functional.cross_entropy(output, target, reduction='sum').item()
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

    return correct / len(test_loader.dataset)

Loading the MNIST dataset

In [19]:
train_data = datasets.MNIST('MNIST-data', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.Pad(2),
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ]))

test_data = datasets.MNIST('MNIST-data', train=False, transform=transforms.Compose([
                       transforms.Pad(2),
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ]))

train_loader = torch.utils.data.DataLoader(train_data, batch_size=128, shuffle=False)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=128, shuffle=False)
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Processing...
Done!

and converting it to JPEG

In [0]:
class MNISTJpegDataset(torch.utils.data.Dataset):
    def __init__(self, dataset, labels):
        self.dataset = dataset
        self.labels = labels

    def __len__(self):
        return self.dataset.size()[0]

    def __getitem__(self, idx):
        _, l  = self.labels[idx]
        return self.dataset[idx, :, :, :, :], l

device = torch.device('cuda')
J = torch.einsum('srxyij,ijab,abg,gk->srxyk', (B((32, 32), (8, 8)), D(), Z(), S())).to(device)


def jpeg_encode(batch):
    batch = batch.to(device)
    jpeg_batch = torch.einsum('srxyk,ncsr->ncxyk', (J, batch))
    return jpeg_batch

test_jpegconvert_loader = torch.utils.data.DataLoader(test_data, batch_size=10000, shuffle=False)
test_jpeg_data = []
    
for data, _ in test_jpegconvert_loader:
    jpeg_data = jpeg_encode(data)
    test_jpeg_data.append(jpeg_data)
    
test_jpeg_data = torch.cat(test_jpeg_data)

jpeg_test_data = MNISTJpegDataset(test_jpeg_data, test_data)
jpeg_test_loader = torch.utils.data.DataLoader(jpeg_test_data, batch_size=128, shuffle=False)

train_jpegconvert_loader = torch.utils.data.DataLoader(train_data, batch_size=60000, shuffle=False)
train_jpeg_data = []
    
for data, _ in train_jpegconvert_loader:
    jpeg_data = jpeg_encode(data)
    train_jpeg_data.append(jpeg_data)
    
train_jpeg_data = torch.cat(train_jpeg_data)

jpeg_train_data = MNISTJpegDataset(train_jpeg_data, train_data)
jpeg_train_loader = torch.utils.data.DataLoader(jpeg_train_data, batch_size=128, shuffle=False)

And finally training and testing. We only do 1 epoch as that is all that is needed to get resonable performance on MNIST.

In [27]:
model = SpatialResNet(1, 10).to(device)
jpeg_model = JpegResNet(model, n_freqs=14).to(device)

jpeg_optimizer = torch.optim.Adam(jpeg_model.parameters())
spatial_optimizer = torch.optim.Adam(model.parameters())

for epoch in range(1):
    train(jpeg_model, device, jpeg_train_loader, jpeg_optimizer, epoch, doencode=False)
    test(jpeg_model, device, jpeg_test_loader, doencode=False)

    train(model, device, train_loader, spatial_optimizer, epoch, doencode=False)
    test(model, device, test_loader, doencode=False)
Train Epoch: 0 [0/60000 (0%)]	Loss: 2.282061
Train Epoch: 0 [12800/60000 (21%)]	Loss: 0.939266
Train Epoch: 0 [25600/60000 (43%)]	Loss: 0.303440
Train Epoch: 0 [38400/60000 (64%)]	Loss: 0.194922
Train Epoch: 0 [51200/60000 (85%)]	Loss: 0.170025

Test set: Average loss: 0.1713, Accuracy: 9567/10000 (96%)

Train Epoch: 0 [0/60000 (0%)]	Loss: 2.285923
Train Epoch: 0 [12800/60000 (21%)]	Loss: 0.403069
Train Epoch: 0 [25600/60000 (43%)]	Loss: 0.184515
Train Epoch: 0 [38400/60000 (64%)]	Loss: 0.165370
Train Epoch: 0 [51200/60000 (85%)]	Loss: 0.147846

Test set: Average loss: 0.1179, Accuracy: 9696/10000 (97%)

© 2018 Max Ehrlich