Graph neural networks explained
Convolutional networks assume pixels live on a grid. Transformers assume tokens live in a sequence. But much of the world is relational: users follow each other, wallets send payments to wallets, atoms bond into molecules, warehouses ship to stores. A graph neural network (GNN) learns on that structure directly: each node aggregates information from its neighbors through repeated message passing, producing embeddings that capture both local context and multi-hop patterns. GNNs power fraud-ring detection on payment graphs, link prediction in social networks, property prediction on molecules, and retrieval over knowledge graphs. They are not a replacement for every tabular model, but when the relationship is the signal, flattening the graph into rows throws information away. This guide covers graph representations, message-passing mechanics, major architectures (GCN, GraphSAGE, GAT), task types, oversmoothing and scalability limits, and when GNNs beat classical machine learning or dense deep learning on the same data.
Graphs as machine-learning inputs
A graph G = (V, E) has a set of nodes V and edges E connecting them. Each node can carry a feature vector (user age, wallet balance histogram, atom type one-hot). Each edge can carry features too (transaction amount, bond order, follow date). Connectivity can be stored as an adjacency matrix A or, more efficiently for sparse graphs where most node pairs are unconnected, as an edge list plus per-node neighbor indices.
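To make the storage trade-off concrete, here is a minimal sketch (standard library only) that builds both a dense adjacency matrix and a CSR-style (compressed sparse row) layout from the same edge list. The toy graph and the helper names `to_csr` and `out_neighbors` are illustrative assumptions, not part of any library; real systems would typically use SciPy sparse matrices or a GNN framework's own format.

```python
# Toy directed graph (hypothetical): 4 nodes, edges as (src, dst) pairs.
edges = [(0, 1), (0, 2), (1, 2), (2, 3)]
num_nodes = 4

# Dense adjacency matrix: O(|V|^2) memory, mostly zeros on real graphs.
A = [[0] * num_nodes for _ in range(num_nodes)]
for src, dst in edges:
    A[src][dst] = 1


def to_csr(edges, num_nodes):
    """Return (indptr, indices) so that the out-neighbors of v are
    indices[indptr[v]:indptr[v + 1]]. Memory is O(|V| + |E|)."""
    buckets = [[] for _ in range(num_nodes)]
    for src, dst in edges:
        buckets[src].append(dst)
    indptr, indices = [0], []
    for nbrs in buckets:
        indices.extend(sorted(nbrs))
        indptr.append(len(indices))
    return indptr, indices


indptr, indices = to_csr(edges, num_nodes)


def out_neighbors(v):
    """Slice the flat index array using the row pointers."""
    return indices[indptr[v]:indptr[v + 1]]


print(out_neighbors(0))  # [1, 2]
print(out_neighbors(3))  # []
```

The dense matrix stores every possible pair; the CSR arrays store only existing edges, which is why sparse layouts dominate on graphs with millions of nodes.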
Graphs come in several flavors relevant to modeling:
- Homogeneous — one node type and one edge type (social follow graph).
- Heterogeneous — multiple node and edge types (users, items, clicks in a recommender).
- Directed vs undirected — payment flows are directed; mutual friendships are often treated as undirected.
- Static vs dynamic — edges that appear over time need temporal GNNs or snapshot windows.
The key insight: two nodes with identical feature vectors but different neighborhoods should receive different embeddings after message passing. A tabular model that ignores A cannot distinguish them. That is the gap GNNs fill.
Message passing: how GNN layers work
Most GNNs follow a three-step pattern at each layer l:
- Message — each edge computes a message from source and destination hidden states.
- Aggregate — each node pools messages from all incoming (or bidirectional) neighbors.
- Update — the node combines its previous state with the aggregated message via an MLP or GRU.
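The three steps can be sketched in plain Python. This is a hypothetical minimal layer, assuming messages depend only on the source node's state (a common simplification of the source-and-destination form above), mean aggregation, and a ReLU update; `W_msg` and `W_self` stand in for learned weights. The toy example also illustrates the earlier point about neighborhoods: nodes `a` and `b` have identical features but different neighbors, so their outputs differ.

```python
def matvec(W, x):
    """Multiply matrix W (list of rows) by vector x."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def relu(x):
    return [max(0.0, v) for v in x]


def message_passing_layer(h, in_neighbors, W_msg, W_self):
    """One layer: message = W_msg @ h_u for each in-neighbor u,
    aggregate = elementwise mean of messages, update = ReLU(W_self @ h_v + aggregate)."""
    new_h = {}
    for v, h_v in h.items():
        msgs = [matvec(W_msg, h[u]) for u in in_neighbors.get(v, [])]  # message
        if msgs:
            agg = [sum(col) / len(msgs) for col in zip(*msgs)]  # aggregate
        else:
            agg = [0.0] * len(W_msg)  # isolated node: nothing to aggregate
        new_h[v] = relu([s + m for s, m in zip(matvec(W_self, h_v), agg)])  # update
    return new_h


# Nodes a and b share identical features; only their neighbors differ.
h = {"a": [1.0, 0.0], "b": [1.0, 0.0], "c": [0.0, 1.0], "d": [0.0, 0.0]}
in_neighbors = {"a": ["c"], "b": ["d"]}
identity = [[1.0, 0.0], [0.0, 1.0]]  # placeholder for learned weights

out = message_passing_layer(h, in_neighbors, identity, identity)
print(out["a"], out["b"])  # [1.0, 1.0] [1.0, 0.0]
```

Swapping the mean for a sum or max, or the ReLU update for a GRU, changes the architecture but not the message, aggregate, update skeleton.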
Stacking L layers lets a node indirectly receive information from nodes up to L hops away, analogous to receptive fields in CNNs. After L layers, the node embedding h_v encodes both its own features and the structure around it.
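One way to see the L-hop receptive field is to track which nodes' input features can reach a given node after each round. The helper below is an illustrative sketch, assuming the update step always keeps the node's own state (so the node itself is included in its receptive field).

```python
def receptive_field(in_neighbors, v, num_layers):
    """Set of nodes whose input features can influence h_v after num_layers rounds."""
    seen, frontier = {v}, {v}
    for _ in range(num_layers):
        # One more round of message passing pulls in the next ring of neighbors.
        frontier = {u for w in frontier for u in in_neighbors.get(w, [])} - seen
        seen |= frontier
    return seen


# Undirected path 0 - 1 - 2 - 3 - 4, stored as neighbor lists.
path = {i: [j for j in (i - 1, i + 1) if 0 <= j <= 4] for i in range(5)}

for layers in (1, 2, 4):
    print(layers, sorted(receptive_field(path, 0, layers)))
# 1 [0, 1]
# 2 [0, 1, 2]
# 4 [0, 1, 2, 3, 4]
```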
Three families dominate practice:
- Graph Convolutional Network (GCN): each node takes a degree-normalized weighted sum of its own and its neighbors' features, a simplification inspired by spectral graph convolution. Simple, fast, strong baseline on citation and social graphs.
- GraphSAGE — samples a fixed number of neighbors per layer instead of using all of them. Supports inductive learning: embed nodes not seen at training time by applying the same aggregation function. Critical for billion-edge graphs.
- Graph Attention Network (GAT) — learns attention weights over neighbors so some edges matter more than others. Useful when neighbor importance is heterogeneous (fraud: one bad counterparty vs many benign ones).
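The contrast between fixed and learned neighbor weights can be sketched as follows. `gcn_aggregate` applies the symmetric normalization usually written D^(-1/2) (A + I) D^(-1/2) H, where D is the degree matrix of A + I, with the weight matrix and nonlinearity omitted for brevity. `gat_attention` is a simplified single-head GAT score, assuming the shared weight matrix is the identity; in a real model the vector `a` and the weights are learned. Both are illustrative sketches, not library APIs.

```python
import math


def gcn_aggregate(h, neighbors):
    """GCN propagation without the weight matrix or nonlinearity:
    out_v = sum over u in N(v) + {v} of h_u / sqrt(deg(u) * deg(v)),
    where deg counts the added self-loop."""
    deg = {v: len(neighbors.get(v, [])) + 1 for v in h}
    out = {}
    for v in h:
        total = [0.0] * len(h[v])
        for u in neighbors.get(v, []) + [v]:
            coef = 1.0 / math.sqrt(deg[u] * deg[v])  # fixed by structure, not learned
            total = [t + coef * x for t, x in zip(total, h[u])]
        out[v] = total
    return out


def gat_attention(h_v, neighbor_feats, a, negative_slope=0.2):
    """Simplified single-head GAT weights, assuming the shared weight matrix is the
    identity: softmax over LeakyReLU(a . [h_v || h_u]) for each neighbor u."""
    scores = []
    for h_u in neighbor_feats:
        z = sum(ai * xi for ai, xi in zip(a, h_v + h_u))  # h_v + h_u concatenates lists
        scores.append(z if z > 0 else negative_slope * z)
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


# Star graph: hub 0 linked to leaves 1, 2, 3 (undirected).
star = {0: [1, 2, 3], 1: [0], 2: [0], 3: [0]}
feats = {0: [1.0], 1: [0.0], 2: [0.0], 3: [0.0]}
print(gcn_aggregate(feats, star))

# Learned attention can favor one neighbor over another.
print(gat_attention([1.0], [[0.0], [2.0]], a=[0.0, 1.0]))
```

In GCN the hub's contribution to each leaf is set by degrees alone; in GAT the same neighbor set can receive very different weights depending on features, which is what helps when one counterparty matters more than the rest.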
Heterogeneous graphs often use relational GNNs (R-GCN, HGT) with separate message functions per edge type, so that, for example, a purchase edge and a click edge transform neighbor information differently.
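A minimal sketch of the relational idea, loosely following the R-GCN update: a self term plus, for each edge type, a mean over that type's neighbors transformed by a type-specific matrix. The relation names, weights, and features are hypothetical, and HGT's attention-based variant is not shown.

```python
def matvec(W, x):
    """Multiply matrix W (list of rows) by vector x."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def rgcn_layer(h, typed_in_neighbors, W_rel, W_self):
    """R-GCN-style update:
    h_v' = ReLU(W_self h_v + sum over relations r of mean_{u in N_r(v)} W_r h_u).
    typed_in_neighbors[v] maps a relation name to the list of source nodes."""
    out = {}
    for v, h_v in h.items():
        total = matvec(W_self, h_v)
        for rel, srcs in typed_in_neighbors.get(v, {}).items():
            for u in srcs:
                msg = matvec(W_rel[rel], h[u])  # separate weights per edge type
                total = [t + m / len(srcs) for t, m in zip(total, msg)]
        out[v] = [max(0.0, t) for t in total]
    return out


# Hypothetical recommender fragment: one user, two items, two edge types.
h = {"user": [0.0, 0.0], "item1": [1.0, 0.0], "item2": [0.0, 1.0]}
typed_in_neighbors = {"user": {"clicked": ["item1"], "purchased": ["item2"]}}
W_rel = {
    "clicked": [[0.5, 0.0], [0.0, 0.5]],    # hypothetical: clicks weighted down
    "purchased": [[2.0, 0.0], [0.0, 2.0]],  # hypothetical: purchases weighted up
}
W_self = [[1.0, 0.0], [0.0, 1.0]]

print(rgcn_layer(h, typed_in_neighbors, W_rel, W_self)["user"])  # [0.5, 2.0]
```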
Task types: node, edge, and graph level
GNNs support predictions at three granularities:
| Task level | Question | Example |
|---|---|---|
| Node classification | What label does this node have? | Is this wallet fraudulent? What topic is this paper? |
| Edge prediction | Will / should this edge exist? | Recommend a follow, predict a payment, drug–target interaction |
| Graph classification | What label does the whole graph have? | Is this molecule toxic? Is this code dependency graph malicious? |
For node classification, you run message passing and apply a classifier head on each node's final embedding. For graph classification, you add a readout step — sum, mean, or attention pooling over all node embeddings — to produce a single graph vector before the classifier.
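The readout step can be sketched as a single pooling function. This is an illustrative version, assuming final node embeddings are already computed; the `query` vector for attention pooling is a hypothetical stand-in for a learned parameter.

```python
import math


def readout(node_embeddings, mode="mean", query=None):
    """Pool a list of node vectors into a single graph vector.
    'attention' weights nodes by softmax(query . h_v); query would be learned in practice."""
    n = len(node_embeddings)
    columns = list(zip(*node_embeddings))
    if mode == "sum":
        return [sum(c) for c in columns]
    if mode == "mean":
        return [sum(c) / n for c in columns]
    if mode == "attention":
        scores = [sum(q * x for q, x in zip(query, h)) for h in node_embeddings]
        m = max(scores)
        weights = [math.exp(s - m) for s in scores]
        z = sum(weights)
        return [sum(w / z * x for w, x in zip(weights, c)) for c in columns]
    raise ValueError(f"unknown readout mode: {mode}")


# Hypothetical final embeddings for a 3-atom molecule.
atoms = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
print(readout(atoms, "sum"))  # [2.0, 2.0]
print(readout(atoms, "mean"))
print(readout(atoms, "attention", query=[1.0, 0.0]))
```

A sum readout keeps information about graph size, which mean discards: two molecules with the same average atom embedding but different atom counts get different sum vectors. The pooled vector is then passed to the classifier head.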
Link prediction often scores candidate edges by combining embeddings of the two endpoints (dot product, MLP on concatenation) — closely related to recommendation system two-tower models, but with embeddings shaped by graph structure rather than only ID lookups.
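The dot-product variant of edge scoring is short enough to sketch directly. The embeddings below are hypothetical outputs of a trained GNN encoder; an MLP on the concatenated pair would replace `dot_score` in the other variant mentioned above.

```python
import math


def dot_score(z_u, z_v):
    return sum(a * b for a, b in zip(z_u, z_v))


def edge_probability(z_u, z_v):
    """Sigmoid of the dot product: a common link-prediction decoder."""
    return 1.0 / (1.0 + math.exp(-dot_score(z_u, z_v)))


def top_k_candidates(z, source, candidates, k):
    """Rank candidate destinations for `source` by dot-product score."""
    return sorted(candidates, key=lambda c: dot_score(z[source], z[c]), reverse=True)[:k]


# Hypothetical embeddings produced by a trained GNN encoder.
z = {"alice": [1.0, 0.0], "bob": [0.9, 0.1], "carol": [0.0, 1.0]}
print(top_k_candidates(z, "alice", ["bob", "carol"], k=1))  # ['bob']
print(edge_probability(z["alice"], z["carol"]))             # 0.5
```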
Where GNNs shine — and where they do not
Strong fits:
- Fraud and abuse: ring detection on transaction graphs, where colluding accounts look unremarkable one by one but transfer funds densely among themselves.
- Knowledge graphs — entity typing, link prediction, and question answering over structured triples.
- Chemistry and materials — molecular property prediction where atoms are nodes and bonds are edges.
- Supply chain and logistics — delay propagation, anomaly detection on shipment networks.
- Social and content graphs — community detection, influence estimation, semi-supervised label spread.
Weak fits:
- Data with no meaningful edges — use tabular boosting or MLPs instead.
- Graphs so dense that message passing degenerates into an expensive matrix multiplication with little structural bias.
- Tasks where hand-crafted neighbor statistics (degree, PageRank, clustering coefficient) already capture the signal — always benchmark a strong XGBoost baseline on node features plus graph metrics.
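Hand-crafted graph statistics for such a baseline can be computed without any GNN machinery. The sketch below assumes a directed graph stored as out-neighbor lists and computes out-degree, in-degree, and power-iteration PageRank per node; the rows would then be joined with node attributes and fed to a gradient-boosted model such as XGBoost (training not shown). The payment graph is hypothetical.

```python
def pagerank(out_neighbors, num_nodes, damping=0.85, iters=50):
    """Power-iteration PageRank; rank mass from dangling nodes is spread uniformly."""
    rank = [1.0 / num_nodes] * num_nodes
    for _ in range(iters):
        new = [(1.0 - damping) / num_nodes] * num_nodes
        for u in range(num_nodes):
            nbrs = out_neighbors.get(u, [])
            targets = nbrs if nbrs else range(num_nodes)  # dangling: spread everywhere
            share = damping * rank[u] / len(targets)
            for v in targets:
                new[v] += share
        rank = new
    return rank


def graph_features(out_neighbors, num_nodes):
    """Per-node [out_degree, in_degree, pagerank] rows for a tabular baseline."""
    in_deg = [0] * num_nodes
    for nbrs in out_neighbors.values():
        for v in nbrs:
            in_deg[v] += 1
    pr = pagerank(out_neighbors, num_nodes)
    return [[len(out_neighbors.get(v, [])), in_deg[v], pr[v]] for v in range(num_nodes)]


# Hypothetical payment graph: wallets 1, 2, 3 all pay wallet 0.
payments = {1: [0], 2: [0], 3: [0]}
for row in graph_features(payments, 4):
    print(row)
```

If a model on these rows already matches the GNN, the graph's signal is mostly captured by simple statistics.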
GNNs complement, rather than replace, anomaly detection pipelines: use GNN embeddings as features for downstream detectors, or train end-to-end when the anomaly is defined by subgraph structure (for example, a sudden star pattern of micro-transactions).
Oversmoothing, depth, and scalability
Stacking too many GNN layers causes oversmoothing: node embeddings converge toward similar values because repeated averaging washes out local differences. In practice, 2–4 layers cover most tasks; deeper models need residual connections, Jumping Knowledge (concatenating embeddings from all layers), or separate hop encodings.
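Oversmoothing can be reproduced with a parameter-free mean aggregation, which isolates the averaging effect from any learned weights. In this illustrative sketch on a 5-node path, the spread between the largest and smallest node value shrinks as layers are added, while a Jumping Knowledge style concatenation (the concatenation variant, for a hypothetical 3-layer model) still keeps nodes distinguishable through their early-layer values.

```python
def mean_layer(x, neighbors):
    """Parameter-free mean over N(v) + {v}: the smoothing core shared by many GNN layers."""
    return [
        (x[v] + sum(x[u] for u in neighbors[v])) / (len(neighbors[v]) + 1)
        for v in range(len(x))
    ]


# Undirected path 0 - 1 - 2 - 3 - 4; only node 0 starts with a nonzero feature.
path = [[1], [0, 2], [1, 3], [2, 4], [3]]
layers = [[1.0, 0.0, 0.0, 0.0, 0.0]]
for _ in range(32):
    layers.append(mean_layer(layers[-1], path))

# Spread between the most and least "activated" node shrinks with depth.
spreads = [max(h) - min(h) for h in layers]
for depth in (1, 2, 4, 8, 32):
    print(f"{depth:>2} layers: spread = {spreads[depth]:.4f}")

# Jumping Knowledge (concatenation variant) for a 3-layer model: keep every layer's
# value for a node, so early, still-distinct representations survive.
jk = [[layers[l][v] for l in range(4)] for v in range(5)]
print(jk[0], jk[4])
```

Each mean is a convex combination of the previous values, so the spread can never grow; on a connected graph with self-loops it tends to zero, which is the oversmoothing effect in its simplest form.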
Scalability challenges appear on web-scale graphs:
- Neighbor sampling (GraphSAGE, FastGCN) limits fan-out per layer.
- Mini-batch training builds subgraphs around seed nodes instead of full-graph forward passes.
- Partitioning shards the graph across workers; cross-partition edges need halo exchanges.
- Feature precomputation — run a shallow GNN offline to produce node embeddings, then serve retrieval from a vector database for latency-sensitive queries.
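Neighbor sampling with per-layer fan-out caps can be sketched as follows. This is a GraphSAGE-style illustration under stated assumptions (uniform sampling without replacement, neighbor lists held in memory); production loaders in GNN frameworks add batching, deduplication, and feature gathering on top of this idea. The hub graph is hypothetical.

```python
import random


def sample_computation_graph(neighbors, seeds, fanouts, rng):
    """GraphSAGE-style sampling: at hop l, keep at most fanouts[l] random neighbors per
    frontier node. Returns one {node: sampled_neighbors} dict per hop, seeds first."""
    hops = []
    frontier = list(seeds)
    for fanout in fanouts:
        sampled, next_frontier = {}, set()
        for v in frontier:
            nbrs = neighbors.get(v, [])
            picked = list(nbrs) if len(nbrs) <= fanout else rng.sample(nbrs, fanout)
            sampled[v] = picked
            next_frontier.update(picked)
        hops.append(sampled)
        frontier = sorted(next_frontier)
    return hops


# Hypothetical hub wallet with 1,000 counterparties, each linked back to the hub.
neighbors = {0: list(range(1, 1001))}
for i in range(1, 1001):
    neighbors[i] = [0]

hops = sample_computation_graph(neighbors, seeds=[0], fanouts=[10, 5], rng=random.Random(0))
print(len(hops[0][0]), len(hops[1]))  # 10 10
```

With fan-outs of 10 and 5, each seed touches at most 1 + 10 + 50 nodes regardless of its true degree, which is what keeps memory bounded on hub nodes.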
Inductive vs transductive: transductive models (classic GCN on a fixed graph) embed only nodes present at training time. Inductive models (GraphSAGE, GAT with proper feature inputs) generalize to new nodes — required for live user signup or new wallet addresses.
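The difference shows up as soon as a new node arrives. In this hypothetical sketch, a transductive model is represented by a per-node embedding table, which simply has no entry for a new wallet, while an inductive GraphSAGE-mean style function computes an embedding from features and shared weights. `W` is a made-up stand-in for trained parameters.

```python
def sage_mean_embed(x_v, neighbor_xs, W):
    """GraphSAGE-mean style: ReLU(W @ [x_v || mean(neighbor_xs)]).
    Depends only on features and weights, so it applies to nodes unseen in training."""
    if neighbor_xs:
        mean = [sum(c) / len(neighbor_xs) for c in zip(*neighbor_xs)]
    else:
        mean = [0.0] * len(x_v)
    z = x_v + mean  # list concatenation
    return [max(0.0, sum(w * zi for w, zi in zip(row, z))) for row in W]


# Transductive: one free vector per training-time node; a new wallet has no row.
embedding_table = {"w1": [0.3, 0.1], "w2": [0.0, 0.9]}
print("w_new" in embedding_table)  # False

# Inductive: hypothetical trained W (2 x 4) applied to the new wallet's own features
# and those of the wallets it has already transacted with.
W = [[1.0, 0.0, 0.5, 0.0],
     [0.0, 1.0, 0.0, 0.5]]
x_new = [0.2, 0.4]
neighbor_xs = [[1.0, 0.0], [0.0, 1.0]]
print(sage_mean_embed(x_new, neighbor_xs, W))
```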
Training and evaluation discipline
Graph data leaks easily. A random 80/20 node split puts neighbors of test nodes in the training set, inflating metrics. Use:
- Masked training — hide labels on test nodes but keep edges (transductive semi-supervised).
- Inductive splits — hold out entire subgraphs or time slices.
- Edge-level splits for link prediction — never train on edges you test.
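For link prediction, the main leakage traps are test edges that remain in the message-passing graph and undirected edges whose two directions land in different splits. The helpers below are an illustrative sketch of a random edge split and a time-based split; the names are hypothetical, and sampling negative (non-existent) edges for evaluation is omitted.

```python
import random


def split_edges(edges, test_frac, rng, undirected=True):
    """Random edge split for link prediction. Returns (train_edges, test_edges);
    message passing should run only over train_edges."""
    if undirected:
        # (u, v) and (v, u) are one edge; canonicalize so both directions share a split.
        edges = sorted({(min(u, v), max(u, v)) for u, v in edges})
    else:
        edges = list(edges)
    rng.shuffle(edges)
    n_test = int(len(edges) * test_frac)
    return edges[n_test:], edges[:n_test]


def temporal_split(timed_edges, cutoff):
    """Time-based split: train on edges before the cutoff, test on edges at or after it."""
    train = [(u, v) for u, v, t in timed_edges if t < cutoff]
    test = [(u, v) for u, v, t in timed_edges if t >= cutoff]
    return train, test


# Hypothetical undirected graph given with both directions listed.
chain = [(i, i + 1) for i in range(100)]
train, test = split_edges(chain + [(v, u) for u, v in chain], 0.2, random.Random(42))
print(len(train), len(test))  # 80 20
```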
Report metrics appropriate to class imbalance. Fraud graphs can be 99% negative, so accuracy is misleading; use the area under the precision-recall curve (PR-AUC). For multi-class node tasks, use macro-F1 so rare classes count as much as common ones. Track calibration if scores feed human review queues.
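Both metrics are simple enough to write out, which makes their behavior on imbalanced data easy to inspect. This is an illustrative standard-library sketch with hypothetical labels and scores; in practice scikit-learn's `average_precision_score` and `f1_score(y_true, y_pred, average="macro")` are the usual choices.

```python
def average_precision(labels, scores):
    """Step-wise area under the precision-recall curve (average precision).
    Ties in scores keep input order here; libraries may treat ties differently."""
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    total_pos = sum(labels)
    tp, ap = 0, 0.0
    for rank, i in enumerate(order, start=1):
        if labels[i]:
            tp += 1
            ap += tp / rank  # precision at each recall step of 1 / total_pos
    return ap / total_pos


def macro_f1(y_true, y_pred):
    """Unweighted mean of per-class F1, so rare classes count as much as common ones."""
    f1s = []
    for c in sorted(set(y_true) | set(y_pred)):
        tp = sum(t == c and p == c for t, p in zip(y_true, y_pred))
        fp = sum(t != c and p == c for t, p in zip(y_true, y_pred))
        fn = sum(t == c and p != c for t, p in zip(y_true, y_pred))
        denom = 2 * tp + fp + fn
        f1s.append(2 * tp / denom if denom else 0.0)
    return sum(f1s) / len(f1s)


# Hypothetical fraud scores: 2 positives among 10 nodes.
labels = [0, 0, 1, 0, 1, 0, 0, 0, 0, 0]
scores = [0.1, 0.2, 0.9, 0.3, 0.4, 0.05, 0.15, 0.25, 0.35, 0.5]
print(average_precision(labels, scores))

# A majority-class predictor scores 80% accuracy but poor macro-F1.
y_true = ["benign"] * 8 + ["fraud"] * 2
y_pred = ["benign"] * 10
print(macro_f1(y_true, y_pred))
```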
Feature engineering still matters: combine learned GNN embeddings with static node attributes, temporal aggregates (e.g., 30-day in-degree), and graph statistics. Hybrid models often beat either a pure GNN or a pure tabular model, the same lesson as hybrid retrieval in search systems.
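A hybrid feature row can be assembled by plain concatenation. The sketch below is hypothetical throughout: the embedding, the static attribute (for instance account age in days), and the payment timestamps are made up, and the resulting rows could feed a gradient-boosted model or an MLP head.

```python
from datetime import datetime, timedelta


def in_degree_since(timed_edges, node, now, days=30):
    """Count edges pointing at `node` within the trailing `days`-day window."""
    start = now - timedelta(days=days)
    return sum(1 for _src, dst, ts in timed_edges if dst == node and start <= ts <= now)


def hybrid_row(node, gnn_embedding, static_attrs, timed_edges, now):
    """Concatenate a learned embedding, static attributes, and a temporal aggregate."""
    return list(gnn_embedding) + list(static_attrs) + [in_degree_since(timed_edges, node, now)]


# Hypothetical payments into wallet "w1"; only the two June payments fall in the window.
now = datetime(2024, 6, 30)
payments = [
    ("w2", "w1", datetime(2024, 6, 1)),
    ("w3", "w1", datetime(2024, 6, 29)),
    ("w4", "w1", datetime(2024, 4, 1)),
]
print(hybrid_row("w1", [0.1, 0.2], [35.0], payments, now))  # [0.1, 0.2, 35.0, 2]
```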
Decision table: which approach when?
| Situation | Recommended approach |
|---|---|
| Small static graph, all nodes known at train time | 2-layer GCN, masked label training |
| New nodes arrive continuously | GraphSAGE or GAT with node features, inductive mini-batches |
| Multiple node/edge types (user–item–category) | Heterogeneous GNN (R-GCN, HGT) or metapath-based sampling |
| Billion-edge social or payment graph | Neighbor sampling + offline embedding table + periodic refresh |
| Whole-graph label (molecule, program AST) | GNN with global attention readout or Set2Set pooling |
| Edges are the product (recommendations) | Link prediction head; compare to matrix factorization baseline |
| Graph structure is weak signal | Gradient boosting on node features + degree/PageRank features first |
Anti-patterns to avoid
- Random node splits on connected graphs — metrics look great, production fails.
- Too many layers without residual or JK connections — oversmoothed useless embeddings.
- Ignoring edge direction on directed flows: the model can no longer tell who paid whom, blurring the causal signal.
- Training on test edges for link prediction — trivial memorization.
- No tabular baseline — stakeholders cannot tell if the graph complexity earned its keep.
- Full-graph inference on every request — precompute embeddings; GNN forward passes do not belong in a 50ms API path for million-node graphs.
Production checklist
- Define the graph explicitly: nodes, edge types, direction, feature schema.
- Choose inductive vs transductive training to match deployment (new nodes yes/no).
- Use leakage-safe splits (subgraph, temporal, or masked transductive).
- Benchmark against XGBoost on node features + graph statistics.
- Start with 2-layer GCN or GraphSAGE; add GAT attention only if ablation helps.
- Cap neighbor samples per layer; profile GPU memory on the highest-degree nodes, which are the worst case.
- Monitor embedding drift and graph growth; schedule periodic retraining.
- Expose top contributing neighbors for fraud and safety review workflows.
Key takeaways
- GNNs learn on relational data via repeated message passing over edges — not on flattened tables.
- GCN, GraphSAGE, and GAT differ in aggregation, sampling, and attention — pick based on scale and whether new nodes appear at inference.
- Tasks span node, edge, and graph levels; readout pooling bridges node embeddings to whole-graph labels.
- Oversmoothing and leaky evaluation splits are the two most common ways GNN projects fail in production.
- Always compare to a strong non-graph baseline; use GNN embeddings as features when full end-to-end training is too heavy.
Related reading
- Deep learning explained — neural network foundations that GNN update functions build on
- Recommendation systems explained — link prediction and two-tower retrieval on user–item graphs
- Anomaly detection explained — spotting outliers on graph-derived features and live telemetry
- Vector databases explained — serving precomputed GNN node embeddings at retrieval latency