year: 2025
paper: https://arxiv.org/pdf/2504.08339
website:
code: https://github.com/EMI-Group/tensorneat
connections: NEAT, jax


TensorNEAT Library Overview


TLDR

TensorNEAT is a JAX-based library for NeuroEvolution of Augmenting Topologies (NEAT) algorithms, designed to leverage GPU acceleration for evolving neural network structures. It enables parallel processing of networks with different architectures through tensorization, achieving up to 500x speedup compared to CPU implementations. Built on JAX with support for Brax/Gymnax RL environments, the library implements a functional programming paradigm with immutable state management.

Core Features & Quick Start

Core Features

  • GPU-accelerated NEAT and HyperNEAT algorithms
  • Batch inference across heterogeneous network architectures
  • Network visualization and symbolic representation (LaTeX/Python code generation)
  • Integration with EvoX for distributed computing
  • Support for RL tasks (Brax, Gymnax, MuJoCo)
  • Function fitting and symbolic regression capabilities

Installation

# Install JAX (CPU version)
pip install -U jax
 
# Or for NVIDIA GPUs
pip install -U "jax[cuda12]"
 
# Install TensorNEAT
pip install git+https://github.com/EMI-Group/tensorneat.git

Basic Usage - XOR Problem

from tensorneat.pipeline import Pipeline
from tensorneat import algorithm, genome, problem
from tensorneat.common import ACT
 
# Configure NEAT algorithm
algorithm = algorithm.NEAT(
	pop_size=10000,
	species_size=20,
	survival_threshold=0.01,
	genome=genome.DefaultGenome(
		num_inputs=3,
		num_outputs=1,
		output_transform=ACT.sigmoid,
	),
)
 
# Define problem and pipeline
problem = problem.XOR3d()
pipeline = Pipeline(
	algorithm,
	problem,
	generation_limit=200,
	fitness_target=-1e-6,
	seed=42,
)
 
# Run evolution
state = pipeline.setup()
state, best = pipeline.auto_run(state)
pipeline.show(state, best)
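The fitness_target=-1e-6 above implies that fitness is a negated error, so 0 is a perfect score. A minimal sketch of the task and such a fitness, assuming XOR3d is 3-bit odd parity scored by negative mean squared error (TensorNEAT's actual error metric may differ); xor3d_dataset and neg_mse_fitness are hypothetical helpers, not library functions:

```python
import itertools

import jax.numpy as jnp


def xor3d_dataset():
    """3-input XOR (odd parity): the target is 1 when an odd number of inputs are 1."""
    inputs = jnp.array(list(itertools.product([0.0, 1.0], repeat=3)))  # (8, 3)
    targets = (inputs.sum(axis=1) % 2).reshape(-1, 1)                  # (8, 1)
    return inputs, targets


def neg_mse_fitness(predictions, targets):
    """Fitness as negated mean squared error: 0 is perfect, more negative is worse."""
    return -jnp.mean((predictions - targets) ** 2)


inputs, targets = xor3d_dataset()
print(neg_mse_fitness(targets, targets))                        # perfect predictor
print(neg_mse_fitness(jnp.full_like(targets, 0.5), targets))    # -0.25
```

A constant guess of 0.5 already scores -0.25, so reaching -1e-6 means pushing the error down by more than five orders of magnitude.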

RL Task Example - Brax Hopper

from tensorneat.pipeline import Pipeline
from tensorneat.algorithm.neat import NEAT
from tensorneat.genome import DefaultGenome, BiasNode
from tensorneat.problem.rl import BraxEnv
from tensorneat.common import ACT, AGG
 
pipeline = Pipeline(
	algorithm=NEAT(
		pop_size=1000,
		species_size=20,
		survival_threshold=0.1,
		genome=DefaultGenome(
			num_inputs=11,
			num_outputs=3,
			node_gene=BiasNode(
				activation_options=ACT.tanh,
				aggregation_options=AGG.sum,
			),
			output_transform=ACT.tanh,
		),
	),
	problem=BraxEnv(env_name="hopper", max_step=1000),
	generation_limit=100,
	fitness_target=5000,
)
 
state = pipeline.setup()
state, best = pipeline.auto_run(state)
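Fitness for an RL problem is usually the return of one episode, which suggests fitness_target=5000 is a target cumulative reward for the hopper. A sketch of a jittable return computation with jax.lax.scan, assuming episodes are cut at max_step and rewards after termination are masked out; episode_return, toy_step and toy_policy are hypothetical and do not reproduce BraxEnv internals:

```python
import jax
import jax.numpy as jnp


def episode_return(policy, env_step, init_obs, max_step):
    """Sum rewards for up to max_step steps, ignoring rewards after the first done."""

    def body(carry, _):
        obs, total, done = carry
        action = policy(obs)
        next_obs, reward, step_done = env_step(obs, action)
        total = total + jnp.where(done, 0.0, reward)  # stop accumulating once done
        done = jnp.logical_or(done, step_done)
        return (next_obs, total, done), None

    init = (init_obs, jnp.float32(0.0), jnp.bool_(False))
    (_, total, _), _ = jax.lax.scan(body, init, None, length=max_step)
    return total


# Hypothetical toy environment: the position moves by the action, reward is 1 per
# step, and the episode ends once |position| exceeds 3.
def toy_step(obs, action):
    next_obs = obs + action
    return next_obs, 1.0, jnp.abs(next_obs[0]) > 3.0


def toy_policy(obs):
    return jnp.ones_like(obs)  # always push right


print(episode_return(toy_policy, toy_step, jnp.zeros(1), max_step=1000))  # 4.0
```

Masking instead of breaking out of the loop keeps the number of steps static, which is what lets scan compile and lets many policies be rolled out side by side with vmap.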

File Structure

src/tensorneat/
├── __init__.py                   # Main package exports [Pipeline, algorithm, genome, problem, common]
├── pipeline.py                   # Core Pipeline class - orchestrates evolution workflow, handles JIT compilation
├── algorithm/                    # Evolution algorithms
│   ├── base.py                   # BaseAlgorithm abstract class defining algorithm interface
│   ├── neat/
│   │   ├── neat.py               # NEAT implementation with population management, crossover/mutation
│   │   └── species.py            # SpeciesController for speciation, fitness sharing, stagnation
│   └── hyperneat/
│       ├── hyperneat.py          # HyperNEAT algorithm using CPPN to generate substrate weights
│       └── substrate/            # Substrate patterns [BaseSubstrate, FullSubstrate, MLPSubstrate]
├── genome/                       # Network genome representation
│   ├── base.py                   # BaseGenome abstract class with transform/forward methods
│   ├── default.py                # DefaultGenome - feedforward networks with topological sort
│   ├── recurrent.py              # RecurrentGenome - supports recurrent connections
│   ├── gene/                     # Gene components
│   │   ├── node/                 # Node genes [DefaultNode, BiasNode, OriginNode]
│   │   └── conn/                 # Connection genes [DefaultConn, OriginConn]
│   └── operations/               # Genetic operations
│       ├── mutation/             # Mutation strategies [DefaultMutation, RecurrentMutation]
│       ├── crossover/            # Crossover strategies [DefaultCrossover]
│       └── distance/             # Distance metrics [DefaultDistance]
├── problem/                      # Problem definitions
│   ├── base.py                   # BaseProblem interface for fitness evaluation
│   ├── func_fit/                 # Function fitting problems
│   │   ├── xor.py                # XOR, XOR3d benchmark problems
│   │   └── custom.py             # CustomFuncFit for user-defined functions
│   └── rl/                       # Reinforcement learning
│       ├── brax.py               # Brax physics environments integration
│       ├── gymnax.py             # Gymnax environments integration
│       └── rl_jit.py             # Base RL environment with JIT support
└── common/                       # Shared utilities
    ├── state.py                  # Immutable State class for functional programming
    ├── stateful_class.py         # StatefulBaseClass for state management
    ├── functions/                # Activation/aggregation functions
    │   ├── act_jnp.py            # JAX activation functions [sigmoid, tanh, relu, sin, exp, log]
    │   ├── agg_jnp.py            # JAX aggregation functions [sum, product, mean, max, min]
    │   └── manager.py            # FunctionManager for registering custom functions
    ├── graph.py                  # Graph operations [topological_sort, find_useful_nodes]
    ├── sympy_tools.py            # Symbolic math tools for LaTeX/Python code generation
    └── evox_adaptors/            # EvoX integration
        ├── algorithm_adaptor.py  # EvoXAlgorithmAdaptor wrapper
        └── tensorneat_monitor.py # TensorNEATMonitor for EvoX workflows

Key Components

Pipeline

Central orchestrator managing the evolution workflow. Handles algorithm-problem integration, JIT compilation, multi-device support, fitness evaluation, and generation tracking.

from tensorneat.common import State
 
class Pipeline:
	def setup(self, state: State = State()) -> State: ...  # Initialize algorithm and problem
	def auto_run(self, state: State) -> tuple: ...  # Run evolution loop; returns (state, best_genome)
	def step(self, state: State) -> tuple: ...  # Single evolution step; returns (state, pop, fitnesses)
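Judging from the generation_limit and fitness_target arguments used in the examples, auto_run repeats step until the best fitness reaches the target or the generation limit is hit. A framework-free sketch of that control flow; auto_run_sketch and the step(state) -> (state, fitnesses) signature are simplifying assumptions, not the real Pipeline API:

```python
def auto_run_sketch(step, state, generation_limit, fitness_target):
    """Call step until the best fitness reaches fitness_target or the limit is hit."""
    best = float("-inf")
    generations_run = 0
    for _ in range(generation_limit):
        state, fitnesses = step(state)
        generations_run += 1
        best = max(best, max(fitnesses))
        if best >= fitness_target:  # early stop: target reached
            break
    return state, best, generations_run


# Toy usage: the "population" fitness improves by 1 per generation.
def toy_step(state):
    new_state = state + 1
    return new_state, [float(new_state), float(new_state) - 0.5]


print(auto_run_sketch(toy_step, 0, generation_limit=200, fitness_target=5.0))  # (5, 5.0, 5)
```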

Algorithm Interface

Base interface that all algorithms (NEAT, HyperNEAT) implement. Defines the ask-tell paradigm for evolutionary algorithms.

class BaseAlgorithm:
	def ask(self, state): ...  # Return the population to evaluate, e.g. (pop_nodes, pop_conns)
	def tell(self, state, fitness): ...  # Update with fitness feedback; returns a new State
	def transform(self, state, individual): ...  # Genome -> executable network
	def forward(self, state, transformed, inputs): ...  # Execute network; returns outputs
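The ask-tell split keeps evaluation outside the algorithm: ask hands out a batch, the caller scores it (typically with jax.vmap), and tell folds the scores into a new state. A toy sketch of that contract with a hypothetical elitist hill climber over real-valued vectors instead of NEAT genomes, so transform and forward are not needed; ToyHillClimber and fitness_fn are illustrative only:

```python
import jax
import jax.numpy as jnp


class ToyHillClimber:
    """Hypothetical elitist hill climber exposing an ask/tell interface."""

    def __init__(self, pop_size, dim, sigma=0.1):
        self.pop_size, self.dim, self.sigma = pop_size, dim, sigma

    def setup(self, key):
        key, sub = jax.random.split(key)
        pop = self.sigma * jax.random.normal(sub, (self.pop_size, self.dim))
        return {"key": key, "pop": pop}

    def ask(self, state):
        return state["pop"]  # population to evaluate

    def tell(self, state, fitness):
        best = state["pop"][jnp.argmax(fitness)]
        key, sub = jax.random.split(state["key"])
        pop = best + self.sigma * jax.random.normal(sub, (self.pop_size, self.dim))
        pop = pop.at[0].set(best)  # elitism: carry the best individual over unchanged
        return {"key": key, "pop": pop}  # a new state; the old one is untouched


def fitness_fn(x):
    return -jnp.sum((x - 1.0) ** 2)  # maximized at x = (1, ..., 1)


algo = ToyHillClimber(pop_size=64, dim=3)
state = algo.setup(jax.random.PRNGKey(0))
for _ in range(50):
    pop = algo.ask(state)
    fitness = jax.vmap(fitness_fn)(pop)  # score the whole population in one call
    state = algo.tell(state, fitness)

print(fitness_fn(algo.ask(state)[0]))  # best fitness found so far
```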

Problem Interface

Unified interface for fitness evaluation. All problems implement this contract:

class BaseProblem:
	def evaluate(self, state, randkey, act_func, params): ...  # Returns a scalar fitness
 
	@property
	def input_shape(self) -> tuple: ...  # e.g., (observation_dim,)
 
	@property
	def output_shape(self) -> tuple: ...  # e.g., (action_dim,)
 
	def show(self, state, randkey, act_func, params): ...  # Visualize solution
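A toy function-fitting problem written against this contract makes the role of each argument concrete. It is a standalone mock that does not subclass BaseProblem, and it assumes act_func(state, params, x) runs one network on one input vector; ToyFuncFit and perfect_net are hypothetical:

```python
import jax
import jax.numpy as jnp


class ToyFuncFit:
    """Hypothetical problem: fit y = x^2 on [-1, 1], following evaluate(state, randkey, act_func, params)."""

    def __init__(self):
        self.inputs = jnp.linspace(-1.0, 1.0, 21).reshape(-1, 1)  # (21, 1)
        self.targets = self.inputs ** 2                            # (21, 1)

    @property
    def input_shape(self):
        return (1,)  # one input value per sample

    @property
    def output_shape(self):
        return (1,)  # one output value per sample

    def evaluate(self, state, randkey, act_func, params):
        # Run the network on every sample at once; randkey is unused (deterministic task).
        preds = jax.vmap(act_func, in_axes=(None, None, 0))(state, params, self.inputs)
        return -jnp.mean((preds - self.targets) ** 2)  # scalar fitness, higher is better


def perfect_net(state, params, x):
    return params * x * x  # a "network" with a single scale parameter


problem = ToyFuncFit()
print(problem.evaluate(None, jax.random.PRNGKey(0), perfect_net, 1.0))
```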

Genome System

Flexible genome representation with customizable genes. Networks are stored as two JAX arrays:

import jax.numpy as jnp
 
nan = jnp.nan
 
# Nodes array: (max_nodes, node_gene_length); unused rows are NaN padding.
# The column layout depends on the node gene; for BiasNode each row is
# [index, bias, aggregation, activation], with functions stored as integer ids
# (the ids below are placeholders).
nodes = jnp.array([
	[0, 0.5, 0, 0],        # Input node
	[5, -0.2, 1, 2],       # Hidden node
	[nan, nan, nan, nan],  # Padding
])
 
# Connections array: (max_conns, conn_gene_length); unused rows are NaN padding.
# For DefaultConn each row is [in_index, out_index, weight].
conns = jnp.array([
	[0, 5, 0.7],      # Connection from node 0 to 5
	[5, 3, -0.3],     # Connection from node 5 to 3
	[nan, nan, nan],  # Padding
])
 
# Genome operations
genome.initialize(state, randkey)  # Create initial topology
genome.execute_mutation(state, randkey, nodes, conns, ...)
genome.execute_crossover(state, randkey, parent1, parent2)
genome.forward(state, transformed, inputs)  # Network execution
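Given these two arrays, a feedforward pass visits nodes in topological order and, for each non-input node, sums the weighted values of its incoming connections, adds the bias, and applies the activation. A readable, non-jitted sketch assuming a simplified layout of [index, bias] for nodes and [in_index, out_index, weight] for connections, a single tanh activation, and a precomputed topological order; the layout and the forward helper are illustrative, not TensorNEAT's implementation:

```python
import jax.numpy as jnp

nan = jnp.nan
# Simplified columns (illustrative): nodes = [index, bias]; conns = [in_index, out_index, weight].
nodes = jnp.array([
    [0, 0.0],    # input
    [1, 0.0],    # input
    [3, 0.5],    # output
    [5, -0.2],   # hidden
    [nan, nan],  # padding
])
conns = jnp.array([
    [0, 5, 0.7],
    [1, 5, -0.4],
    [5, 3, 1.5],
    [0, 3, 0.3],
    [nan, nan, nan],  # padding
])


def forward(nodes, conns, inputs, input_keys, output_keys, topo_order, act=jnp.tanh):
    """Evaluate a feedforward genome node by node in topological order."""
    nodes = nodes[~jnp.isnan(nodes[:, 0])]  # drop padding rows
    conns = conns[~jnp.isnan(conns[:, 0])]
    bias = {int(k): b for k, b in zip(nodes[:, 0], nodes[:, 1])}
    values = dict(zip(input_keys, inputs))  # input nodes take the given values
    for key in topo_order:
        if key in values:
            continue
        incoming = conns[conns[:, 1] == key]
        total = sum(values[int(i)] * w for i, w in zip(incoming[:, 0], incoming[:, 2]))
        values[key] = act(total + bias[key])  # aggregate (sum), add bias, activate
    return jnp.array([values[k] for k in output_keys])


out = forward(nodes, conns, jnp.array([1.0, 0.5]), [0, 1], [3], [0, 1, 5, 3])
print(out)
```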

State Management

Functional programming with immutable state. All operations return new state instances:

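import jax
import jax.numpy as jnp
from tensorneat.common import State
 
key = jax.random.PRNGKey(0)
nodes = jnp.full((100, 50, 5), jnp.nan)  # Placeholder population: (pop_size, max_nodes, node_length)
conns = jnp.full((100, 100, 3), jnp.nan)  # Placeholder population: (pop_size, max_conns, conn_length)
 
# State is immutable - operations return new instances
state = State(randkey=key, generation=0)
state = state.register(pop_nodes=nodes, pop_conns=conns)  # Add new keys
state = state.update(generation=1)  # Update existing keys
 
# Access state attributes
print(state.generation)  # 1
print(state.pop_nodes.shape)  # (100, 50, 5), i.e. (pop_size, max_nodes, node_length)
 
# State is pytree-compatible, so it can pass through JAX transformations
@jax.jit
def step(state):
	# Return a new State instead of mutating the old one
	return state.update(generation=state.generation + 1)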
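The properties listed above (immutability, attribute access, register for new keys, update for existing ones, pytree compatibility) take only a few lines to reproduce. A sketch of one possible design; MiniState is hypothetical and is not TensorNEAT's State class:

```python
import jax


@jax.tree_util.register_pytree_node_class
class MiniState:
    """Hypothetical immutable key-value container registered as a JAX pytree."""

    def __init__(self, **kwargs):
        object.__setattr__(self, "_data", dict(kwargs))

    def __getattr__(self, name):
        data = self.__dict__.get("_data", {})
        if name in data:
            return data[name]
        raise AttributeError(name)

    def __setattr__(self, name, value):
        raise AttributeError("MiniState is immutable; use register() or update()")

    def register(self, **kwargs):
        """Add new keys; refuses to overwrite existing ones."""
        clash = set(kwargs) & set(self._data)
        if clash:
            raise KeyError(f"already registered: {sorted(clash)}")
        return MiniState(**self._data, **kwargs)

    def update(self, **kwargs):
        """Replace values of existing keys; refuses unknown keys."""
        missing = set(kwargs) - set(self._data)
        if missing:
            raise KeyError(f"not registered: {sorted(missing)}")
        return MiniState(**{**self._data, **kwargs})

    # Pytree protocol: values are leaves, the sorted key names are static aux data.
    def tree_flatten(self):
        keys = tuple(sorted(self._data))
        return tuple(self._data[k] for k in keys), keys

    @classmethod
    def tree_unflatten(cls, keys, children):
        return cls(**dict(zip(keys, children)))


state = MiniState(generation=0)
new_state = state.register(best_fitness=-1.0).update(generation=1)
print(state.generation, new_state.generation)  # 0 1


@jax.jit
def step(s):
    return s.update(generation=s.generation + 1)


print(step(new_state).generation)  # 2
```

Registering the class as a pytree is what lets jax.jit accept it as an argument: its values become traced leaves, while the key names are treated as static structure.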

Code Examples

Custom Activation Function

from tensorneat.common import ACT, AGG
from tensorneat.genome import DefaultGenome, BiasNode
 
# Define custom activation
def square(x):
	return x ** 2
 
# Register for use in evolution
ACT.add_func("square", square)
 
# Use in genome configuration
genome = DefaultGenome(
	num_inputs=2,
	num_outputs=1,
	node_gene=BiasNode(
		activation_options=[ACT.identity, ACT.square, ACT.tanh],
		aggregation_options=[AGG.sum, AGG.product],
	),
)
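Registration by name matters because genes are numeric arrays, so a node can only store its activation as an integer id. A sketch of how such ids can be dispatched inside jit and vmap with jax.lax.switch; FuncRegistry is hypothetical and is not TensorNEAT's FunctionManager:

```python
import jax
import jax.numpy as jnp


class FuncRegistry:
    """Hypothetical name -> integer-id registry, so functions can be referenced from numeric arrays."""

    def __init__(self):
        self.names, self.funcs = [], []

    def add_func(self, name, func):
        self.names.append(name)
        self.funcs.append(func)
        return len(self.funcs) - 1  # integer id to store in the gene array

    def apply(self, func_id, x):
        # lax.switch picks a branch from a (possibly traced) integer, so this works under jit/vmap.
        return jax.lax.switch(func_id, self.funcs, x)


act = FuncRegistry()
identity_id = act.add_func("identity", lambda x: x)
square_id = act.add_func("square", lambda x: x ** 2)
tanh_id = act.add_func("tanh", jnp.tanh)

ids = jnp.array([identity_id, square_id, tanh_id])
x = jnp.array([3.0, 3.0, 0.0])
print(jax.vmap(act.apply)(ids, x))  # [3. 9. 0.]
```

Under vmap with a batched index, lax.switch evaluates every branch and selects the result per element, so the cost grows with the number of registered functions.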

Accessing Network Data

TensorNEAT provides multiple ways to access and export network representations:

from tensorneat.common import ACT
from tensorneat.common.sympy_tools import to_latex_code, to_python_code
 
# `algorithm` is a NEAT instance and `state` its current State
# 1. Raw JAX arrays - direct access to the internal representation
pop_nodes, pop_conns = algorithm.ask(state)  # Whole population: (pop_size, max_nodes, node_len), (pop_size, max_conns, conn_len)
nodes, conns = pop_nodes[0], pop_conns[0]  # Arrays of a single individual
 
# 2. Dictionary format - structured network data
network = algorithm.genome.network_dict(state, nodes, conns)
# Returns (illustrative):
# {
#	 "nodes": {0: {"idx": 0, "activation": "tanh", ...}, ...},
#	 "conns": {(0, 5): {"in": 0, "out": 5, "weight": 0.5, ...}, ...},
#	 "topo_order": [0, 1, 2, 5, 3],  # Execution order
#	 "topo_layers": [[0, 1, 2], [5], [3]],  # Layer grouping
#	 "useful_nodes": [0, 1, 2, 5, 3]  # Nodes affecting outputs
# }
 
# 3. Visualization - exports to NetworkX internally
algorithm.genome.visualize(network, save_path="network.svg")
 
# 4. Symbolic representation - SymPy expressions
sympy_res = algorithm.genome.sympy_func(
	state, network,
	sympy_output_transform=ACT.obtain_sympy(ACT.sigmoid)
)
latex_code = to_latex_code(*sympy_res)  # LaTeX equations
 
# 5. Executable code generation
python_code = to_python_code(*sympy_res)  # Standalone NumPy/JAX code
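The topo_order and useful_nodes fields follow from two standard graph passes: a topological sort (Kahn's algorithm here) and a backward reachability search from the output nodes. A plain-Python sketch over (in, out) key pairs; these helpers are illustrative and are not TensorNEAT's topological_sort or find_useful_nodes:

```python
from collections import defaultdict, deque


def topo_sort(node_keys, conns):
    """Kahn's algorithm over (in, out) pairs; raises on cycles."""
    indeg = {k: 0 for k in node_keys}
    succ = defaultdict(list)
    for i, o in conns:
        succ[i].append(o)
        indeg[o] += 1
    queue = deque(k for k in node_keys if indeg[k] == 0)
    order = []
    while queue:
        k = queue.popleft()
        order.append(k)
        for o in succ[k]:
            indeg[o] -= 1
            if indeg[o] == 0:
                queue.append(o)
    if len(order) != len(node_keys):
        raise ValueError("graph has a cycle")
    return order


def useful_nodes(output_keys, conns):
    """Nodes from which some output is reachable (walk connections backwards)."""
    pred = defaultdict(list)
    for i, o in conns:
        pred[o].append(i)
    seen, stack = set(output_keys), list(output_keys)
    while stack:
        for i in pred[stack.pop()]:
            if i not in seen:
                seen.add(i)
                stack.append(i)
    return seen


conns = [(0, 5), (1, 5), (5, 3), (2, 4)]  # node 4 is a dead end
print(topo_sort([0, 1, 2, 3, 4, 5], conns))  # [0, 1, 2, 5, 4, 3]
print(sorted(useful_nodes([3], conns)))      # [0, 1, 3, 5]
```

Nodes 2 and 4 never reach output 3, so they are left out of the useful set even though they are part of the graph.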

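The internal representation uses plain JAX arrays of fixed shape, with unused node and connection slots padded with NaN. Because every genome has the same array shape regardless of its topology, whole populations can be batched with vmap, compiled with jit, and split across devices with pmap, and the arrays can in principle be combined with other JAX transformations (such as grad) and JAX libraries (such as Optax, Haiku, or Flax).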
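Padding is what makes heterogeneous topologies batchable: every genome has the same array shape, and padded rows are masked out at compute time, so one compiled, vmapped function covers the whole population. A toy sketch assuming a simplified single-output layout of [input_index, weight] rows; it illustrates the masking idea only and is not TensorNEAT's forward pass:

```python
import jax
import jax.numpy as jnp

nan = jnp.nan

# Two "networks" with different numbers of connections, padded to the same 4 rows.
# Each row is [input_index, weight]; this single-output layout is a simplification.
net_a = jnp.array([[0, 0.5], [1, -1.0], [nan, nan], [nan, nan]])
net_b = jnp.array([[0, 2.0], [1, 0.25], [2, 1.0], [nan, nan]])
population = jnp.stack([net_a, net_b])  # (2, 4, 2): one fixed-shape batch


def single_output(conns, x):
    """Weighted sum of inputs over the real (non-NaN) connections only."""
    valid = ~jnp.isnan(conns[:, 0])
    idx = jnp.where(valid, conns[:, 0], 0).astype(jnp.int32)  # safe index for padding rows
    contrib = jnp.where(valid, x[idx] * conns[:, 1], 0.0)     # padding contributes 0
    return contrib.sum()


x = jnp.array([1.0, 2.0, 3.0])
outputs = jax.jit(jax.vmap(single_output, in_axes=(0, None)))(population, x)
print(outputs)
```

Here net_a evaluates to 1*0.5 + 2*(-1) = -1.5 and net_b to 2 + 0.5 + 3 = 5.5. Masking with jnp.where rather than slicing off the padding keeps shapes static, which jit requires.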

HyperNEAT Configuration

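from tensorneat.algorithm.hyperneat import HyperNEAT, FullSubstrate
from tensorneat.algorithm.neat import NEAT
from tensorneat.genome import DefaultGenome
from tensorneat.common import ACT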
 
# Define substrate topology
substrate = FullSubstrate(
	input_coors=((-1, -1), (0, -1), (1, -1)),  # Input layer positions
	hidden_coors=((-1, 0), (0, 0), (1, 0)),	# Hidden layer positions
	output_coors=((0, 1),),					 # Output layer position
)
 
# NEAT for evolving CPPN
neat = NEAT(
	pop_size=10000,
	genome=DefaultGenome(
		num_inputs=4,   # Query coordinates (x1, y1, x2, y2)
		num_outputs=1,  # Connection weight
		output_transform=ACT.tanh,
	),
)
 
# HyperNEAT algorithm
algorithm = HyperNEAT(
	substrate=substrate,
	neat=neat,
	weight_threshold=0.3,
	activation=ACT.tanh,
)
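In HyperNEAT the evolved network (the CPPN) is not the controller: it is queried with the coordinates of each pair of substrate nodes, and its output becomes the weight of the corresponding connection, which is why the CPPN genome above has four inputs and one output. A sketch of that query for one input-to-hidden layer, assuming the common HyperNEAT convention that weights with magnitude below weight_threshold are pruned to zero; toy_cppn stands in for an evolved network, and TensorNEAT's exact pruning and scaling may differ:

```python
import itertools

import jax
import jax.numpy as jnp


def toy_cppn(query):
    """Stand-in for an evolved CPPN: (x1, y1, x2, y2) -> weight in (-1, 1)."""
    x1, y1, x2, y2 = query
    return jnp.tanh(jnp.sin(2.0 * x1) * y2 - x2 * y1)


def substrate_weights(cppn, src_coors, dst_coors, weight_threshold):
    """Query the CPPN for every (source, target) pair and prune weak weights to 0."""
    queries = jnp.array(
        [s + d for s, d in itertools.product(src_coors, dst_coors)], dtype=jnp.float32
    )  # (num_src * num_dst, 4)
    w = jax.vmap(cppn)(queries).reshape(len(src_coors), len(dst_coors))
    return jnp.where(jnp.abs(w) < weight_threshold, 0.0, w)


input_coors = ((-1, -1), (0, -1), (1, -1))
hidden_coors = ((-1, 0), (0, 0), (1, 0))
w_in_hidden = substrate_weights(toy_cppn, input_coors, hidden_coors, weight_threshold=0.3)
print(w_in_hidden.shape)  # (3, 3)
```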

Multi-Device with EvoX

import jax
from evox import workflows
from tensorneat.common.evox_adaptors import EvoXAlgorithmAdaptor
 
# Assumes `neat_algorithm` (a TensorNEAT NEAT instance) and `problem`
# (an EvoX problem, e.g. a Brax task) are already defined
evox_algorithm = EvoXAlgorithmAdaptor(neat_algorithm)
 
# Create workflow with multi-device support
workflow = workflows.StdWorkflow(
	algorithm=evox_algorithm,
	problem=problem,
	solution_transforms=[jax.vmap(evox_algorithm.transform)],
)
 
# Enable multi-device execution
key = jax.random.PRNGKey(42)
state = workflow.init(key)
state = workflow.enable_multi_devices(state)
 
# Run distributed evolution
generations = 100  # Example value
for i in range(generations):
	state = workflow.step(state)
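Multi-device execution ultimately means splitting the population across accelerators and evaluating each shard in parallel. A plain-JAX sketch of that idea with jax.pmap over shards and jax.vmap within each shard; it does not use EvoX, evaluate_on_devices is hypothetical, and it assumes the population size is a multiple of the device count (on a single-device machine it simply runs one shard):

```python
import jax
import jax.numpy as jnp


def evaluate_on_devices(fitness_fn, population):
    """Shard the population over local devices; vmap within each shard."""
    n_dev = jax.local_device_count()
    pop_size = population.shape[0]
    if pop_size % n_dev != 0:
        raise ValueError("pop_size must be divisible by the number of devices")
    shards = population.reshape(n_dev, pop_size // n_dev, *population.shape[1:])
    fitness = jax.pmap(jax.vmap(fitness_fn))(shards)  # (n_dev, pop_size // n_dev)
    return fitness.reshape(pop_size)


# Population size chosen as a multiple of the device count.
population = jnp.arange(4.0 * jax.local_device_count()).reshape(-1, 1)
print(evaluate_on_devices(lambda x: -jnp.sum(x ** 2), population))
```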

Additional Resources

  • Documentation: Full API docs in docs/source/ with Sphinx configuration
  • Examples: Complete examples in examples/ covering:
    • Function fitting: XOR, custom functions, symbolic regression
    • RL tasks: Brax (ant, hopper, walker2d), Gymnax (cartpole, mountain car)
    • Visualization: Network graphs, symbolic representation
    • EvoX integration: Multi-device walker2d example
  • Tutorials: Jupyter notebooks in tutorials/ covering:
    • Functional programming and state management
    • Genome structure and operations
    • Pipeline usage patterns
  • Tests: Unit tests in test/ for genome operations, crossover/mutation, substrates

Key Papers

  • Original NEAT: Stanley & Miikkulainen (2002)
  • HyperNEAT: Stanley et al. (2009)
  • TensorNEAT: Wang et al. (2024), GECCO 2024 Best Paper Award; library paper: Wang et al. (2025), arXiv:2504.08339

Dependencies

  • JAX >= 0.4.28 (core computation)
  • Brax >= 0.10.3 (physics simulation)
  • Gymnax >= 0.0.8 (JAX-based gym environments)
  • NetworkX >= 3.3 (graph visualization)
  • SymPy >= 1.12.1 (symbolic math)
  • EvoX v0.9.1-dev (optional, for distributed computing)