evosax: JAX-Based Evolution Strategies 🦎
Tired of having to handle asynchronous processes for neuroevolution? Do you want to leverage massive vectorization and high-throughput accelerators for evolution strategies (ES)? evosax allows you to leverage JAX, XLA compilation and auto-vectorization/parallelization to scale ES to your favorite accelerators. The API is based on the classical ask, evaluate, tell cycle of ES. Both ask and tell calls are compatible with jit, vmap/pmap and lax.scan. It includes a vast set of both classic (e.g. CMA-ES, Differential Evolution, etc.) and modern neuroevolution (e.g. OpenAI-ES, Augmented RS, etc.) strategies. You can get started here 👉
Basic evosax API Usage 🍲
import jax
from evosax import CMA_ES
# Instantiate the search strategy
rng = jax.random.PRNGKey(0)
strategy = CMA_ES(popsize=20, num_dims=2, elite_ratio=0.5)
es_params = strategy.default_params
state = strategy.initialize(rng, es_params)
# Run ask-eval-tell loop - NOTE: By default minimization!
num_generations = 100  # Number of ask-eval-tell iterations
for t in range(num_generations):
    rng, rng_gen, rng_eval = jax.random.split(rng, 3)
    x, state = strategy.ask(rng_gen, state, es_params)
    fitness = ...  # Your population evaluation fct
    state = strategy.tell(x, fitness, state, es_params)
# Get best overall population member & its fitness
state.best_member, state.best_fitness
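To make the ask-evaluate-tell cycle concrete without relying on evosax internals, here is a minimal sketch in plain JAX of a toy elitist Gaussian ES with the same interface shape. The `ToyState` container, the `initialize`/`ask`/`tell` helpers and the shifted sphere objective are hypothetical illustrations, not evosax's CMA-ES implementation.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToyState(NamedTuple):
    mean: jnp.ndarray          # current search distribution mean
    sigma: float               # isotropic perturbation scale (kept fixed here)
    best_member: jnp.ndarray   # best solution seen so far
    best_fitness: jnp.ndarray  # its fitness (lower is better)


def initialize(num_dims, sigma=0.5):
    zeros = jnp.zeros(num_dims)
    return ToyState(zeros, sigma, zeros, jnp.array(jnp.inf))


def ask(rng, state, popsize):
    # Sample popsize candidates around the current mean
    eps = jax.random.normal(rng, (popsize, state.mean.shape[0]))
    return state.mean + state.sigma * eps


def tell(x, fitness, state, elite_ratio=0.5):
    # Move the mean to the average of the best (lowest-fitness) candidates
    num_elite = max(1, int(elite_ratio * x.shape[0]))
    elite_idx = jnp.argsort(fitness)[:num_elite]
    new_mean = x[elite_idx].mean(axis=0)
    # Keep track of the best member seen so far
    i = jnp.argmin(fitness)
    improved = fitness[i] < state.best_fitness
    best_member = jnp.where(improved, x[i], state.best_member)
    best_fitness = jnp.where(improved, fitness[i], state.best_fitness)
    return ToyState(new_mean, state.sigma, best_member, best_fitness)


def sphere(x):
    # Population evaluation: one fitness value per row, minimum at (1, 1)
    return jnp.sum((x - 1.0) ** 2, axis=1)


rng = jax.random.PRNGKey(0)
state = initialize(num_dims=2)
for _ in range(50):
    rng, rng_gen = jax.random.split(rng)
    x = ask(rng_gen, state, popsize=20)
    state = tell(x, sphere(x), state)
print(state.best_member, state.best_fitness)
```

Keeping `ask` and `tell` as pure functions of an explicit state is what makes them compatible with `jit`, `vmap` and `lax.scan`.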
Implemented Evolution Strategies 🦎
Installation ⏳
The latest evosax release can be installed directly from PyPI:
pip install evosax
If you want to get the most recent commit, please install directly from the repository:
pip install git+https://github.com/RobertTLange/evosax.git@main
To use JAX on your accelerators (GPU/TPU), follow the installation instructions in the JAX documentation.
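After installing an accelerator-enabled JAX build, a quick sanity check is to ask JAX which backend and devices it picked up. This uses only standard JAX calls; on a CPU-only machine it simply reports CPU devices.

```python
import jax

# Default backend name (e.g. "cpu", "gpu" or "tpu") and the visible devices
backend = jax.default_backend()
devices = jax.devices()
print(backend, devices)
```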
Examples 📖
- 📓 Classic ES Tasks: API introduction on the Rosenbrock function (CMA-ES, Simple GA, etc.).
- 📓 CartPole-Control: OpenES & PEPG on the `CartPole-v1` gym task (MLP/LSTM controller).
- 📓 MNIST-Classifier: OpenES on MNIST with a CNN.
- 📓 LRateTune-PES: Persistent/Noise-Reuse ES on a meta-learning problem as in Vicol et al. (2021).
- 📓 Quadratic-PBT: PBT on a toy quadratic problem as in Jaderberg et al. (2017).
- 📓 Restart-Wrappers: Custom restart wrappers as e.g. used in (B)IPOP-CMA-ES.
- 📓 Brax Control: Evolve Tanh MLPs on Brax tasks using the `EvoJAX` wrapper.
- 📓 BBOB Visualizer: Visualize evolution rollouts on 2D fitness landscapes.
Key Features 💵
- Strategy Diversity: `evosax` implements more than 30 classical and modern neuroevolution strategies. All of them follow the same simple `ask`/`tell` API and come with tailored tools such as the ClipUp optimizer, parameter reshaping into PyTrees and fitness shaping (see below).
- Vectorization/Parallelization of `ask`/`tell` Calls: Both `ask` and `tell` calls can leverage `jit` and `vmap`/`pmap`. This enables vectorized/parallel rollouts of different evolution strategies.
import jax
import jax.numpy as jnp
from evosax.strategies.ars import ARS, EvoParams
# E.g. vectorize over different initial perturbation stds
strategy = ARS(popsize=100, num_dims=20)
es_params = EvoParams(sigma_init=jnp.array([0.1, 0.01, 0.001]), sigma_decay=0.999, ...)
# Specify how to map over ES hyperparameters (0: batched leaf, None: shared leaf)
map_dict = EvoParams(sigma_init=0, sigma_decay=None, ...)
# Vmap-composed batch initialize, ask and tell functions
batch_init = jax.vmap(strategy.initialize, in_axes=(None, map_dict))
batch_ask = jax.vmap(strategy.ask, in_axes=(None, 0, map_dict))
batch_tell = jax.vmap(strategy.tell, in_axes=(0, 0, 0, map_dict))
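The `in_axes` argument mirrors the structure of the parameter container: `0` marks leaves that are batched along their leading axis, `None` marks leaves shared by every batch element. Below is a self-contained plain-JAX sketch of the same pattern; `Params` and the `init`/`ask` functions are hypothetical stand-ins, not evosax's ARS.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class Params(NamedTuple):
    sigma_init: jnp.ndarray
    sigma_decay: float


def init(rng, params, num_dims=20):
    # Hypothetical ES state: a mean vector and the current perturbation std
    mean = jax.random.normal(rng, (num_dims,)) * 0.01
    return {"mean": mean, "sigma": params.sigma_init}


def ask(rng, state, params, popsize=100):
    # Sample a population and decay sigma for the next generation
    eps = jax.random.normal(rng, (popsize, state["mean"].shape[0]))
    x = state["mean"] + state["sigma"] * eps
    new_state = {"mean": state["mean"], "sigma": state["sigma"] * params.sigma_decay}
    return x, new_state


params = Params(sigma_init=jnp.array([0.1, 0.01, 0.001]), sigma_decay=0.999)
# 0: map over the leading axis of this leaf; None: pass it through unchanged
map_axes = Params(sigma_init=0, sigma_decay=None)

batch_init = jax.vmap(init, in_axes=(None, map_axes))
batch_ask = jax.vmap(ask, in_axes=(None, 0, map_axes))

rng = jax.random.PRNGKey(0)
state = batch_init(rng, params)
x, state = batch_ask(rng, state, params)
print(x.shape)  # one population per sigma_init value
```

Outputs that do not depend on a mapped input (here the initial `mean`) are broadcast along the batch axis by `vmap`, so every batched state has a consistent leading dimension.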
- Scan Through Evolution Rollouts: You can also `lax.scan` through entire `init`, `ask`, `eval`, `tell` loops for fast compilation of ES loops:
from functools import partial
import jax
import jax.numpy as jnp
# `strategy` is any evosax strategy instance, e.g. CMA_ES(...) from above
@partial(jax.jit, static_argnums=(1,))
def run_es_loop(rng, num_steps):
    """Run evolution ask-eval-tell loop."""
    es_params = strategy.default_params
    state = strategy.initialize(rng, es_params)
    def es_step(state_input, tmp):
        """Helper es step to lax.scan through."""
        rng, state = state_input
        rng, rng_iter = jax.random.split(rng)
        x, state = strategy.ask(rng_iter, state, es_params)
        fitness = ...  # Your population evaluation fct
        state = strategy.tell(x, fitness, state, es_params)
        return [rng, state], fitness[jnp.argmin(fitness)]
    _, scan_out = jax.lax.scan(es_step,
                               [rng, state],
                               [jnp.zeros(num_steps)])
    return jnp.min(scan_out)
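A self-contained plain-JAX version of the scanned loop, assuming a hypothetical toy ES (the mean moves to the elite average, sigma stays fixed) in place of an evosax strategy and a sphere objective in place of `fitness = ...`. Because `jax.lax.scan` traces `es_step` only once, the whole multi-generation run compiles into a single XLA program.

```python
from functools import partial

import jax
import jax.numpy as jnp

# Hypothetical toy-ES settings
POPSIZE, NUM_DIMS, NUM_ELITE, SIGMA = 32, 5, 8, 0.3


def sphere(x):
    # Hypothetical objective with its minimum 0 at x = 0
    return jnp.sum(x ** 2, axis=-1)


@partial(jax.jit, static_argnums=(1,))
def run_es_loop(rng, num_steps):
    """Run a toy ask-eval-tell loop with lax.scan."""
    mean = jnp.ones(NUM_DIMS)

    def es_step(carry, _):
        rng, mean = carry
        rng, rng_iter = jax.random.split(rng)
        # ask: perturb the mean
        x = mean + SIGMA * jax.random.normal(rng_iter, (POPSIZE, NUM_DIMS))
        # eval
        fitness = sphere(x)
        # tell: move the mean to the elite average
        elite = x[jnp.argsort(fitness)[:NUM_ELITE]]
        return (rng, elite.mean(axis=0)), fitness.min()

    # xs=None with an explicit length scans num_steps times
    _, best_per_gen = jax.lax.scan(es_step, (rng, mean), None, length=num_steps)
    return best_per_gen


best = run_es_loop(jax.random.PRNGKey(0), 100)
print(best.shape, best.min())
```

Passing `num_steps` as a static argument is required here because `length` must be a concrete Python integer at trace time.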
- Population Parameter Reshaping: We provide a `ParameterReshaper` wrapper to reshape flat parameter vectors into PyTrees. The wrapper is compatible with JAX neural network libraries such as Flax/Haiku and makes it easier to evaluate network populations afterwards.
import jax
import jax.numpy as jnp
from flax import linen as nn
from evosax import ParameterReshaper
class MLP(nn.Module):
    num_hidden_units: int
    ...
    @nn.compact
    def __call__(self, obs):
        ...
        return ...
rng = jax.random.PRNGKey(0)
network = MLP(64)
# __call__ only takes `obs`, so init receives the rng plus a dummy observation
net_params = network.init(rng, jnp.zeros(4))
# Initialize reshaper based on placeholder network shapes
param_reshaper = ParameterReshaper(net_params)
# Get population candidates & reshape into stacked pytrees
x, state = strategy.ask(...)
x_shaped = param_reshaper.reshape(x)
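Conceptually, reshaping a flat population into stacked parameter PyTrees can be sketched with `jax.flatten_util.ravel_pytree` plus `jax.vmap`. This is an illustrative sketch of the idea, not how evosax's `ParameterReshaper` is implemented; the hand-written two-layer parameter dictionary is a hypothetical stand-in for `net_params`.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# Placeholder parameters with the shapes of a small 4 -> 64 -> 2 MLP
net_params = {
    "dense_0": {"kernel": jnp.zeros((4, 64)), "bias": jnp.zeros(64)},
    "dense_1": {"kernel": jnp.zeros((64, 2)), "bias": jnp.zeros(2)},
}
flat, unravel = ravel_pytree(net_params)
num_dims = flat.shape[0]  # search-space dimensionality for the ES

# A population of flat candidate vectors, e.g. as returned by an ES ask call
popsize = 8
x = jax.random.normal(jax.random.PRNGKey(0), (popsize, num_dims))

# vmap the unravel function: every leaf gains a leading popsize axis
x_shaped = jax.vmap(unravel)(x)
print(num_dims, x_shaped["dense_0"]["kernel"].shape)
```

The stacked PyTree can then be fed to `jax.vmap(network.apply)` so that the whole population is evaluated in one call.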
- Flexible Fitness Shaping: By default `evosax` assumes that the fitness objective is to be minimized. If you would like to maximize instead, apply rank centering or z-scoring, or add weight regularization, you can use the `FitnessShaper`:
from evosax import FitnessShaper
# Instantiate jittable fitness shaper (e.g. for Open ES)
fit_shaper = FitnessShaper(centered_rank=True,
z_score=False,
weight_decay=0.01,
maximize=True)
# Shape the evaluated fitness scores
fit_shaped = fit_shaper.apply(x, fitness)
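The transformations above can be sketched in a few lines of plain JAX. `shape_fitness` is a hypothetical helper that illustrates a common recipe (sign flip for maximization, an L2 penalty on the parameters, centered ranks in [-0.5, 0.5]); the exact order of operations and the penalty form used by evosax's `FitnessShaper` may differ.

```python
import jax.numpy as jnp


def shape_fitness(x, fitness, centered_rank=True, weight_decay=0.0, maximize=False):
    """Hypothetical fitness shaping helper (lower output = better)."""
    # Turn maximization into minimization by flipping the sign
    fit = -fitness if maximize else fitness
    # L2 penalty on the candidate parameters (larger norm -> worse score)
    fit = fit + weight_decay * jnp.mean(x ** 2, axis=1)
    if centered_rank:
        # Rank 0 for the best (smallest) value, then map ranks to [-0.5, 0.5]
        ranks = jnp.argsort(jnp.argsort(fit))
        fit = ranks / (fit.shape[0] - 1) - 0.5
    return fit


x = jnp.zeros((3, 2))
fitness = jnp.array([3.0, 1.0, 2.0])
shaped = shape_fitness(x, fitness, weight_decay=0.01, maximize=True)
print(shaped)
```

Centered ranks make the update invariant to monotone rescalings of the raw fitness, which is one reason they are popular for OpenES-style gradient estimates.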
Additional Work-in-Progress
- **Strategy Restart Wrappers**: *Work-in-progress*. You can also choose from a set of different restart mechanisms, which relaunch a strategy (e.g. with a new population size) based on termination criteria. Note: for all restart strategies that alter the population size, the `ask` and `tell` methods have to be re-compiled when the size changes. All strategies can also be executed without explicitly providing `es_params`; in this case the default parameters are used.
import jax
from evosax import CMA_ES
from evosax.restarts import BIPOP_Restarter
# Illustrative settings
num_dims, popsize, elite_ratio = 10, 20, 0.5
rng = jax.random.PRNGKey(0)
# Define a termination criterion (kwargs - fitness, state, params)
def std_criterion(fitness, state, params):
    """Restart strategy if fitness std across population is small."""
    return fitness.std() < 0.001
# Instantiate Base CMA-ES & wrap with BIPOP restarts
# Pass strategy-specific kwargs separately (e.g. elite_ratio or opt_name)
strategy = CMA_ES(popsize=popsize, num_dims=num_dims, elite_ratio=elite_ratio)
re_strategy = BIPOP_Restarter(
    strategy,
    stop_criteria=[std_criterion],
    strategy_kwargs={"elite_ratio": elite_ratio}
)
state = re_strategy.initialize(rng)
# ask/tell loop - restarts are automatically handled
rng, rng_gen, rng_eval = jax.random.split(rng, 3)
x, state = re_strategy.ask(rng_gen, state)
fitness = ...  # Your population evaluation fct
state = re_strategy.tell(x, fitness, state)
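The restart mechanism itself can be sketched in plain JAX (this is not evosax's `BIPOP_Restarter`): a std-based termination criterion like `std_criterion` above, combined with an IPOP-style rule that doubles the population size at every restart. The Rastrigin objective, the sigma decay of 0.9 and the resampling range are assumptions for illustration. Since `popsize` is a static argument of the jitted `generation` function, each new population size triggers a recompilation, matching the note above.

```python
from functools import partial

import jax
import jax.numpy as jnp

NUM_DIMS = 3


def fitness_fn(x):
    # Hypothetical multimodal objective (Rastrigin), minimum 0 at x = 0
    return 10 * NUM_DIMS + jnp.sum(x ** 2 - 10 * jnp.cos(2 * jnp.pi * x), axis=-1)


def std_criterion(fitness, threshold=1e-3):
    """Restart if the fitness std across the population is small."""
    return jnp.std(fitness) < threshold


@partial(jax.jit, static_argnums=(3,))
def generation(rng, mean, sigma, popsize):
    # popsize is static: every new value triggers a fresh compilation
    x = mean + sigma * jax.random.normal(rng, (popsize, NUM_DIMS))
    fitness = fitness_fn(x)
    elite = x[jnp.argsort(fitness)[: popsize // 2]]
    # Toy update: elite mean, geometric sigma decay
    return elite.mean(axis=0), sigma * 0.9, x, fitness


rng = jax.random.PRNGKey(0)
popsize, restarts, best = 8, 0, float("inf")
mean, sigma = jnp.full(NUM_DIMS, 3.0), 1.0
for _ in range(300):
    rng, rng_gen, rng_restart = jax.random.split(rng, 3)
    mean, sigma, x, fitness = generation(rng_gen, mean, sigma, popsize)
    best = min(best, float(fitness.min()))
    if std_criterion(fitness):
        # IPOP-style restart: double the population, resample the mean, reset sigma
        restarts += 1
        popsize *= 2
        mean = jax.random.uniform(rng_restart, (NUM_DIMS,), minval=-5.0, maxval=5.0)
        sigma = 1.0
print(restarts, popsize, best)
```

BIPOP additionally interleaves restarts with small populations and varied step sizes; the sketch only shows the increasing-population branch.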
- **Batch Strategy Rollouts**: *Work-in-progress*. We are currently also working on different ways of incorporating multiple subpopulations with different communication protocols.
<div class="pdoc-code codehilite">
<pre><span></span><code><span class="kn">from</span> <span class="nn">evosax.experimental.subpops</span> <span class="kn">import</span> <span class="n">BatchStrategy</span>
<span class="c1"># Instantiates 5 CMA-ES subpops of 20 members</span>
<span class="n">strategy</span> <span class="o">=</span> <span class="n">BatchStrategy</span><span class="p">(</span>
<span class="n">strategy_name</span><span class="o">=</span><span class="s2">"CMA_ES"</span><span class="p">,</span>
<span class="n">num_dims</span><span class="o">=</span><span class="mi">4096</span><span class="p">,</span>
<span class="n">popsize</span><span class="o">=</span><span class="mi">100</span><span class="p">,</span>
<span class="n">num_subpops</span><span class="o">=</span><span class="mi">5</span><span class="p">,</span>
<span class="n">strategy_kwargs</span><span class="o">=</span><span class="p">{</span><span class="s2">"elite_ratio"</span><span class="p">:</span> <span class="mf">0.5</span><span class="p">},</span>
<span class="n">communication</span><span class="o">=</span><span class="s2">"best_subpop"</span><span class="p">,</span>
<span class="p">)</span>
<span class="n">state</span> <span class="o">=</span> <span class="n">strategy</span><span class="o">.</span><span class="n">initialize</span><span class="p">(</span><span class="n">rng</span><span class="p">)</span>
<span class="c1"># Ask for evaluation candidates of different subpopulation ES</span>
<span class="n">x</span><span class="p">,</span> <span class="n">state</span> <span class="o">=</span> <span class="n">strategy</span><span class="o">.</span><span class="n">ask</span><span class="p">(</span><span class="n">rng_iter</span><span class="p">,</span> <span class="n">state</span><span class="p">)</span>
<span class="n">fitness</span> <span class="o">=</span> <span class="o">...</span>
<span class="n">state</span> <span class="o">=</span> <span class="n">strategy</span><span class="o">.</span><span class="n">tell</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">fitness</span><span class="p">,</span> <span class="n">state</span><span class="p">)</span>
</code></pre>
</div>
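The subpopulation layout from the example above (5 subpopulations of 20 members, 100 in total) can be sketched with plain array reshapes. The "best subpop" communication shown here, where every subpopulation mean is reset to the globally best member, is a hypothetical protocol for illustration; evosax's `BatchStrategy` may communicate differently.

```python
import jax
import jax.numpy as jnp

NUM_SUBPOPS, SUBPOP_SIZE, NUM_DIMS = 5, 20, 8


def sphere(x):
    # Hypothetical objective, one value per row
    return jnp.sum(x ** 2, axis=-1)


rng = jax.random.PRNGKey(0)
rng_mean, rng_pop = jax.random.split(rng)
# One search-distribution mean per subpopulation
means = jax.random.normal(rng_mean, (NUM_SUBPOPS, NUM_DIMS))
# ask: sample SUBPOP_SIZE candidates per subpopulation -> (5, 20, D)
noise = jax.random.normal(rng_pop, (NUM_SUBPOPS, SUBPOP_SIZE, NUM_DIMS))
x = means[:, None, :] + 0.1 * noise
# Evaluate the flat population of 100 members, then split fitness per subpop
fitness = sphere(x.reshape(-1, NUM_DIMS)).reshape(NUM_SUBPOPS, SUBPOP_SIZE)
# Best member of each subpopulation
best_idx = jnp.argmin(fitness, axis=1)
best_members = x[jnp.arange(NUM_SUBPOPS), best_idx]
best_fitness = fitness[jnp.arange(NUM_SUBPOPS), best_idx]
# Hypothetical communication: all subpops continue from the global best member
global_best = best_members[jnp.argmin(best_fitness)]
means = jnp.broadcast_to(global_best, means.shape)
print(best_fitness, means.shape)
```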
- **Indirect Encodings**: *Work-in-progress*. ES can struggle with high-dimensional search spaces (e.g. because covariances are harder to estimate). One potential way to alleviate this challenge is to search over indirect parameter encodings in a lower-dimensional space. So far we provide JAX-compatible encodings with random projections (Gaussian/Rademacher) and Hypernetworks for MLPs. They act as drop-in replacements for the `ParameterReshaper`:
<div class="pdoc-code codehilite">
<pre><span></span><code><span class="kn">from</span> <span class="nn">evosax.experimental.decodings</span> <span class="kn">import</span> <span class="n">RandomDecoder</span><span class="p">,</span> <span class="n">HyperDecoder</span>
<span class="c1"># For arbitrary network architectures / search spaces</span>
<span class="n">num_encoding_dims</span> <span class="o">=</span> <span class="mi">6</span>
<span class="n">param_reshaper</span> <span class="o">=</span> <span class="n">RandomDecoder</span><span class="p">(</span><span class="n">num_encoding_dims</span><span class="p">,</span> <span class="n">net_params</span><span class="p">)</span>
<span class="n">x_shaped</span> <span class="o">=</span> <span class="n">param_reshaper</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="c1"># For MLP-based models we also support a HyperNetwork en/decoding</span>
<span class="n">param_reshaper</span> <span class="o">=</span> <span class="n">HyperDecoder</span><span class="p">(</span>
<span class="n">net_params</span><span class="p">,</span>
<span class="n">hypernet_config</span><span class="o">=</span><span class="p">{</span>
<span class="s2">"num_latent_units"</span><span class="p">:</span> <span class="mi">3</span><span class="p">,</span> <span class="c1"># Latent units per module kernel/bias</span>
<span class="s2">"num_hidden_units"</span><span class="p">:</span> <span class="mi">2</span><span class="p">,</span> <span class="c1"># Hidden dimensionality of a_i^j embedding</span>
<span class="p">},</span>
<span class="p">)</span>
<span class="n">x_shaped</span> <span class="o">=</span> <span class="n">param_reshaper</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
</code></pre>
</div>
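Random-projection decoding can be sketched in plain JAX: the ES searches a 6-dimensional latent space, and a fixed random matrix (Gaussian or Rademacher) maps each latent vector to a full parameter vector, which is then unravelled into a PyTree. The toy parameter dictionary, the `decode` helper and the `1/sqrt(num_encoding_dims)` scaling are assumptions; this is not how evosax's `RandomDecoder` is implemented.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# Placeholder network parameters (only the shapes matter)
net_params = {"kernel": jnp.zeros((4, 32)), "bias": jnp.zeros(32)}
flat, unravel = ravel_pytree(net_params)
num_params = flat.shape[0]
num_encoding_dims = 6

rng = jax.random.PRNGKey(0)
rng_gauss, rng_rad, rng_x = jax.random.split(rng, 3)
scale = 1.0 / jnp.sqrt(num_encoding_dims)
# Fixed random projections from latent codes to full parameter vectors
gaussian_proj = jax.random.normal(rng_gauss, (num_encoding_dims, num_params)) * scale
rademacher_proj = jax.random.rademacher(rng_rad, (num_encoding_dims, num_params)).astype(jnp.float32) * scale


def decode(x, proj):
    # x: (popsize, num_encoding_dims) -> stacked PyTree of full parameters
    return jax.vmap(unravel)(x @ proj)


# The ES only has to search over num_encoding_dims = 6 dimensions
x = jax.random.normal(rng_x, (16, num_encoding_dims))
x_shaped = decode(x, gaussian_proj)
print(num_params, x_shaped["kernel"].shape)
```

The projection matrix stays fixed during the search, so only the low-dimensional latent vectors are evolved.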
Resources & Other Great JAX-ES Tools 📝
- 📺 Rob's MLC Research Jam Talk: Small motivation talk at the ML Collective Research Jam.
- 📝 Rob's 02/2021 Blog: Tutorial on CMA-ES & leveraging JAX's primitives.
- 💻 EvoJAX: JAX-ES library by Google Brain with great rollout wrappers.
- 💻 QDax: Quality-Diversity algorithms in JAX.
Acknowledgements & Citing evosax ✏️
If you use evosax in your research, please cite the following paper:
@article{evosax2022github,
author = {Robert Tjarko Lange},
title = {evosax: JAX-based Evolution Strategies},
journal={arXiv preprint arXiv:2212.04180},
year = {2022},
}
We acknowledge financial support by the Google TRC and the Deutsche Forschungsgemeinschaft (DFG, German Research Foundation) under Germany's Excellence Strategy - EXC 2002/1 "Science of Intelligence" - project number 390523135.
Development 👷
You can run the test suite via python -m pytest -vv --all. If you find a bug or are missing your favourite feature, feel free to create an issue and/or start contributing 🤗.
Disclaimer ⚠️
This repository contains an independent reimplementation of LES and DES based on the corresponding ICLR 2023 publication (Lange et al., 2023). It is unrelated to Google or DeepMind. The implementation has been tested to roughly reproduce the official results on a range of tasks.
"""
.. include:: ../README.md
"""
from .strategy import Strategy, EvoState, EvoParams
from .strategies import (
    SimpleGA,
    SimpleES,
    CMA_ES,
    DE,
    PSO,
    OpenES,
    PGPE,
    PBT,
    PersistentES,
    ARS,
    Sep_CMA_ES,
    BIPOP_CMA_ES,
    IPOP_CMA_ES,
    Full_iAMaLGaM,
    Indep_iAMaLGaM,
    MA_ES,
    LM_MA_ES,
    RmES,
    GLD,
    SimAnneal,
    SNES,
    xNES,
    ESMC,
    DES,
    SAMR_GA,
    GESMR_GA,
    GuidedES,
    ASEBO,
    CR_FM_NES,
    MR15_GA,
    RandomSearch,
    LES,
    LGA,
    NoiseReuseES,
)
from .core import FitnessShaper, ParameterReshaper
from .utils import ESLog
from .networks import NetworkMapper
from .problems import ProblemMapper
Strategies = {
    "SimpleGA": SimpleGA,
    "SimpleES": SimpleES,
    "CMA_ES": CMA_ES,
    "DE": DE,
    "PSO": PSO,
    "OpenES": OpenES,
    "PGPE": PGPE,
    "PBT": PBT,
    "PersistentES": PersistentES,
    "ARS": ARS,
    "Sep_CMA_ES": Sep_CMA_ES,
    "BIPOP_CMA_ES": BIPOP_CMA_ES,
    "IPOP_CMA_ES": IPOP_CMA_ES,
    "Full_iAMaLGaM": Full_iAMaLGaM,
    "Indep_iAMaLGaM": Indep_iAMaLGaM,
    "MA_ES": MA_ES,
    "LM_MA_ES": LM_MA_ES,
    "RmES": RmES,
    "GLD": GLD,
    "SimAnneal": SimAnneal,
    "SNES": SNES,
    "xNES": xNES,
    "ESMC": ESMC,
    "DES": DES,
    "SAMR_GA": SAMR_GA,
    "GESMR_GA": GESMR_GA,
    "GuidedES": GuidedES,
    "ASEBO": ASEBO,
    "CR_FM_NES": CR_FM_NES,
    "MR15_GA": MR15_GA,
    "RandomSearch": RandomSearch,
    "LES": LES,
    "LGA": LGA,
    "NoiseReuseES": NoiseReuseES,
}
__all__ = [
    "Strategies",
    "EvoState",
    "EvoParams",
    "FitnessShaper",
    "ParameterReshaper",
    "ESLog",
    "NetworkMapper",
    "ProblemMapper",
    "Strategy",
    "SimpleGA",
    "SimpleES",
    "CMA_ES",
    "DE",
    "PSO",
    "OpenES",
    "PGPE",
    "PBT",
    "PersistentES",
    "ARS",
    "Sep_CMA_ES",
    "BIPOP_CMA_ES",
    "IPOP_CMA_ES",
    "Full_iAMaLGaM",
    "Indep_iAMaLGaM",
    "MA_ES",
    "LM_MA_ES",
    "RmES",
    "GLD",
    "SimAnneal",
    "SNES",
    "xNES",
    "ESMC",
    "DES",
    "SAMR_GA",
    "GESMR_GA",
    "GuidedES",
    "ASEBO",
    "CR_FM_NES",
    "MR15_GA",
    "RandomSearch",
    "LES",
    "LGA",
    "NoiseReuseES",
]
