Optimization-based meta-learning: Using MAML with PyTorch on the MNIST dataset

In this tutorial, we continue looking at MAML optimization methods with the MNIST dataset.

2 months ago   •   10 min read

By Adrien Payong

Meta-learning, also referred to as learning to learn, has become an active area of research in machine learning. Its objective is to give models the ability to adapt quickly to new tasks or domains when only limited data is available. One notable meta-learning algorithm is Model-Agnostic Meta-Learning (MAML).

MAML is an optimization-based meta-learning method proposed by Chelsea Finn et al. at UC Berkeley. Its distinguishing feature is model-agnosticism: it is compatible with any model trained with gradient descent, including but not limited to convolutional and recurrent networks.

MAML works on two levels: an inner loop and an outer loop. In the inner loop, gradient descent is applied to each individual task to update the model's parameters, which allows rapid task-specific adaptation. The outer loop searches for the best possible initialization of those parameters. That initialization is a starting point from which new tasks can be learned quickly and efficiently.

Practical Example: Few-shot Image Classification

To see MAML in action, consider the real-world problem of few-shot image classification: a dataset in which only a few images are annotated with the desired labels. With so little data, conventional training from scratch often overfits and generalizes poorly. This is where MAML helps:

Inner level

In MAML, and in meta-learning more generally, the inner level describes how the model is adapted to a specific task during meta-training. This adaptation is performed separately for each task encountered during meta-training and involves a few key steps:

  • Initialization: At the start of each task, the model is initialized with the meta-learned parameters produced by the outer level. These parameters have been optimized to be a good starting point across many tasks, rather than to solve any single one.
  • Task-Specific Training: The model is then trained on the task using a limited amount of task-specific data. This stage is usually short (often just one or a few gradient steps). It aims to align the model's parameters with the characteristics of the current task's data.
  • Gradient Calculation: The loss on the task's training data is backpropagated to compute the gradients of that loss with respect to the model's parameters.
  • Parameter Update: The model's parameters are updated in the opposite direction of these gradients, scaled by an inner-loop learning rate.
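The four inner-level steps can be sketched in PyTorch as follows. This is a minimal sketch that assumes PyTorch 2.x, which provides `torch.func.functional_call`. The helper name `adapt` and its arguments are illustrative and do not come from the tutorial's code. Instead of overwriting the model's weights, the function builds a new dictionary of adapted parameters, so the meta-learned initialization stays untouched.

```python
import torch
import torch.nn as nn
from torch.func import functional_call


def adapt(model, x, y, loss_fn, alpha=0.05, steps=1, create_graph=True):
    """Hypothetical helper: run `steps` inner-loop gradient steps on one task.

    Returns a dict of adapted parameters; the model's own parameters are not modified.
    """
    # Initialization: start from the current meta-learned parameters.
    params = dict(model.named_parameters())
    for _ in range(steps):
        # Task-specific training: forward pass using the current adapted parameters.
        loss = loss_fn(functional_call(model, params, (x,)), y)
        # Gradient calculation; create_graph=True keeps the graph so an outer loop
        # can later differentiate through this update (second-order MAML).
        grads = torch.autograd.grad(loss, list(params.values()), create_graph=create_graph)
        # Parameter update: step against the gradient, scaled by the inner learning rate.
        params = {name: p - alpha * g for (name, p), g in zip(params.items(), grads)}
    return params


# Illustrative use on a toy linear-regression task with random data.
torch.manual_seed(0)
net = nn.Linear(3, 1)
x, y = torch.randn(8, 3), torch.randn(8, 1)
loss_fn = nn.MSELoss()
adapted = adapt(net, x, y, loss_fn, alpha=0.05, steps=3)
loss_before = loss_fn(net(x), y).item()
loss_after = loss_fn(functional_call(net, adapted, (x,)), y).item()
```

Returning a parameter dictionary rather than calling an optimizer keeps every adapted tensor connected to the original parameters. The outer loop needs that connection to compute the meta-gradient.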

Outer Level

The outer level controls the meta-learning process in Model-Agnostic Meta-Learning (MAML). MAML meta-learns over a distribution of tasks, and the outer loop updates the model's initial parameters according to how well the adapted model performs across those tasks. The main activities at the outer level are as follows:

Initialization:

  • Initialize the model parameters randomly or with pretrained values.

Meta-Training Loop:

  • For each iteration of the meta-training loop, sample a batch of tasks from the task distribution.
  • For each task in the batch, run the inner loop (task-specific training) to adapt the model to that task.
  • Compute a task-specific loss by evaluating each adapted model on that task's validation set.

Meta-Update:

  • Calculate the gradient of the average task-specific loss across all tasks in the batch with respect to the initial model parameters.
  • Update the initial parameters in the opposite direction of this gradient, so that they become easier to adapt to a wide range of tasks.

The goal is to adjust the initialization parameters so that the model learns faster when it encounters new tasks. In effect, the model is learning how to learn, and the outer loop makes it progressively better at adapting quickly.
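Putting the outer-level steps into code, the sketch below performs one meta-update over a batch of tasks. It rests on several assumptions. Each task is a hypothetical `(support_x, support_y, query_x, query_y)` tuple, where the support set plays the role of the task's training data and the query set plays the role of its validation set. The inner loop takes a single gradient step with step size `alpha`. `meta_opt` can be any `torch.optim` optimizer over the model's parameters. The inner gradient is computed with `create_graph=True`, so `backward()` differentiates through the adaptation step. That is what distinguishes full (second-order) MAML from ordinary multi-task training.

```python
import torch
import torch.nn as nn
from torch.func import functional_call


def meta_update(model, meta_opt, tasks, loss_fn, alpha=0.01):
    """Hypothetical helper: one MAML outer-loop step over a batch of tasks."""
    meta_opt.zero_grad()
    params = dict(model.named_parameters())  # the shared initialization theta
    query_losses = []
    for support_x, support_y, query_x, query_y in tasks:
        # Inner loop: one gradient step on the task's support (training) set.
        support_loss = loss_fn(functional_call(model, params, (support_x,)), support_y)
        grads = torch.autograd.grad(support_loss, list(params.values()), create_graph=True)
        adapted = {n: p - alpha * g for (n, p), g in zip(params.items(), grads)}
        # Task-specific loss: evaluate the adapted parameters on the query (validation) set.
        query_losses.append(loss_fn(functional_call(model, adapted, (query_x,)), query_y))
    # Meta-objective: average query loss across the batch of tasks.
    meta_loss = torch.stack(query_losses).mean()
    # Gradient w.r.t. the initial parameters, flowing back through the inner step.
    meta_loss.backward()
    meta_opt.step()  # move theta against the meta-gradient
    return meta_loss.item()


# Illustrative run on random linear-regression tasks (synthetic data only).
torch.manual_seed(0)
model = nn.Linear(2, 1)
meta_opt = torch.optim.SGD(model.parameters(), lr=0.01)
tasks = []
for _ in range(4):
    w = torch.randn(2, 1)                      # each task has its own true weights
    xs, xq = torch.randn(5, 2), torch.randn(5, 2)
    tasks.append((xs, xs @ w, xq, xq @ w))
before = model.weight.detach().clone()
loss = meta_update(model, meta_opt, tasks, nn.MSELoss(), alpha=0.05)
```

Calling `meta_update` repeatedly on freshly sampled task batches gives the meta-training loop described above.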

The mathematical formula for MAML

The mathematical formula for MAML can be expressed as follows:

Given a set of tasks T = {T1, T2, ..., TN}, where each task Ti has a training set Di and a separate validation set, MAML aims to find a set of parameters θ that can be quickly adapted to new tasks.

  1. Initialization: Initialize the model parameters θ randomly or with pretrained weights.
  2. Inner loop: For each task Ti, compute the adapted parameters θi by taking a few gradient steps on the loss function L(Di, θ) using the training data Di.
  3. Outer loop: Update the initial parameters θ by taking a gradient descent step on the meta-objective J(T, θ) over all tasks. This objective measures how well the adapted parameters θi perform on each task's validation set. Because the update is a gradient step, the meta-objective must be differentiable. In practice it is usually the average validation loss across tasks, while metrics such as accuracy are tracked for monitoring rather than optimized directly.
  4. Repeat steps 2 and 3 over many iterations, sampling new tasks each time, to refine the initial parameters.
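With a single inner gradient step, the four steps can be written out explicitly. The formulas below use the notation above plus a few assumed symbols: α and β are the inner and outer learning rates, and D_i^val denotes task Ti's validation set, which the list mentions but does not name.

```latex
\begin{aligned}
\theta_i &= \theta - \alpha \, \nabla_{\theta} \mathcal{L}(D_i, \theta) \\
J(\mathcal{T}, \theta) &= \frac{1}{N} \sum_{i=1}^{N} \mathcal{L}\big(D_i^{\mathrm{val}}, \theta_i\big) \\
\theta &\leftarrow \theta - \beta \, \nabla_{\theta} J(\mathcal{T}, \theta) \\
\nabla_{\theta} \mathcal{L}\big(D_i^{\mathrm{val}}, \theta_i\big) &= \big(I - \alpha \, \nabla_{\theta}^{2} \mathcal{L}(D_i, \theta)\big) \, \nabla_{\theta_i} \mathcal{L}\big(D_i^{\mathrm{val}}, \theta_i\big)
\end{aligned}
```

The last line applies the chain rule through θi and shows why MAML is called second-order: the meta-gradient contains the Hessian of the inner loss. First-order approximations such as FOMAML drop that term and use the gradient with respect to θi directly.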

MAML with PyTorch and MNIST dataset

Here, we'll demonstrate how to use MAML with PyTorch and the MNIST dataset. MNIST consists of 28x28-pixel grayscale images of handwritten digits 0-9, and the objective is to train a model to classify the digits correctly. For MAML on image data, we first initialize a model, often a simple convolutional neural network. We then simulate a learning process on a variety of tasks, each task being to recognize a specific digit from 0 to 9.

For each task, we compute the loss and gradients and adapt the model parameters. After processing a batch of tasks, we compute the meta-gradient: the gradient of the average post-adaptation loss across those tasks with respect to the initial parameters. In full MAML this gradient flows back through the inner updates, so it is not simply the average of each task's ordinary gradient. The initial parameters are then updated with this meta-gradient, and the process repeats until the model's performance meets the desired criteria. The appeal of MAML is its ability to adapt to new tasks with just a few gradient updates. That ability makes it attractive for a setting like MNIST, where the model must adapt to recognizing each of the 10 digits.
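A one-parameter example makes the difference concrete between the MAML meta-gradient and simply averaging each task's ordinary gradient. Assume, purely for illustration, tasks whose loss is L_i(θ) = (θ - c_i)², one inner step with step size α, and the same loss used for adaptation and evaluation. By hand, θ_i = θ - 2α(θ - c_i) = (1 - 2α)θ + 2αc_i, so the post-adaptation loss is (1 - 2α)²(θ - c_i)². The MAML meta-gradient is therefore the plain gradient scaled by (1 - 2α)². The code checks this with autograd; `alpha`, `centers`, and `theta` are made-up values.

```python
import torch

alpha = 0.1                                # hypothetical inner learning rate
centers = torch.tensor([1.0, 3.0, -2.0])   # hypothetical task optima c_i
theta = torch.tensor(0.5, requires_grad=True)


def task_loss(t, c):
    return (t - c) ** 2


# MAML meta-objective: average loss *after* one inner gradient step per task.
post_losses = []
for c in centers:
    (g,) = torch.autograd.grad(task_loss(theta, c), theta, create_graph=True)
    theta_i = theta - alpha * g            # adapted parameter for task i
    post_losses.append(task_loss(theta_i, c))
(meta_grad,) = torch.autograd.grad(torch.stack(post_losses).mean(), theta)

# Naive alternative: average of each task's ordinary gradient at theta.
avg_grad = torch.stack([2 * (theta - c) for c in centers]).mean().detach()

print(meta_grad.item(), avg_grad.item())   # meta_grad is 0.64 * avg_grad (up to rounding)
```

In this example the two differ only by a constant factor because every task has the same curvature. When tasks have different curvature, the Hessian term weights the tasks differently, and the meta-gradient can point somewhere the naive average does not.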

Step 1: Import Libraries and Load Data

First, we import the essential libraries and load the MNIST dataset. The PyTorch DataLoader then serves the data in batches.

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor

# Load the MNIST dataset
train_dataset = MNIST(root='data/', train=True, transform=ToTensor(), download=True)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
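The DataLoader above yields ordinary shuffled mini-batches rather than tasks. MAML is usually trained on N-way, K-shot episodes: pick N classes, give the model K labeled examples of each (the support set), and hold out a few more (the query set). The function below is a hypothetical sketch of such a sampler. It works on plain image and label tensors and relabels the chosen classes as 0 to N-1, so that each episode is a fresh classification problem. The name `sample_task` and its defaults are illustrative assumptions.

```python
import torch


def sample_task(images, labels, n_way=5, k_shot=1, q_query=5, generator=None):
    """Hypothetical episode sampler: returns (support_x, support_y, query_x, query_y).

    Assumes at least n_way classes and at least k_shot + q_query images per class.
    """
    classes = torch.unique(labels)
    chosen = classes[torch.randperm(len(classes), generator=generator)[:n_way]]
    sx, sy, qx, qy = [], [], [], []
    for new_label, cls in enumerate(chosen):
        idx = torch.nonzero(labels == cls).flatten()
        idx = idx[torch.randperm(len(idx), generator=generator)[: k_shot + q_query]]
        sx.append(images[idx[:k_shot]])      # support examples for this class
        qx.append(images[idx[k_shot:]])      # query examples for this class
        # Relabel the selected classes as 0..n_way-1 within this episode.
        sy.append(torch.full((k_shot,), new_label, dtype=torch.long))
        qy.append(torch.full((q_query,), new_label, dtype=torch.long))
    return torch.cat(sx), torch.cat(sy), torch.cat(qx), torch.cat(qy)


# Illustrative call on random stand-in data (10 classes, 20 images each).
g = torch.Generator().manual_seed(0)
fake_images = torch.rand(200, 1, 28, 28)
fake_labels = torch.arange(10).repeat_interleave(20)
support_x, support_y, query_x, query_y = sample_task(fake_images, fake_labels, generator=g)
```

To sample from the real MNIST training set, you could pass `train_dataset.data.unsqueeze(1).float() / 255` as the images and `train_dataset.targets` as the labels. These attributes exist on torchvision's MNIST dataset, and dividing by 255 reproduces the [0, 1] scaling that `ToTensor()` applies.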

Step 2: Define the Model

Next, we define the model that MAML will train. The CNN we'll use consists of just two convolutional layers, two max pooling layers, and two fully connected layers.

# Define the CNN model
class CNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3)  # First convolutional layer: 1x28x28 -> 32x26x26
        self.relu1 = nn.ReLU()  # ReLU activation function
        self.pool1 = nn.MaxPool2d(kernel_size=2)  # Max pooling: 32x26x26 -> 32x13x13
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)  # Second convolutional layer: -> 64x11x11
        self.relu2 = nn.ReLU()  # ReLU activation function
        self.pool2 = nn.MaxPool2d(kernel_size=2)  # Max pooling: -> 64x5x5
        self.fc1 = nn.Linear(64 * 5 * 5, 128)  # First fully connected layer
        self.relu3 = nn.ReLU()  # ReLU activation function
        self.fc2 = nn.Linear(128, 10)  # Second fully connected layer: one logit per digit
        # No softmax layer: nn.CrossEntropyLoss expects raw logits and applies log-softmax itself

    def forward(self, x):
        x = self.conv1(x)  # Convolutional layer
        x = self.relu1(x)  # ReLU activation
        x = self.pool1(x)  # Max pooling
        x = self.conv2(x)  # Convolutional layer
        x = self.relu2(x)  # ReLU activation
        x = self.pool2(x)  # Max pooling
        x = torch.flatten(x, 1)  # Flatten to (batch, 64 * 5 * 5)
        x = self.fc1(x)  # Fully connected layer
        x = self.relu3(x)  # ReLU activation
        x = self.fc2(x)  # Fully connected layer, returns logits
        return x

Building a convolutional neural network for image classification can get a bit complicated, so let's walk through it step by step.

  • First, we define our CNN class. The __init__ method sets up the layers. We start with a convolutional layer to extract features from the input images, followed by a ReLU activation to introduce non-linearity and max pooling to reduce the spatial dimensions.
  • We repeat this pattern (convolution, ReLU, pooling) for a second layer, which extracts higher-level features built on top of the first layer's outputs.
  • After the convolutional layers, we flatten the tensor and pass it through a fully connected layer, apply ReLU again, and use a second fully connected layer to reduce the output to the 10 classes.
  • The forward pass chains everything together: the two convolution/ReLU/pooling stages extract features from the input, and the fully connected layers classify based on those features.
  • The final layer outputs raw scores (logits), one per class. We deliberately do not end with a softmax, because nn.CrossEntropyLoss applies log-softmax internally and a second softmax would distort the loss. At prediction time, taking the argmax of the logits picks the highest-scoring class as the predicted label; applying softmax first would not change which class that is.

So, that is a basic CNN architecture for image classification. The key is stacking those convolutional and pooling layers to build up hierarchical feature representations. This lets the fully connected layers efficiently learn the weights to transform those features into accurate predictions.
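The `64 * 5 * 5` input size of `fc1` follows from how each layer changes the 28x28 input. A 3x3 convolution without padding reduces each spatial dimension by 2 (28 to 26). 2x2 max pooling halves it, rounding down (26 to 13). The second convolution gives 11, and the second pooling gives 5. This quick check rebuilds only the convolutional part of the model and confirms the arithmetic with a dummy tensor.

```python
import torch
import torch.nn as nn

# Same convolution/pooling stack as the CNN's feature extractor.
features = nn.Sequential(
    nn.Conv2d(1, 32, kernel_size=3), nn.ReLU(), nn.MaxPool2d(kernel_size=2),   # 28 -> 26 -> 13
    nn.Conv2d(32, 64, kernel_size=3), nn.ReLU(), nn.MaxPool2d(kernel_size=2),  # 13 -> 11 -> 5
)
dummy = torch.zeros(1, 1, 28, 28)  # one fake grayscale MNIST image
out = features(dummy)
print(out.shape)              # torch.Size([1, 64, 5, 5])
print(out.flatten(1).shape)   # torch.Size([1, 1600]), i.e. 64 * 5 * 5
```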

Step 3: Initialize the Model and define the loss function and the optimizer

# Initialize the model
model = CNN()

# Define the loss function and optimizer
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001)

First, we initialize the model using the basic CNN defined above. Nothing fancy here; this step just puts the architecture in place. Then we define how to train it. Cross-entropy loss is the standard choice for classification tasks like this one, and we use SGD as the optimizer with a small learning rate.
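One detail matters when pairing a model with this loss: `nn.CrossEntropyLoss` expects raw logits, because it applies log-softmax internally. It is equivalent to `nn.NLLLoss` applied to `log_softmax` outputs. The snippet below confirms this on random numbers and shows that feeding softmax probabilities into the loss gives a different value. That is why the model's last layer should output logits rather than probabilities.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 10)            # raw scores for a batch of 4, 10 classes
targets = torch.tensor([3, 0, 9, 1])

ce = nn.CrossEntropyLoss()(logits, targets)
nll = nn.NLLLoss()(F.log_softmax(logits, dim=1), targets)
# Feeding probabilities instead of logits applies softmax twice and changes the loss.
double = nn.CrossEntropyLoss()(F.softmax(logits, dim=1), targets)

print(torch.allclose(ce, nll))  # True
```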

Step 4: Define the inner and outer optimization loops

# Inner optimization loop
def inner_loop(task_data):
    for data, labels in task_data:
        optimizer.zero_grad()  # Clear gradients from the previous step
        outputs = model(data)  # Forward pass on this task's data
        loss = loss_fn(outputs, labels)  # Task-specific loss
        loss.backward()  # Backpropagate the error
        optimizer.step()  # Update the model parameters

# Outer optimization loop
def outer_loop(meta_data):
    for task_data in meta_data:
        inner_loop(task_data)  # Adapt the model on each task in turn

  • The inner loop is where the actual optimization happens. For each task, it loops through that task's data: it zeroes out the gradients, makes predictions, calculates the loss, backpropagates, and updates the model parameters. The key point is that the inner loop only sees the data for that specific task.
  • The outer loop controls the meta-learning aspect. It loops through the tasks in the meta-training set and calls the inner loop for each one, so the model is updated on task 1, then task 2, and so on. This simulates the quick adaptation steps used in few-shot learning.

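In summary, the inner loop optimizes on each task and the outer loop controls the optimization over the distribution of tasks. Note that this simplified version updates the shared weights directly with SGD on each task's data. It does not keep separate adapted parameters per task or differentiate through the inner updates, so it mirrors the structure of MAML rather than computing the full second-order meta-gradient. You can tweak the loops and training procedure, but this two-level structure is the core logic behind optimization-based approaches like MAML.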

Step 5: Train the model

# Training loop
num_epochs = 20
for epoch in range(num_epochs):
    model.train()  # Make sure the model is in training mode
    outer_loop([train_loader])  # The whole training set is passed in as a single task

  • The training loop iterates over all the epochs and drives the training process. The loop variable epoch holds the current epoch number, starting at 0 and counting up to num_epochs - 1.
  • Inside the loop, it calls the outer_loop function.
  • The train_loader is a DataLoader object that supplies batches of training data on each pass.

Overall, the loop goes epoch by epoch calling the training function and getting new batches of data to train on for each epoch. It handles driving the entire training process.

Step 6: Evaluate the trained model on a new task or domain

To evaluate the model on a new task, we create a new DataLoader, switch the model to evaluation mode, iterate through the new data, compute the accuracy, and print the result.

# Create a new DataLoader for the new task or domain
new_dataset = MNIST(root='data/', train=False, transform=ToTensor(), download=True)
new_loader = DataLoader(new_dataset, batch_size=32, shuffle=False)

# Set the model to evaluation mode
model.eval()

# Initialize variables for evaluation
total_samples = 0
correct_predictions = 0

# Iterate over the new data and perform evaluation
with torch.no_grad():
    for data, labels in new_loader:
        outputs = model(data) # Forward pass through the model
        _, predicted = torch.max(outputs.data, 1) # Get the predicted labels
        total_samples += labels.size(0) # Accumulate the total number of samples
        correct_predictions += (predicted == labels).sum().item()

# Calculate accuracy
accuracy = 100 * correct_predictions / total_samples

# Print the accuracy
print(f"Accuracy on the new task or domain: {accuracy:.2f}%")

The model we trained reached 83% accuracy on the MNIST test set. Strictly speaking, that split contains the same ten digit classes as the training data, so this result measures generalization to held-out images rather than adaptation to a genuinely new task. Whether 83% is good enough depends on what you need the model to do. For a high-stakes application it may well be insufficient, and you would need to improve it.
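In a genuinely few-shot evaluation, the usual protocol is to adapt a copy of the trained model on a new task's small support set, and only then measure accuracy on that task's query set. The function below is a hedged sketch of that protocol. Its name, arguments, and default step counts are illustrative assumptions. It uses `copy.deepcopy` so that every test task starts from the same meta-learned initialization.

```python
import copy
import torch
import torch.nn as nn


def few_shot_accuracy(model, support_x, support_y, query_x, query_y,
                      inner_lr=0.01, steps=5):
    """Hypothetical helper: fine-tune a copy on the support set, score it on the query set."""
    learner = copy.deepcopy(model)        # leave the meta-learned initialization untouched
    learner.train()
    opt = torch.optim.SGD(learner.parameters(), lr=inner_lr)
    loss_fn = nn.CrossEntropyLoss()
    for _ in range(steps):                # inner-loop adaptation on the support set
        opt.zero_grad()
        loss_fn(learner(support_x), support_y).backward()
        opt.step()
    learner.eval()
    with torch.no_grad():                 # evaluate the adapted copy on the query set
        preds = learner(query_x).argmax(dim=1)
    return (preds == query_y).float().mean().item()


# Illustrative call with a tiny stand-in classifier and random data.
torch.manual_seed(0)
net = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 5))
sx, sy = torch.rand(5, 1, 28, 28), torch.arange(5)
qx, qy = torch.rand(25, 1, 28, 28), torch.arange(5).repeat(5)
before = [p.detach().clone() for p in net.parameters()]
acc = few_shot_accuracy(net, sx, sy, qx, qy)
```

With the tutorial's CNN, the 10-way output layer can still be used for N-way episodes whose labels have been remapped to 0 to N-1 (N at most 10). Averaging the returned accuracy over many sampled test tasks gives a more reliable estimate than a single task.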

This is a basic implementation of MAML. In an actual scenario, you would use a much more complex model, and you would have to fine-tune the hyperparameters for optimal performance. The number of epochs, the learning rate, the batch size, and the architecture of the model itself are all hyperparameters that can be tweaked to increase performance. For this tutorial, I made the decision to use a simple model and basic hyperparameters for simplicity and readability.

Some variants of MAML

Different variants of MAML and related algorithms provide alternate approaches to meta-learning and few-shot learning. They tackle various weaknesses and challenges of the original MAML method, offering new solutions for efficient and effective meta-learning.

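  • Reptile: Like first-order MAML (FOMAML), Reptile avoids second-order derivatives. It runs several steps of ordinary gradient descent on a sampled task and then moves the initialization toward the resulting task-adapted weights.
  • iMAML: Implicit MAML computes the meta-gradient through implicit differentiation instead of backpropagating through the inner-loop updates, so its memory cost does not grow with the number of inner steps.
  • Meta-SGD: Along with the initialization, Meta-SGD learns a learning rate for every parameter (a vector with the same shape as the parameters), which controls both the direction and the size of the inner-loop update.
  • ANIL: ANIL (Almost No Inner Loop) reduces MAML's computation by removing the inner-loop update for all layers except the final classification head. The rest of the network is learned only in the outer loop.
  • Proto-MAML: Proto-MAML combines Prototypical Networks with MAML. It initializes the output layer from per-class prototypes (the mean embedding of each class's support examples) and then fine-tunes it with MAML-style inner-loop updates.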
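Reptile's update is simple enough to show in a few lines. The sketch below is a minimal version built on assumptions: the function name `reptile_step`, the single-task input, and the step sizes are all illustrative. It runs a few plain SGD steps on one task, starting from the current weights θ, to reach adapted weights φ. It then moves the initialization a fraction `epsilon` of the way toward them, θ ← θ + ε(φ - θ). No second-order derivatives are needed.

```python
import copy
import torch
import torch.nn as nn


def reptile_step(model, task_x, task_y, loss_fn, inner_lr=0.02, inner_steps=5, epsilon=0.1):
    """Hypothetical helper: one Reptile meta-update using a single task."""
    adapted = copy.deepcopy(model)        # phi starts as a copy of theta
    opt = torch.optim.SGD(adapted.parameters(), lr=inner_lr)
    for _ in range(inner_steps):          # ordinary SGD on the task
        opt.zero_grad()
        loss_fn(adapted(task_x), task_y).backward()
        opt.step()
    with torch.no_grad():                 # theta <- theta + epsilon * (phi - theta)
        for p, q in zip(model.parameters(), adapted.parameters()):
            p.add_(epsilon * (q - p))


# Illustrative run on one synthetic linear-regression task.
torch.manual_seed(0)
net = nn.Linear(3, 1)
x = torch.randn(10, 3)
y = x @ torch.tensor([[1.0], [-2.0], [0.5]])
w0 = net.weight.detach().clone()
reptile_step(net, x, y, nn.MSELoss())
```

In a full training run, `reptile_step` would be called on a new sampled task at every iteration. The original Reptile paper also describes a batched version that averages (φ - θ) over several tasks.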


Because MAML is model-agnostic, it can be used with any model trained via gradient descent, such as convolutional and recurrent networks. Its inner loop applies gradient descent on a per-task basis for fast task-specific adaptation. Its outer loop searches for an initialization from which new tasks can be learned efficiently.

Few-shot image classification is a good example of where MAML is effective. Conventional machine learning approaches may fall short when only a few annotated images are available, whereas MAML adapts a meta-learned initialization to each particular task with just a few gradient steps.

The inner level of meta-learning involves initialization, task-specific training on limited data, gradient calculation through backpropagation, and parameter updates. The outer level controls the meta-learning process. It initializes the model parameters, runs a meta-training loop over a distribution of tasks, computes meta-updates from the task-specific losses, and adjusts the initial parameters to make them more adaptable.

The mathematical formulation of MAML is about finding a set of parameters that can be adapted quickly to new tasks. The inner loop adapts the model to each individual task, while the outer loop updates the initial parameters according to how well the adapted models perform across multiple tasks.

An example implementation of MAML using PyTorch and the MNIST dataset is provided. The step-by-step process covers importing libraries, defining the model architecture, initializing the model, setting up the inner and outer optimization loops, and training the model.

The final step is testing the trained model on a new task or domain: we create a new DataLoader, set the model to evaluation mode, iterate through the new data, and calculate the accuracy. Several variants of MAML, such as Reptile, iMAML, Meta-SGD, ANIL, and Proto-MAML, offer alternative approaches that address different challenges and weaknesses of the original method.
