FedJAX: Federated learning with JAX

NOTE: FedJAX is still in the early stages and the API will likely continue to change.

What is FedJAX?

FedJAX is a library for developing custom Federated Learning (FL) algorithms in JAX. FedJAX prioritizes ease of use and is intended to be useful for anyone with knowledge of NumPy.

FedJAX is built around the common core components needed in the FL setting:

  • Federated datasets: Clients and a dataset for each client
  • Models: CNN, ResNet, etc.
  • Optimizers: SGD, Momentum, etc.
  • Federated algorithms: Client updates and server aggregation
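To make the optimizer component more concrete, here is a minimal sketch of SGD and Momentum as pairs of pure functions (init, apply) that operate on any pytree of parameters. It is written in plain JAX, not against the FedJAX API; the names sgd, momentum, init and apply are hypothetical and chosen only for illustration.

```python
import jax
import jax.numpy as jnp

# Hypothetical functional optimizers (not the FedJAX API): each factory
# returns (init, apply), where apply(grads, state, params) -> (params, state).

def sgd(learning_rate):
    def init(params):
        return ()  # plain SGD keeps no optimizer state

    def apply(grads, state, params):
        new_params = jax.tree_util.tree_map(
            lambda p, g: p - learning_rate * g, params, grads)
        return new_params, state

    return init, apply


def momentum(learning_rate, beta=0.9):
    def init(params):
        # One velocity buffer per parameter, same pytree structure.
        return jax.tree_util.tree_map(jnp.zeros_like, params)

    def apply(grads, velocity, params):
        velocity = jax.tree_util.tree_map(
            lambda v, g: beta * v + g, velocity, grads)
        new_params = jax.tree_util.tree_map(
            lambda p, v: p - learning_rate * v, params, velocity)
        return new_params, velocity

    return init, apply


# Usage: one Momentum step on a small dict-of-arrays "model".
params = {"w": jnp.array([1.0, -2.0]), "b": jnp.array(0.5)}
grads = {"w": jnp.array([0.5, 0.5]), "b": jnp.array(1.0)}
init, apply = momentum(learning_rate=0.1)
state = init(params)
params, state = apply(grads, state, params)
```

Keeping optimizer state explicit, rather than hidden inside an object, makes it straightforward to reset or carry state per client in a federated loop.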

For Models and Optimizers, FedJAX provides lightweight wrappers and containers that work with a variety of existing implementations (e.g. a model wrapper that supports both Haiku and Stax). Similarly, for Federated datasets, TensorFlow Federated (TFF) already provides a well-established API, so FedJAX only provides utilities for converting those datasets into NumPy input that JAX accepts.
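As a rough illustration of what per-client NumPy input can look like, the sketch below assumes a federated dataset has already been converted into a Python dict that maps client ids to dicts of NumPy arrays, and yields shuffled batches for one client. The data layout and the client_batches helper are hypothetical; they are not FedJAX's or TFF's actual conversion utilities.

```python
import numpy as np

# Hypothetical layout: client id -> {feature name -> NumPy array}, where all
# arrays for one client share the same leading (example) dimension.

def client_batches(client_data, batch_size, num_epochs=1, seed=0):
    """Yield shuffled dict-of-NumPy batches for a single client."""
    num_examples = len(next(iter(client_data.values())))
    rng = np.random.default_rng(seed)
    for _ in range(num_epochs):
        perm = rng.permutation(num_examples)
        for start in range(0, num_examples, batch_size):
            idx = perm[start:start + batch_size]
            # Index every feature with the same permutation slice.
            yield {name: arr[idx] for name, arr in client_data.items()}


federated_data = {
    "client_0": {"x": np.arange(10, dtype=np.float32).reshape(5, 2),
                 "y": np.arange(5)},
    "client_1": {"x": np.ones((3, 2), dtype=np.float32),
                 "y": np.zeros(3, dtype=np.int64)},
}

# Five examples with batch_size=2 give batches of 2, 2 and 1 examples.
batches = list(client_batches(federated_data["client_0"], batch_size=2))
```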

However, what FL researchers will likely find most useful is the collection of Federated algorithms that FedJAX provides out of the box, and how easily they can be customized.

Quickstart

Take a look at the simple Federated Averaging implementation to get a sense of how to write FL algorithms with FedJAX. To run an FL simulation (setting up the federated dataset, model, etc.), see the full EMNIST example.
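For orientation, the following is a self-contained sketch of the Federated Averaging idea in plain JAX; it is not FedJAX's own implementation. Each client runs a few local SGD steps starting from the current server model, and the server then replaces its model with the average of the client models, weighted by each client's number of examples. The linear-regression model, the synthetic clients and all function names are assumptions made for illustration.

```python
import jax
import jax.numpy as jnp

# Assumed model: linear regression y = x @ w + b with mean squared error.
def loss_fn(params, x, y):
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)

grad_fn = jax.jit(jax.grad(loss_fn))

def client_update(params, x, y, learning_rate, local_steps):
    # A few full-batch SGD steps on one client's data, starting from the server model.
    for _ in range(local_steps):
        grads = grad_fn(params, x, y)
        params = jax.tree_util.tree_map(
            lambda p, g: p - learning_rate * g, params, grads)
    return params

def server_aggregate(client_params, client_weights):
    # Weighted average of client models; weights are example counts.
    total = sum(client_weights)
    return jax.tree_util.tree_map(
        lambda *ps: sum(w * p for w, p in zip(client_weights, ps)) / total,
        *client_params)

def federated_averaging(params, clients, num_rounds, learning_rate=0.1, local_steps=5):
    for _ in range(num_rounds):
        updated = [client_update(params, x, y, learning_rate, local_steps)
                   for x, y in clients]
        params = server_aggregate(updated, [x.shape[0] for x, _ in clients])
    return params

# Synthetic clients of different sizes, all drawn from the same linear model.
key = jax.random.PRNGKey(0)
true_w = jnp.array([2.0, -1.0])
clients = []
for n in (20, 50, 30):
    key, sub = jax.random.split(key)
    x = jax.random.normal(sub, (n, 2))
    clients.append((x, x @ true_w + 0.5))

params = {"w": jnp.zeros(2), "b": jnp.array(0.0)}
params = federated_averaging(params, clients, num_rounds=100)
```

Because every synthetic client shares the same noiseless linear model, the averaged parameters approach true_w and a bias of 0.5. With heterogeneous client data, the result also depends on the number of local steps and the learning rate.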


NOTE: This is not an officially supported Google product.
