Optimizers#
JaQMC provides natural gradient optimizers (KFAC, SR) alongside standard optimizers from Optax. Natural gradient methods update the wavefunction in the Hilbert space rather than parameter space, which is typically more stable for VMC. See Optimizers for background on choosing an optimizer.
Configuration#
For optimizer config keys, see the configuration reference: Molecule, Solid, or Hall.
Protocol#
- class jaqmc.optimizer.base.OptimizerLike(*args, **kwargs)[source]#
Protocol for optimizers.
All optimizers expose init and update methods. The only positional arguments required are params (for init) or grads, state, params (for update). Runtime dependencies such as f_log_psi are wired via runtime_dep() fields before init is called. Call-time arguments (batched_data, rngs) are passed via **kwargs.
- class jaqmc.optimizer.base.OptimizerInit(*args, **kwargs)[source]#
- __call__(params, **kwargs)[source]#
Initialize optimizer state.
- Parameters:
  - params (Params) – Wavefunction parameters.
  - **kwargs – Optimizer-specific arguments (e.g., batched_data, rngs). Runtime dependencies such as f_log_psi are wired via runtime_dep() fields, not passed here.
- Returns:
  Initial optimizer state.
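To make the protocol concrete, here is a hypothetical, minimal class following the init/update shape described above. This is plain SGD on a flat dict for illustration, not JaQMC code; real optimizers operate on parameter pytrees and receive runtime dependencies via runtime_dep() fields.

```python
# Hypothetical minimal optimizer matching the OptimizerLike shape
# (plain SGD on a flat dict; not JaQMC code).
class SGDLike:
    def __init__(self, learning_rate=0.01):
        self.learning_rate = learning_rate

    def init(self, params, **kwargs):
        # Plain SGD keeps no state; return an empty dict.
        return {}

    def update(self, grads, state, params, **kwargs):
        # One gradient-descent step per parameter leaf.
        new_params = {k: params[k] - self.learning_rate * grads[k] for k in params}
        return new_params, state

opt = SGDLike(learning_rate=0.1)
state = opt.init({"w": 1.0})
new_params, state = opt.update({"w": 2.0}, state, {"w": 1.0})
```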
Optimizers provided by JaQMC#
- class jaqmc.optimizer.sr.SROptimizer(*, learning_rate=<factory>, max_norm=<factory>, damping=<factory>, max_cond_num=10000000.0, spring_mu=<factory>, march_beta=<factory>, march_mode='var', eps=1e-08, mixed_precision=True, score_chunk_size=128, score_norm_clip=None, gram_num_chunks=4, gram_dot_prec='F64', prune_inactive=False, f_log_psi=LateInit())[source]#
Robust stochastic reconfiguration optimizer.
Stochastic reconfiguration (SR) is a natural-gradient method that updates parameters in wavefunction space rather than raw parameter space. This is often more stable than first-order optimizers for variational Monte Carlo, especially when small parameter changes can produce uneven changes in the wavefunction.
JaQMC’s SR optimizer uses a robust SR update together with two optional extensions:
- SPRING adds momentum through spring_mu.
- MARCH adds an adaptive metric through march_beta and march_mode.
Use SR when you want an SR-style natural-gradient update instead of KFAC’s structured approximation. The chunking options trade speed for lower memory use on larger systems.
- Parameters:
  - learning_rate (Any, default: <factory>) – Step size (scalar or schedule).
  - max_norm (Any, default: <factory>) – Constrained update norm C (scalar or schedule). If None, only the learning-rate scaling is applied.
  - damping (Any, default: <factory>) – Damping lambda (scalar or schedule).
  - max_cond_num (float | None, default: 10000000.0) – Maximum condition number for adaptive damping. If None, adaptive damping is disabled.
  - spring_mu (Any, default: <factory>) – SPRING momentum coefficient mu (scalar or schedule). If None, SPRING momentum is disabled.
  - march_beta (Any, default: <factory>) – Decay factor for the MARCH variance accumulator (scalar or schedule). If None, the MARCH metric is disabled.
  - march_mode (Literal['var', 'diff'], default: 'var') – MARCH variance mode. "diff" uses update differences and "var" uses score variance along the batch axis.
  - eps (float, default: 1e-08) – Small numerical constant for stability.
  - mixed_precision (bool, default: True) – Whether to use mixed precision for Gram factorization.
  - score_chunk_size (int | None, default: 128) – Chunk size for score computation. If None, full-batch score computation is used.
  - score_norm_clip (float | None, default: None) – Optional clip value for the mean absolute score per batch row. If None, score clipping is disabled.
  - gram_num_chunks (int | None, default: 4) – Number of chunks for Gram matrix computation. If None, full-batch Gram computation is used.
  - gram_dot_prec (str | None, default: 'F64') – Precision mode for Gram matrix dot products.
  - prune_inactive (bool, default: False) – Whether to structurally prune inactive parameter leaves when forming the SR system.
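The core linear algebra of an SR step can be sketched in a few lines. This is a simplified illustration, not JaQMC's implementation: chunking, mixed precision, SPRING momentum and the MARCH metric are all omitted, and the scores and local energies are just random stand-ins.

```python
import numpy as np

# Simplified sketch of one stochastic reconfiguration step (illustration
# only, not JaQMC's implementation).
# O: per-sample scores d log|psi| / d theta, shape (batch, n_params).
# e_loc: local energies, shape (batch,).
def sr_step(O, e_loc, learning_rate=0.05, damping=1e-3):
    O_c = O - O.mean(axis=0)                                  # centered scores
    grad = 2.0 * O_c.T @ (e_loc - e_loc.mean()) / len(e_loc)  # energy gradient
    S = O_c.T @ O_c / len(e_loc)                              # overlap (S) matrix
    # Damped natural-gradient solve: (S + lambda * I) delta = grad.
    delta = np.linalg.solve(S + damping * np.eye(S.shape[0]), grad)
    return -learning_rate * delta

rng = np.random.default_rng(0)
step = sr_step(rng.normal(size=(256, 4)), rng.normal(size=256))
```

The damping term regularizes the ill-conditioned S matrix; the chunked score and Gram options in SROptimizer compute the same quantities piecewise to reduce peak memory.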
- class jaqmc.optimizer.kfac.kfac.KFACOptimizer(*, learning_rate=<factory>, norm_constraint=0.001, curvature_ema=0.95, l2_reg=0.0, inverse_update_period=1, damping=0.001, f_log_psi=LateInit())[source]#
Kronecker-Factored Approximate Curvature (KFAC) optimizer.
KFAC [arXiv:1503.05671] is a second-order optimization technique. It uses a Kronecker-product structure to approximate natural gradient descent, which accounts for the geometry of the parameter space during optimization. For more details on the Kronecker-product structure, see [arXiv:2507.05127] for a comprehensive tutorial.
Natural gradient descent updates for optimizing loss \(\mathcal{L}\) with respect to parameters \(\theta\) have the form \(\delta \theta \propto \mathcal{F}^{-1} \nabla_\theta \mathcal{L}(\theta)\), where \(\mathcal{F}\) is the Fisher Information Matrix (FIM):
\[\mathcal{F}_{i j}=\mathbb{E}_{p(\mathbf{X})}\left[ \frac{\partial \log p(\mathbf{X})}{\partial \theta_i} \frac{\partial \log p(\mathbf{X})}{\partial \theta_j} \right].\]
For real-valued wavefunctions, where \(p(\mathbf{X}) \propto \psi^2(\mathbf{X})\), this gives the same formula as stochastic reconfiguration (Appendix C of [Phys. Rev. Research 2, 033429 (2020)]).
However, for complex-valued wavefunctions, natural gradient descent deviates from stochastic reconfiguration, but KFAC can still be used to approximate the stochastic reconfiguration update. In stochastic reconfiguration, the parameter updates follow \(\delta \theta \propto \operatorname{Re} [S]^{-1} \operatorname{Re} \left[\nabla_\theta \mathcal{L}(\theta)\right]\), where
\[S_{ij} = \mathbb{E}_{|\psi|^2(\mathbf{X})}\left[ \frac{\partial \log \psi^*}{\partial \theta_i} \frac{\partial \log \psi}{\partial \theta_j} \right].\]
We only show the uncentered version (i.e. assuming a normalized wavefunction) for simplicity. We take \(\operatorname{Re}\) of both \(S\) and \(\nabla_\theta \mathcal{L}(\theta)\) because parameter updates must be real; see the “Complex neural networks” paragraph of [Nat. Phys. 20, 1476 (2024)].
The key difference between the \(S\) matrix and the FIM is the complex conjugate on the left. The version of KFAC included in JaQMC accounts for this by patching some internal parts of the original kfac_jax.
- Parameters:
  - learning_rate (Any, default: <factory>) – The learning rate.
  - norm_constraint (float, default: 0.001) – The update is scaled down so that its approximate squared Fisher norm \(v^T F v\) is at most this value.
  - curvature_ema (float, default: 0.95) – Decay factor for the moving averages of the covariance estimates.
  - l2_reg (float, default: 0.0) – The L2 regularization coefficient used in the loss, so the optimizer can account for it.
  - inverse_update_period (int, default: 1) – Number of steps between updates of the inverse curvature approximation.
  - damping (float, default: 0.001) – Fixed damping parameter.
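The \(S\) matrix above can be estimated by Monte Carlo from per-sample complex scores, with only its real part entering the update. The toy sketch below illustrates that structure with random scores; it is not how JaQMC's patched kfac_jax computes curvature internally.

```python
import numpy as np

# Toy Monte Carlo estimate of the (uncentered) S matrix,
# S_ij = E[ (d log psi / d theta_i)^* (d log psi / d theta_j) ],
# from random complex per-sample scores (illustration only).
rng = np.random.default_rng(0)
scores = rng.normal(size=(512, 3)) + 1j * rng.normal(size=(512, 3))
S = scores.conj().T @ scores / scores.shape[0]  # Hermitian by construction
S_real = S.real  # only Re[S] enters the update: parameter updates are real
```

Because \(S\) is Hermitian positive semi-definite, its real part is a symmetric positive semi-definite matrix, so the \(\operatorname{Re}[S]^{-1}\) solve in the SR update is well posed once damping is added.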
Optimizers provided by Optax#
Note
When using Optax, always use the optax:<name> wrapper (e.g. optax:adam) to ensure compatibility with JaQMC’s configuration system.
You can find the full list of optimizers in Optax documentation.
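As a hypothetical illustration of the optax:&lt;name&gt; convention (the key names here are invented for this sketch; the actual schema is defined in the configuration reference linked above):

```python
# Hypothetical config fragment; the real key names and schema are
# defined in JaQMC's configuration reference, not here.
optimizer_config = {
    "optimizer": "optax:adam",   # optax:<name> selects an Optax optimizer
    "learning_rate": 1e-3,
}
```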
- class jaqmc.optimizer.optax.adam(learning_rate=<factory>, b1=0.9, b2=0.999, eps=1e-08, eps_root=0.0)[source]#
The Adam optimizer.
Adam is an SGD variant with gradient scaling adaptation. The scaling used for each parameter is computed from estimates of first and second-order moments of the gradients (using suitable exponential moving averages).
Let \(\alpha_t\) represent the learning rate and \(\beta_1, \beta_2\), \(\varepsilon\), \(\bar{\varepsilon}\) represent the arguments b1, b2, eps and eps_root respectively. The learning rate is indexed by \(t\) since it may also be provided by a schedule function.
The init function of this optimizer initializes an internal state \(S_0 := (m_0, v_0) = (0, 0)\), representing initial estimates for the first and second moments. In practice these values are stored as pytrees containing all zeros, with the same shape as the model updates. At step \(t\), the update function of this optimizer takes as arguments the incoming gradients \(g_t\) and optimizer state \(S_t\) and computes updates \(u_t\) and new state \(S_{t+1}\). Thus, for \(t > 0\), we have
\[\begin{split}\begin{align*} m_t &\leftarrow \beta_1 \cdot m_{t-1} + (1-\beta_1) \cdot g_t \\ v_t &\leftarrow \beta_2 \cdot v_{t-1} + (1-\beta_2) \cdot {g_t}^2 \\ \hat{m}_t &\leftarrow m_t / {(1-\beta_1^t)} \\ \hat{v}_t &\leftarrow v_t / {(1-\beta_2^t)} \\ u_t &\leftarrow -\alpha_t \cdot \hat{m}_t / \left({\sqrt{\hat{v}_t + \bar{\varepsilon}} + \varepsilon} \right)\\ S_t &\leftarrow (m_t, v_t). \end{align*}\end{split}\]
With the keyword argument nesterov=True, the optimizer uses Nesterov momentum, replacing the above \(\hat{m}_t\) with
\[\hat{m}_t \leftarrow \beta_1 m_t / {(1-\beta_1^{t+1})} + (1 - \beta_1) g_t / {(1-\beta_1^t)}.\]
- Parameters:
  - learning_rate (Any, default: <factory>) – A global scaling factor, either fixed or evolving along iterations with a scheduler, see optax.scale_by_learning_rate().
  - b1 (float, default: 0.9) – Exponential decay rate to track the first moment of past gradients.
  - b2 (float, default: 0.999) – Exponential decay rate to track the second moment of past gradients.
  - eps (float, default: 1e-08) – A small constant applied to the denominator outside of the square root (as in the Adam paper) to avoid dividing by zero when rescaling.
  - eps_root (float, default: 0.0) – A small constant applied to the denominator inside the square root (as in RMSProp), to avoid dividing by zero when rescaling. This is needed for example when computing (meta-)gradients through Adam.
  - mu_dtype – Optional dtype for the first-order accumulator; if None, the dtype is inferred from params and updates.
  - nesterov – Whether to use Nesterov momentum. The solver with nesterov=True is equivalent to the optax.nadam() optimizer, described in [Dozat 2016].
- Returns:
  The corresponding optax.GradientTransformationExtraArgs.
Examples
>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = optax.adam(learning_rate=0.003)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...     grad = jax.grad(f)(params)
...     updates, opt_state = solver.update(grad, opt_state, params)
...     params = optax.apply_updates(params, updates)
...     print('Objective function: {:.2E}'.format(f(params)))
Objective function: 1.40E+01
Objective function: 1.39E+01
Objective function: 1.39E+01
Objective function: 1.39E+01
Objective function: 1.38E+01
References
Kingma et al, Adam: A Method for Stochastic Optimization, 2014
Dozat, Incorporating Nesterov Momentum into Adam, 2016
Warning
PyTorch's and Optax's implementations follow Algorithm 1 of [Kingma et al. 2014]. Note that TensorFlow instead used the formulation just before Section 2.1 of the paper. See deepmind/optax#571 for more detail.
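The Adam update equations above can be traced by hand for a single step. This is a plain-Python illustration of the math, not a replacement for optax.adam:

```python
import math

# One Adam step (t = 1) computed directly from the update equations,
# starting from m_0 = v_0 = 0 (plain Python, for illustration only).
def adam_step1(g, lr=0.003, b1=0.9, b2=0.999, eps=1e-8, eps_root=0.0):
    m = (1 - b1) * g              # m_1
    v = (1 - b2) * g * g          # v_1
    m_hat = m / (1 - b1)          # bias-corrected first moment
    v_hat = v / (1 - b2)          # bias-corrected second moment
    return -lr * m_hat / (math.sqrt(v_hat + eps_root) + eps)

u = adam_step1(2.0)
# At t = 1 the bias corrections cancel, so u is approximately -lr * sign(g).
```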
- class jaqmc.optimizer.optax.lamb(learning_rate=<factory>, b1=0.9, b2=0.999, eps=1e-06, eps_root=0.0, weight_decay=<factory>)[source]#
The LAMB optimizer.
LAMB is a general purpose layer-wise adaptive large batch optimizer designed to provide consistent training performance across a wide range of tasks, including those that use attention-based models (such as Transformers) and ResNet-50. The optimizer is able to work with small and large batch sizes. LAMB was inspired by the LARS learning algorithm.
- Parameters:
  - learning_rate (Any, default: <factory>) – A global scaling factor, either fixed or evolving along iterations with a scheduler, see optax.scale_by_learning_rate().
  - b1 (float, default: 0.9) – Exponential decay rate to track the first moment of past gradients.
  - b2 (float, default: 0.999) – Exponential decay rate to track the second moment of past gradients.
  - eps (float, default: 1e-06) – A small constant applied to the denominator outside of the square root (as in the Adam paper) to avoid dividing by zero when rescaling.
  - eps_root (float, default: 0.0) – A small constant applied to the denominator inside the square root (as in RMSProp), to avoid dividing by zero when rescaling. This is needed for instance when computing (meta-)gradients through Adam.
  - weight_decay (Any, default: <factory>) – Strength of the weight decay regularization.
  - mask – A tree with the same structure as (or a prefix of) the params PyTree, or a Callable that returns such a pytree given the params/updates. The leaves should be booleans: True for leaves/subtrees you want to apply the transformation to, and False for those you want to skip.
- Returns:
  The corresponding optax.GradientTransformationExtraArgs.
Examples
>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = optax.lamb(learning_rate=0.003)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...     grad = jax.grad(f)(params)
...     updates, opt_state = solver.update(grad, opt_state, params)
...     params = optax.apply_updates(params, updates)
...     print('Objective function: {:.2E}'.format(f(params)))
Objective function: 1.39E+01
Objective function: 1.38E+01
Objective function: 1.38E+01
Objective function: 1.37E+01
Objective function: 1.36E+01
References
You et al, Large Batch Optimization for Deep Learning: Training BERT in 76 minutes, 2020
Learning rate schedules#
- class jaqmc.optimizer.schedule.Standard(*, rate=0.05, delay=2000, decay=1)[source]#
Standard learning rate schedule.
\[\text{lr}(t) = \text{rate} \cdot \left( \frac{1}{1+t/\text{delay}}\right )^\text{decay}\]
- Parameters:
  - rate (default: 0.05) – Base learning rate at \(t = 0\).
  - delay (default: 2000) – Delay constant setting the time scale of the decay.
  - decay (default: 1) – Decay exponent.
Examples
>>> s = Standard(rate=0.05, delay=2000, decay=1)
>>> s(0)
0.05
>>> s(2000)
0.025
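For reference, the schedule formula can be re-implemented in a few lines of plain Python (an illustration of the formula, not the JaQMC class itself):

```python
# Plain-Python re-implementation of the Standard schedule formula,
# lr(t) = rate * (1 / (1 + t / delay)) ** decay, for illustration.
def standard_lr(t, rate=0.05, delay=2000, decay=1):
    return rate * (1.0 / (1.0 + t / delay)) ** decay
```

At t = delay the rate has halved (for decay = 1), matching the doctest above.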