Thursday, 27 April 2017

MNIST Study in PyTorch

MNIST is the "hello world" of machine learning. The MNIST dataset is a collection of 28x28 grayscale images of the handwritten digits 0 to 9, each paired with its label. Let's take a few samples from the dataset to get a feel for what it looks like.

[[ 6.  9.  9.  5.  4.]
 [ 3.  6.  5.  0.  1.]
 [ 8.  1.  3.  6.  2.]
 [ 9.  4.  8.  8.  6.]
 [ 0.  6.  4.  2.  3.]]


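Each image of a digit is paired with the digit it shows, so the dataset is a list of (image, label) pairs. As usual, we feed the neural network the image (the left element of each pair) and train it to predict the label (the right element). We will train a simple feed-forward network and call it Model0. It has no hidden layers: a single linear layer maps the 28*28 = 784 pixel values of a flattened image to 10 scores, one per digit, and log_softmax turns those scores into log-probabilities.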
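Because Model0 takes a flat vector rather than a 2-D image, each 28x28 image must be flattened into 784 numbers before it goes in. A minimal sketch, assuming batches shaped (batch, 1, 28, 28), which is what torchvision's MNIST dataset with a ToTensor transform produces; the random tensor is a stand-in, not real MNIST data:

```python
import torch

# Hypothetical batch shaped like MNIST images from a DataLoader:
# (batch, channels, height, width) = (64, 1, 28, 28). Random values stand in for pixels.
images = torch.rand(64, 1, 28, 28)

# Flatten each image into a vector of 28*28 = 784 values, keeping the batch dimension.
flat = images.view(images.size(0), -1)
print(flat.shape)  # torch.Size([64, 784])
```

Using `images.size(0)` instead of a fixed batch size keeps the reshape correct when the last batch of an epoch is smaller.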

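import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

class Model0(nn.Module):
    def __init__(self):
        super(Model0, self).__init__()
        # A single linear layer: 784 pixel values in, 10 class scores out.
        self.output_layer = nn.Linear(28*28, 10)

    def forward(self, x):
        x = self.output_layer(x)
        # Log-probabilities over the 10 digits, computed along each row of the batch.
        return F.log_softmax(x, dim=1)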
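To see what Model0 actually returns, here is a small sketch that runs a copy of it on random input (the input is made up; only the shapes matter). Each output row holds 10 log-probabilities, so exponentiating a row gives probabilities that sum to 1, and the whole model has 784*10 weights plus 10 biases:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Copy of Model0 from above, so this snippet runs on its own.
class Model0(nn.Module):
    def __init__(self):
        super().__init__()
        self.output_layer = nn.Linear(28 * 28, 10)

    def forward(self, x):
        return F.log_softmax(self.output_layer(x), dim=1)

torch.manual_seed(0)
model = Model0()
x = torch.rand(4, 28 * 28)       # four hypothetical flattened images
log_probs = model(x)             # shape (4, 10)
probs = log_probs.exp()          # back to probabilities

n_params = sum(p.numel() for p in model.parameters())
print(log_probs.shape)           # torch.Size([4, 10])
print(probs.sum(dim=1))          # each row sums to 1, up to float rounding
print(n_params)                  # 7850
```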
And let's train it:

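def train(model):
    # lr, momentum and train_loader are assumed to be defined earlier in the script.
    optimizer = optim.SGD(model.parameters(),
                          lr=lr,
                          momentum=momentum)
    model.train()
    # The loader yields (data, target) batches directly; enumerate() would
    # instead yield (index, batch) pairs and break the unpacking.
    for data, target in train_loader:
        optimizer.zero_grad()
        # Flatten each 28x28 image; data.size(0) also handles a smaller last batch.
        # Since PyTorch 0.4, tensors work with autograd directly, so no Variable wrapper is needed.
        data = data.view(data.size(0), -1)
        output = model(data)

        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()

model = Model0()
train(model)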
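The training function relies on train_loader, lr and momentum defined elsewhere, so it cannot run on its own. Here is a self-contained sketch of the same loop on synthetic data, assuming lr=0.01 and momentum=0.5 (values chosen for illustration, not taken from this post) and using nn.Sequential with nn.Flatten as a stand-in for Model0 plus the manual `view`. It checks that the loss on the training data drops after a few epochs:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Hypothetical stand-in for MNIST: 256 random "images" and random labels 0..9.
images = torch.randn(256, 1, 28, 28)
labels = torch.randint(0, 10, (256,))
train_loader = DataLoader(TensorDataset(images, labels), batch_size=64, shuffle=True)

# Same architecture as Model0: flatten, one linear layer, log-softmax.
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10), nn.LogSoftmax(dim=1))
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)

def average_loss():
    # Mean negative log-likelihood over the whole dataset, without tracking gradients.
    with torch.no_grad():
        return F.nll_loss(model(images), labels).item()

loss_before = average_loss()
model.train()
for epoch in range(5):
    for data, target in train_loader:    # iterate the loader directly, no enumerate()
        optimizer.zero_grad()
        loss = F.nll_loss(model(data), target)
        loss.backward()
        optimizer.step()
loss_after = average_loss()
print(f"loss before: {loss_before:.3f}, after: {loss_after:.3f}")
```

Random labels carry no real signal, so this only shows that the loop updates the weights and fits the training set; real MNIST is needed to judge accuracy on unseen images.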

Let's see how our network predicts the images.
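A prediction is simply the digit with the highest log-probability in each output row. A minimal sketch of turning a batch of 25 outputs into a 5x5 grid of predicted digits like the one below; the random log-probabilities are a hypothetical stand-in for `model(data)` on 25 flattened test images:

```python
import torch

torch.manual_seed(0)
# Hypothetical model output: log-probabilities for 25 images over the 10 digits.
log_probs = torch.log_softmax(torch.randn(25, 10), dim=1)

# The predicted digit is the index of the largest log-probability in each row.
pred = log_probs.argmax(dim=1)

# Arrange the 25 predictions in a 5x5 grid for display.
grid = pred.view(5, 5)
print(grid)
```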

[[ 6.  2.  9.  1.  8.]
 [ 5.  6.  5.  7.  5.]
 [ 4.  8.  6.  3.  0.]
 [ 6.  1.  0.  9.  3.]
 [ 7.  2.  8.  4.  4.]]
Most of the predictions look right. Let's run the model over the entire test dataset.
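def test_tuts(model):
    # test_loader is assumed to be defined earlier in the script, like train_loader.
    model.eval()
    test_loss, correct = 0, 0

    # No gradients are needed for evaluation.
    with torch.no_grad():
        for data, target in test_loader:
            data = data.view(data.size(0), -1)
            output = model(data)

            # The predicted digit is the index of the largest log-probability.
            pred = output.argmax(dim=1)
            correct += pred.eq(target).sum().item()
            # nll_loss returns the batch mean; .item() turns it into a Python float.
            test_loss += F.nll_loss(output, target).item()

    # Average of the per-batch mean losses.
    test_loss /= len(test_loader)
    print(
        'Avg loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'
        .format(
            test_loss,
            correct,
            len(test_loader.dataset),
            100. * correct / len(test_loader.dataset)
        )
    )

test_tuts(model)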
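To make the accuracy and loss arithmetic in test_tuts concrete, here is a tiny worked example with four made-up outputs over three classes (three instead of ten only to keep it readable). Three of the four argmax predictions match the targets, giving 75% accuracy, and the loss is the mean of -log p[target]:

```python
import torch
import torch.nn.functional as F

# Hypothetical log-probabilities for 4 images over 3 classes.
log_probs = torch.log(torch.tensor([
    [0.7, 0.2, 0.1],   # predicts class 0
    [0.1, 0.8, 0.1],   # predicts class 1
    [0.3, 0.3, 0.4],   # predicts class 2
    [0.6, 0.3, 0.1],   # predicts class 0
]))
target = torch.tensor([0, 1, 1, 0])

pred = log_probs.argmax(dim=1)                # tensor([0, 1, 2, 0])
correct = pred.eq(target).sum().item()        # 3
accuracy = 100.0 * correct / len(target)      # 75.0
# Mean of -log(0.7), -log(0.8), -log(0.3), -log(0.6).
loss = F.nll_loss(log_probs, target).item()
print(correct, accuracy, round(loss, 4))
```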

   
