year:
paper: https://arxiv.org/pdf/2504.08339
website:
code:
connections: NEAT, jax
Comment by Kenneth O. Stanley on a discussion about historical markers on nodes vs. connections (a divergence of the TensorNEAT implementation from original NEAT):
TLDR h/t claude:
The fundamental difference centers on historical markers (innovation numbers):
- Original NEAT: Places unique historical markers on every connection (edge) in the network
- TensorNEAT/Python-NEAT: Only places historical markers on nodes, using (start_node, end_node) tuples to identify connections
Implications:
- Homology determination: When two genomes independently evolve the same connection (e.g., node 2→3), original NEAT treats them as different genes (disjoint), while TensorNEAT treats them as the same gene (matching)
- Crossover behavior: This affects which genes get inherited during reproduction - matching genes are selected 50/50 from parents, while disjoint genes come from the fitter parent
- Speciation: Distance calculations for species formation are computed differently, potentially creating more inclusive species in TensorNEAT
Performance Impact:
- The difference likely has some adverse effect on efficiency, but probably not very large
- The main risk is wasting effort crossing over incompatible individuals from distant lineages
- However, speciation should naturally keep most truly distant genomes apart anyway
- The impact would mainly affect “borderline cases” at the edge of compatibility
- Empirical testing would be needed to determine the actual performance difference
Philosophical Perspective:
- From a biological standpoint, genes from very distant lineages shouldn’t be considered “the same” just because they’re structurally similar
- His analogy: “if a fish had a mutation that encodes the same protein as something in the human gut, I would not regard the fish as being more compatible with a human for the purposes of mating”
- He estimates “while there may be some difference in performance, it would likely not be very large”
Resolution:
- WLS2002 implemented OriginNode and OriginConn classes in TensorNEAT that follow the original NEAT paper's approach (the two matching schemes are contrasted in the sketch below)
- Initial experiments by sopotc showed 10-15% better fitness with the original implementation, though more rigorous testing is needed
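To make the crossover/speciation difference concrete, here is a minimal, hypothetical sketch (not TensorNEAT's actual data structures) contrasting the two ways of deciding whether two connection genes match:

# Hypothetical gene records, for illustration only.
# Original NEAT: each connection gene carries a globally assigned innovation number.
parent_a = {7: (2, 3, 0.5), 9: (3, 4, -0.1)}  # innovation -> (in_node, out_node, weight)
parent_b = {12: (2, 3, 0.8)}                  # same 2->3 edge, discovered independently

# Matching by innovation number: 7 != 12, so the two 2->3 genes are disjoint
# and never line up during crossover.
matching_by_innovation = set(parent_a) & set(parent_b)  # -> set()

# TensorNEAT-style: a connection is identified by its (in_node, out_node) tuple,
# so the independently evolved 2->3 genes are treated as matching.
edges_a = {(i, o): w for (i, o, w) in parent_a.values()}
edges_b = {(i, o): w for (i, o, w) in parent_b.values()}
matching_by_endpoints = set(edges_a) & set(edges_b)  # -> {(2, 3)}

Under the first scheme the 2→3 gene is inherited only from the fitter parent and counts toward the disjoint term of the compatibility distance; under the second it is picked 50/50 from either parent and pulls the two genomes closer together for speciation.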
TensorNEAT Library Overview
TLDR
TensorNEAT is a JAX-based library for NeuroEvolution of Augmenting Topologies (NEAT) algorithms, designed to leverage GPU acceleration for evolving neural network structures. It enables parallel processing of networks with different architectures through tensorization, achieving up to 500x speedup compared to CPU implementations. Built on JAX with support for Brax/Gymnax RL environments, the library implements a functional programming paradigm with immutable state management.
Core Features & Quick Start
Core Features
- GPU-accelerated NEAT and HyperNEAT algorithms
- Batch inference across heterogeneous network architectures
- Network visualization and symbolic representation (LaTeX/Python code generation)
- Integration with EvoX for distributed computing
- Support for RL tasks (Brax, Gymnax, MuJoCo)
- Function fitting and symbolic regression capabilities
Installation
# Install JAX (CPU version)
pip install -U jax
# Or for NVIDIA GPUs
pip install -U "jax[cuda12]"
# Install TensorNEAT
pip install git+https://github.com/EMI-Group/tensorneat.git
Basic Usage - XOR Problem
from tensorneat.pipeline import Pipeline
from tensorneat import algorithm, genome, problem
from tensorneat.common import ACT
# Configure NEAT algorithm
algorithm = algorithm.NEAT(
    pop_size=10000,
    species_size=20,
    survival_threshold=0.01,
    genome=genome.DefaultGenome(
        num_inputs=3,
        num_outputs=1,
        output_transform=ACT.sigmoid,
    ),
)
# Define problem and pipeline
problem = problem.XOR3d()
pipeline = Pipeline(
    algorithm,
    problem,
    generation_limit=200,
    fitness_target=-1e-6,
    seed=42,
)
# Run evolution
state = pipeline.setup()
state, best = pipeline.auto_run(state)
pipeline.show(state, best)
RL Task Example - Brax Hopper
from tensorneat.pipeline import Pipeline
from tensorneat.algorithm.neat import NEAT
from tensorneat.genome import DefaultGenome, BiasNode
from tensorneat.problem.rl import BraxEnv
from tensorneat.common import ACT, AGG
pipeline = Pipeline(
    algorithm=NEAT(
        pop_size=1000,
        species_size=20,
        survival_threshold=0.1,
        genome=DefaultGenome(
            num_inputs=11,
            num_outputs=3,
            node_gene=BiasNode(
                activation_options=ACT.tanh,
                aggregation_options=AGG.sum,
            ),
            output_transform=ACT.tanh,
        ),
    ),
    problem=BraxEnv(env_name="hopper", max_step=1000),
    generation_limit=100,
    fitness_target=5000,
)
state = pipeline.setup()
state, best = pipeline.auto_run(state)
File Structure
src/tensorneat/
├── __init__.py # Main package exports [Pipeline, algorithm, genome, problem, common]
├── pipeline.py # Core Pipeline class - orchestrates evolution workflow, handles JIT compilation
├── algorithm/ # Evolution algorithms
│ ├── base.py # BaseAlgorithm abstract class defining algorithm interface
│ ├── neat/
│ │ ├── neat.py # NEAT implementation with population management, crossover/mutation
│ │ └── species.py # SpeciesController for speciation, fitness sharing, stagnation
│ └── hyperneat/
│ ├── hyperneat.py # HyperNEAT algorithm using CPPN to generate substrate weights
│ └── substrate/ # Substrate patterns [BaseSubstrate, FullSubstrate, MLPSubstrate]
├── genome/ # Network genome representation
│ ├── base.py # BaseGenome abstract class with transform/forward methods
│ ├── default.py # DefaultGenome - feedforward networks with topological sort
│ ├── recurrent.py # RecurrentGenome - supports recurrent connections
│ ├── gene/ # Gene components
│ │ ├── node/ # Node genes [DefaultNode, BiasNode, OriginNode]
│ │ └── conn/ # Connection genes [DefaultConn, OriginConn]
│ └── operations/ # Genetic operations
│ ├── mutation/ # Mutation strategies [DefaultMutation, RecurrentMutation]
│ ├── crossover/ # Crossover strategies [DefaultCrossover]
│ └── distance/ # Distance metrics [DefaultDistance]
├── problem/ # Problem definitions
│ ├── base.py # BaseProblem interface for fitness evaluation
│ ├── func_fit/ # Function fitting problems
│ │ ├── xor.py # XOR, XOR3d benchmark problems
│ │ └── custom.py # CustomFuncFit for user-defined functions
│ └── rl/ # Reinforcement learning
│ ├── brax.py # Brax physics environments integration
│ ├── gymnax.py # Gymnax environments integration
│ └── rl_jit.py # Base RL environment with JIT support
└── common/ # Shared utilities
├── state.py # Immutable State class for functional programming
├── stateful_class.py # StatefulBaseClass for state management
├── functions/ # Activation/aggregation functions
│ ├── act_jnp.py # JAX activation functions [sigmoid, tanh, relu, sin, exp, log]
│ ├── agg_jnp.py # JAX aggregation functions [sum, product, mean, max, min]
│ └── manager.py # FunctionManager for registering custom functions
├── graph.py # Graph operations [topological_sort, find_useful_nodes]
├── sympy_tools.py # Symbolic math tools for LaTeX/Python code generation
└── evox_adaptors/ # EvoX integration
├── algorithm_adaptor.py # EvoXAlgorithmAdaptor wrapper
└── tensorneat_monitor.py # TensorNEATMonitor for EvoX workflows
Key Components
Pipeline
Central orchestrator managing the evolution workflow. Handles algorithm-problem integration, JIT compilation, multi-device support, fitness evaluation, and generation tracking.
class Pipeline:
    def setup(self, state=State()) -> State            # Initialize algorithm and problem
    def auto_run(self, state) -> (State, best_genome)  # Run evolution loop
    def step(self, state) -> (State, pop, fitnesses)   # Single evolution step
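As a usage sketch (assuming step() returns exactly the triple listed above, which is an assumption about the public API), the loop run by auto_run could also be driven manually, e.g. for custom logging:

# Minimal manual evolution loop; a sketch, not the library's canonical pattern.
state = pipeline.setup()
for gen in range(200):
    state, pop, fitnesses = pipeline.step(state)  # one generation: ask -> evaluate -> tell
    print(f"generation {gen}: best fitness = {float(fitnesses.max()):.4f}")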
Algorithm Interface
Base interface that all algorithms (NEAT, HyperNEAT) implement. Defines the ask-tell paradigm for evolutionary algorithms.
class BaseAlgorithm:
    def ask(self, state) -> (nodes, conns)                     # Get population to evaluate
    def tell(self, state, fitness) -> State                    # Update with fitness feedback
    def transform(self, state, individual) -> network          # Genome to network
    def forward(self, state, transformed, inputs) -> outputs   # Execute network
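For intuition, a rough sketch of how an ask-tell generation could be wired up (simplified; evaluate_population is a hypothetical stand-in for the problem's batched fitness evaluation, and the real Pipeline also handles JIT compilation and device placement):

# Simplified ask-tell generation, assuming the interface listed above.
def one_generation(state, algorithm, evaluate_population):
    pop_nodes, pop_conns = algorithm.ask(state)                 # padded genome arrays
    fitness = evaluate_population(state, pop_nodes, pop_conns)  # one score per individual
    state = algorithm.tell(state, fitness)                      # selection, speciation, reproduction
    return state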
Problem Interface
Unified interface for fitness evaluation. All problems implement this contract:
class BaseProblem:
    def evaluate(self, state, randkey, act_func, params) -> fitness_scalar
    @property
    def input_shape(self) -> tuple   # e.g., (observation_dim,)
    @property
    def output_shape(self) -> tuple  # e.g., (action_dim,)
    def show(self, state, randkey, act_func, params)  # Visualize solution
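As a hedged sketch of what implementing this contract might look like (the class name, the import path, and the act_func argument order are assumptions based on the interfaces shown here, not verified against the library):

import jax
import jax.numpy as jnp
from tensorneat.problem import BaseProblem  # import path assumed from the file tree below

class SquareFit(BaseProblem):
    """Hypothetical problem: fit y = x^2 on a fixed input grid."""

    def __init__(self):
        self.xs = jnp.linspace(-1.0, 1.0, 32).reshape(-1, 1)  # (32, 1) inputs
        self.ys = self.xs ** 2                                 # (32, 1) targets

    def evaluate(self, state, randkey, act_func, params):
        # act_func runs the transformed network; the (state, params, inputs) order is assumed.
        preds = jax.vmap(act_func, in_axes=(None, None, 0))(state, params, self.xs)
        return -jnp.mean((preds - self.ys) ** 2)  # negative MSE: higher is better

    @property
    def input_shape(self):
        return (1,)

    @property
    def output_shape(self):
        return (1,)

    def show(self, state, randkey, act_func, params):
        preds = jax.vmap(act_func, in_axes=(None, None, 0))(state, params, self.xs)
        print("MSE of best network:", float(jnp.mean((preds - self.ys) ** 2)))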
Genome System
Flexible genome representation with customizable genes. Networks are stored as two JAX arrays:
# Nodes array: (max_nodes, node_gene_length)
# Each row: [node_key, activation, aggregation, bias, ...]
nodes = jnp.array([
    [0, ACT.tanh, AGG.sum, 0.5, ...],    # Input node
    [5, ACT.relu, AGG.mean, -0.2, ...],  # Hidden node
    ...
])
# Connections array: (max_conns, conn_gene_length)
# Each row: [in_key, out_key, weight, enabled, ...]
conns = jnp.array([
    [0, 5, 0.7, 1.0, ...],   # Connection from node 0 to 5
    [5, 3, -0.3, 1.0, ...],  # Connection from node 5 to 3
    ...
])
# Genome operations
genome.initialize(state, randkey) # Create initial topology
genome.execute_mutation(state, randkey, nodes, conns, ...)
genome.execute_crossover(state, randkey, parent1, parent2)
genome.forward(state, transformed, inputs) # Network execution
State Management
Functional programming with immutable state. All operations return new state instances:
# State is immutable - operations return new instances
state = State(randkey=key, generation=0)
state = state.register(pop_nodes=nodes, pop_conns=conns) # Add new keys
state = state.update(generation=1) # Update existing keys
# Access state attributes
print(state.generation) # 1
print(state.pop_nodes.shape) # (pop_size, max_nodes, node_length)
# State is pytree-compatible for JAX transformations
@jax.jit
def step(state):
    # All state mutations create new instances
    # Enables pure functional programming
    return new_state
Code Examples
Custom Activation Function
from tensorneat.common import ACT
# Define custom activation
def square(x):
    return x ** 2
# Register for use in evolution
ACT.add_func("square", square)
# Use in genome configuration
genome = DefaultGenome(
    num_inputs=2,
    num_outputs=1,
    node_gene=BiasNode(
        activation_options=[ACT.identity, ACT.square, ACT.tanh],
        aggregation_options=[AGG.sum, AGG.product],
    ),
)
Accessing Network Data
TensorNEAT provides multiple ways to access and export network representations:
# 1. Raw JAX arrays - direct access to internal representation
nodes, conns = algorithm.ask(state) # Returns (max_nodes, gene_length) arrays
# 2. Dictionary format - structured network data
network = algorithm.genome.network_dict(state, nodes, conns)
# Returns:
# {
# "nodes": {0: {"idx": 0, "activation": "tanh", ...}, ...},
# "conns": {(0, 5): {"in": 0, "out": 5, "weight": 0.5, ...}, ...},
# "topo_order": [0, 1, 2, 5, 3], # Execution order
# "topo_layers": [[0, 1], [2], [5, 3]], # Layer grouping
# "useful_nodes": [0, 1, 2, 3] # Nodes affecting outputs
# }
# 3. Visualization - exports to NetworkX internally
algorithm.genome.visualize(network, save_path="network.svg")
# 4. Symbolic representation - SymPy expressions
sympy_res = algorithm.genome.sympy_func(
    state, network,
    sympy_output_transform=ACT.obtain_sympy(ACT.sigmoid)
)
latex_code = to_latex_code(*sympy_res) # LaTeX equations
# 5. Executable code generation
python_code = to_python_code(*sympy_res) # Standalone NumPy/JAX code
The internal representation uses pure JAX arrays with NaN padding, making networks fully compatible with JAX transformations such as jit, vmap, pmap, and grad. This enables seamless integration with the broader JAX ecosystem, including Optax, Haiku, and Flax.
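A small sketch of what that compatibility enables: batching one forward pass across the whole padded population with jax.vmap (assuming state and algorithm from the earlier examples, and that an individual is the (nodes, conns) pair returned by ask()):

import jax
import jax.numpy as jnp

# pop_nodes: (pop_size, max_nodes, node_len), pop_conns: (pop_size, max_conns, conn_len)
pop_nodes, pop_conns = algorithm.ask(state)
inputs = jnp.zeros((3,))  # one XOR3d-sized input vector, for illustration

def forward_one(nodes, conns):
    transformed = algorithm.transform(state, (nodes, conns))
    return algorithm.forward(state, transformed, inputs)

# Same input through every genome in the population, in parallel.
outputs = jax.vmap(forward_one)(pop_nodes, pop_conns)  # (pop_size, num_outputs)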
HyperNEAT Configuration
from tensorneat.algorithm.hyperneat import HyperNEAT, FullSubstrate
# Define substrate topology
substrate = FullSubstrate(
    input_coors=((-1, -1), (0, -1), (1, -1)),  # Input layer positions
    hidden_coors=((-1, 0), (0, 0), (1, 0)),    # Hidden layer positions
    output_coors=((0, 1),),                    # Output layer position
)
# NEAT for evolving CPPN
neat = NEAT(
    pop_size=10000,
    genome=DefaultGenome(
        num_inputs=4,   # Query coordinates (x1, y1, x2, y2)
        num_outputs=1,  # Connection weight
        output_transform=ACT.tanh,
    ),
)
# HyperNEAT algorithm
algorithm = HyperNEAT(
    substrate=substrate,
    neat=neat,
    weight_threshold=0.3,
    activation=ACT.tanh,
)
Multi-Device with EvoX
import jax
from tensorneat.common.evox_adaptors import EvoXAlgorithmAdaptor
from evox import workflows
# Wrap TensorNEAT algorithm for EvoX
evox_algorithm = EvoXAlgorithmAdaptor(neat_algorithm)
# Create workflow with multi-device support
workflow = workflows.StdWorkflow(
    algorithm=evox_algorithm,
    problem=problem,
    solution_transforms=[jax.vmap(evox_algorithm.transform)],
)
# Enable multi-device execution
state = workflow.init(key)
state = workflow.enable_multi_devices(state)
# Run distributed evolution
for i in range(generations):
    state = workflow.step(state)
Additional Resources
- Documentation: Full API docs in docs/source/ with Sphinx configuration
- Examples: Complete examples in examples/ covering:
  - Function fitting: XOR, custom functions, symbolic regression
  - RL tasks: Brax (ant, hopper, walker2d), Gymnax (cartpole, mountain car)
  - Visualization: Network graphs, symbolic representation
  - EvoX integration: Multi-device walker2d example
- Tutorials: Jupyter notebooks in tutorials/ covering:
  - Functional programming and state management
  - Genome structure and operations
  - Pipeline usage patterns
- Tests: Unit tests in test/ for genome operations, crossover/mutation, substrates
Key Papers
- Original NEAT: Stanley & Miikkulainen (2002)
- HyperNEAT: Stanley et al. (2009)
- TensorNEAT: Wang et al. (2025) - GECCO 2024 Best Paper Award
Dependencies
- JAX >= 0.4.28 (core computation)
- Brax >= 0.10.3 (physics simulation)
- Gymnax >= 0.0.8 (JAX-based gym environments)
- NetworkX >= 3.3 (graph visualization)
- SymPy >= 1.12.1 (symbolic math)
- EvoX v0.9.1-dev (optional, for distributed computing)