ResNet Inference

These notes demonstrate taking a pre-trained model, one that was trained on spatial domain images, and converting it to perform inference in the JPEG transform domain. A simple demonstration is provided using a small ResNet and the MNIST dataset. The following cell is boilerplate: it installs the dependencies and imports the required modules.

In [0]:
!pip install torch torchvision
!pip install opt_einsum

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np
import opt_einsum as oe

torch.backends.cudnn.enabled = False
Requirement already satisfied: torch in /usr/local/lib/python3.6/dist-packages (0.4.1)
Requirement already satisfied: torchvision in /usr/local/lib/python3.6/dist-packages (0.2.1)
Requirement already satisfied: pillow>=4.1.1 in /usr/local/lib/python3.6/dist-packages (from torchvision) (5.3.0)
Requirement already satisfied: six in /usr/local/lib/python3.6/dist-packages (from torchvision) (1.11.0)
Requirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from torchvision) (1.14.6)
Requirement already satisfied: opt_einsum in /usr/local/lib/python3.6/dist-packages (2.3.1)
Requirement already satisfied: numpy>=1.7 in /usr/local/lib/python3.6/dist-packages (from opt_einsum) (1.14.6)

Data

The first step is to get some data. For a simple and small problem, we use MNIST. The $28 \times 28$ images are zero padded to $32 \times 32$ during loading so that they divide evenly into $8 \times 8$ JPEG blocks; this will be important later. The train and test sets are loaded using PyTorch, and 10 random train and test images are displayed to show that the data was loaded correctly.

In [0]:
def display_digit(x, y):
    plt.title('Label: {}'.format(y))
    plt.grid(False)
    plt.xticks([])
    plt.yticks([])
    plt.imshow(x, 'gray_r')
    
train_data = datasets.MNIST('MNIST-data', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.Pad(2),
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ]))

test_data = datasets.MNIST('MNIST-data', train=False, transform=transforms.Compose([
                       transforms.Pad(2),
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ]))

train_loader = torch.utils.data.DataLoader(train_data, batch_size=128, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=128, shuffle=True)

plt.figure(figsize=(20, 3))
plt.suptitle('Random MNIST Training Images')
for i in range(10):
    plt.subplot(1, 10, i+1)
    ind = np.random.randint(0, len(train_data))
    # int() works whether the label is a 0-dim tensor (older torchvision) or a Python int
    display_digit(train_data[ind][0].numpy().reshape(32, 32), int(train_data[ind][1]))
    
plt.figure(figsize=(20, 3))
plt.suptitle('Random MNIST Testing Images')
for i in range(10):
    plt.subplot(1, 10, i+1)
    ind = np.random.randint(0, len(test_data))
    display_digit(test_data[ind][0].numpy().reshape(32, 32), int(test_data[ind][1]))
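To make the block structure concrete, here is a minimal NumPy sketch (the helper name to_blocks and the random stand-in digit are ours, not part of the notes) that pads a $28 \times 28$ array by 2 pixels on each side, as transforms.Pad(2) does, and splits the result into a $4 \times 4$ grid of $8 \times 8$ blocks with reshape and transpose.

```python
import numpy as np

def to_blocks(img, block=8):
    """Split an (H, W) array into an (H//block, W//block, block, block) grid of blocks."""
    h, w = img.shape
    assert h % block == 0 and w % block == 0, "image must divide evenly into blocks"
    # (H, W) -> (rows, block, cols, block) -> (rows, cols, block, block)
    return img.reshape(h // block, block, w // block, block).transpose(0, 2, 1, 3)

digit = np.random.rand(28, 28)                 # stand-in for one MNIST digit
padded = np.pad(digit, 2, mode='constant')     # 2 + 28 + 2 = 32 pixels per side
blocks = to_blocks(padded)

print(padded.shape)   # (32, 32)
print(blocks.shape)   # (4, 4, 8, 8)
```

Since transforms.Pad runs before transforms.Normalize in the loaders above, the border is zero only in raw pixel units; after normalization it holds $(0 - 0.1307)/0.3081 \approx -0.424$. This does not change the block structure.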

Training

Next, we train a highly simplified version of ResNet on the MNIST dataset. Since this is only a demonstration, and MNIST is an easy problem, this small network is sufficient. The network has three residual blocks: the first keeps the $32 \times 32$ resolution, and the second and third each downsample by a factor of two (to $16 \times 16$ and then $8 \times 8$) while doubling the number of channels (16, 32, 64). The output of the final residual block is average-pooled and a single fully connected layer learns the classification. Note that MNIST can be solved effectively with even simpler ResNets, but they converge more slowly (and are less interesting). Note also that nothing involving compressed images happens yet.

First, the residual block and network architecture are implemented using the PyTorch object-oriented convention.

In [0]:
class ResBlock(nn.Module):
    def __init__(self, in_channels, out_channels, downsample=True):
        super(ResBlock, self).__init__()
        
        self.downsample = downsample
        
        stride = 2 if downsample else 1
        
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.conv2 = nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
      
        if downsample:
            self.downsampler = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=2, padding=0, bias=False)
        
    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample:
            residual = self.downsampler(x)
        else:
            residual = x

        out += residual
        out = self.relu(out)

        return out
    
class ResNet(nn.Module):
    def __init__(self):
        super(ResNet, self).__init__()
        
        self.block1 = ResBlock(in_channels=1, out_channels=16, downsample=False)
        self.block2 = ResBlock(in_channels=16, out_channels=32)
        self.block3 = ResBlock(in_channels=32, out_channels=64)
        
        self.averagepooling = nn.AvgPool2d(8, stride=1)
        self.fc = nn.Linear(64, 10)
        
    def forward(self, x):
        out = self.block1(x)
        out = self.block2(out)
        out = self.block3(out)
        
        out = self.averagepooling(out)
        out = out.view(x.size(0), -1)

        out = self.fc(out)
        
        return out
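As a sanity check on the architecture above, the spatial size of each feature map can be traced with the standard output-size formula for a convolution or pooling layer, $\lfloor (n + 2p - k)/s \rfloor + 1$. This plain-Python sketch (the helper name conv_out_size is ours) follows the three ResBlocks and shows why nn.AvgPool2d(8, stride=1) acts as global average pooling here.

```python
def conv_out_size(n, kernel, stride, padding):
    """Output height/width of a PyTorch Conv2d or AvgPool2d with dilation 1."""
    return (n + 2 * padding - kernel) // stride + 1

size = 32  # padded MNIST input
sizes = {}
# (block name, stride of conv1, output channels) for the three ResBlocks
for name, stride, channels in [('block1', 1, 16), ('block2', 2, 32), ('block3', 2, 64)]:
    # conv1: 3x3, stride `stride`, padding 1; conv2: 3x3, stride 1, padding 1
    main = conv_out_size(conv_out_size(size, 3, stride, 1), 3, 1, 1)
    if stride == 2:
        # The 1x1, stride-2 downsampler on the shortcut must match the main path
        assert conv_out_size(size, 1, 2, 0) == main
    sizes[name] = main
    size = main
    print(name, channels, 'channels,', main, 'x', main)

pooled = conv_out_size(size, 8, 1, 0)  # nn.AvgPool2d(8, stride=1) on the final map
print('avgpool output:', pooled, 'x', pooled)
```

The sizes 32, 16, and 8 are the feature-map sizes that the JPEG conversion later has to handle, and the final pooling window covers the whole $8 \times 8$ map, leaving a single value per channel.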

Next, helper functions for training and testing the model are defined.

In [0]:
def train(model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.cross_entropy(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % 100 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))

def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.cross_entropy(output, target, reduction='sum').item() # sum up batch loss
            pred = output.max(1, keepdim=True)[1] # index of the largest logit, i.e. the predicted class
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

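Finally, the model is trained for 5 epochs using a GPU. The final test set accuracy should be in the high 90s, and training should take no more than a few minutes. Test accuracy fluctuates from epoch to epoch (between 96% and 99% in the run below); the weights after the last epoch are the ones converted to the JPEG domain later.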

In [0]:
device = torch.device('cuda')

model = ResNet().to(device)
optimizer = optim.Adam(model.parameters())

for epoch in range(5):
    train(model, device, train_loader, optimizer, epoch)
    test(model, device, test_loader)
Train Epoch: 0 [0/60000 (0%)]	Loss: 2.306411
Train Epoch: 0 [12800/60000 (21%)]	Loss: 0.798400
Train Epoch: 0 [25600/60000 (43%)]	Loss: 0.260864
Train Epoch: 0 [38400/60000 (64%)]	Loss: 0.165256
Train Epoch: 0 [51200/60000 (85%)]	Loss: 0.115742

Test set: Average loss: 0.1278, Accuracy: 9717/10000 (97%)

Train Epoch: 1 [0/60000 (0%)]	Loss: 0.060571
Train Epoch: 1 [12800/60000 (21%)]	Loss: 0.102701
Train Epoch: 1 [25600/60000 (43%)]	Loss: 0.074513
Train Epoch: 1 [38400/60000 (64%)]	Loss: 0.064179
Train Epoch: 1 [51200/60000 (85%)]	Loss: 0.051954

Test set: Average loss: 0.1269, Accuracy: 9628/10000 (96%)

Train Epoch: 2 [0/60000 (0%)]	Loss: 0.113030
Train Epoch: 2 [12800/60000 (21%)]	Loss: 0.065610
Train Epoch: 2 [25600/60000 (43%)]	Loss: 0.033204
Train Epoch: 2 [38400/60000 (64%)]	Loss: 0.019927
Train Epoch: 2 [51200/60000 (85%)]	Loss: 0.033717

Test set: Average loss: 0.0559, Accuracy: 9837/10000 (98%)

Train Epoch: 3 [0/60000 (0%)]	Loss: 0.019948
Train Epoch: 3 [12800/60000 (21%)]	Loss: 0.019733
Train Epoch: 3 [25600/60000 (43%)]	Loss: 0.028246
Train Epoch: 3 [38400/60000 (64%)]	Loss: 0.038952
Train Epoch: 3 [51200/60000 (85%)]	Loss: 0.017133

Test set: Average loss: 0.0379, Accuracy: 9883/10000 (99%)

Train Epoch: 4 [0/60000 (0%)]	Loss: 0.049492
Train Epoch: 4 [12800/60000 (21%)]	Loss: 0.009796
Train Epoch: 4 [25600/60000 (43%)]	Loss: 0.019557
Train Epoch: 4 [38400/60000 (64%)]	Loss: 0.062875
Train Epoch: 4 [51200/60000 (85%)]	Loss: 0.031138

Test set: Average loss: 0.0963, Accuracy: 9697/10000 (97%)

GPU JPEG Codec

Before we can carry out inference on JPEG compressed images, we need to convert the MNIST data. We can leverage the tensor method developed previously, along with PyTorch, to create a fast GPU-enabled JPEG compression codec that converts the images quickly. The code below is adapted from the "Tensor Methods" notes to use PyTorch GPU tensors, and it computes the JPEG compression and decompression tensors. Remember that this operation took several minutes on CPU; this implementation forms the tensors in less than a second.

In [0]:
def A(alpha):
    if alpha == 0:
        return 1.0 / np.sqrt(2)
    else:
        return 1
    
    
def D():    
    D_t = torch.zeros([8, 8, 8, 8], device=device, dtype=torch.float)
    
    for i in range(8):
        for j in range(8):
            for alpha in range(8):
                for beta in range(8):
                    scale_a = A(alpha)
                    scale_b = A(beta)
                    
                    coeff_x = np.cos(((2 * i + 1) * alpha * np.pi) / 16)
                    coeff_y = np.cos(((2 * j + 1) * beta * np.pi) / 16)
                    
                    D_t[i, j, alpha, beta] = 0.25 * scale_a * scale_b * coeff_x * coeff_y    

    return D_t
    

def D_n(n_freqs):    
    D_t = torch.zeros([8, 8, 8, 8], device=device, dtype=torch.float)
    
    for i in range(8):
        for j in range(8):
            for alpha in range(8):
                for beta in range(8):
                    if alpha + beta <= n_freqs:
                        scale_a = A(alpha)
                        scale_b = A(beta)

                        coeff_x = np.cos(((2 * i + 1) * alpha * np.pi) / 16)
                        coeff_y = np.cos(((2 * j + 1) * beta * np.pi) / 16)

                        D_t[i, j, alpha, beta] = 0.25 * scale_a * scale_b * coeff_x * coeff_y    

    return D_t
    
    
def Z():
    z = np.array([[ 0,  1,  5,  6, 14, 15, 27, 28],
                  [ 2,  4,  7, 13, 16, 26, 29, 42],
                  [ 3,  8, 12, 17, 25, 30, 41, 43],
                  [ 9, 11, 18, 24, 31, 40, 44, 53],
                  [10, 19, 23, 32, 39, 45, 52, 54],
                  [20, 22, 33, 38, 46, 51, 55, 60],
                  [21, 34, 37, 47, 50, 56, 59, 61],
                  [35, 36, 48, 49, 57, 58, 62, 63]], dtype=float)
    
    Z_t = torch.zeros([8, 8, 64], device=device, dtype=torch.float) 
   
    for alpha in range(8):
        for beta in range(8):
            for gamma in range(64):
                if z[alpha, beta] == gamma:
                    Z_t[alpha, beta, gamma] = 1         
    
    return Z_t


def S():    
    q = np.array([ 8, 16, 16, 19, 16, 19, 22, 22, 22, 22, 22, 22, 26, 24, 26, 27, 
                  27, 27, 26, 26, 26, 26, 27, 27, 27, 29, 29, 29, 34, 34, 34, 29, 
                  29, 29, 27, 27, 29, 29, 32, 32, 34, 34, 37, 38, 37, 35, 35, 34, 
                  35, 38, 38, 40, 40, 40, 48, 48, 46, 46, 56, 56, 58, 69, 69, 83], dtype=float)
    
    S_t = torch.zeros([64, 64], device=device, dtype=torch.float)
    
    for gamma in range(64):
        for k in range(64):
            if gamma == k:
                S_t[gamma, k] = 1.0 / q[k]            
                
    return S_t


def S_i():    
    q = np.array([ 8, 16, 16, 19, 16, 19, 22, 22, 22, 22, 22, 22, 26, 24, 26, 27, 
                  27, 27, 26, 26, 26, 26, 27, 27, 27, 29, 29, 29, 34, 34, 34, 29, 
                  29, 29, 27, 27, 29, 29, 32, 32, 34, 34, 37, 38, 37, 35, 35, 34, 
                  35, 38, 38, 40, 40, 40, 48, 48, 46, 46, 56, 56, 58, 69, 69, 83], dtype=float)
    
    S_t = torch.zeros([64, 64], device=device, dtype=torch.float)
    
    for gamma in range(64):
        for k in range(64):
            if gamma == k:
                S_t[gamma, k] = q[k] 
                
    return S_t
                
                
def B(shape, block_size):
    blocks_shape = (shape[0] // block_size[0], shape[1] // block_size[1])
    
    B_t = torch.zeros([shape[0], shape[1], blocks_shape[0], blocks_shape[1], block_size[0], block_size[1]], device=device, dtype=torch.float)
    
    for s_x in range(shape[0]):
        for s_y in range(shape[1]):
            for x in range(blocks_shape[0]):
                for y in range(blocks_shape[1]):
                    for i in range(block_size[0]):
                        for j in range(block_size[1]):
                            if x * block_size[0] + i == s_x and y * block_size[1] + j == s_y:
                                B_t[s_x, s_y, x, y, i ,j] = 1.0
    
    return B_t

J = torch.einsum('srxyij,ijab,abg,gk->srxyk', (B((32, 32), (8, 8)), D(), Z(), S()))
J_i = torch.einsum('srxyij,ijab,abg,gk->xyksr', (B((32, 32), (8, 8)), D(), Z(), S_i()))

print(J.size())
print(J_i.size())
torch.Size([32, 32, 4, 4, 64])
torch.Size([4, 4, 64, 32, 32])
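The constructors above loop over every index, which is easy to read but slow. As a CPU-only cross-check, here is a hedged NumPy sketch (the names dct_tensor, ZIGZAG, and Q are ours) that builds the same $D$, $Z$, $S$, and $S^{-1}$ tensors with broadcasting and verifies three properties the codec relies on: the DCT tensor is orthonormal, the zigzag tensor is a permutation, and quantization and dequantization are inverses.

```python
import numpy as np

def dct_tensor():
    """D[i, j, alpha, beta] as in D(), built with broadcasting instead of loops."""
    n = np.arange(8)
    a = np.where(n == 0, 1 / np.sqrt(2), 1.0)  # A(alpha)
    # basis[i, alpha] = 0.5 * A(alpha) * cos((2i + 1) * alpha * pi / 16)
    basis = 0.5 * a[None, :] * np.cos((2 * n[:, None] + 1) * n[None, :] * np.pi / 16)
    # 0.5 * 0.5 = 0.25, matching 0.25 * A(alpha) * A(beta) * cos(...) * cos(...)
    return np.einsum('ia,jb->ijab', basis, basis)

ZIGZAG = np.array([[ 0,  1,  5,  6, 14, 15, 27, 28],
                   [ 2,  4,  7, 13, 16, 26, 29, 42],
                   [ 3,  8, 12, 17, 25, 30, 41, 43],
                   [ 9, 11, 18, 24, 31, 40, 44, 53],
                   [10, 19, 23, 32, 39, 45, 52, 54],
                   [20, 22, 33, 38, 46, 51, 55, 60],
                   [21, 34, 37, 47, 50, 56, 59, 61],
                   [35, 36, 48, 49, 57, 58, 62, 63]])

Q = np.array([ 8, 16, 16, 19, 16, 19, 22, 22, 22, 22, 22, 22, 26, 24, 26, 27,
              27, 27, 26, 26, 26, 26, 27, 27, 27, 29, 29, 29, 34, 34, 34, 29,
              29, 29, 27, 27, 29, 29, 32, 32, 34, 34, 37, 38, 37, 35, 35, 34,
              35, 38, 38, 40, 40, 40, 48, 48, 46, 46, 56, 56, 58, 69, 69, 83], dtype=float)

D = dct_tensor()
Z = (ZIGZAG[:, :, None] == np.arange(64)).astype(float)  # Z[alpha, beta, gamma]
S = np.diag(1.0 / Q)   # quantization, as in S()
S_i = np.diag(Q)       # dequantization, as in S_i()

# The DCT is orthonormal: contracting D with itself over pixels gives the identity
gram = np.einsum('ijab,ijcd->abcd', D, D).reshape(64, 64)
print(np.allclose(gram, np.eye(64)))                     # True
# Zigzag is a permutation: every index gamma is hit exactly once
print(np.array_equal(Z.sum(axis=(0, 1)), np.ones(64)))   # True
# Quantization and dequantization undo each other
print(np.allclose(S @ S_i, np.eye(64)))                  # True
```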

Next the images are encoded using the tensors. Using GPU parallelization, the entirety of each dataset can be converted in a single step. All 70,000 MNIST images are converted in just a few seconds.

In [0]:
def jpeg_encode(batch):
    batch = batch.to(device)
    jpeg_batch = torch.einsum('srxyk,ncsr->ncxyk', (J, batch))
    return jpeg_batch

test_jpegconvert_loader = torch.utils.data.DataLoader(test_data, batch_size=10000, shuffle=False)
test_jpeg_data = []
    
for data, _ in test_jpegconvert_loader:
    jpeg_data = jpeg_encode(data)
    test_jpeg_data.append(jpeg_data)
    
test_jpeg_data = torch.cat(test_jpeg_data)

train_jpegconvert_loader = torch.utils.data.DataLoader(train_data, batch_size=60000, shuffle=False)
train_jpeg_data = []
    
for data, _ in train_jpegconvert_loader:
    jpeg_data = jpeg_encode(data)
    train_jpeg_data.append(jpeg_data)
    
train_jpeg_data = torch.cat(train_jpeg_data)
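To see the batched contraction without a GPU, the following NumPy sketch encodes and decodes a random batch with the same einsum subscripts as jpeg_encode (and a batched version of the reconstruction subscripts used below). It is a simplification under stated assumptions: only blocking and the DCT are included, since zigzag and quantization are a permutation and a diagonal scaling and do not affect invertibility, and the helper names are ours.

```python
import numpy as np

def dct_tensor():
    n = np.arange(8)
    a = np.where(n == 0, 1 / np.sqrt(2), 1.0)
    basis = 0.5 * a[None, :] * np.cos((2 * n[:, None] + 1) * n[None, :] * np.pi / 16)
    return np.einsum('ia,jb->ijab', basis, basis)

def block_tensor(size, block=8):
    """B[s, r, x, y, i, j] = 1 when pixel (s, r) is pixel (i, j) of block (x, y)."""
    nb = size // block
    B = np.zeros((size, size, nb, nb, block, block))
    p = np.arange(size)
    rows, cols = p[:, None], p[None, :]
    B[rows, cols, rows // block, cols // block, rows % block, cols % block] = 1.0
    return B

# Hypothetical DCT-only codec: zigzag and quantization omitted for brevity
C = dct_tensor().reshape(8, 8, 64)                          # flatten (alpha, beta) -> k
B = block_tensor(32)
J = np.einsum('srxyij,ijk->srxyk', B, C, optimize=True)     # encoder
J_i = np.einsum('srxyij,ijk->xyksr', B, C, optimize=True)   # decoder (orthonormal: same entries)

rng = np.random.default_rng(0)
batch = rng.standard_normal((5, 1, 32, 32))                 # stand-in for (N, C, H, W) images
coeffs = np.einsum('srxyk,ncsr->ncxyk', J, batch)           # same subscripts as jpeg_encode
recon = np.einsum('xyksr,ncxyk->ncsr', J_i, coeffs)

print(coeffs.shape)                 # (5, 1, 4, 4, 64)
print(np.allclose(recon, batch))    # True
```

Because the codec in these notes divides by the quantization table but never rounds, decoding recovers the input up to floating-point error: it is the linear part of JPEG, not a lossy compressor.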

To verify that things are working, ten random test images are decoded and displayed along with their labels.

In [0]:
plt.figure(figsize=(20, 3))
plt.suptitle('Reconstructed After JPEG Convert')
for i in range(10):
    plt.subplot(1, 10, i+1)
    ind = np.random.randint(0, len(test_data))
    
    current_jpeg = test_jpeg_data[ind, :, :, :, :].view(4, 4, 64)
    current_recn = torch.einsum('xyksr,xyk->sr', (J_i, current_jpeg)).cpu()
    
    display_digit(current_recn.numpy().reshape(32, 32), int(test_data[ind][1]))

As a last step, a PyTorch Dataset class is prepared for the inference pipeline. Since the compressed JPEGs are all stored in a single GPU tensor, the dataset is extremely simple: it indexes into that tensor and takes each label from the original MNIST dataset.

In [0]:
class MNISTJpegDataset(torch.utils.data.Dataset):
    def __init__(self, dataset, labels):
        self.dataset = dataset
        self.labels = labels

    def __len__(self):
        return self.dataset.size()[0]

    def __getitem__(self, idx):
        _, l  = self.labels[idx]
        return self.dataset[idx, :, :, :, :], l

Model Conversion

Now it is time to convert the learned parameters to their tensor forms. The first thing we need for this is a set of JPEG compression and decompression tensors, one pair for each feature map size. This is because the ResNet model downsamples its feature maps ($32 \times 32 \to 16 \times 16 \to 8 \times 8$), which changes the number of JPEG blocks. Since only the block structure changes, this process can be sped up by precomputing the compression part of the transform (DCT, zigzag, quantization), which does not change, and then applying the different size blocking operations to it.

In [0]:
C = torch.einsum('ijab,abg,gk->ijk', (D(), Z(), S()))
C_i = torch.einsum('ijab,abg,gk->ijk', (D(), Z(), S_i()))

B_32 = B((32, 32), (8, 8))
B_16 = B((16, 16), (8, 8))
B_8 = B((8, 8), (8, 8))


J_32 = torch.einsum('srxyij,ijk->srxyk', (B_32, C))
J_32_i = torch.einsum('srxyij,ijk->xyksr', (B_32, C_i))

J_16 = torch.einsum('srxyij,ijk->srxyk', (B_16, C))
J_16_i = torch.einsum('srxyij,ijk->xyksr', (B_16, C_i))

J_8 = torch.einsum('srxyij,ijk->srxyk', (B_8, C))
J_8_i = torch.einsum('srxyij,ijk->xyksr', (B_8, C_i))

J_32 = (J_32, J_32_i)
J_16 = (J_16, J_16_i)
J_8 = (J_8, J_8_i)
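The saving relies on the associativity of tensor contraction: the DCT, zigzag, and quantization factors do not depend on the feature map size, so they can be contracted once into $C$ and reused with each blocking tensor. This NumPy sketch uses random stand-in tensors with the shapes used above (the values are irrelevant to the identity) and checks that the factored and direct contractions agree.

```python
import numpy as np

rng = np.random.default_rng(0)
# Random stand-ins with the shapes used above; the identity holds for any values
B_16 = rng.random((16, 16, 2, 2, 8, 8))
D = rng.random((8, 8, 8, 8))
Z = rng.random((8, 8, 64))
S = rng.random((64, 64))

# Direct: contract all four factors for this block size
J_direct = np.einsum('srxyij,ijab,abg,gk->srxyk', B_16, D, Z, S, optimize=True)

# Factored: contract the size-independent part once, then attach the blocking tensor
C = np.einsum('ijab,abg,gk->ijk', D, Z, S, optimize=True)
J_factored = np.einsum('srxyij,ijk->srxyk', B_16, C, optimize=True)

print(C.shape, J_factored.shape)             # (8, 8, 64) (16, 16, 2, 2, 64)
print(np.allclose(J_direct, J_factored))     # True
```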

Next, we define classes for each layer of the ResNet architecture. These convert the parameters learned by the spatial model to operate on the JPEG transformed inputs. This gets pretty complicated, but the basic building blocks are identical to the algorithms developed in previous notes. The only notable additions are some alterations to the shape of the inputs and outputs, which are now multichannel batches of JPEG transformed images at each step.

Convolution Layer

The convolution layer is taken almost directly from the "Tensor Methods" notes. In addition to the learned weights that will be converted, the layer is given JPEG encoding and decoding tensors, which also fix the input and output image shapes. One notable addition is that the exploded convolution carries two channel indices so that all pairs of input and output channels are computed at the same time. For example, the first convolution of the second residual block has 16 input channels and 32 output channels and downsamples from $32 \times 32$ to $16 \times 16$, so it is represented by a single $32 \times 16 \times 32 \times 32 \times 16 \times 16$ tensor, where the input and output image sizes in $H \times W$ are $32 \times 32$ and $16 \times 16$ respectively. After exploding, the convolution is combined with the provided JPEG encoding and decoding tensors (which can be of different sizes since the layer might be downsampling) to give the final tensor. In the code, the exploding and decoding steps are merged: on every forward call, a single conv2d convolves each basis image of the decoding tensor with the weights, which gives decoding followed by convolution directly. Note that we are treating each channel as a JPEG compressed "image". Finally, the forward function applies the resulting operator, together with the encoding tensor, to a JPEG compressed input. The main difference here is the addition of channels and batches as indices $m, n, t$, where $t$ is the batch index, $n$ is the input channel index, and $m$ is the output channel index.

In [0]:
class JpegConv2d(torch.nn.modules.Module):
    def __init__(self, conv_spatial, J):
        super(JpegConv2d, self).__init__()

        self.stride = conv_spatial.stride
        self.weight = conv_spatial.weight
        self.padding = conv_spatial.padding
        
        self.J = J
        self.J_batched = self.J[1].contiguous().view(np.prod(self.J[1].shape[0:3]), 1, *self.J[1].shape[3:5])
        
        input_shape = [0, self.weight.shape[1], *self.J[1].shape[0:3]]        
        jpeg_op_shape = [self.weight.shape[0], self.weight.shape[1], *self.J[1].shape[0:3], *self.J[0].shape[0:2]]     
        
        self.apply_conv = oe.contract_expression('mnxyksr,srabc,tnxyk->tmabc',  jpeg_op_shape, self.J[0], input_shape, constants=[1], optimize='optimal')
        self.apply_conv.evaluate_constants(backend='torch')                                                                     
        
    def forward(self, input):
        
        out_channels = self.weight.shape[0]
        in_channels = self.weight.shape[1]
        
        jpeg_op = torch.nn.functional.conv2d(self.J_batched, self.weight.view(out_channels * in_channels, 1, self.weight.shape[2], self.weight.shape[3]), padding=self.padding, stride=self.stride)
        jpeg_op = jpeg_op.permute(1, 0, 2, 3)
        jpeg_op = jpeg_op.view(out_channels, in_channels, *self.J[1].shape[0:3], *(np.array(self.J[1].shape[3:5]) // self.stride))

        return self.apply_conv(jpeg_op, input, backend='torch')
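The exploding step can be checked in isolation. The NumPy sketch below (helper names are ours; a single channel and no JPEG tensors) builds the exploded operator of a $3 \times 3$, stride-2 convolution by convolving unit impulses, and confirms that contracting it with an image matches direct convolution. Placing such an operator between a decoding and an encoding tensor is what makes it act on JPEG data; JpegConv2d reaches the same composite more cheaply by convolving the decoder's basis images instead of unit impulses.

```python
import numpy as np

def conv2d_single(img, kernel, stride=1, padding=1):
    """Cross-correlation of one (H, W) image with one (k, k) kernel, as torch's conv2d computes."""
    k = kernel.shape[0]
    x = np.pad(img, padding, mode='constant')
    out_h = (x.shape[0] - k) // stride + 1
    out_w = (x.shape[1] - k) // stride + 1
    out = np.empty((out_h, out_w))
    for u in range(out_h):
        for v in range(out_w):
            patch = x[u * stride:u * stride + k, v * stride:v * stride + k]
            out[u, v] = np.sum(patch * kernel)
    return out

def explode(kernel, size, stride=1, padding=1):
    """E[s, r, u, v]: output (u, v) produced by a unit impulse at input pixel (s, r)."""
    impulses = np.eye(size * size).reshape(size * size, size, size)
    responses = np.stack([conv2d_single(e, kernel, stride, padding) for e in impulses])
    return responses.reshape(size, size, *responses.shape[1:])

rng = np.random.default_rng(0)
kernel = rng.standard_normal((3, 3))
img = rng.standard_normal((8, 8))

E = explode(kernel, 8, stride=2)                 # 8x8 input -> 4x4 output
direct = conv2d_single(img, kernel, stride=2)
via_operator = np.einsum('sruv,sr->uv', E, img)  # convolution as a single contraction
print(E.shape)                                   # (8, 8, 4, 4)
print(np.allclose(direct, via_operator))         # True
```

This works only because convolution is linear: the response to any image is the weighted sum of the responses to its pixels.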

Batch Normalization

Batch normalization is also quite similar to the notes. Since this layer is for inference only, the running mean and variance are folded into the $\gamma$ and $\beta$ parameters, giving a single scale and shift per channel. These are also reshaped so that they can be applied to the input, which is now a batched multichannel input with several JPEG blocks; automatic broadcasting cannot handle this case without help. One important difference is in the application of $\beta$. Remember that when applying $\beta$ to DCT coefficients, we needed to multiply it by 8 beforehand. This is not necessary after the full JPEG transform, because the quantization coefficient for the DC term is 8, so the value stored in the DC coefficient is exactly the mean of the block. To verify this, look at the S() or S_i() functions presented in the "GPU JPEG Codec" section, and note that the first coefficient in the quantization vector q is 8.

In [0]:
class JpegBatchNorm(torch.nn.modules.Module):
    def __init__(self, bn):
        super(JpegBatchNorm, self).__init__()
        
        self.mean = bn.running_mean
        self.var = bn.running_var
        
        self.gamma = bn.weight
        self.beta = bn.bias

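        # Fold the running statistics into one scale and one shift per channel.
        # nn.BatchNorm2d normalizes by sqrt(running_var + eps), so eps is included here.
        std = torch.sqrt(self.var + bn.eps)
        self.gamma_final = (self.gamma / std).view(1, self.gamma.shape[0], 1, 1, 1)
        self.beta_final = (self.beta - (self.gamma * self.mean) / std).view(1, self.beta.shape[0], 1, 1)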
        
    def forward(self, input):
        input = input * self.gamma_final
        input[:, :, :, :, 0] = input[:, :, :, :, 0] + self.beta_final
        return input
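Both claims behind this layer, that the DC coefficient is the block mean and that scaling every coefficient and then shifting only the DC coefficient reproduces spatial batch norm, can be checked on a single block. This NumPy sketch assumes a hypothetical quantization table whose DC entry is 8 (as in the notes' q) and whose other entries are arbitrary; zigzag ordering is omitted, and the statistics are made-up values.

```python
import numpy as np

n = np.arange(8)
A = np.where(n == 0, 1 / np.sqrt(2), 1.0)
# basis[i, alpha] = 0.5 * A(alpha) * cos((2i + 1) alpha pi / 16): orthonormal 1-D DCT-II
basis = 0.5 * A[None, :] * np.cos((2 * n[:, None] + 1) * n[None, :] * np.pi / 16)

q = np.full((8, 8), 20.0)   # hypothetical quantization table...
q[0, 0] = 8.0               # ...whose DC entry is 8, as in the notes' q vector

def encode(block):
    return (basis.T @ block @ basis) / q

def decode(coeffs):
    return basis @ (coeffs * q) @ basis.T

rng = np.random.default_rng(0)
block = rng.standard_normal((8, 8))
coeffs = encode(block)
print(np.isclose(coeffs[0, 0], block.mean()))     # True: DC coefficient is the block mean

# Inference-time batch norm folded into a scale and a shift (made-up statistics)
mean, var, gamma, beta, eps = 0.3, 2.0, 1.5, -0.2, 1e-5
gamma_final = gamma / np.sqrt(var + eps)
beta_final = beta - gamma * mean / np.sqrt(var + eps)

jpeg_bn = coeffs * gamma_final    # scale every coefficient
jpeg_bn[0, 0] += beta_final       # shift only the DC coefficient

spatial_bn = gamma * (block - mean) / np.sqrt(var + eps) + beta
print(np.allclose(decode(jpeg_bn), spatial_bn))   # True
```

The shift works because decoding a DC coefficient of 1 multiplies it by 8 and then by the DC basis value $1/8$, adding exactly 1 to every pixel of the block.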

ReLu

The multilinear ReLu approximation from the ReLu notes is implemented here, with some additions to account for the other steps of the JPEG transform, since those notes were developed around the DCT alone. Instead of a DCT approximation that keeps only the low frequencies (those with $\alpha + \beta \le n$), we combine that approximation with the inverse zigzag and quantization steps to give an approximate decompression tensor. Note that this preserves the block structure of the image: for example, a $32 \times 32$ original image becomes a $4 \times 4 \times 8 \times 8$ array after decompression with this tensor, and there is no need to undo the block structure for the purposes of this operation. Also note that the harmonic mixing tensor is combined with both the forward and reverse zigzag and quantization steps so that it too can operate directly on JPEG transformed data.

In [0]:
class JpegRelu(torch.nn.modules.Module):
    def __init__(self, n_freqs):
        super(JpegRelu, self).__init__()
        self.C_n = torch.einsum('ijab,abg,gk->ijk', [D_n(n_freqs), Z(), S_i()])
        self.Hm = torch.einsum('ijab,ijuv,abg,gk,uvh,hl->ijkl', [D(), D(), Z(), S_i(), Z(), S()])
        
        self.annm_op = oe.contract_expression('ijk,tmxyk->tmxyij', self.C_n, [0, 0, 0, 0, 64], constants=[0], optimize='optimal')
        self.annm_op.evaluate_constants(backend='torch')
        
        self.hsm_op = oe.contract_expression('ijkl,tmxyk,tmxyij->tmxyl', self.Hm, [0, 0, 0, 0, 64], [0, 0, 0, 0, 8, 8], constants=[0], optimize='optimal') 
        self.hsm_op.evaluate_constants(backend='torch')
    
    def annm(self, x):
        appx_im = self.annm_op(x, backend='torch')
        mask = torch.zeros_like(appx_im)
        mask[appx_im >= 0] = 1
        return mask
    
    def half_spatial_mask(self, x, m):
        return self.hsm_op(x, m, backend='torch')
        
    def forward(self, input):
        annm = self.annm(input)
        out_comp = self.half_spatial_mask(input, annm)
        return out_comp
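Stripped of zigzag and quantization, the approximation works on one block as follows. This NumPy sketch (helper names ours; DCT only, a single $8 \times 8$ block) builds the approximate mask from the coefficients with $\alpha + \beta \le n$, applies it to the exactly decoded block, and re-encodes; the annm and half_spatial_mask methods above compute the same steps with precomputed tensors.

```python
import numpy as np

n = np.arange(8)
A = np.where(n == 0, 1 / np.sqrt(2), 1.0)
basis = 0.5 * A[None, :] * np.cos((2 * n[:, None] + 1) * n[None, :] * np.pi / 16)

def dct2(block):
    return basis.T @ block @ basis

def idct2(coeffs):
    return basis @ coeffs @ basis.T

def approx_relu(coeffs, n_freqs):
    """ReLU on one block's DCT coefficients using an approximate nonnegative mask."""
    alpha, beta = np.meshgrid(n, n, indexing='ij')
    low_pass = np.where(alpha + beta <= n_freqs, coeffs, 0.0)   # keep alpha + beta <= n_freqs
    mask = (idct2(low_pass) >= 0).astype(float)                  # approximate mask (annm)
    # Apply the mask to the exactly decoded block, then transform back
    return dct2(mask * idct2(coeffs))

rng = np.random.default_rng(0)
block = rng.standard_normal((8, 8))
coeffs = dct2(block)

exact = dct2(np.maximum(block, 0))
print(np.allclose(approx_relu(coeffs, 14), exact))   # True: alpha + beta <= 14 keeps all 64 terms
err = np.abs(idct2(approx_relu(coeffs, 4)) - np.maximum(block, 0)).max()
print(err)   # generally nonzero: with fewer frequencies the mask is only approximate
```

Only the mask is approximated; the values that pass through it come from the full set of coefficients, which is why the error stays confined to pixels near zero crossings.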

Global Average Pooling

Global average pooling, an essential part of the ResNet architecture, has not been discussed in any notes so far. Average pooling in general is quite simple to implement in the JPEG transform domain since it is a linear operation; however, global average pooling can be done far more efficiently than the general case. Recall that, as shown in the "Batch Normalization" section of these notes, the first (DC) element of each encoded block is exactly the mean of that block. Since all blocks are the same size, the mean of the image is equal to the mean of the block means. Therefore, we need only extract the first element of each block in the final result of the network and average these per channel to get the global average pooling result. This reads one value per $8 \times 8$ block, 1/64 of the values the spatial domain algorithm averages, and is quite simple to implement; in this network the final feature maps are $8 \times 8$, a single block, so the result is just the DC coefficient of each channel. It also avoids any kind of decompression before feeding the result of the JPEG domain network into the fully connected layer, since the information it needs is already encapsulated by the JPEG representation.

In [0]:
class JpegAvgPool(torch.nn.modules.Module):
    def __init__(self):
        super(JpegAvgPool, self).__init__()
        
    def forward(self, input):
        result = torch.mean(input[:, :, :, :, 0].view(-1, input.shape[1], input.shape[2]*input.shape[3]), 2)
        return result
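A quick NumPy check of this argument, as a sketch with random data: a 3-channel $32 \times 32$ map is split into $8 \times 8$ blocks, each block is transformed with the DCT, the DC coefficient is divided by the DC quantization value 8, and the per-channel mean of those values is compared with ordinary global average pooling. Only the DCT and the DC quantization value are modelled.

```python
import numpy as np

n = np.arange(8)
A = np.where(n == 0, 1 / np.sqrt(2), 1.0)
basis = 0.5 * A[None, :] * np.cos((2 * n[:, None] + 1) * n[None, :] * np.pi / 16)

rng = np.random.default_rng(0)
channels = rng.standard_normal((3, 32, 32))           # 3 channels of a 32x32 feature map

# Split into 8x8 blocks: (channels, 4, 4, 8, 8)
blocks = channels.reshape(3, 4, 8, 4, 8).transpose(0, 1, 3, 2, 4)
# DCT of every block, then divide by the DC quantization value 8
dct = np.einsum('ia,cxyij,jb->cxyab', basis, blocks, basis)
dc = dct[..., 0, 0] / 8.0                              # (channels, 4, 4): block means

pooled_from_dc = dc.reshape(3, -1).mean(axis=1)       # mean of the block means
pooled_spatial = channels.mean(axis=(1, 2))           # ordinary global average pooling
print(np.allclose(pooled_from_dc, pooled_spatial))    # True
```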

ResNet

Next, the residual block structure and final network are constructed from these parts, mirroring the structure of the spatial domain model that was trained earlier. By default, the ReLu approximations keep all frequencies with $\alpha + \beta \le 10$ (54 of the 64 coefficients in each block). This can be tuned, but it affects the accuracy of inference: the model was trained with an exact ReLu, so the network weights are not designed to handle this approximation. Passing exact=True uses $\alpha + \beta \le 14$, which keeps all 64 coefficients and makes the ReLu exact.

In [0]:
class JpegResBlock(nn.Module):
    def __init__(self, spatial_resblock, n_freqs, J_in, J_out):
        super(JpegResBlock, self).__init__()
        
        J_down = (J_out[0], J_in[1])
        
        self.conv1 = JpegConv2d(spatial_resblock.conv1, J_down)
        self.conv2 = JpegConv2d(spatial_resblock.conv2, J_out)
        
        self.bn1 = JpegBatchNorm(spatial_resblock.bn1)
        self.bn2 = JpegBatchNorm(spatial_resblock.bn2)
        
        self.relu = JpegRelu(n_freqs=n_freqs)
      
        if spatial_resblock.downsample:
            self.downsample = True
            self.downsampler = JpegConv2d(spatial_resblock.downsampler, J_down)
        else:
            self.downsample = False
        
    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample:
            residual = self.downsampler(x)
        else:
            residual = x
        
        out += residual
        out = self.relu(out)
        return out

            
class JpegResNet(nn.Module):
    def __init__(self, spatial_model, exact=False):
        super(JpegResNet, self).__init__()
        
        if exact:
            n_freqs = 14
        else:
            n_freqs = 10
        
        self.block1 = JpegResBlock(spatial_model.block1, n_freqs=n_freqs, J_in=J_32, J_out=J_32)       
        self.block2 = JpegResBlock(spatial_model.block2, n_freqs=n_freqs, J_in=J_32, J_out=J_16)
        self.block3 = JpegResBlock(spatial_model.block3, n_freqs=n_freqs, J_in=J_16, J_out=J_8)
        
        self.averagepooling = JpegAvgPool()
        self.fc = spatial_model.fc 
        
    def forward(self, x):
        out = self.block1(x)
        out = self.block2(out)
        out = self.block3(out)
        
        out = self.averagepooling(out)
        out = out.view(x.size(0), -1)

        out = self.fc(out)
        
        return out
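Note that n_freqs is a cutoff on $\alpha + \beta$ in D_n, not a count of coefficients. This small sketch (the helper name is ours) counts how many of the 64 coefficients per block each setting keeps, which shows why 14 is the exact setting for $8 \times 8$ blocks.

```python
def kept_coefficients(n_freqs, block=8):
    """How many (alpha, beta) pairs in one block satisfy alpha + beta <= n_freqs, as in D_n."""
    return sum(1 for a in range(block) for b in range(block) if a + b <= n_freqs)

for n in (4, 6, 10, 14):
    print(n, kept_coefficients(n))
# 4 15
# 6 28
# 10 54
# 14 64
```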

Testing

Finally, the model is ready to be converted. Because of the way the code was written, this is as simple as providing the spatial domain model as an argument. The returned model performs the same inference on JPEGs to within the ReLu approximation error.

In [0]:
jpeg_model = JpegResNet(model)

To demonstrate the model, we test it using the previously converted MNIST test images. The result should be close to, but not exactly the same as, the spatial domain test accuracy.

In [0]:
jpeg_test_data = MNISTJpegDataset(test_jpeg_data, test_data)
jpeg_test_loader = torch.utils.data.DataLoader(jpeg_test_data, batch_size=128, shuffle=False)

test(jpeg_model, device, jpeg_test_loader)
Test set: Average loss: 0.0934, Accuracy: 9702/10000 (97%)

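Next, the same test is repeated using the exact ReLu (exact=True, i.e. n_freqs = 14, which keeps all 64 coefficients of every block). This should reproduce the spatial model's result up to floating-point error; for comparison, the final training epoch above reported an average loss of 0.0963 and 9697/10000 correct.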

In [0]:
jpeg_model_exact = JpegResNet(model, exact=True)
test(jpeg_model_exact, device, jpeg_test_loader)
Test set: Average loss: 0.0963, Accuracy: 9698/10000 (97%)

In [0]:
import time

def test_timing(model, device, test_loader):
    model.eval()
    t0 = time.perf_counter()
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            model(data)
    
    torch.cuda.synchronize()
    t1 = time.perf_counter()
    print('Time: {:.02e}s'.format(t1 - t0))
          
test_timing(model, device, test_loader)
test_timing(jpeg_model, device, jpeg_test_loader)
Time: 5.72e+00s
Time: 4.46e+01s
In [0]:
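def model_size(model):
    # Bytes of learned parameters, assuming float32 (4 bytes each). model.parameters()
    # already recurses into submodules, so also iterating over model.modules() would
    # count each parameter once per enclosing module. The JPEG model's precomputed
    # transform tensors are plain tensors, not parameters, so they are not counted.
    return sum(p.numel() for p in model.parameters()) * 4

print(model_size(model))
print(model_size(jpeg_model))
300904
300904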

© 2018 Max Ehrlich