Source code for eztaox.kernels.eqx_utils
"""Utility methods for Equinox modules."""
import equinox as eqx
import jax
from equinox._module import BoundMethod
from jax.tree_util import GetAttrKey
[docs]
def find_param_by_name(module: eqx.Module, name: str) -> list | None:
    """Find all leaves in an Equinox module stored under attribute ``name``.

    The module is flattened with ``jax.tree_util.tree_leaves_with_path``.
    A leaf is selected when the last entry of its key path is a
    ``GetAttrKey`` whose ``name`` equals ``name``. Leaves that are
    ``BoundMethod`` instances are skipped.

    Args:
        module (eqx.Module): The Equinox module to search in.
        name (str): The attribute name of the parameter(s) to find.

    Returns:
        list | None: Every matching leaf, in flattening order, or ``None``
        if no leaf matches. The list can hold more than one element when
        several (sub)modules store a leaf under the same attribute name.
    """
    matches = [
        leaf
        for path, leaf in jax.tree_util.tree_leaves_with_path(module)
        # An empty path means the whole input is itself a leaf, so there is
        # no attribute name to compare against.
        if path
        and not isinstance(leaf, BoundMethod)
        and isinstance(path[-1], GetAttrKey)
        and path[-1].name == name
    ]
    # Keep the existing contract: ``None`` (not an empty list) when nothing matches.
    return matches or None