
Workshop: Workshop of Graph Neural Networks and Systems (GNNSys'21)

Keynote Talk: High Performance GNNs in JAX by Jonathan Godwin (DeepMind)


Jraph (pronounced "giraffe") is a lightweight library for working with graph neural networks in JAX. It provides a data structure for graphs, a set of utilities for working with graphs, and a 'zoo' of forkable graph neural network models. In this talk, we’ll cover the basics of Jraph, XLA, and graph nets, including how we handle padding for graphs whose numbers of nodes and edges vary from input to input. Then we’ll discuss how JAX makes it easier for us to write new kinds of graph neural networks, with interesting applications in scientific domains such as simulation. Finally, we’ll cover how JAX lets us straightforwardly shard a graph net across multiple devices, allowing training on graphs with millions (or even billions) of edges.
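The padding idea mentioned above can be illustrated with a small sketch. XLA compiles a separate program for each distinct input shape, so a common approach is to pad every batch to fixed node and edge budgets and append a dummy "padding graph" that absorbs the extra entries. The `Graph` container, `pad_graph`, and `aggregate_messages` below are hypothetical names written in plain NumPy for illustration only. They are not Jraph's API: Jraph's real container is `jraph.GraphsTuple`, and it ships its own padding utilities (e.g. `jraph.pad_with_graphs`). In JAX, the `np.add.at` aggregation would typically be written as `jax.ops.segment_sum`.

```python
import numpy as np
from typing import NamedTuple


class Graph(NamedTuple):
    """Hypothetical batched-graph container (not jraph.GraphsTuple)."""
    nodes: np.ndarray      # [total_nodes, node_features]
    senders: np.ndarray    # [total_edges] source node index of each edge
    receivers: np.ndarray  # [total_edges] target node index of each edge
    n_node: np.ndarray     # [num_graphs] node count per graph
    n_edge: np.ndarray     # [num_graphs] edge count per graph


def pad_graph(g: Graph, max_nodes: int, max_edges: int) -> Graph:
    """Pad to fixed budgets by appending one dummy 'padding graph'.

    Padded edges connect padding nodes only, so they never touch real
    nodes. Requires at least one spare node slot to hold the padding graph.
    """
    n_real_nodes = int(g.n_node.sum())
    n_real_edges = int(g.n_edge.sum())
    pad_nodes = max_nodes - n_real_nodes
    pad_edges = max_edges - n_real_edges
    if pad_nodes < 1 or pad_edges < 0:
        raise ValueError("budget too small: need at least one padding node")
    nodes = np.concatenate(
        [g.nodes, np.zeros((pad_nodes, g.nodes.shape[1]), g.nodes.dtype)])
    # Every padded edge is a self-loop on the first padding node.
    pad_idx = np.full(pad_edges, n_real_nodes, dtype=g.senders.dtype)
    return Graph(
        nodes=nodes,
        senders=np.concatenate([g.senders, pad_idx]),
        receivers=np.concatenate([g.receivers, pad_idx]),
        n_node=np.append(g.n_node, pad_nodes),
        n_edge=np.append(g.n_edge, pad_edges),
    )


def aggregate_messages(g: Graph) -> np.ndarray:
    """Sum each sender's features into its receiver (a segment sum)."""
    out = np.zeros_like(g.nodes)
    np.add.at(out, g.receivers, g.nodes[g.senders])
    return out


# A single 3-node chain graph: 0 -> 1 -> 2.
g = Graph(
    nodes=np.array([[1.0], [2.0], [3.0]]),
    senders=np.array([0, 1]),
    receivers=np.array([1, 2]),
    n_node=np.array([3]),
    n_edge=np.array([2]),
)
padded = pad_graph(g, max_nodes=8, max_edges=6)
agg = aggregate_messages(padded)
print(padded.nodes.shape)  # (8, 1)
print(agg[:3, 0])          # [0. 1. 2.]
```

Because every batch now has the same array shapes, a jitted model would compile once and reuse that program. The padded edges only move zeros between padding nodes, so the real nodes' aggregated messages are unchanged.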