Estimators
API reference for the built-in estimators. For background, formulas, and configuration guidance, see How Estimators Work.
Base classes
- class jaqmc.estimator.base.Estimator
Base estimator with default no-op implementations.
An estimator computes an observable quantity through a lifecycle with two output paths:
Stats path (per-step statistics → final values):
evaluate (evaluate_local / evaluate_batch): compute per-walker local values. For example, a kinetic energy estimator returns one energy scalar per walker.
reduce: aggregate local values across walkers into per-step statistics. The default computes the mean and variance over walkers (via mean_reduce). The output is what gets written to disk at each step.
finalize_stats: combine per-step statistics (with a leading step dimension) into final physical quantities. This stage exists because some observables cannot be expressed as a single expectation; they require ratios, products, or other nonlinear combinations of step-level averages (e.g. overlap, polarization, energy gradients).
State path (accumulated state → final values):
finalize_state: extract final observables from accumulated estimator state. Used by estimators that accumulate results directly in state (e.g. histograms) rather than through per-step statistics. Called only during the evaluation digest, never inside JIT.
Subclass and override only the methods you need.
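The stats path can be illustrated with a minimal numpy sketch. The functions evaluate, mean_reduce, and finalize_stats below are hypothetical stand-ins for the lifecycle stages, and the ":mean"/":var" key names are assumptions. jaqmc's real methods also take params, data, state, and rngs, and may name their outputs differently.

```python
import numpy as np

# Hypothetical stand-ins for the stats-path stages; names and key layout are
# illustrative, not jaqmc's actual API.


def evaluate(positions):
    """Per-walker local values: one scalar per walker (toy observable |r|^2)."""
    return {"r2": np.sum(positions**2, axis=-1)}


def mean_reduce(local_stats):
    """Aggregate over the walker axis into a per-step mean and variance."""
    out = {}
    for key, values in local_stats.items():
        out[f"{key}:mean"] = np.mean(values)
        out[f"{key}:var"] = np.var(values)
    return out


def finalize_stats(batched_stats):
    """Default-style finalization: average every statistic over the step axis."""
    return {key: np.mean(values, axis=0) for key, values in batched_stats.items()}


rng = np.random.default_rng(0)
steps = [mean_reduce(evaluate(rng.normal(size=(128, 3)))) for _ in range(10)]
# Stack per-step dictionaries so every value gains a leading step dimension.
batched = {key: np.stack([s[key] for s in steps]) for key in steps[0]}
final = finalize_stats(batched)
```

Because the walkers here are standard-normal 3-D points, final["r2:mean"] should land near 3.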
Runtime dependencies should be declared as runtime_dep() fields. They can be provided in two ways:
Programmatic: pass them directly to the constructor:
estimator = EuclideanKinetic(f_log_psi=wf.evaluate)
Config-driven: use wire(estimator, **context):
wire(estimator, f_log_psi=wf.evaluate)
Subclasses that need to compute derived state from runtime deps should do so in init().
- Type Parameters:
DataT – Concrete one-walker Data subtype consumed by this estimator.
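The runtime-dependency pattern can be sketched with plain dataclasses. LateInit, runtime_dep, and wire below are hypothetical re-implementations written only for illustration. jaqmc's own versions may behave differently, for example in whether wire overwrites a value that is already set.

```python
from dataclasses import dataclass, field, fields
from typing import Any


class LateInit:
    """Hypothetical sentinel for a dependency that will be supplied later."""

    def __repr__(self):
        return "LateInit()"


def runtime_dep():
    """Hypothetical helper: a dataclass field tagged as a runtime dependency."""
    return field(default_factory=LateInit, metadata={"runtime_dep": True})


def wire(obj, **context):
    """Fill each still-unset runtime dependency whose name appears in context."""
    for f in fields(obj):
        if f.metadata.get("runtime_dep") and isinstance(getattr(obj, f.name), LateInit):
            if f.name in context:
                setattr(obj, f.name, context[f.name])
    return obj  # extra context entries are simply ignored in this sketch


@dataclass
class ToyKinetic:
    mode: str = "scan"
    f_log_psi: Any = runtime_dep()


def log_psi(params, x):
    return -abs(x)


programmatic = ToyKinetic(f_log_psi=log_psi)  # dependency passed to the constructor
configured = wire(ToyKinetic(mode="fori_loop"), f_log_psi=log_psi, unused=123)
```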
- evaluate_batch(params, batched_data, prev_local_stats, state, rngs)
Compute local values over a batch of walkers.
By default, vmaps evaluate_local over the walker dimension. Override this method directly for estimators that don't need per-walker vmapping (e.g. histogram aggregation, stats-only estimators).
- Parameters:
params (Params) – Wavefunction parameters.
batched_data (BatchedData[TypeVar(DataT, bound=Data)]) – Batched sampled data.
prev_local_stats (Mapping[str, Any]) – Local values produced by earlier estimators in the pipeline (with walker dimension).
state (Any) – Estimator state.
rngs (PRNGKey) – Random state.
- Returns:
A tuple (local_stats, state) where local_stats values have a leading walker dimension.
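A loop-and-stack reference makes the walker dimension concrete. Vmapping a per-walker function gives the same values as calling it in a Python loop and stacking the results. This is a sketch, not jaqmc code. evaluate_batch_reference is hypothetical, and the sketch assumes one random key per walker and a single state shared by all walkers; how the real evaluate_batch splits rngs and treats state is not specified here.

```python
import numpy as np


def evaluate_local(params, data, prev_local_stats, state, rng):
    """Toy single-walker estimator: scaled squared distance from the origin."""
    return {"r2": params["scale"] * np.sum(data**2)}, state


def evaluate_batch_reference(params, batched_data, prev_local_stats, state, rngs):
    """Loop-and-stack reference for vmapping evaluate_local over walkers.

    Walker-indexed inputs (data, prev_local_stats, rngs) are sliced per walker;
    params and state are shared by every walker (an assumption of this sketch).
    """
    per_walker = []
    for w in range(batched_data.shape[0]):
        local_prev = {k: v[w] for k, v in prev_local_stats.items()}
        stats, _ = evaluate_local(params, batched_data[w], local_prev, state, rngs[w])
        per_walker.append(stats)
    # Stack each key so every value gains a leading walker dimension.
    local_stats = {k: np.stack([s[k] for s in per_walker]) for k in per_walker[0]}
    return local_stats, state


walkers = np.arange(6.0).reshape(3, 2)  # 3 walkers, 2 coordinates each
stats, _ = evaluate_batch_reference({"scale": 2.0}, walkers, {}, None, np.arange(3))
```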
- evaluate_local(params, data, prev_local_stats, state, rngs)
Compute local values for a single walker.
This is the main method to override. The default evaluate_batch vmaps it over the walker dimension.
- Parameters:
params (Params) – Wavefunction parameters.
data (TypeVar(DataT, bound=Data)) – Data for a single walker.
prev_local_stats (Mapping[str, Any]) – Local values produced by earlier estimators in the pipeline (single-walker).
state (Any) – Estimator state from init or the previous step.
rngs (PRNGKey) – Random state.
- Returns:
A tuple (local_stats, state) where local_stats maps string keys to per-walker scalar or array values.
- finalize_state(state, *, n_steps)
Extract final observables from accumulated estimator state.
Override this for estimators that accumulate results directly in state rather than through per-step statistics (e.g. histogram estimators). Called only during evaluation digest, never inside JIT.
- finalize_stats(batched_stats, state)
Combine per-step statistics into final physical quantities.
Receives the reduce output accumulated over multiple steps, with a leading step dimension on every value, and produces the final observable values.
Override this when the observable requires nonlinear combinations of step-level averages (ratios, products, etc.). The default simply averages over steps.
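A small numeric case shows why finalize_stats may need to be nonlinear. For a ratio observable, the ratio of long-run averages differs from the average of per-step ratios. The finalize_ratio function and the num/den keys are hypothetical.

```python
import numpy as np


def finalize_ratio(batched_stats):
    """Nonlinear finalization: ratio of the long-run averages of two step-level means."""
    numerator = np.mean(batched_stats["num:mean"], axis=0)
    denominator = np.mean(batched_stats["den:mean"], axis=0)
    return {"ratio": numerator / denominator}


batched = {"num:mean": np.array([1.0, 3.0]), "den:mean": np.array([1.0, 2.0])}
ratio_of_averages = finalize_ratio(batched)["ratio"]  # (1 + 3) / (1 + 2) = 4/3
average_of_ratios = np.mean(batched["num:mean"] / batched["den:mean"])  # (1 + 1.5) / 2
```

The default step average would effectively report the second number, which is biased for a ratio.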
- init(data, rngs)
Initialize estimator state from an example data point.
Called once before the first evaluate call.
- Returns:
State to thread through evaluate calls, or None if no state is needed.
- class jaqmc.estimator.base.FunctionEstimator(fn)
Wraps a plain function as an Estimator.
The function is called as evaluate_local; init, reduce, finalize_stats, and finalize_state use the base-class defaults.
- evaluate_local(params, data, prev_local_stats, state, rngs)
Compute local values for a single walker by calling the wrapped function.
The default evaluate_batch vmaps this over the walker dimension.
- Parameters:
params – Wavefunction parameters.
data – Data for a single walker.
prev_local_stats – Local values produced by earlier estimators in the pipeline (single-walker).
state – Estimator state from init or the previous step.
rngs – Random state.
- Returns:
A tuple (local_stats, state) where local_stats maps string keys to per-walker scalar or array values.
- class jaqmc.estimator.base.EstimatorPipeline(estimators)
Chains named estimators into an evaluate → reduce → finalize pipeline.
Each estimator runs in insertion order. Later estimators can read earlier estimators' local values via prev_local_stats. Key ownership is tracked so that finalize_stats() dispatches each subset of statistics to the correct estimator.
- Parameters:
estimators (Mapping[str, EstimatorLike]) – Mapping from estimator name to either an Estimator instance or a plain estimator function. Plain functions are wrapped in FunctionEstimator.
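The ordering and key-ownership rules can be sketched with plain functions. The one-argument signature, run_pipeline, and split_by_owner are simplifications invented for illustration. EstimatorPipeline itself passes params, data, state, and rngs, and wraps plain functions in FunctionEstimator.

```python
import numpy as np

# Toy estimators with a simplified signature: each receives the local values
# produced so far and returns new local values (illustrative, not jaqmc's API).


def kinetic(prev):
    return {"energy:kinetic": np.array([0.5, 0.7])}


def potential(prev):
    return {"energy:potential": np.array([-1.0, -1.2])}


def total(prev):
    # A later estimator reading earlier estimators' local values.
    return {"total_energy": prev["energy:kinetic"] + prev["energy:potential"]}


def run_pipeline(estimators):
    """Run estimators in insertion order and record which one owns each key."""
    local, owner = {}, {}
    for name, fn in estimators.items():
        out = fn(dict(local))
        for key in out:
            owner[key] = name
        local.update(out)
    return local, owner


def split_by_owner(stats, owner):
    """Group a flat stats dict by owning estimator, as finalize dispatch requires."""
    groups = {}
    for key, value in stats.items():
        groups.setdefault(owner[key], {})[key] = value
    return groups


local, owner = run_pipeline({"kinetic": kinetic, "potential": potential, "total": total})
groups = split_by_owner(local, owner)
```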
- digest(batched_stats, state, *, n_steps)
Produce the full evaluation digest.
Combines finalize_stats() (from per-step statistics) with each estimator's finalize_state() (from accumulated state). Call this during the evaluation digest, not inside JIT.
- Returns:
Merged final values from both stats and state paths.
- evaluate(params, batched_data, state, rngs)
Compute averaged local values of the observables.
- Returns:
A tuple (step_stats, state), where step_stats is a flat dictionary merging each estimator's reduce() output.
- finalize_stats(batched_stats, state)
Finalize observables from per-step statistics.
batched_stats must have a leading batch/step dimension on every value. This method splits batched_stats by key ownership and dispatches each part to the owning estimator's finalize_stats().
In VMC, pass single-step stats with a batch dimension of 1 (e.g. via tree.map(lambda x: x[None], step_stats)). In evaluation, pass the stacked multi-step stats directly.
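For flat statistics, adding the leading step dimension looks like this in numpy. The docstring's tree.map(lambda x: x[None], step_stats) does the same for nested pytrees. The key names are illustrative.

```python
import numpy as np

step_stats = {"total_energy:mean": -1.1, "total_energy:var": 0.02}  # illustrative keys

# VMC: a single step becomes a batch of one step (leading dimension of size 1).
single = {key: np.asarray(value)[None] for key, value in step_stats.items()}

# Evaluation: several steps stacked along a new leading axis.
history = [step_stats, step_stats, step_stats]
stacked = {key: np.stack([np.asarray(s[key]) for s in history]) for key in step_stats}
```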
- init(batched_data, rngs)
Initialize per-estimator state.
- Returns:
A {name: state} dict threaded through evaluate() and finalize_stats().
- type jaqmc.estimator.base.EstimatorLike = Estimator | EstimateFn
Anything accepted where an Estimator is expected.
Built-in estimators
Kinetic energy
- class jaqmc.estimator.kinetic.EuclideanKinetic(*, mode=forward_laplacian, sparsity_threshold=0, f_log_psi=LateInit(), data_field='electrons')
Kinetic energy estimator in Euclidean geometry.
This is the most computationally expensive of the default energy components. The mode setting controls how the diagonal Hessian is computed and is the main performance knob; see LaplacianMode for the trade-offs.
See also
Kinetic energy for the derivation and Laplacian computation details.
- Parameters:
mode (LaplacianMode, default: forward_laplacian) – Laplacian computation strategy. forward_laplacian is the default for JAX 0.7.1 and later, and scan for earlier versions. See LaplacianMode for details.
sparsity_threshold (int, default: 0) – Sparsity threshold used in forward_laplacian mode. Always verify numerical correctness before adopting a sparsity threshold in production runs.
f_log_psi (NumericWavefunctionEvaluate, default: LateInit()) – Log-psi evaluate function (runtime dep).
data_field (str, default: 'electrons') – Name of the data field containing positions (runtime dep).
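For a real wavefunction in atomic units, the local kinetic energy follows from the diagonal of the Hessian of \(\log\psi\): \(T_L = -\tfrac{1}{2}\sum_i \left[\partial_i^2 \log\psi + (\partial_i \log\psi)^2\right]\). The sketch below evaluates this identity with central finite differences on the hydrogen 1s orbital, where \(\log\psi = -r\) gives \(T_L = 1/r - 1/2\). It only illustrates the quantity being computed. EuclideanKinetic obtains the diagonal Hessian with automatic differentiation through the chosen LaplacianMode, not with finite differences, and local_kinetic is a hypothetical helper.

```python
import numpy as np


def local_kinetic(log_psi, x, h=1e-4):
    """Local kinetic energy -1/2 * sum_i [d2 log psi/dx_i2 + (d log psi/dx_i)^2].

    Derivatives use central finite differences with step h (illustration only).
    """
    f0 = log_psi(x)
    total = 0.0
    for i in range(x.size):
        step = np.zeros_like(x)
        step[i] = h
        f_plus, f_minus = log_psi(x + step), log_psi(x - step)
        grad_i = (f_plus - f_minus) / (2 * h)
        hess_ii = (f_plus - 2 * f0 + f_minus) / h**2  # diagonal Hessian entry
        total += hess_ii + grad_i**2
    return -0.5 * total


def log_psi_1s(x):
    """Hydrogen 1s orbital: log psi = -r (unnormalized)."""
    return -np.linalg.norm(x)
```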
- class jaqmc.estimator.kinetic.SphericalKinetic(*, mode=scan, monopole_strength=1.0, radius=None, f_log_psi=LateInit(), data_field='electrons')
Kinetic energy on a sphere with a magnetic monopole.
Uses the Hessian-based calculation following the formulas in section 3.10.3 of Composite Fermions (Jain):
\[\frac{|\Lambda|^2 \psi}{2R^2 \psi} = \frac{1}{2R^2}\left[ -R^2 \frac{\nabla^2\psi}{\psi} + (Q\cot\theta)^2 + 2iQ \frac{\cot\theta}{\sin\theta} \frac{\partial\log\psi}{\partial\phi} \right]\]
Also computes angular momentum observables.
- Parameters:
mode (LaplacianMode, default: scan) – Laplacian computation mode. scan and fori_loop use a Hessian-based approach; forward_laplacian uses the forward Laplacian.
monopole_strength (float, default: 1.0) – Half the magnetic flux (\(Q = \text{flux}/2\)).
radius (float | None, default: None) – Sphere radius. Defaults to \(\sqrt{Q}\).
f_log_psi (NumericWavefunctionEvaluate, default: LateInit()) – Complex log-psi function (runtime dep).
data_field (str, default: 'electrons') – Name of the data field (runtime dep).
- class jaqmc.estimator.kinetic.LaplacianMode(*values)
Modes of calculating the diagonal Hessian for the Laplacian.
- scan
Materializes all iterations via jax.lax.scan(): higher memory, faster compilation. A good default for small to medium systems.
- fori_loop
Runs one iteration at a time via jax.lax.fori_loop(): constant memory, slower compilation. Use it when scan causes out-of-memory errors during compilation.
Ewald summation
- class jaqmc.estimator.ewald.EwaldSum(supercell_lattice, ewald_gmax=200, nlatvec=1)
Ewald summation for electrostatic energy in periodic systems.
Decomposes the Coulomb interaction into real-space and reciprocal-space series that converge rapidly:
\[V_{\text{Ewald}} = V_{\text{real}} + V_{\text{recip}} + V_{\text{self}} + V_{\text{charged}}\]
All charged particles (electrons and ions) are treated uniformly.
See also
Ewald summation for the full formulation and implementation notes.
- Parameters:
supercell_lattice (Array) – (3, 3) matrix of supercell lattice vectors.
ewald_gmax (int, default: 200) – Cutoff for the reciprocal-space sum (number of G-vectors in each direction). Determines the accuracy of \(V_{\text{recip}}\).
nlatvec (int, default: 1) – Cutoff for the real-space sum (number of periodic images in each direction). Determines the accuracy of \(V_{\text{real}}\).
- energy(coords, charges)
Calculates the total electrostatic energy for a general system.
This method implements a unified Ewald summation where all particles (electrons and ions) are treated as point charges. It computes the total energy as a sum of four components:
\[E_{\text{total}} = E_{\text{real}} + E_{\text{recip}} + E_{\text{self}} + E_{\text{charged}}\]
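The four components can be written out directly for a small cell. This numpy sketch is a textbook Ewald sum, not EwaldSum's implementation. ewald_energy, the splitting parameter alpha, and the cutoffs nmax and mmax are assumptions; nmax and mmax play roughly the roles of nlatvec and ewald_gmax. For the rock-salt conventional cell with unit lattice constant, the result should come out close to the Madelung energy \(-4M/0.5\), with \(M \approx 1.7476\).

```python
import itertools
import math

import numpy as np


def ewald_energy(coords, charges, lattice, alpha, nmax=1, mmax=6):
    """Ewald energy of point charges in a periodic cell (atomic units).

    Rows of lattice are lattice vectors. nmax and mmax are the real-space image
    and reciprocal-space cutoffs (analogous to nlatvec and ewald_gmax).
    """
    volume = abs(np.linalg.det(lattice))
    recip = 2 * np.pi * np.linalg.inv(lattice).T  # rows are reciprocal vectors
    n = len(charges)

    # Real-space term: erfc-screened pair interactions summed over images.
    e_real = 0.0
    for shift in itertools.product(range(-nmax, nmax + 1), repeat=3):
        t = np.array(shift, dtype=float) @ lattice
        for i in range(n):
            for j in range(n):
                if i == j and not any(shift):
                    continue  # no self-interaction in the home cell
                d = np.linalg.norm(coords[i] - coords[j] + t)
                e_real += 0.5 * charges[i] * charges[j] * math.erfc(alpha * d) / d

    # Reciprocal-space term: smooth Gaussian part, summed over G != 0.
    e_recip = 0.0
    for m in itertools.product(range(-mmax, mmax + 1), repeat=3):
        if not any(m):
            continue
        g = np.array(m, dtype=float) @ recip
        g2 = g @ g
        structure = np.sum(charges * np.exp(1j * (coords @ g)))
        e_recip += 2 * np.pi / volume * math.exp(-g2 / (4 * alpha**2)) / g2 * abs(structure) ** 2

    # Self-interaction correction and neutralizing-background term.
    e_self = -alpha / math.sqrt(math.pi) * np.sum(charges**2)
    e_charged = -math.pi * np.sum(charges) ** 2 / (2 * volume * alpha**2)
    return e_real + e_recip + e_self + e_charged


# Rock-salt (NaCl) conventional cell with unit lattice constant.
na = [[0, 0, 0], [0, 0.5, 0.5], [0.5, 0, 0.5], [0.5, 0.5, 0]]
cl = [[0.5, 0, 0], [0, 0.5, 0], [0, 0, 0.5], [0.5, 0.5, 0.5]]
coords = np.array(na + cl, dtype=float)
charges = np.array([1.0] * 4 + [-1.0] * 4)
energy = ewald_energy(coords, charges, np.eye(3), alpha=5.0)
```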
ECP energy
- class jaqmc.estimator.ecp.estimator.ECPEnergy(*, max_core=2, quadrature_id=None, electrons_field='electrons', atoms_field='atoms', phase_logpsi=LateInit(), ecp_coefficients=LateInit(), atom_symbols=LateInit(), lattice=None, twist=None)
ECP energy estimator.
Computes both local and nonlocal effective core potential contributions. Added automatically when ecp is set in the system configuration.
Local (\(l=0\)): direct potential energy from the \(l=0\) channel.
Nonlocal (\(l>0\)): angular integral weighted by \(V_l(r)\).
The estimator outputs energy:ecp, which is included in the total_energy sum automatically.
See also
Pseudopotentials (ECP) for the local/nonlocal decomposition and quadrature details.
- Parameters:
max_core (int, default: 2) – Maximum number of nearest ECP atoms to consider per electron when evaluating nonlocal integrals. Only the closest max_core ECP atoms contribute for each electron; the rest are skipped. Increase this if your system has many ECP atoms in close proximity.
quadrature_id (str | None, default: None) – Spherical quadrature rule used to evaluate nonlocal ECP integrals. When None, a default rule is selected automatically.
electrons_field (str, default: 'electrons') – Name of the electron position field in data.
atoms_field (str, default: 'atoms') – Name of the atom position field in data.
phase_logpsi (WavefunctionEvaluate, default: LateInit()) – Wavefunction ratio function (runtime dep).
ecp_coefficients (dict[str, Any], default: LateInit()) – PySCF ECP dict (runtime dep).
atom_symbols (list[str], default: LateInit()) – List of element symbols, e.g. ["Li", "H"] (runtime dep).
lattice (Array | None, default: None) – Lattice vectors for PBC (runtime dep, optional).
twist (Array | None, default: None) – Twist angle for PBC (runtime dep, optional).
- Raises:
ValueError – If no atoms have ECP coefficients.
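Nonlocal ECP terms are commonly evaluated in QMC by applying a spherical quadrature to the angular projector \(\frac{2l+1}{4\pi}\int P_l(\hat{\mathbf{r}}\cdot\hat{\mathbf{r}}')\,f(\hat{\mathbf{r}}')\,d\Omega'\). The minimal numpy sketch below uses the six-point octahedral rule, which is a stand-in; which rule quadrature_id selects is not specified here. The sketch shows that the projector picks out the \(l=1\) component of a test function. angular_projection and the other names are hypothetical.

```python
import numpy as np

# Six-point octahedral rule: the points ±x, ±y, ±z with equal weights summing
# to one. It integrates spherical polynomials up to degree 3 exactly.
POINTS = np.vstack([np.eye(3), -np.eye(3)])
WEIGHTS = np.full(6, 1.0 / 6.0)


def legendre(l, x):
    """Legendre polynomials P_0 and P_1 (enough for this sketch)."""
    if l == 0:
        return np.ones_like(x)
    if l == 1:
        return x
    raise ValueError("this sketch only implements l = 0 and l = 1")


def angular_projection(l, r_hat, f):
    """Approximate (2l+1)/(4 pi) * integral of P_l(r_hat . r') f(r') over the sphere.

    The weights average over the sphere, absorbing the 1/(4 pi) factor. In an ECP
    estimator, f would be the wavefunction ratio psi(r')/psi(r) with the electron
    moved to the quadrature point at the same distance from the atom.
    """
    cos_angle = POINTS @ r_hat
    return (2 * l + 1) * np.sum(WEIGHTS * legendre(l, cos_angle) * f(POINTS))


def p_z(points):
    return points[:, 2]  # an l = 1 angular function


z_hat = np.array([0.0, 0.0, 1.0])
```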
Spin squared
- class jaqmc.estimator.spin.SpinSquared(*, n_up=LateInit(), n_down=LateInit(), phase_logpsi=LateInit())
Estimator for the total spin operator \(S^2\).
Computes the local value of \(S^2\) for a single walker using
\[S^2_\text{local} = S_z(S_z + 1) + n_\text{minority} - \sum_{i \in \text{minority}} \sum_{j \in \text{majority}} \frac{\Psi(\mathbf{r}_{i \leftrightarrow j})} {\Psi(\mathbf{r})}\]
where \(S_z = |n_\uparrow - n_\downarrow| / 2\) and minority is the spin channel with fewer (or equal) electrons.
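The exchange formula can be checked on a two-electron toy problem. In this sketch psi is a plain real-valued function, whereas the estimator works through phase_logpsi. Positions are 1-D, and s2_local is hypothetical. A spatially symmetric wavefunction (singlet) should give \(S^2 = 0\), and an antisymmetric one (the \(S_z = 0\) triplet component) should give \(S^2 = 2\).

```python
import numpy as np


def s2_local(psi, r, n_up, n_down):
    """Local S^2 via the exchange formula; r lists spin-up then spin-down electrons."""
    up = list(range(n_up))
    down = list(range(n_up, n_up + n_down))
    minority, majority = (up, down) if n_up <= n_down else (down, up)
    s_z = abs(n_up - n_down) / 2
    psi_r = psi(r)
    exchange = 0.0
    for i in minority:
        for j in majority:
            swapped = r.copy()
            swapped[[i, j]] = r[[j, i]]  # exchange positions of electrons i and j
            exchange += psi(swapped) / psi_r
    return s_z * (s_z + 1) + len(minority) - exchange


def phi_a(x):
    return np.exp(-x**2)


def phi_b(x):
    return x * np.exp(-x**2 / 2)


def singlet(r):  # spatially symmetric: total spin 0
    return phi_a(r[0]) * phi_b(r[1]) + phi_b(r[0]) * phi_a(r[1])


def triplet(r):  # spatially antisymmetric: total spin 1 (S_z = 0 component)
    return phi_a(r[0]) * phi_b(r[1]) - phi_b(r[0]) * phi_a(r[1])


r = np.array([0.3, -0.8])  # 1-D toy positions: electron 0 up, electron 1 down
```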
Total energy
- class jaqmc.estimator.total_energy.TotalEnergy(*, output_name='total_energy', components=None)
Estimator that computes total energy from component energies.
Energy components use the energy: prefix convention (e.g. energy:kinetic, energy:potential). When components is None, all keys starting with energy: are summed automatically.
When the total energy is complex (e.g. from a magnetic kinetic energy term), reduce splits it into real and imaginary parts so that the variance is computed on the real part only.
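The prefix convention and the complex-energy handling can be sketched as follows. total_energy, reduce_total, and the mean/var/imag_mean keys are hypothetical names chosen for illustration.

```python
import numpy as np


def total_energy(local_stats, components=None):
    """Sum energy components; components=None selects every 'energy:' key."""
    if components is None:
        components = [k for k in local_stats if k.startswith("energy:")]
    return sum(local_stats[k] for k in components)


def reduce_total(values):
    """Per-step statistics; for complex energies the variance uses the real part only."""
    values = np.asarray(values)
    stats = {"mean": np.mean(values.real), "var": np.var(values.real)}
    if np.iscomplexobj(values):
        stats["imag_mean"] = np.mean(values.imag)
    return stats


local = {
    "energy:kinetic": np.array([1.0, 2.0]),
    "energy:potential": np.array([-3.0, -3.0]),
    "acceptance": np.array([0.5, 0.5]),  # no energy: prefix, so ignored by default
}
```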
Density
- class jaqmc.estimator.density.CartesianDensity(*, axes=<factory>)
Electron density along arbitrary directions in Cartesian space.
Each histogram axis is defined by a direction, a bin count, and a range. Direction vectors are normalized internally, so they need not be unit vectors.
For periodic systems where lattice-aligned density is needed, use FractionalDensity instead.
- Parameters:
axes (dict[str, CartesianAxis | None], default: <factory>) – Per-axis configuration keyed by user-chosen labels. Set a value to None to disable an axis inherited from the defaults (e.g. when the workflow provides x/y/z but you only want z).
- class jaqmc.estimator.density.CartesianAxis(*, direction=(0.0, 0.0, 1.0), bins=50, range=(0.0, 1.0))
Configuration for one histogram axis in Cartesian coordinates.
- Parameters:
direction (tuple[float, ...], default: (0.0, 0.0, 1.0)) – Direction vector to project positions onto. Normalized internally, so it need not be a unit vector.
bins (int, default: 50) – Number of histogram bins.
range (tuple[float, float], default: (0.0, 1.0)) – (min, max) bounds for the projected coordinate.
- class jaqmc.estimator.density.cartesian.CartesianAxis(*, direction=(0.0, 0.0, 1.0), bins=50, range=(0.0, 1.0))
Same signature and parameters as jaqmc.estimator.density.CartesianAxis above.
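A CartesianAxis histogram amounts to projecting positions onto the normalized direction and binning the projections. The function below is a hypothetical numpy sketch for a single configuration. It uses range_ to avoid shadowing Python's built-in range, and it does not reproduce how the estimator accumulates counts in its state across steps.

```python
import numpy as np


def cartesian_histogram(positions, direction=(0.0, 0.0, 1.0), bins=50, range_=(0.0, 1.0)):
    """Project positions onto a normalized direction, then histogram the projections."""
    d = np.asarray(direction, dtype=float)
    d = d / np.linalg.norm(d)  # the direction need not be a unit vector
    projected = positions.reshape(-1, positions.shape[-1]) @ d
    counts, edges = np.histogram(projected, bins=bins, range=range_)
    return counts, edges


positions = np.array([[0.0, 0.0, 0.25], [0.0, 0.0, 0.75], [1.0, 1.0, 0.1]])
counts, edges = cartesian_histogram(positions, direction=(0.0, 0.0, 2.0), bins=2)
```

Projections that fall outside range_ are simply not counted, as with np.histogram.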
- class jaqmc.estimator.density.FractionalDensity(*, axes=<factory>, inv_lattice=LateInit())
Electron density in fractional (lattice) coordinates.
Converts Cartesian electron positions to fractional coordinates via the inverse lattice matrix:
\[\mathbf{f} = L^{-1}\,\mathbf{r} \mod 1\]
then histograms the selected fractional axes. The range is always \([0, 1)\) per axis, regardless of cell shape.
For molecules or other open-boundary systems, use CartesianDensity instead.
- class jaqmc.estimator.density.FractionalAxis(*, lattice_index=0, bins=50)
Configuration for one histogram axis in fractional coordinates.
- class jaqmc.estimator.density.fractional.FractionalAxis(*, lattice_index=0, bins=50)
Same signature as jaqmc.estimator.density.FractionalAxis above.
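The fractional-coordinate mapping can be sketched in numpy. The sketch assumes the columns of the lattice matrix are the lattice vectors, matching the \(L^{-1}\mathbf{r}\) form above. fractional_histogram is hypothetical, and its lattice_index argument picks which fractional component is binned, mirroring FractionalAxis.

```python
import numpy as np


def fractional_histogram(positions, lattice, lattice_index=0, bins=50):
    """Histogram one component of f = L^{-1} r mod 1 over the fixed range [0, 1).

    Assumes the columns of lattice are the lattice vectors, so r = lattice @ f.
    """
    inv_lattice = np.linalg.inv(lattice)
    frac = (positions @ inv_lattice.T) % 1.0  # row-wise L^{-1} r, wrapped into [0, 1)
    counts, _ = np.histogram(frac[:, lattice_index], bins=bins, range=(0.0, 1.0))
    return counts


lattice = np.array([[2.0, 1.0, 0.0], [0.0, 2.0, 0.0], [0.0, 0.0, 3.0]])
r = lattice @ np.array([0.3, 0.5, 0.7])  # a point with known fractional coordinates
a1 = lattice[:, 0]
positions = np.stack([r, r + a1, r - a1])  # periodic images of the same point
```

All three periodic images map to the same fractional coordinates, so they land in the same bin regardless of cell shape.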
- class jaqmc.estimator.density.SphericalDensity(*, bins_theta=50, bins_phi=None)
Electron density on the Haldane sphere.
Accumulates a histogram of electron positions in spherical coordinates. By default only the polar angle \(\theta \in [0, \pi]\) is binned (1-D histogram). Setting bins_phi enables a 2-D \((\theta, \varphi)\) histogram with \(\varphi \in [-\pi, \pi]\).
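Binning on the sphere can be sketched by converting Cartesian positions to angles. This assumes positions are stored as 3-D Cartesian vectors measured from the sphere's center, which may not match the estimator's internal representation. spherical_histogram is hypothetical.

```python
import numpy as np


def spherical_histogram(positions, bins_theta=50, bins_phi=None):
    """Histogram positions by polar angle, or by (theta, phi) when bins_phi is set."""
    x, y, z = positions[:, 0], positions[:, 1], positions[:, 2]
    r = np.linalg.norm(positions, axis=1)
    theta = np.arccos(np.clip(z / r, -1.0, 1.0))  # polar angle in [0, pi]
    if bins_phi is None:
        counts, _ = np.histogram(theta, bins=bins_theta, range=(0.0, np.pi))
        return counts
    phi = np.arctan2(y, x)  # azimuthal angle in [-pi, pi]
    counts, _, _ = np.histogram2d(
        theta, phi, bins=(bins_theta, bins_phi), range=((0.0, np.pi), (-np.pi, np.pi))
    )
    return counts


positions = np.array([[1.0, 1.0, 1.0], [1.0, -1.0, -1.0], [-1.0, -1.0, -2.0]])
```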
Loss and gradient
- class jaqmc.estimator.loss_grad.LossAndGrad(*, loss_key='total_energy', clip_scale=5.0, f_log_psi=LateInit())
Estimator that computes the VMC loss and parameter gradients.
The gradient of the variational energy with respect to the wavefunction parameters \(\theta\) is:
\[\nabla_\theta \langle E_L \rangle = 2 \left\langle (E_L - \langle E_L \rangle) \, \nabla_\theta \log|\psi_\theta| \right\rangle\]
The computation proceeds in three stages across the estimator lifecycle:
evaluate_local: evaluates \(\log\psi\) and its parameter gradient for each walker, and reads the loss value from prev_local_stats[loss_key].
reduce: applies IQR clipping to the local energies for robustness against outliers, then forms the per-walker product \(\nabla\log\psi \cdot E_L^{\text{clipped}}\).
finalize_stats: averages over walkers and subtracts the baseline to assemble the final gradient.
- Parameters:
loss_key (str, default: 'total_energy') – Key in prev_local_stats to use as the loss.
clip_scale (float, default: 5.0) – Multiplier on the interquartile range (IQR) that sets the clipping window for local energies. A walker whose energy falls outside \([Q_1 - s \cdot \text{IQR},\; Q_3 + s \cdot \text{IQR}]\), where \(s\) is clip_scale, is clipped to that boundary. Smaller values clip more aggressively, which stabilises gradients but biases the energy estimate. The default of 5.0 is a common choice; set it to a large value (e.g. 1e8) to effectively disable clipping.
f_log_psi (NumericWavefunctionEvaluate, default: LateInit()) – Log-psi function to differentiate (runtime dep).
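The clipping window and the gradient formula can be sketched in numpy. clip_iqr and energy_gradient are hypothetical. Using the mean of the clipped energies as the baseline is an assumption. The real estimator also differentiates \(\log\psi\) with JAX rather than receiving the per-walker gradients as an array.

```python
import numpy as np


def clip_iqr(local_energy, clip_scale=5.0):
    """Clip local energies to [Q1 - s*IQR, Q3 + s*IQR] with s = clip_scale."""
    q1, q3 = np.percentile(local_energy, [25, 75])
    iqr = q3 - q1
    return np.clip(local_energy, q1 - clip_scale * iqr, q3 + clip_scale * iqr)


def energy_gradient(local_energy, grad_log_psi, clip_scale=5.0):
    """2 * <(E_L - <E_L>) * grad log|psi|>, using clipped local energies.

    grad_log_psi has shape (n_walkers, n_params); the baseline <E_L> is taken
    over the clipped energies (an assumption of this sketch).
    """
    e = clip_iqr(local_energy, clip_scale)
    centered = e - np.mean(e)
    return 2.0 * np.mean(centered[:, None] * grad_log_psi, axis=0)


energies = np.array([1.0, 2.0, 3.0, 4.0, 100.0])
clipped = clip_iqr(energies, clip_scale=1.0)  # window [0, 6]: the outlier becomes 6
```

Subtracting the baseline does not change the expected gradient, because \(\langle\nabla_\theta\log|\psi_\theta|\rangle = 0\) under \(|\psi_\theta|^2\) sampling. It does reduce the variance of the estimate.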