Estimators

API reference for built-in estimators. For background, formulas, and configuration guidance, see How Estimators Work.

Base classes

class jaqmc.estimator.base.Estimator

Base estimator with default no-op implementations.

An estimator computes an observable quantity through a lifecycle with two output paths:

Stats path (per-step statistics → final values):

  1. evaluate (evaluate_local / evaluate_batch) — compute per-walker local values. For example, a kinetic energy estimator returns one energy scalar per walker.

  2. reduce — aggregate local values across walkers into per-step statistics. The default computes mean and variance over walkers (via mean_reduce). The output is what gets written to disk at each step.

  3. finalize_stats — combine per-step statistics (with a leading step dimension) into final physical quantities. This exists because some observables cannot be expressed as a single expectation — they require ratios, products, or other nonlinear combinations of step-level averages (e.g. overlap, polarization, energy gradients).

State path (accumulated state → final values):

  1. finalize_state — extract final observables from accumulated estimator state. Used by estimators that accumulate results directly in state (e.g. histograms) rather than through per-step statistics. Called only during evaluation digest, never inside JIT.

Subclass and override only the methods you need.
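
The stats path can be illustrated with a minimal NumPy sketch. It does not use the jaqmc classes: the three functions below are hypothetical stand-ins for evaluate_batch, reduce, and finalize_stats, and the "_var" key suffix is an assumption made for illustration only.

```python
import numpy as np

def evaluate_batch(positions):
    # Hypothetical local value: harmonic potential energy, one scalar per walker.
    return {"energy:potential": 0.5 * np.sum(positions**2, axis=-1)}

def reduce(local_stats):
    # Collapse the walker dimension into per-step mean and variance.
    out = {}
    for key, values in local_stats.items():
        out[key] = np.mean(values, axis=0)
        out[f"{key}_var"] = np.var(values, axis=0)  # key name is an assumption
    return out

def finalize_stats(batched_stats):
    # Default-style finalization: average each statistic over the step dimension.
    return {key: np.mean(values, axis=0) for key, values in batched_stats.items()}

rng = np.random.default_rng(0)
steps = [reduce(evaluate_batch(rng.normal(size=(64, 3)))) for _ in range(10)]
# Stack per-step statistics so every value gains a leading step dimension.
batched = {key: np.stack([s[key] for s in steps]) for key in steps[0]}
final = finalize_stats(batched)
```

Each stage consumes one dimension: reduce removes the walker axis, and finalize_stats removes the step axis.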

Runtime dependencies should be declared as runtime_dep() fields. They can be provided in two ways:

  1. Programmatic: pass directly in the constructor:

    estimator = EuclideanKinetic(f_log_psi=wf.evaluate)
    
  2. Config-driven: use wire(estimator, **context):

    wire(estimator, f_log_psi=wf.evaluate)
    

Subclasses that need to compute derived state from runtime deps should do so in init().

Type Parameters:

DataT – Concrete one-walker Data subtype consumed by this estimator.

evaluate_batch(params, batched_data, prev_local_stats, state, rngs)

Compute local values over a batch of walkers.

By default, vmaps evaluate_local over the walker dimension. Override directly for estimators that don’t need per-walker vmapping (e.g. histogram aggregation, stats-only estimators).

Parameters:
  • params (Params) – Wavefunction parameters.

  • batched_data (BatchedData[DataT]) – Batched sampled data.

  • prev_local_stats (Mapping[str, Any]) – Local values produced by earlier estimators in the pipeline (with walker dimension).

  • state (Any) – Estimator state.

  • rngs (PRNGKey) – Random state.

Return type:

tuple[dict[str, Any], Any]

Returns:

A tuple (local_stats, state) where local_stats values have a leading walker dimension.

evaluate_local(params, data, prev_local_stats, state, rngs)

Compute local values for a single walker.

This is the main method to override. The default evaluate_batch vmaps this over the walker dimension.

Parameters:
  • params (Params) – Wavefunction parameters.

  • data (DataT) – Data for a single walker.

  • prev_local_stats (Mapping[str, Any]) – Local values produced by earlier estimators in the pipeline (single-walker).

  • state (Any) – Estimator state from init or previous step.

  • rngs (PRNGKey) – Random state.

Return type:

tuple[dict[str, Any], Any]

Returns:

A tuple (local_stats, state) where local_stats maps string keys to per-walker scalar or array values.

finalize_state(state, *, n_steps)

Extract final observables from accumulated estimator state.

Override this for estimators that accumulate results directly in state rather than through per-step statistics (e.g. histogram estimators). Called only during evaluation digest, never inside JIT.

Parameters:
  • state (Any) – Estimator state after all evaluation steps.

  • n_steps (int) – Total number of evaluation steps completed.

Return type:

dict[str, Any]

Returns:

Final observable values derived from state.
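
As a rough illustration of the state path, the sketch below accumulates a histogram in state and converts it to a final value once at the end. The function names and the choice to divide by n_steps are assumptions for this example, not the behavior of any built-in estimator.

```python
import numpy as np

def init_state(bins):
    # State is a running histogram of counts.
    return np.zeros(bins)

def accumulate(state, positions, bins, value_range):
    # Add this step's walker positions to the running histogram.
    counts, _ = np.histogram(positions, bins=bins, range=value_range)
    return state + counts

def finalize_state(state, *, n_steps):
    # Convert accumulated counts into an average count per step (assumed normalization).
    return {"density": state / n_steps}

rng = np.random.default_rng(0)
bins, value_range, n_steps = 4, (0.0, 1.0), 5
state = init_state(bins)
for _ in range(n_steps):
    state = accumulate(state, rng.uniform(size=8), bins, value_range)
result = finalize_state(state, n_steps=n_steps)
```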

finalize_stats(batched_stats, state)

Combine per-step statistics into final physical quantities.

Receives the reduce output accumulated over multiple steps, with a leading step dimension on every value. Produces the final observable values.

Override this when the observable requires nonlinear combinations of step-level averages (ratios, products, etc.). The default simply averages over steps.

Parameters:
  • batched_stats (Mapping[str, Any]) – This estimator’s reduce output stacked over steps (values have a leading step dimension).

  • state (Any) – Estimator state.

Return type:

dict[str, Any]

Returns:

Final observable values (step dimension consumed).
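
A small example of why a nonlinear observable needs its own finalization: for a ratio of expectations, the step averages must be taken before dividing. The keys "numerator" and "denominator" and the function name are hypothetical.

```python
import numpy as np

def finalize_ratio(batched_stats):
    # Average numerator and denominator over steps first, then divide.
    num = np.mean(batched_stats["numerator"], axis=0)
    den = np.mean(batched_stats["denominator"], axis=0)
    return {"ratio": num / den}

batched_stats = {
    "numerator": np.array([1.0, 3.0]),
    "denominator": np.array([1.0, 2.0]),
}
final = finalize_ratio(batched_stats)
# Averaging the per-step ratios instead gives a different number.
naive = np.mean(batched_stats["numerator"] / batched_stats["denominator"])
```

Here the ratio of averages is 2.0 / 1.5, while the average of ratios is 1.25, so a plain average over steps would not reproduce the intended observable.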

init(data, rngs)

Initialize estimator state from an example data point.

Called once before the first evaluate call.

Return type:

Any

Returns:

State to thread through evaluate calls, or None if no state is needed.

reduce(local_stats)

Aggregate per-walker local values into per-step statistics.

Called once per step after evaluate_batch. The output is what gets recorded by writers at each step.

The default computes the mean (and variance) over walkers via mean_reduce.

Parameters:

local_stats (Mapping[str, Any]) – This estimator’s output from evaluate_batch (values have a walker dimension).

Return type:

dict[str, Any]

Returns:

Step-level statistics (walker dimension consumed).

class jaqmc.estimator.base.FunctionEstimator(fn)

Wraps a plain function as an Estimator.

The function is called as evaluate_local; init, reduce, finalize_stats, and finalize_state use the base-class defaults.

evaluate_local(params, data, prev_local_stats, state, rngs)

Compute local values for a single walker.

This is the main method to override. The default evaluate_batch vmaps this over the walker dimension.

Parameters:
  • params – Wavefunction parameters.

  • data – Data for a single walker.

  • prev_local_stats – Local values produced by earlier estimators in the pipeline (single-walker).

  • state – Estimator state from init or previous step.

  • rngs – Random state.

Returns:

A tuple (local_stats, state) where local_stats maps string keys to per-walker scalar or array values.

class jaqmc.estimator.base.EstimatorPipeline(estimators)

Chains named estimators into an evaluate → reduce → finalize pipeline.

Each estimator runs in insertion order. Later estimators can read earlier estimators’ local values via prev_local_stats. Key ownership is tracked so that finalize_stats() dispatches each subset of statistics to the correct estimator.

Parameters:

estimators (Mapping[str, EstimatorLike]) – Mapping from estimator name to either an Estimator instance or a plain estimator function. Plain functions are wrapped in FunctionEstimator.
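
The chaining and key-ownership ideas can be sketched for a single walker in plain Python. This is not the EstimatorPipeline implementation: run_pipeline and both estimator functions are hypothetical, and only the five-argument call shape follows the EstimateFn signature documented on this page.

```python
import numpy as np

def kinetic(params, data, prev_local_stats, state, rngs):
    # Hypothetical stand-in, not a real kinetic energy.
    return {"energy:kinetic": 0.5 * params["mass"] * np.sum(data**2)}, state

def total(params, data, prev_local_stats, state, rngs):
    # Reads local values produced by estimators that ran earlier.
    return {"total_energy": sum(prev_local_stats.values())}, state

def run_pipeline(estimators, params, data):
    local_stats, owners = {}, {}
    for name, fn in estimators.items():  # dicts preserve insertion order
        out, _ = fn(params, data, dict(local_stats), None, None)
        owners[name] = list(out)  # remember which keys this estimator produced
        local_stats.update(out)
    return local_stats, owners

stats, owners = run_pipeline(
    {"kinetic": kinetic, "total": total},
    params={"mass": 2.0},
    data=np.array([1.0, 2.0]),
)
```

Recording which estimator produced which keys is what later allows a flat statistics dictionary to be split and routed back to the right estimator.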

digest(batched_stats, state, *, n_steps)

Produce the full evaluation digest.

Combines finalize_stats() (from per-step statistics) with each estimator’s finalize_state() (from accumulated state). Call this during the evaluation digest, not inside JIT.

Parameters:
  • batched_stats (Mapping[str, Any]) – Flat statistics with a leading batch dimension.

  • state (dict[str, Any]) – Evaluator state after all steps.

  • n_steps (int) – Total number of evaluation steps completed.

Return type:

dict[str, Any]

Returns:

Merged final values from both stats and state paths.

evaluate(params, batched_data, state, rngs)

Compute averaged local values of the observables.

Parameters:
  • params (Params) – Wavefunction parameters.

  • batched_data (BatchedData) – Batched sampled data.

  • state (dict[str, Any]) – Evaluator state.

  • rngs (PRNGKey) – Random state.

Return type:

tuple[dict[str, Any], dict[str, Any]]

Returns:

A tuple (step_stats, state), where step_stats is a flat dictionary merging each estimator’s reduce() output.

finalize_stats(batched_stats, state)

Finalize observables from per-step statistics.

batched_stats must have a leading batch/step dimension on every value. This method splits the statistics by key ownership and dispatches each subset to the owning estimator’s finalize_stats().

In VMC, pass single-step stats with a batch dimension of 1 (e.g. via tree.map(lambda x: x[None], step_stats)). In evaluation, pass the stacked multi-step stats directly.

Parameters:
  • batched_stats (Mapping[str, Any]) – Flat statistics with a leading batch dimension.

  • state (dict[str, Any]) – Evaluator state.

Return type:

dict[str, Any]

Returns:

Final values for the observables.

Raises:

RuntimeError – If evaluate() was never called (key ownership is unknown).
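
For the VMC case described above, a NumPy analog of the x[None] mapping on a flat dictionary looks like the following. The keys in step_stats are made up for illustration.

```python
import numpy as np

step_stats = {"total_energy": np.float64(-1.5), "grad": np.zeros(3)}
# Give every value a leading dimension of size 1, as if one step had been stacked.
batched_stats = {key: np.asarray(value)[None] for key, value in step_stats.items()}
```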

init(batched_data, rngs)

Initialize per-estimator state.

Return type:

dict[str, Any]

Returns:

{name: state} dict threaded through evaluate() and finalize_stats().

type jaqmc.estimator.base.EstimatorLike = Estimator | EstimateFn

Anything accepted where an Estimator is expected.

type jaqmc.estimator.base.EstimateFn = Callable[[Params, Data, Mapping[str, Any], Any, PRNGKey], tuple[dict[str, Any], Any]]

Signature for a plain function usable as an estimator’s evaluate_local.

Built-in estimators

Kinetic energy

class jaqmc.estimator.kinetic.EuclideanKinetic(*, mode=forward_laplacian, sparsity_threshold=0, f_log_psi=LateInit(), data_field='electrons')

Kinetic energy estimator in Euclidean geometry.

The most computationally expensive default energy component. The mode setting controls how the diagonal Hessian is computed and is the main performance knob — see LaplacianMode for trade-offs.

See also

Kinetic energy for the derivation and Laplacian computation details.

Parameters:
  • mode (LaplacianMode, default: forward_laplacian) – Laplacian computation strategy. forward_laplacian is the default for JAX 0.7.1 and later, scan for earlier versions. See LaplacianMode for details.

  • sparsity_threshold (int, default: 0) – Sparsity threshold when using forward_laplacian mode. Always verify numerical correctness before adopting it in production runs.

  • f_log_psi (NumericWavefunctionEvaluate, default: LateInit()) – Log-psi evaluate function (runtime dep).

  • data_field (str, default: 'electrons') – Name of the data field containing positions (runtime dep).

class jaqmc.estimator.kinetic.SphericalKinetic(*, mode=scan, monopole_strength=1.0, radius=None, f_log_psi=LateInit(), data_field='electrons')

Kinetic energy on a sphere with magnetic monopole.

Uses the Hessian-based calculation following the formulas in section 3.10.3 of Composite Fermions (Jain):

\[\frac{|\Lambda|^2 \psi}{2R^2 \psi} = \frac{1}{2R^2}\left[ -R^2 \frac{\nabla^2\psi}{\psi} + (Q\cot\theta)^2 + 2iQ \frac{\cot\theta}{\sin\theta} \frac{\partial\log\psi}{\partial\phi} \right]\]

Also computes angular momentum observables.

Parameters:
  • mode (LaplacianMode, default: scan) – Laplacian computation mode. scan and fori_loop use a Hessian-based approach; forward_laplacian uses the forward Laplacian.

  • monopole_strength (float, default: 1.0) – Half the magnetic flux (\(Q = \text{flux}/2\)).

  • radius (float | None, default: None) – Sphere radius. Defaults to \(\sqrt{Q}\).

  • f_log_psi (NumericWavefunctionEvaluate, default: LateInit()) – Complex log-psi function (runtime dep).

  • data_field (str, default: 'electrons') – Name of the data field containing positions (runtime dep).

class jaqmc.estimator.kinetic.LaplacianMode(*values)

Modes of calculating the diagonal Hessian for the Laplacian.

scan

Materializes all iterations via jax.lax.scan() — higher memory, faster compilation. Good default for small to medium systems.

fori_loop

Runs one iteration at a time via jax.lax.fori_loop() — constant memory, slower compilation. Use when scan causes out-of-memory during compilation.

forward_laplacian

Forward-mode Laplacian via folx. Can be fastest for large systems. Requires JAX >= 0.7.1.

Ewald summation

class jaqmc.estimator.ewald.EwaldSum(supercell_lattice, ewald_gmax=200, nlatvec=1)

Ewald summation for electrostatic energy in periodic systems.

Decomposes the Coulomb interaction into real-space and reciprocal-space series for rapid convergence:

\[V_{\text{Ewald}} = V_{\text{real}} + V_{\text{recip}} + V_{\text{self}} + V_{\text{charged}}\]

All charged particles (electrons and ions) are treated uniformly.

See also

Ewald summation for the full formulation and implementation notes.

Parameters:
  • supercell_lattice (Array) – (3, 3) matrix representing the supercell lattice vectors.

  • ewald_gmax (int, default: 200) – Cutoff for the reciprocal space sum (number of G-vectors in each direction). Determines accuracy of \(V_{\text{recip}}\).

  • nlatvec (int, default: 1) – Cutoff for the real space sum (number of periodic images in each direction). Determines accuracy of \(V_{\text{real}}\).

energy(coords, charges)

Calculates the total electrostatic energy for a general system.

This method implements a unified Ewald summation where all particles (electrons and ions) are treated as point charges. It computes the total energy as a sum of four components:

\[E_{\text{total}} = E_{\text{real}} + E_{\text{recip}} + E_{\text{self}} + E_{\text{charged}}\]

Parameters:
  • coords (Array) – Particle coordinates (N, 3).

  • charges (Array) – Particle charges (N,).

Return type:

Array

Returns:

Total electrostatic energy.

ECP energy

class jaqmc.estimator.ecp.estimator.ECPEnergy(*, max_core=2, quadrature_id=None, electrons_field='electrons', atoms_field='atoms', phase_logpsi=LateInit(), ecp_coefficients=LateInit(), atom_symbols=LateInit(), lattice=None, twist=None)

ECP energy estimator.

Computes both local and nonlocal effective core potential contributions. Added automatically when ecp is set in the system configuration.

  • Local (\(l=0\)): Direct potential energy from the \(l=0\) channel

  • Nonlocal (\(l>0\)): Angular integral weighted by \(V_l(r)\)

The estimator outputs energy:ecp which is included in the total_energy sum automatically.

See also

Pseudopotentials (ECP) for the local/nonlocal decomposition and quadrature details.

Parameters:
  • max_core (int, default: 2) – Maximum number of nearest ECP atoms to consider per electron when evaluating nonlocal integrals. Only the closest max_core ECP atoms contribute per electron; the rest are skipped. Increase this if your system has many ECP atoms in close proximity.

  • quadrature_id (str | None, default: None) – Spherical quadrature rule used to evaluate nonlocal ECP integrals. When None, a default rule is selected automatically.

  • electrons_field (str, default: 'electrons') – Name of electron position field in data.

  • atoms_field (str, default: 'atoms') – Name of atom position field in data.

  • phase_logpsi (WavefunctionEvaluate, default: LateInit()) – Wavefunction ratio function (runtime dep).

  • ecp_coefficients (dict[str, Any], default: LateInit()) – PySCF ECP dict (runtime dep).

  • atom_symbols (list[str], default: LateInit()) – List of element symbols, e.g. ["Li", "H"] (runtime dep).

  • lattice (Array | None, default: None) – Lattice vectors for PBC (runtime dep, optional).

  • twist (Array | None, default: None) – Twist angle for PBC (runtime dep, optional).

Raises:

ValueError – If no atoms have ECP coefficients.
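
The effect of max_core can be pictured with a short NumPy sketch that selects, for each electron, the closest ECP atoms. The helper nearest_ecp_atoms is hypothetical and assumes open boundaries (no lattice or twist); it only illustrates the selection step, not the nonlocal integral.

```python
import numpy as np

def nearest_ecp_atoms(electrons, ecp_atoms, max_core):
    # Pairwise electron-atom distances, shape (n_electrons, n_atoms).
    dists = np.linalg.norm(electrons[:, None, :] - ecp_atoms[None, :, :], axis=-1)
    # Indices of the max_core closest atoms for each electron.
    idx = np.argsort(dists, axis=1)[:, :max_core]
    return idx, np.take_along_axis(dists, idx, axis=1)

electrons = np.array([[0.0, 0.0, 0.1], [0.0, 0.0, 2.9]])
ecp_atoms = np.array([[0.0, 0.0, 0.0], [0.0, 0.0, 1.5], [0.0, 0.0, 3.0]])
idx, dists = nearest_ecp_atoms(electrons, ecp_atoms, max_core=2)
```

With three ECP atoms and max_core=2, the farthest atom is skipped for each electron, which is why a larger value may be needed when many ECP atoms sit close together.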

Spin squared

class jaqmc.estimator.spin.SpinSquared(*, n_up=LateInit(), n_down=LateInit(), phase_logpsi=LateInit())

Estimator for the total spin operator \(S^2\).

Computes the local value of \(S^2\) for a single walker using

\[S^2_\text{local} = S_z(S_z + 1) + n_\text{minority} - \sum_{i \in \text{minority}} \sum_{j \in \text{majority}} \frac{\Psi(\mathbf{r}_{i \leftrightarrow j})} {\Psi(\mathbf{r})}\]

where \(S_z = |n_\uparrow - n_\downarrow| / 2\) and minority is the spin channel with fewer (or equal) electrons.

Parameters:
  • n_up (int, default: LateInit()) – Number of spin-up electrons.

  • n_down (int, default: LateInit()) – Number of spin-down electrons.

  • phase_logpsi (Any, default: LateInit()) – Wavefunction evaluate function returning (sign, log|psi|) (runtime dep).
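
The local formula above can be checked on two toy wavefunctions. This sketch evaluates the ratio directly from a plain function psi rather than from (sign, log|psi|), and it assumes spin-up electrons occupy the first rows of the position array; both choices are assumptions for illustration.

```python
import numpy as np

def spin_squared_local(psi, r, n_up, n_down):
    # Assumes rows of r are ordered spin-up first, then spin-down.
    up = range(n_up)
    down = range(n_up, n_up + n_down)
    minority, majority = (up, down) if n_up <= n_down else (down, up)
    s_z = abs(n_up - n_down) / 2
    exchange = 0.0
    for i in minority:
        for j in majority:
            swapped = r.copy()
            swapped[[i, j]] = swapped[[j, i]]  # exchange positions i and j
            exchange += psi(swapped) / psi(r)
    return s_z * (s_z + 1) + len(minority) - exchange

def psi_singlet(r):
    # Spatially symmetric toy wavefunction for two 1-D electrons.
    return np.exp(-np.sum(r**2))

def psi_triplet(r):
    # Spatially antisymmetric toy wavefunction for two 1-D electrons.
    return (r[0, 0] - r[1, 0]) * np.exp(-np.sum(r**2))

r = np.array([[0.3], [-0.7]])
s2_singlet = spin_squared_local(psi_singlet, r, n_up=1, n_down=1)
s2_triplet = spin_squared_local(psi_triplet, r, n_up=1, n_down=1)
```

The symmetric case gives an exchange ratio of +1 and the antisymmetric case gives -1, so the two results are S(S+1) for S = 0 and S = 1 respectively.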

Total energy

class jaqmc.estimator.total_energy.TotalEnergy(*, output_name='total_energy', components=None)

Estimator that computes total energy from component energies.

Energy components use the energy: prefix convention (e.g. energy:kinetic, energy:potential). When components is None, all keys starting with energy: are summed automatically.

When the total energy is complex (e.g. from a magnetic kinetic energy term), reduce splits it into real and imaginary parts so that variance is computed on the real part only.

Parameters:
  • output_name (str, default: 'total_energy') – Name of the output total energy field.

  • components (list[str] | None, default: None) – Stat keys to sum. When None (default), auto-derives from prev_local_stats keys starting with energy:.
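
A minimal sketch of the prefix convention and the complex split, using a hypothetical helper rather than the TotalEnergy class itself. The example values and the artificial imaginary part are made up.

```python
import numpy as np

def total_energy(prev_local_stats, components=None):
    # With no explicit list, sum every key that starts with "energy:".
    if components is None:
        components = [k for k in prev_local_stats if k.startswith("energy:")]
    return sum(prev_local_stats[k] for k in components)

prev_local_stats = {
    "energy:kinetic": np.array([1.0, 1.5]),
    "energy:potential": np.array([-3.0, -3.5]),
    "spin_squared": np.array([0.0, 0.0]),  # ignored: no "energy:" prefix
}
total = total_energy(prev_local_stats)
# A complex total can be split so that variance is taken on the real part only.
complex_total = total + 0.1j
real_mean, imag_mean = complex_total.real.mean(), complex_total.imag.mean()
real_var = complex_total.real.var()
```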

Density

class jaqmc.estimator.density.CartesianDensity(*, axes=<factory>)

Electron density along arbitrary directions in Cartesian space.

Each histogram axis is defined by a direction, bin count, and range. Direction vectors are normalized internally, so they need not be unit vectors.

For periodic systems where lattice-aligned density is needed, use FractionalDensity instead.

Parameters:

axes (dict[str, CartesianAxis | None], default: <factory>) – Per-axis configuration keyed by user-chosen labels. Set a value to None to disable an axis inherited from defaults (e.g. when the workflow provides x/y/z but you only want z).

class jaqmc.estimator.density.CartesianAxis(*, direction=(0.0, 0.0, 1.0), bins=50, range=(0.0, 1.0))

Configuration for one histogram axis in Cartesian coordinates.

Parameters:
  • direction (tuple[float, ...], default: (0.0, 0.0, 1.0)) – Direction vector to project positions onto. Normalized internally — need not be a unit vector.

  • bins (int, default: 50) – Number of histogram bins.

  • range (tuple[float, float], default: (0.0, 1.0)) – (min, max) bounds for the projected coordinate.
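
One way to picture a single Cartesian axis is as a projection followed by a 1-D histogram. The helper axis_histogram below is a hypothetical NumPy sketch, not the estimator's code; its arguments mirror the direction, bins, and range fields.

```python
import numpy as np

def axis_histogram(positions, direction, bins, value_range):
    # Normalize the direction so projections are in length units.
    unit = np.asarray(direction, dtype=float)
    unit = unit / np.linalg.norm(unit)
    projected = positions @ unit  # one scalar coordinate per electron
    counts, edges = np.histogram(projected, bins=bins, range=value_range)
    return counts, edges

positions = np.array([[0.0, 0.0, 0.1], [1.0, 2.0, 0.6], [3.0, 0.0, 0.9]])
counts, edges = axis_histogram(
    positions, direction=(0.0, 0.0, 2.0), bins=2, value_range=(0.0, 1.0)
)
```

The direction (0, 0, 2) behaves the same as (0, 0, 1) because it is normalized before projecting.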

class jaqmc.estimator.density.FractionalDensity(*, axes=<factory>, inv_lattice=LateInit())

Electron density in fractional (lattice) coordinates.

Converts Cartesian electron positions to fractional coordinates via the inverse lattice matrix:

\[\mathbf{f} = L^{-1}\,\mathbf{r} \mod 1\]

then histograms selected fractional axes. The range is always \([0, 1)\) per axis regardless of cell shape.

For molecules or other open-boundary systems, use CartesianDensity instead.

Parameters:
  • axes (dict[str, FractionalAxis | None], default: <factory>) – Per-axis configuration keyed by user-chosen labels. Set a value to None to disable an axis inherited from defaults.

  • inv_lattice (Array, default: LateInit()) – Inverse lattice matrix, shape (3, 3). Set by the workflow via wire().
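
The fractional-coordinate conversion can be sketched in NumPy as follows. The sketch assumes lattice vectors are stored as the columns of L so that f = L^{-1} r applies to a column vector; the storage convention used by jaqmc is not stated here, and the diagonal example cell makes the result independent of that choice.

```python
import numpy as np

lattice = np.array([[2.0, 0.0, 0.0],
                    [0.0, 4.0, 0.0],
                    [0.0, 0.0, 4.0]])
inv_lattice = np.linalg.inv(lattice)

def fractional_coords(positions, inv_lattice):
    # Apply f = L^{-1} r to each row of positions, then wrap into [0, 1).
    return (positions @ inv_lattice.T) % 1.0

positions = np.array([[1.0, 5.0, -1.0]])
frac = fractional_coords(positions, inv_lattice)
# Histogram the first fractional coordinate (lattice_index=0) over [0, 1).
counts, _ = np.histogram(frac[:, 0], bins=4, range=(0.0, 1.0))
```

Positions outside the cell, including negative ones, wrap back into the unit interval, which is why the histogram range never depends on the cell shape.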

class jaqmc.estimator.density.FractionalAxis(*, lattice_index=0, bins=50)

Configuration for one histogram axis in fractional coordinates.

Parameters:
  • lattice_index (int, default: 0) – Which fractional coordinate to histogram (0, 1, or 2 for the first, second, or third lattice vector).

  • bins (int, default: 50) – Number of histogram bins.

class jaqmc.estimator.density.SphericalDensity(*, bins_theta=50, bins_phi=None)

Electron density on the Haldane sphere.

Accumulates a histogram of electron positions in spherical coordinates. By default only the polar angle \(\theta \in [0, \pi]\) is binned (1-D histogram). Setting bins_phi enables a 2-D \((\theta, \varphi)\) histogram with \(\varphi \in [-\pi, \pi]\).

Parameters:
  • bins_theta (int, default: 50) – Number of bins for the polar angle.

  • bins_phi (int | None, default: None) – Number of bins for the azimuthal angle. None (default) produces a 1-D theta-only histogram.
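
The angle ranges above match what arccos and arctan2 return, as the following NumPy sketch shows. The helper spherical_angles is hypothetical, and the sketch assumes positions are given as Cartesian vectors; the estimator's actual input layout is not specified here.

```python
import numpy as np

def spherical_angles(positions):
    x, y, z = positions[:, 0], positions[:, 1], positions[:, 2]
    r = np.linalg.norm(positions, axis=1)
    theta = np.arccos(np.clip(z / r, -1.0, 1.0))  # polar angle in [0, pi]
    phi = np.arctan2(y, x)  # azimuthal angle in [-pi, pi]
    return theta, phi

positions = np.array([[0.1, 0.1, 1.0], [1.0, 1.0, -1.0], [1.0, -1.0, -1.0]])
theta, phi = spherical_angles(positions)
# 1-D histogram in theta only (the bins_phi=None case).
hist_1d, _ = np.histogram(theta, bins=2, range=(0.0, np.pi))
# 2-D (theta, phi) histogram (the case where bins_phi is set).
hist_2d, _, _ = np.histogram2d(
    theta, phi, bins=(2, 2), range=((0.0, np.pi), (-np.pi, np.pi))
)
```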

Loss and gradient

class jaqmc.estimator.loss_grad.LossAndGrad(*, loss_key='total_energy', clip_scale=5.0, f_log_psi=LateInit())

Estimator that computes the VMC loss and parameter gradients.

The gradient of the variational energy with respect to wavefunction parameters \(\theta\) is:

\[\nabla_\theta \langle E_L \rangle = 2 \left\langle (E_L - \langle E_L \rangle) \, \nabla_\theta \log|\psi_\theta| \right\rangle\]

The computation proceeds in three stages across the estimator lifecycle:

  1. evaluate_local — evaluates \(\log\psi\) and its parameter gradient for each walker, and reads the loss value.

  2. reduce — applies IQR clipping to the local energies for outlier robustness, then forms the per-walker product \(\nabla\log\psi \cdot E_L^{\text{clipped}}\).

  3. finalize — averages over walkers and subtracts the baseline to assemble the final gradient.

Parameters:
  • loss_key (str, default: 'total_energy') – Key in prev_local_stats to use as the loss.

  • clip_scale (float, default: 5.0) – Multiplier on the interquartile range (IQR) that sets the clipping window for local energies. A walker whose energy falls outside \([Q_1 - s \cdot \text{IQR},\; Q_3 + s \cdot \text{IQR}]\) is clipped to that boundary. Smaller values clip more aggressively, which stabilises gradients but biases the energy estimate. The default of 5.0 is a common choice; set to a large value (e.g. 1e8) to effectively disable clipping.

  • f_log_psi (NumericWavefunctionEvaluate, default: LateInit()) – Log-psi function to differentiate (runtime dep).
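
The clipping window and the gradient formula can be combined in a short NumPy sketch. The helper names are hypothetical, the quartiles use NumPy's default percentile interpolation, and the baseline is taken as the mean of the clipped energies; the estimator may differ on any of these details.

```python
import numpy as np

def iqr_clip(local_energy, clip_scale):
    # Clip to [Q1 - s * IQR, Q3 + s * IQR].
    q1, q3 = np.percentile(local_energy, [25, 75])
    iqr = q3 - q1
    return np.clip(local_energy, q1 - clip_scale * iqr, q3 + clip_scale * iqr)

def energy_gradient(local_energy, grad_log_psi, clip_scale=5.0):
    clipped = iqr_clip(local_energy, clip_scale)
    centered = clipped - clipped.mean()  # subtract the baseline <E_L>
    # 2 * <(E_L - <E_L>) * grad log|psi|>, averaged over walkers.
    return 2.0 * np.mean(centered[:, None] * grad_log_psi, axis=0)

# Five walkers, one parameter; the last walker is an outlier.
local_energy = np.array([-1.0, -1.1, -0.9, -1.0, 50.0])
grad_log_psi = np.array([[1.0], [0.0], [-1.0], [0.5], [2.0]])
clipped = iqr_clip(local_energy, clip_scale=5.0)
grad = energy_gradient(local_energy, grad_log_psi)
```

With Q1 = -1.0 and Q3 = -0.9, the window is [-1.5, -0.4], so the outlier at 50.0 is pulled down to -0.4 before it enters the gradient.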