Graph Neural Networks: Fundamentals, Implementation, and Practical Uses

In this tutorial, we introduce the fundamentals of Graph Neural Networks, and demonstrate how to use them in a Gradient Notebook with Python code to build a custom GNN.


By Adrien Payong


Graph Neural Networks (GNNs) are a type of neural network designed to process data represented as graphs. They have been applied to problems in many different fields, and their popularity has grown in recent years because of their ability to handle complex, relational data structures. In this post, we discuss the fundamentals of GNNs, including their basic concepts, how they are built, and their practical uses. We also walk through a working example of a GNN built with PyTorch and the PyTorch Geometric library.

What are Graph Neural Networks?

Graph Neural Networks, or GNNs for short, are a pretty neat type of neural net that works with data structured as graphs. A graph is a set of objects, represented as nodes, together with the relationships between those objects, represented as edges connecting the nodes. GNNs can handle both directed graphs, where each edge points from one node to another, and undirected graphs, where edges have no specific direction. These graphs can also vary a lot in size and shape.

A GNN is built from multiple layers, each taking the output of the previous layer as its input. We feed the GNN a graph, represented as a set of nodes and edges along with their associated features. What we get out is an embedding for each node in the input graph: a vector of features the network has learned for that node.

Instead of just operating on vectors, matrices or tensors like a normal neural network, GNNs can work with data structured as full-on graphs. That makes them really flexible for working with networked data, like social networks, molecular structures or transportation systems. The math involved is complex, but the high-level idea is that they iterate through the graph passing messages between nodes to learn useful representations.

How do Graph Neural Networks work?

Graph neural networks are all about learning patterns between nodes in a network. The main idea is that each node passes messages to its neighboring nodes, sharing information about itself. The nodes then aggregate these messages to build up a rich understanding of the network structure.

It works like this: each node computes a message to send to its neighbors based on its own features and the features of its neighbors. Of course, those neighbors are doing the same thing, passing messages of their own.

When a node receives messages, it updates its internal state by essentially aggregating them together. This allows information to propagate through the network node by node. As the messages pass back and forth, nodes gain a wider view of the patterns in the graph beyond just their immediate neighborhood.

By stacking multiple layers that repeat this message-passing process, GNNs can capture complex relationships and feature representations. Each layer extends a node's view by one hop, so after k layers a node's embedding can reflect information from nodes up to k edges away.
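The message-passing loop described above can be sketched in a few lines of plain PyTorch. This is a minimal illustration, not the GCNConv layer used later: it assumes mean aggregation and a fixed, hypothetical update rule (average a node's own features with the mean of its incoming messages), with no learnable weights or activation. The function name `mean_message_passing` is our own.

```python
import torch


def mean_message_passing(x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
    """One round of message passing with mean aggregation (illustrative sketch).

    x:          node features, shape (num_nodes, num_features)
    edge_index: directed edges, shape (2, num_edges), rows = (source, target)
    """
    src, dst = edge_index
    num_nodes = x.size(0)
    # Message: every source node sends its current features along each outgoing edge.
    messages = x[src]
    # Aggregate: sum the incoming messages at each target node ...
    summed = torch.zeros_like(x).index_add_(0, dst, messages)
    # ... and divide by the number of incoming edges to get the mean.
    in_degree = torch.zeros(num_nodes, dtype=x.dtype, device=x.device)
    in_degree.index_add_(0, dst, torch.ones_like(dst, dtype=x.dtype))
    mean_msg = summed / in_degree.clamp(min=1).unsqueeze(1)
    # Update (assumed rule): average the node's own state with its neighborhood summary.
    return 0.5 * (x + mean_msg)


# Usage: a path graph 0 - 1 - 2, with each undirected edge stored in both directions.
edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]])
x = torch.eye(3)  # one-hot features make it easy to see where information came from
h1 = mean_message_passing(x, edge_index)
h2 = mean_message_passing(h1, edge_index)
print(h1[0])  # node 0 after one round: its entry for node 2 is still zero
print(h2[0])  # after two rounds, node 2's feature has reached node 0
```

Running two rounds on this three-node path shows the k-hop effect: after one round node 0 has no information from node 2, and after the second round it does. Trainable GNN layers wrap learned weight matrices and nonlinearities around the same gather, aggregate, and update steps.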

Implementing a Graph Neural Network in PyTorch

Cora dataset

The Cora dataset is a popular benchmark used by researchers working on graph representation learning. It contains scientific publications divided into seven categories: "CaseBased," "GeneticAlgorithms," "NeuralNetworks," "ProbabilisticMethods," "ReinforcementLearning," "RuleLearning," and "Theory."

The Cora dataset has been around for a while and continues to be a go-to for many projects in this space. It tests how well a model can combine the content of each document with the network of citations between documents, and many graph neural network papers report results on it.

It's built as a graph with the publications as nodes and the citations between them as edges. Each document is associated with a feature vector describing its content: a 1,433-dimensional binary vector indicating which words from a fixed vocabulary appear in it. The challenge is to build a model that uses the citation graph, the content vectors, and the relationships between them to predict which of the seven classes a given publication belongs to.

Data Preprocessing

We install the PyTorch Geometric library with the command: pip install torch_geometric. We can then use the PyTorch Geometric library to load and preprocess the dataset.

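from torch_geometric.datasets import Planetoid
import torch_geometric.transforms as T

# Download Cora into data/Cora (first run only) and row-normalize the node features
dataset = Planetoid(root='data/Cora', name='Cora', transform=T.NormalizeFeatures())
# Cora is a single graph; indexing applies the transform and returns a Data object
data = dataset[0]
print(data)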

The Planetoid class downloads the Cora dataset, and the NormalizeFeatures transform rescales each node's feature vector so its entries sum to one. Indexing the dataset with dataset[0] returns the single Cora graph as a Data object with these attributes:

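  • x: node feature matrix of shape (num_nodes, num_features); for Cora, (2708, 1433)
  • edge_index: edge connectivity in coordinate (COO) format, of shape (2, num_edges); each column holds a (source, target) pair, and every undirected citation is stored in both directions, giving (2, 10556) for Cora
  • y: node labels of shape (num_nodes,)
  • train_mask, val_mask, test_mask: boolean masks marking the nodes used for training, validation, and testing (140, 500, and 1,000 nodes in the standard split)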
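To make the edge_index and mask layout concrete, here is a toy graph built by hand in plain PyTorch, following the attribute format listed above. The four-node graph, its features, and its labels are made up for illustration; they are not taken from Cora.

```python
import torch

# Toy graph in the PyTorch Geometric layout (made-up data, not Cora).
# Undirected edges 0-1, 1-2 and 2-3 are each stored once per direction.
edge_index = torch.tensor([[0, 1, 1, 2, 2, 3],   # source nodes
                           [1, 0, 2, 1, 3, 2]])  # target nodes
num_nodes = 4
x = torch.rand(num_nodes, 5)     # node features: (num_nodes, num_features)
y = torch.tensor([0, 1, 1, 0])   # node labels:   (num_nodes,)

# Boolean masks choose which nodes belong to each split.
train_mask = torch.tensor([True, True, False, False])
test_mask = ~train_mask

# A node's degree is the number of edges pointing into it.
degree = torch.bincount(edge_index[1], minlength=num_nodes)

print(x.shape, edge_index.shape, y.shape)
print(degree)          # tensor([1, 2, 2, 1])
print(y[train_mask])   # labels used for training: tensor([0, 1])
print(y[test_mask])    # labels held out for testing: tensor([1, 0])
```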

Model Architecture

When building a graph neural network, choosing the right model architecture is super important. We will walk through a basic implementation using the PyTorch Geometric (torch_geometric) library. We'll use a graph convolutional network (GCN), which is a solid starting point for many different graph learning tasks.

import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

class GNN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super(GNN, self).__init__()
        # Define the first graph convolutional layer
        self.conv1 = GCNConv(in_channels, hidden_channels)
        # Define the second graph convolutional layer
        self.conv2 = GCNConv(hidden_channels, out_channels)
        # Define the linear layer
        self.linear = torch.nn.Linear(out_channels, out_channels)

    def forward(self, x, edge_index):
        # Apply the first graph convolutional layer
        x = self.conv1(x, edge_index)
        # Apply the ReLU activation function
        x = F.relu(x)
        # Apply the second graph convolutional layer
        x = self.conv2(x, edge_index)
        # Apply the ReLU activation function
        x = F.relu(x)
        # Apply the linear layer
        x = self.linear(x)
        # Apply the log softmax activation function
        return F.log_softmax(x, dim=1)
  • In the above code, we import torch and torch.nn.functional to get access to neural network modules and functions, and define a GNN class that inherits from torch.nn.Module.
  • In the __init__ method, we define two graph convolutional layers using the GCNConv module from PyTorch Geometric, which makes graph convolutions easy to implement, followed by a simple linear layer.
  • The forward pass sends the input through the two convolutional layers, applying a ReLU activation after each. The result then goes through the linear layer and finally a log-softmax, which turns the outputs into log-probabilities over the classes.

In a few lines of code, we can build a nice little graph neural network! Obviously this is a simple example, but we can see how PyTorch and PyTorch Geometric let us quickly prototype and iterate on graph neural net architectures. The GCNConv layers make it very easy to incorporate graph structure into our models.
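For intuition about what a GCNConv layer computes, the sketch below implements the GCN propagation rule of Kipf and Welling with dense matrices: add self-loops to the adjacency matrix A, normalize it symmetrically by node degree, and multiply by the transformed features, H' = D^(-1/2) (A + I) D^(-1/2) X W, where D is the degree matrix of A + I. It is a simplification for small graphs (dense adjacency, no bias, no edge weights). PyTorch Geometric's GCNConv uses sparse operations and adds a learnable bias by default, so treat the hypothetical `gcn_layer_dense` as a teaching aid rather than a drop-in replacement.

```python
import torch


def gcn_layer_dense(x: torch.Tensor, edge_index: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    """Dense sketch of one GCN step: D^-1/2 (A + I) D^-1/2 X W (no bias, no edge weights)."""
    num_nodes = x.size(0)
    # Dense adjacency: row i collects from the sources of edges that end at node i.
    adj = torch.zeros(num_nodes, num_nodes, dtype=x.dtype, device=x.device)
    adj[edge_index[1], edge_index[0]] = 1.0
    # Add self-loops so each node keeps part of its own features.
    adj_hat = adj + torch.eye(num_nodes, dtype=x.dtype, device=x.device)
    # Symmetric normalization using the degrees of the self-looped graph.
    deg_inv_sqrt = adj_hat.sum(dim=1).pow(-0.5)
    norm_adj = deg_inv_sqrt.unsqueeze(1) * adj_hat * deg_inv_sqrt.unsqueeze(0)
    # Transform the features with W, then take a degree-weighted sum over each neighborhood.
    return norm_adj @ (x @ weight)


# Usage on the path graph 0 - 1 - 2 with identity features and weights,
# so the output equals the normalized adjacency matrix itself.
edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]])
out = gcn_layer_dense(torch.eye(3), edge_index, torch.eye(3))
print(out)
```

The self-loops let a node's own features survive each layer, and the symmetric normalization keeps high-degree nodes from producing much larger outputs than low-degree ones.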

Training

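For training, we'll use a cross-entropy loss and the Adam optimizer. Because the model already ends in log_softmax, the code uses F.nll_loss, which on log-probabilities is equivalent to cross-entropy. The model sees the whole graph in every forward pass, but the mask attributes on the Data object decide which nodes count: the loss is computed only on the train_mask nodes, while val_mask and test_mask select the nodes used for evaluation.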

# Set the device to CUDA if available, otherwise use CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Get the graph (with normalized features) and move all of its tensors to the device
data = dataset[0].to(device)
# Define the GNN model with the specified input, hidden, and output dimensions, and move it to the device
model = GNN(dataset.num_features, 16, dataset.num_classes).to(device)
# Define the Adam optimizer with the specified learning rate and weight decay
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

# Define the training function
def train():
    # Set the model to training mode
    model.train()
    # Zero the gradients of the optimizer
    optimizer.zero_grad()
    # Perform a forward pass of the model over the full graph
    out = model(data.x, data.edge_index)
    # Compute the negative log-likelihood loss on the training nodes only
    loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
    # Compute the gradients of the loss with respect to the model parameters
    loss.backward()
    # Update the model parameters using the optimizer
    optimizer.step()
    # Return the loss as a scalar value
    return loss.item()

# Define the testing function
@torch.no_grad()
def test():
    # Set the model to evaluation mode
    model.eval()
    # Perform a forward pass of the model on all nodes
    out = model(data.x, data.edge_index)
    # Compute the predicted labels by taking the argmax of the output scores
    pred = out.argmax(dim=1)
    # Compute the training, validation, and testing accuracies
    train_acc = pred[data.train_mask].eq(data.y[data.train_mask]).sum().item() / data.train_mask.sum().item()
    val_acc = pred[data.val_mask].eq(data.y[data.val_mask]).sum().item() / data.val_mask.sum().item()
    test_acc = pred[data.test_mask].eq(data.y[data.test_mask]).sum().item() / data.test_mask.sum().item()
    # Return the accuracies as a tuple
    return train_acc, val_acc, test_acc

# Train the model for 500 epochs (range(1, 501) yields epochs 1 through 500)
for epoch in range(1, 501):
    # Perform a single training iteration and get the loss
    loss = train()
    # Evaluate the model on the training, validation, and testing sets and get the accuracies
    train_acc, val_acc, test_acc = test()
    # Print the epoch number, loss, and accuracies
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Train Acc: {train_acc:.4f}, Val Acc: {val_acc:.4f}, Test Acc: {test_acc:.4f}')

The train function performs one optimization step and returns the loss. The test function measures the model's accuracy on the training, validation, and test nodes. We train the model for 500 epochs and print the loss and all three accuracies at each epoch.
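The three accuracy lines in test() repeat the same masked comparison. A small helper factors it out; `masked_accuracy` is our own name, not a PyTorch Geometric function, and the sketch assumes `pred`, `y`, and `mask` are tensors on the same device.

```python
import torch


def masked_accuracy(pred: torch.Tensor, y: torch.Tensor, mask: torch.Tensor) -> float:
    """Share of the nodes selected by `mask` whose predicted label equals the true label."""
    correct = (pred[mask] == y[mask]).sum().item()
    total = int(mask.sum().item())
    return correct / total


# Usage with made-up predictions for four nodes; the last node is excluded by the mask.
pred = torch.tensor([0, 1, 2, 1])
y = torch.tensor([0, 1, 1, 1])
mask = torch.tensor([True, True, True, False])
acc = masked_accuracy(pred, y, mask)
print(f"{acc:.4f}")  # 2 of the 3 masked nodes are correct
```

Inside test(), the three lines could then read `masked_accuracy(pred, data.y, data.train_mask)` and so on. A common practice is to pick the epoch with the highest validation accuracy and report the test accuracy from that epoch, so the test set plays no part in choosing the model.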

Compute the accuracy of the GNN model

The code below defines a function that computes the model's accuracy over every node in the graph. The compute_accuracy() function switches the model to evaluation mode, runs a forward pass, and predicts a label for each node. It compares these predictions with the ground-truth labels, counts the correct ones, and divides by the total number of nodes to get the accuracy as a fraction between 0 and 1.

# Reuse the graph from dataset[0] on the same device as the model
data = dataset[0].to(device)

@torch.no_grad()
def compute_accuracy():
    model.eval()
    out = model(data.x, data.edge_index)
    pred = out.argmax(dim=1)
    # Count the nodes whose predicted label matches the ground truth
    correct = pred.eq(data.y).sum().item()
    total = data.num_nodes
    accuracy = correct / total
    return accuracy

accuracy = compute_accuracy()
print(f"Accuracy: {accuracy:.4f}")

In this case, the model's accuracy over all nodes of the Cora dataset was 0.8006, meaning it predicted the correct class label for about 80% of the publications. Keep in mind that this figure includes the 140 training nodes the model was fit on, so the test-set accuracy reported by test() is the fairer number to quote. That's pretty good, but not perfect. Accuracy gives a quick, high-level view of how well the model performs overall, but you have to dig deeper to understand where it succeeds and where it struggles. For a fuller picture, consider other evaluation metrics such as precision, recall, F1 score, and the confusion matrix. These show how the model performs on each class, which matters when some classes are much smaller than others.

So while 80% accuracy is solid, we'd want more context before declaring this model a smashing success. The accuracy metric alone doesn't give the full picture of what's going on under the hood, but it's a good starting point for gauging performance.
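To go beyond a single accuracy number, the sketch below computes a confusion matrix and per-class precision, recall, and F1 in plain PyTorch. `classification_metrics` is a hypothetical helper of our own; scikit-learn's `sklearn.metrics.confusion_matrix` and `classification_report` compute the same quantities from NumPy arrays if you prefer a library call. For Cora, you would pass the predicted and true labels of the test nodes (moved to the CPU) and `dataset.num_classes`, which is 7.

```python
import torch


def classification_metrics(pred: torch.Tensor, y: torch.Tensor, num_classes: int):
    """Confusion matrix and per-class precision, recall, and F1 from integer label tensors."""
    # Encode each (true, predicted) pair as one index, count, and reshape:
    # rows are true classes, columns are predicted classes.
    cm = torch.bincount(y * num_classes + pred, minlength=num_classes ** 2)
    cm = cm.reshape(num_classes, num_classes).float()
    tp = cm.diag()
    # clamp() avoids dividing by zero for classes that are never predicted or never present.
    precision = tp / cm.sum(dim=0).clamp(min=1)
    recall = tp / cm.sum(dim=1).clamp(min=1)
    f1 = 2 * precision * recall / (precision + recall).clamp(min=1e-12)
    return cm, precision, recall, f1


# Usage with made-up labels for five nodes and three classes.
y = torch.tensor([0, 0, 1, 1, 2])
pred = torch.tensor([0, 1, 1, 1, 2])
cm, precision, recall, f1 = classification_metrics(pred, y, num_classes=3)
print(cm)
print(precision, recall, f1)
print(f"Macro F1: {f1.mean().item():.4f}")
```

Averaging the per-class F1 scores (macro F1) weights every class equally, so a weak score on a small class is not hidden by strong results on the large ones.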

Evaluation

We can evaluate how the GNN is performing using metrics like accuracy, precision, recall, and F1 score. We can also visualize what the model has learned with t-SNE, which projects high-dimensional vectors down to 2D so we can actually plot them. The code below applies it to the model's output for each node, that is, its 7 class scores.

# Import the necessary libraries
import torch
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt

# Get the graph on the same device as the model
data = dataset[0].to(device)

# Set the model to evaluation mode
model.eval()

# Perform a forward pass of the model on the dataset without tracking gradients
with torch.no_grad():
    out = model(data.x, data.edge_index)

# Apply t-SNE to the output feature matrix to obtain a 2D embedding
emb = TSNE(n_components=2).fit_transform(out.cpu().numpy())

# Create a figure with a specified size
plt.figure(figsize=(10, 10))

# Create a scatter plot of the embeddings, color-coded by the true labels
plt.scatter(emb[:, 0], emb[:, 1], c=data.y.cpu().numpy(), cmap='jet')

# Display the plot
plt.show()

Note: the resulting scatter plot is not reproduced here; run the code above to display it.

The code uses t-SNE to show the model's node representations in a 2D scatter plot, which is a smart way to visualize high-dimensional data. Let us walk through what's going on:

  • Each point in the plot represents a node in the dataset. The x and y axes are the two dimensions that t-SNE squeezed the embeddings into, and the color of each point is the true label of the corresponding node.
  • If the model has learned useful representations, nodes with the same label end up with similar embeddings and cluster together on the plot, while nodes from different classes land in separate groups. t-SNE preserves local neighborhoods rather than global distances, so the spacing between clusters should not be read too literally.
  • Overall, the plot gives a picture of how nodes relate to one another in the learned representation. Groups that form likely share some underlying similarity. It's a handy way to peek inside the model and see how it organizes concepts.
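The t-SNE plot above uses the model's final 7-dimensional output. To visualize the 16-dimensional hidden representation instead, one option is a PyTorch forward hook, which records a layer's output during the forward pass. The sketch below shows the mechanism on a small stand-in model (`TwoLayer` and `toy_model` are hypothetical names) so it runs without PyTorch Geometric; with the tutorial's GNN, the hook would go on `model.conv1`. A hook captures the layer's output before the ReLU applied in forward(), so the sketch applies that ReLU afterwards.

```python
import torch


class TwoLayer(torch.nn.Module):
    """Hypothetical stand-in for the tutorial's GNN: layer1 -> ReLU -> layer2."""

    def __init__(self, in_dim: int, hidden_dim: int, out_dim: int):
        super().__init__()
        self.layer1 = torch.nn.Linear(in_dim, hidden_dim)
        self.layer2 = torch.nn.Linear(hidden_dim, out_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.layer2(torch.relu(self.layer1(x)))


captured = {}


def save_hidden(module, inputs, output):
    # Forward hooks are called with (module, inputs, output) after the module runs.
    captured["hidden"] = output.detach()


toy_model = TwoLayer(in_dim=8, hidden_dim=16, out_dim=7)
handle = toy_model.layer1.register_forward_hook(save_hidden)
with torch.no_grad():
    logits = toy_model(torch.randn(5, 8))
handle.remove()  # stop recording once the activations are captured

hidden = torch.relu(captured["hidden"])  # apply the ReLU that forward() uses after layer1
print(hidden.shape)  # torch.Size([5, 16])
```

The resulting `hidden.cpu().numpy()` array can be passed to `TSNE(n_components=2).fit_transform` exactly like the output scores in the plotting code.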

Potential Challenges and Considerations

  • With 2,708 nodes and 5,429 edges, the Cora dataset is considered to be on the smaller side. Limited data can hurt how well a GNN generalizes, and techniques such as data augmentation or transfer learning may help.
  • Cora contains one type of node and one type of edge, making it a homogeneous graph. Results on it may not carry over to more complex, heterogeneous graphs with several node and edge types.
  • Hyperparameters such as the number of hidden layers, the number of hidden units, and the learning rate can significantly affect the performance of the GNN and require careful tuning.
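As one concrete example of data augmentation for graphs, DropEdge randomly removes a fraction of edges at each training step. The sketch below is our own minimal version (`drop_edges` is a hypothetical helper). It drops each stored direction of an edge independently, so an undirected edge may lose only one direction. Recent PyTorch Geometric releases include `torch_geometric.utils.dropout_edge` for the same purpose, with a `force_undirected` option; check the documentation for your installed version.

```python
import torch


def drop_edges(edge_index: torch.Tensor, p: float = 0.2, generator=None) -> torch.Tensor:
    """DropEdge-style augmentation sketch: keep each stored edge with probability 1 - p."""
    # One uniform random number in [0, 1) per edge column; keep the edge if it is >= p.
    keep = torch.rand(edge_index.size(1), generator=generator, device=edge_index.device) >= p
    return edge_index[:, keep]


# Usage on a toy edge list; a seeded generator makes the result repeatable.
edge_index = torch.tensor([[0, 1, 1, 2, 2, 3],
                           [1, 0, 2, 1, 3, 2]])
gen = torch.Generator().manual_seed(0)
print(drop_edges(edge_index, p=0.5, generator=gen))
```

During training, one might call `model(data.x, drop_edges(data.edge_index, p=0.2))` inside train() while evaluation keeps the full edge_index. The value p=0.2 is an illustrative choice, not a tuned one.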

Conclusion

In this article, we explored the fundamentals of Graph Neural Networks (GNNs) and their application in various fields. GNNs are a powerful type of neural network designed to process graph-structured data, making them suitable for tasks involving complex data structures such as social networks, molecular structures, and transportation systems.

We used one of these graph networks, built with PyTorch and PyTorch Geometric, to classify scientific publications by category. We worked with Cora, a standard dataset for testing graph learning methods, which represents publications as nodes and citations between them as edges. Our goal was to have the network use each publication's content and its citation relationships to predict its category.

We loaded and preprocessed the Cora dataset with the PyTorch Geometric library, normalizing each publication's feature vector and using the dataset's masks to separate training, validation, and test nodes. We defined the GNN architecture with graph convolutional layers and a linear layer, trained it by minimizing the cross-entropy loss with the Adam optimizer, and then computed the model's accuracy.

There is definitely more the reader could do to improve graph networks like this, but this project gave us a good taste of how powerful GNNs can be on complex relational data.
