Optimizers#

JaQMC provides natural gradient optimizers (KFAC, SR) alongside standard optimizers from Optax. Natural gradient methods update the wavefunction in the Hilbert space rather than parameter space, which is typically more stable for VMC. See Optimizers for background on choosing an optimizer.

Configuration#

For optimizer config keys, see the configuration reference: Molecule, Solid, or Hall.

Protocol#

class jaqmc.optimizer.base.OptimizerLike(*args, **kwargs)[source]#

Protocol for optimizers.

All optimizers expose init and update methods. init takes a single positional argument, params; update takes three: grads, state, params. Runtime deps like f_log_psi are wired via runtime_dep() fields before init is called. Call-time args (batched_data, rngs) are passed via **kwargs.

init: jaqmc.optimizer.base.OptimizerInit[source]#
update: jaqmc.optimizer.base.OptimizerUpdate[source]#
class jaqmc.optimizer.base.OptimizerInit(*args, **kwargs)[source]#
__call__(params, **kwargs)[source]#

Initialize optimizer state.

Parameters:
  • params (Params) – Wavefunction parameters.

  • **kwargs – Optimizer-specific arguments (e.g., batched_data, rngs). Runtime deps like f_log_psi are wired via runtime_dep() fields, not passed here.

Return type:

ArrayTree

Returns:

Initial optimizer state.

class jaqmc.optimizer.base.OptimizerUpdate(*args, **kwargs)[source]#
__call__(grads, state, params, **kwargs)[source]#

Apply optimizer update.

Parameters:
  • grads (Params) – Gradient updates.

  • state (ArrayTree) – Current optimizer state.

  • params (Params) – Current wavefunction parameters.

  • **kwargs – Optimizer-specific arguments (e.g., batched_data, rngs).

Return type:

tuple[Params, ArrayTree]

Returns:

Tuple of (processed updates, new optimizer state).
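Taken together, the protocol amounts to a simple init/update loop. The sketch below shows a minimal object satisfying it; the SGD stand-in, the loss function, and the training loop are illustrative only, not JaQMC API:

```python
# Sketch of the OptimizerLike protocol: any object exposing init/update with
# these signatures fits. Plain SGD is used as a stand-in optimizer.
import jax
import jax.numpy as jnp

class SGDLike:
    def __init__(self, learning_rate=0.1):
        self.learning_rate = learning_rate

    def init(self, params, **kwargs):
        # Plain SGD needs no state; return an empty pytree.
        return ()

    def update(self, grads, state, params, **kwargs):
        # Returns (processed updates, new optimizer state).
        updates = jax.tree_util.tree_map(lambda g: -self.learning_rate * g, grads)
        return updates, state

def loss(params):
    return jnp.sum(params["w"] ** 2)

opt = SGDLike(learning_rate=0.1)
params = {"w": jnp.array([1.0, 2.0])}
state = opt.init(params)
for _ in range(3):
    grads = jax.grad(loss)(params)
    updates, state = opt.update(grads, state, params)
    params = jax.tree_util.tree_map(lambda p, u: p + u, params, updates)
print(loss(params))  # well below the initial value of 5.0
```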

Optimizers provided by JaQMC#

class jaqmc.optimizer.sr.SROptimizer(*, learning_rate=<factory>, max_norm=<factory>, damping=<factory>, max_cond_num=10000000.0, spring_mu=<factory>, march_beta=<factory>, march_mode='var', eps=1e-08, mixed_precision=True, score_chunk_size=128, score_norm_clip=None, gram_num_chunks=4, gram_dot_prec='F64', prune_inactive=False, f_log_psi=LateInit())[source]#

Robust stochastic reconfiguration optimizer.

Stochastic reconfiguration (SR) is a natural-gradient method that updates parameters in wavefunction space rather than raw parameter space. This is often more stable than first-order optimizers for variational Monte Carlo, especially when small parameter changes can produce uneven changes in the wavefunction.

JaQMC’s SR optimizer uses a robust SR update together with two optional extensions:

  • SPRING adds momentum through spring_mu.

  • MARCH adds an adaptive metric through march_beta and march_mode.

Use SR when you want an SR-style natural-gradient update instead of KFAC’s structured approximation. The chunking options trade speed for lower memory use on larger systems.

Parameters:
  • learning_rate (Any, default: <factory>) – Step size (scalar or schedule).

  • max_norm (Any, default: <factory>) – Constrained update norm C (scalar or schedule). If None, only the learning-rate scaling is applied.

  • damping (Any, default: <factory>) – Damping lambda (scalar or schedule).

  • max_cond_num (float | None, default: 10000000.0) – Maximum condition number for adaptive damping. If None, adaptive damping is disabled.

  • spring_mu (Any, default: <factory>) – SPRING momentum coefficient mu (scalar or schedule). If None, SPRING momentum is disabled.

  • march_beta (Any, default: <factory>) – Decay factor for the MARCH variance accumulator (scalar or schedule). If None, the MARCH metric is disabled.

  • march_mode (Literal['var', 'diff'], default: 'var') – MARCH variance mode. "diff" uses update differences and "var" uses score variance along the batch axis.

  • eps (float, default: 1e-08) – Small numerical constant for stability.

  • mixed_precision (bool, default: True) – Whether to use mixed precision for Gram factorization.

  • score_chunk_size (int | None, default: 128) – Chunk size for score computation. If None, full-batch score computation is used.

  • score_norm_clip (float | None, default: None) – Optional clip value for the mean absolute score per batch row. If None, score clipping is disabled.

  • gram_num_chunks (int | None, default: 4) – Number of chunks for Gram matrix computation. If None, full-batch Gram computation is used.

  • gram_dot_prec (str | None, default: 'F64') – Precision mode for Gram matrix dot products.

  • prune_inactive (bool, default: False) – Whether to structurally prune inactive parameter leaves when forming the SR system.
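The memory saving behind score_chunk_size and gram_num_chunks can be illustrated in pure NumPy (this is a sketch of the idea, not JaQMC internals, and the chunking axis below is an assumption made for illustration): accumulating Gram contributions chunk by chunk yields exactly the same matrix as the full-batch product while keeping only one slice live at a time.

```python
# Illustration (not JaQMC internals) of chunked Gram-matrix accumulation.
import numpy as np

rng = np.random.default_rng(0)
scores = rng.normal(size=(64, 10))   # (batch, n_params) per-sample scores

full = scores @ scores.T             # full-batch Gram matrix in sample space

# Accumulate the same Gram matrix from 4 parameter-axis chunks:
# only one (batch, chunk) slice and the running sum are live at a time.
chunked = sum(c @ c.T for c in np.array_split(scores, 4, axis=1))

print(np.allclose(full, chunked))    # True
```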

class jaqmc.optimizer.kfac.kfac.KFACOptimizer(*, learning_rate=<factory>, norm_constraint=0.001, curvature_ema=0.95, l2_reg=0.0, inverse_update_period=1, damping=0.001, f_log_psi=LateInit())[source]#

Kronecker-Factored Approximate Curvature (KFAC) optimizer.

KFAC [arXiv:1503.05671] is a second-order optimization technique. It employs a Kronecker-product structure to approximate natural gradient descent, which accounts for the geometry of the parameter space during optimization. For more details on the Kronecker-product structure, see the comprehensive tutorial in [arXiv:2507.05127].

Natural gradient descent updates for optimizing loss \(\mathcal{L}\) with respect to parameters \(\theta\) have the form \(\delta \theta \propto \mathcal{F}^{-1} \nabla_\theta \mathcal{L}(\theta)\), where \(\mathcal{F}\) is the Fisher Information Matrix (FIM):

\[\mathcal{F}_{i j}=\mathbb{E}_{p(\mathbf{X})}\left[ \frac{\partial \log p(\mathbf{X})}{\partial \theta_i} \frac{\partial \log p(\mathbf{X})}{\partial \theta_j} \right].\]

For real-valued wavefunctions, where \(p(\mathbf{X}) \propto \psi^2(\mathbf{X})\), this gives the same formula as stochastic reconfiguration (Appendix C of [Phys. Rev. Research 2, 033429 (2020)]).

However, for complex-valued wavefunctions, natural gradient descent deviates from stochastic reconfiguration, although KFAC can still be used to approximate it. In stochastic reconfiguration, the parameter updates follow \(\delta \theta \propto \operatorname{Re} [S]^{-1} \operatorname{Re} \left[\nabla_\theta \mathcal{L}(\theta)\right]\), where

\[S_{ij} = \mathbb{E}_{|\psi|^2(\mathbf{X})}\left[ \frac{\partial \log \psi^*}{\partial \theta_i} \frac{\partial \log \psi}{\partial \theta_j} \right].\]

We only show the uncentered version (i.e. assuming a normalized wavefunction) for simplicity. We take \(\operatorname{Re}\) of both \(S\) and \(\nabla_\theta \mathcal{L}(\theta)\) because the parameter updates must be real; see the “Complex neural networks” paragraph of [Nat. Phys. 20, 1476 (2024)].

The key difference between the \(S\) matrix and the FIM is the complex conjugate on the left. The version of KFAC included in JaQMC accounts for this by patching some internals of the original kfac_jax.
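The effect of that left conjugate can be seen in a small NumPy sketch (illustrative only, not JaQMC or kfac_jax code; the score data and gradient vector are synthetic): the conjugated product is Hermitian while the unconjugated one is not, and the update is formed from real parts so the parameters stay real.

```python
# Toy sketch of the left conjugate in the S matrix (synthetic data).
import numpy as np

rng = np.random.default_rng(0)
# Per-sample complex scores O[k, i] ~ d log psi(X_k) / d theta_i (illustrative).
scores = rng.normal(size=(1000, 3)) + 1j * rng.normal(size=(1000, 3))

S = scores.conj().T @ scores / scores.shape[0]   # conjugate on the left: Hermitian
naive = scores.T @ scores / scores.shape[0]      # no conjugate: symmetric, not Hermitian

# SR-style update uses only real parts, so the parameter update stays real.
grad = scores.conj().mean(axis=0)                # illustrative gradient-like vector
delta = np.linalg.solve(S.real + 1e-3 * np.eye(3), grad.real)

print(np.allclose(S, S.conj().T))                # True
print(np.allclose(naive, naive.conj().T))        # False for generic complex data
```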

Parameters:
  • learning_rate (Any, default: <factory>) – The learning rate.

  • norm_constraint (float, default: 0.001) – The update is scaled down so that its approximate squared Fisher norm \(v^T F v\) is at most the specified value.

  • curvature_ema (float, default: 0.95) – Decay factor used when calculating the covariance estimate moving averages.

  • l2_reg (float, default: 0.0) – The L2 regularization coefficient used in your loss, so the optimizer can account for it.

  • inverse_update_period (int, default: 1) – Number of steps in between updating the inverse curvature approximation.

  • damping (float, default: 0.001) – Fixed damping parameter.

Optimizers provided by Optax#

Note

When using Optax, always use the optax:<name> wrapper (e.g. optax:adam) to ensure compatibility with JaQMC’s configuration system.

You can find the full list of optimizers in the Optax documentation.

class jaqmc.optimizer.optax.adam(learning_rate=<factory>, b1=0.9, b2=0.999, eps=1e-08, eps_root=0.0)[source]#

The Adam optimizer.

Adam is an SGD variant with gradient scaling adaptation. The scaling used for each parameter is computed from estimates of first and second-order moments of the gradients (using suitable exponential moving averages).

Let \(\alpha_t\) represent the learning rate and \(\beta_1, \beta_2\), \(\varepsilon\), \(\bar{\varepsilon}\) represent the arguments b1, b2, eps and eps_root respectively. The learning rate is indexed by \(t\) since the learning rate may also be provided by a schedule function.

The init function of this optimizer initializes an internal state \(S_0 := (m_0, v_0) = (0, 0)\), representing initial estimates for the first and second moments. In practice these values are stored as pytrees containing all zeros, with the same shape as the model updates. At step \(t\), the update function of this optimizer takes as arguments the incoming gradients \(g_t\) and optimizer state \(S_t\) and computes updates \(u_t\) and new state \(S_{t+1}\). Thus, for \(t > 0\), we have,

\[\begin{split}\begin{align*} m_t &\leftarrow \beta_1 \cdot m_{t-1} + (1-\beta_1) \cdot g_t \\ v_t &\leftarrow \beta_2 \cdot v_{t-1} + (1-\beta_2) \cdot {g_t}^2 \\ \hat{m}_t &\leftarrow m_t / {(1-\beta_1^t)} \\ \hat{v}_t &\leftarrow v_t / {(1-\beta_2^t)} \\ u_t &\leftarrow -\alpha_t \cdot \hat{m}_t / \left({\sqrt{\hat{v}_t + \bar{\varepsilon}} + \varepsilon} \right)\\ S_t &\leftarrow (m_t, v_t). \end{align*}\end{split}\]

With the keyword argument nesterov=True, the optimizer uses Nesterov momentum, replacing the above \(\hat{m}_t\) with

\[\hat{m}_t \leftarrow \beta_1 m_t / {(1-\beta_1^{t+1})} + (1 - \beta_1) g_t / {(1-\beta_1^t)}.\]
Parameters:
  • learning_rate (Any, default: <factory>) – A global scaling factor, either fixed or evolving along iterations with a scheduler, see optax.scale_by_learning_rate().

  • b1 (float, default: 0.9) – Exponential decay rate to track the first moment of past gradients.

  • b2 (float, default: 0.999) – Exponential decay rate to track the second moment of past gradients.

  • eps (float, default: 1e-08) – A small constant applied to denominator outside of the square root (as in the Adam paper) to avoid dividing by zero when rescaling.

  • eps_root (float, default: 0.0) – A small constant applied to denominator inside the square root (as in RMSProp), to avoid dividing by zero when rescaling. This is needed for example when computing (meta-)gradients through Adam.

  • mu_dtype – Optional dtype to be used for the first order accumulator; if None then the dtype is inferred from params and updates.

  • nesterov – Whether to use Nesterov momentum. The solver with nesterov=True is equivalent to the optax.nadam() optimizer, and described in [Dozat 2016].

Returns:

The corresponding optax.GradientTransformationExtraArgs.

Examples

>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = optax.adam(learning_rate=0.003)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...  grad = jax.grad(f)(params)
...  updates, opt_state = solver.update(grad, opt_state, params)
...  params = optax.apply_updates(params, updates)
...  print('Objective function: {:.2E}'.format(f(params)))
Objective function: 1.40E+01
Objective function: 1.39E+01
Objective function: 1.39E+01
Objective function: 1.39E+01
Objective function: 1.38E+01

References

Kingma et al, Adam: A Method for Stochastic Optimization, 2014

Dozat, Incorporating Nesterov Momentum into Adam, 2016

Warning

The PyTorch and optax implementations follow Algorithm 1 of [Kingma et al. 2014]. Note that TensorFlow instead used the formulation just before Section 2.1 of the paper. See deepmind/optax#571 for more detail.

class jaqmc.optimizer.optax.lamb(learning_rate=<factory>, b1=0.9, b2=0.999, eps=1e-06, eps_root=0.0, weight_decay=<factory>)[source]#

The LAMB optimizer.

LAMB is a general purpose layer-wise adaptive large batch optimizer designed to provide consistent training performance across a wide range of tasks, including those that use attention-based models (such as Transformers) and ResNet-50. The optimizer is able to work with small and large batch sizes. LAMB was inspired by the LARS learning algorithm.

Parameters:
  • learning_rate (Any, default: <factory>) – A global scaling factor, either fixed or evolving along iterations with a scheduler, see optax.scale_by_learning_rate().

  • b1 (float, default: 0.9) – Exponential decay rate to track the first moment of past gradients.

  • b2 (float, default: 0.999) – Exponential decay rate to track the second moment of past gradients.

  • eps (float, default: 1e-06) – A small constant applied to denominator outside of the square root (as in the Adam paper) to avoid dividing by zero when rescaling.

  • eps_root (float, default: 0.0) – A small constant applied to denominator inside the square root (as in RMSProp), to avoid dividing by zero when rescaling. This is needed for instance when computing (meta-)gradients through Adam.

  • weight_decay (Any, default: <factory>) – Strength of the weight decay regularization.

  • mask – A tree with same structure as (or a prefix of) the params PyTree, or a Callable that returns such a pytree given the params/updates. The leaves should be booleans, True for leaves/subtrees you want to apply the transformation to, and False for those you want to skip.

Returns:

The corresponding optax.GradientTransformationExtraArgs.

Examples

>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = optax.lamb(learning_rate=0.003)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...  grad = jax.grad(f)(params)
...  updates, opt_state = solver.update(grad, opt_state, params)
...  params = optax.apply_updates(params, updates)
...  print('Objective function: {:.2E}'.format(f(params)))
Objective function: 1.39E+01
Objective function: 1.38E+01
Objective function: 1.38E+01
Objective function: 1.37E+01
Objective function: 1.36E+01

References

You et al, Large Batch Optimization for Deep Learning: Training BERT in 76 minutes, 2020

Learning rate schedules#

class jaqmc.optimizer.schedule.Standard(*, rate=0.05, delay=2000, decay=1)[source]#

Standard learning rate schedule.

\[\text{lr}(t) = \text{rate} \cdot \left( \frac{1}{1+t/\text{delay}}\right )^\text{decay}\]
Parameters:
  • rate (float, default: 0.05) – Initial learning rate.

  • delay (float, default: 2000) – Decay timescale in steps: per the formula above, decay begins immediately, and with decay=1 the learning rate drops to half its initial value at t = delay.

  • decay (float, default: 1) – Decay rate exponent.

Examples

>>> s = Standard(rate=0.05, delay=2000, decay=1)
>>> s(0)
0.05
>>> s(2000)
0.025
class jaqmc.optimizer.schedule.Constant(*, rate=0.05)[source]#

Constant schedule.

Parameters:

rate (float, default: 0.05) – The constant rate.

Examples

>>> c = Constant(rate=0.01)
>>> c(0)
0.01
>>> c(999)
0.01