Work Stages

Work stages are the execution units within a Workflow. Each stage encapsulates its own run loop together with the components it uses, for example an optimizer, samplers, estimators, and writers.

Stages are created with the builder pattern: call VMCWorkStage.builder(cfg, wf) for training or EvaluationWorkStage.builder(cfg, wf) for evaluation, configure the returned builder with its configure_* methods, and then call build().
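The following is a minimal, self-contained sketch of the progressive configure/build flow, written against hypothetical ToyStageBuilder and ToyStage classes rather than the jaqmc API. It mirrors two behaviors documented on this page: configure_loss_grads() raises ValueError when configure_estimators() has not been called, and the builder validates its configuration before returning a stage.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, Optional


@dataclass(frozen=True)
class ToyStage:
    """Hypothetical stand-in for a fully-configured work stage."""
    estimators: Dict[str, Callable[..., Any]]
    optimizer: Optional[str]


class ToyStageBuilder:
    """Hypothetical builder illustrating the configure_* / build() flow."""

    def __init__(self) -> None:
        self._estimators: Optional[Dict[str, Callable[..., Any]]] = None
        self._optimizer: Optional[str] = None

    def configure_estimators(self, **estimators: Callable[..., Any]) -> None:
        self._estimators = dict(estimators)

    def configure_loss_grads(self, loss_grad: Callable[..., Any]) -> None:
        # Ordering rule: the estimator dict must exist before adding to it.
        if self._estimators is None:
            raise ValueError("configure_estimators() was not called")
        self._estimators["loss_grads"] = loss_grad

    def configure_optimizer(self, *, default: str) -> None:
        self._optimizer = default

    def build(self) -> ToyStage:
        # Validate first, so every stage returned here is complete.
        if self._estimators is None:
            raise ValueError("configure_estimators() was not called")
        return ToyStage(estimators=dict(self._estimators), optimizer=self._optimizer)
```

Usage: create a ToyStageBuilder, call configure_estimators(total=...), then configure_loss_grads(...) and configure_optimizer(default="kfac"), and finish with build(). Keeping validation in one place means partially configured stages never escape the builder.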

VMC stage

class jaqmc.workflow.stage.vmc.VMCStageBuilder(cfg, wavefunction, *, name=None)

Builder for VMC work stages.

Extends SamplingStageBuilder with optimizer and loss gradient configuration. Call build() to create a fully-configured VMCWorkStage.

build()

Build a fully-configured VMCWorkStage.

Return type:

VMCWorkStage

Returns:

A fully-configured VMCWorkStage.

configure_estimators(**estimators)

Configure estimators.

Parameters:

**estimators (EstimatorLike) – Named estimators (e.g. kinetic=..., total=...).

Return type:

None

configure_loss_grads(loss_grads=<class 'jaqmc.estimator.loss_grad.LossAndGrad'>, *, f_log_psi)

Configure the loss-gradient estimator and add it to the estimator dict.

Parameters:
Raises:

ValueError – If configure_estimators() was not called before this method.

Return type:

None

configure_optimizer(*, default, f_log_psi, **kwargs)

Configure the optimizer.

Parameters:
  • default (str | type) – The default optimizer, given either as the object itself or as its module path.

  • f_log_psi (NumericWavefunctionEvaluate) – Log wavefunction for optimizer wiring.

  • **kwargs – Additional keyword arguments for optimizer wiring.

Raises:

TypeError – If the configured optimizer is not an OptimizerLike.

Return type:

None

configure_sample_plan(f_log_amplitude, samplers=None)

Configure MCMC sampling.

Parameters:
Return type:

None

configure_writers(writers=None)

Configure writers. When called without arguments, loads the default writers from the config.

Parameters:

writers (Writers | None, default: None) – Pre-built writers, or None to load from config.

Return type:

None

class jaqmc.workflow.stage.sampling.SamplingStageBuilder(cfg, wavefunction, *, name=None)

Base builder for sampling-based work stages.

Provides the progressive configure_* API. The stage returned by build() is guaranteed to be fully configured.

Usage:

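```python
from jaqmc.workflow.stage.vmc import VMCWorkStage

# cfg (ConfigManagerLike) and wf (WavefunctionLike) are assumed to exist already.
builder = VMCWorkStage.builder(cfg.scoped("train"), wf)
sampler = cfg.get_module("sampler", "jaqmc.sampler.mcmc:MCMCSampler")
builder.configure_sample_plan(wf.logpsi, {"electrons": sampler})
builder.configure_optimizer(default="jaqmc.optimizer.kfac", f_log_psi=wf.logpsi)
# Replace each ... with an actual estimator instance.
builder.configure_estimators(kinetic=..., potential=..., total=...)
builder.configure_loss_grads(f_log_psi=wf.logpsi)  # must follow configure_estimators()
stage = builder.build()
```

Parameters: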
  • cfg (ConfigManagerLike) – Scoped configuration manager (e.g. cfg.scoped("train")).

  • wavefunction (WavefunctionLike) – Wavefunction instance.

  • name (str | None, default: None) – Stage name. Defaults to cfg.name or the class name.

build()

Build the fully-configured work stage.

Raises:

NotImplementedError – Subclasses must override this method.

Return type:

WorkStage

config_class

alias of WorkStageConfig

configure_estimators(**estimators)

Configure estimators.

Parameters:

**estimators (EstimatorLike) – Named estimators (e.g. kinetic=..., total=...).

Return type:

None

configure_sample_plan(f_log_amplitude, samplers=None)

Configure MCMC sampling.

Parameters:
Return type:

None

configure_writers(writers=None)

Configure writers. When called without arguments, loads the default writers from the config.

Parameters:

writers (Writers | None, default: None) – Pre-built writers, or None to load from config.

Return type:

None

ensure_configured()

Resolve defaults and validate the configuration. Called by build().

Raises:

ValueError – If configure_estimators() was not called.

Return type:

None

class jaqmc.workflow.stage.vmc.VMCWorkStage(*, sample_plan, estimators, wavefunction, config, name, writers, optimizer)

Variational Monte Carlo work stage for sampling and training.

Performs MCMC sampling, observable estimation, and parameter optimization. For evaluation without training, use EvaluationWorkStage instead.
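To make the three phases concrete, here is a toy one-dimensional VMC loop; it is not jaqmc code. It performs Metropolis sampling of |psi|^2 for the trial function psi_a(x) = exp(-a x^2), estimates the local energy of the harmonic oscillator H = -1/2 d^2/dx^2 + 1/2 x^2, and updates a by plain gradient descent, which stands in for an optimizer such as KFAC. All function names and hyperparameters are illustrative assumptions.

```python
import math
import random
from typing import List, Tuple


def local_energy(x: float, a: float) -> float:
    # E_L = (H psi) / psi for H = -1/2 d^2/dx^2 + 1/2 x^2 and psi = exp(-a x^2).
    return a + x * x * (0.5 - 2.0 * a * a)


def metropolis_sweep(walkers: List[float], a: float, rng: random.Random,
                     step: float = 1.0) -> None:
    # One Metropolis update per walker, targeting |psi|^2 = exp(-2 a x^2).
    for i, x in enumerate(walkers):
        x_new = x + step * rng.gauss(0.0, 1.0)
        log_ratio = -2.0 * a * (x_new * x_new - x * x)
        if log_ratio >= 0 or rng.random() < math.exp(log_ratio):
            walkers[i] = x_new


def toy_vmc(a: float = 0.3, n_walkers: int = 400, burn_in: int = 100,
            iterations: int = 150, sweeps: int = 5, lr: float = 0.1,
            seed: int = 0) -> Tuple[float, float]:
    rng = random.Random(seed)
    walkers = [rng.gauss(0.0, 1.0) for _ in range(n_walkers)]
    for _ in range(burn_in):  # equilibrate the walkers before the main loop
        metropolis_sweep(walkers, a, rng)
    energy = float("nan")
    for _ in range(iterations):
        for _ in range(sweeps):  # 1) sampling
            metropolis_sweep(walkers, a, rng)
        e_loc = [local_energy(x, a) for x in walkers]  # 2) estimation
        energy = sum(e_loc) / n_walkers
        # dE/da = 2 <(E_L - E) d(log psi)/da>, with d(log psi)/da = -x^2.
        grad = 2.0 * sum((e - energy) * (-x * x)
                         for e, x in zip(e_loc, walkers)) / n_walkers
        a = max(a - lr * grad, 1e-3)  # 3) optimization; keep a positive
    return a, energy
```

Because psi_a is exact at a = 0.5, the local energy is constant there and the gradient estimate loses its noise (the zero-variance property), so the loop settles near a = 0.5 with an energy of 0.5.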

Usage:

```python
from jaqmc.sampler.mcmc import MCMCSampler
from jaqmc.workflow.stage.vmc import VMCWorkStage

# cfg (ConfigManagerLike) and wf (WavefunctionLike) are assumed to exist already.
builder = VMCWorkStage.builder(cfg.scoped("train"), wf)
sampler = cfg.get("sampler", MCMCSampler)
builder.configure_sample_plan(wf.logpsi, {"electrons": sampler})
builder.configure_optimizer(default="jaqmc.optimizer.kfac", f_log_psi=wf.logpsi)
builder.configure_estimators(kinetic=..., potential=..., total=...)
builder.configure_loss_grads(f_log_psi=wf.logpsi)
train = builder.build()
```

class jaqmc.workflow.stage.vmc.VMCWorkStageConfig(*, check_vma=True, iterations=100, burn_in=100, save_time_interval=600, save_step_interval=1000, stop_on_nan='loss')

Configuration for VMC work stages.

Parameters:

stop_on_nan (bool | str, default: 'loss') – Abort training when a NaN is detected in the step statistics. True checks every statistics key, False disables the check, and a comma-separated string names the specific keys to monitor (e.g. "loss", the default).
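One way to interpret the three accepted forms of stop_on_nan is sketched below. keys_to_check and has_nan are hypothetical helpers, and silently skipping keys that are missing from the statistics is an assumption, not documented behavior.

```python
import math
from typing import Dict, List, Union


def keys_to_check(stop_on_nan: Union[bool, str], stats: Dict[str, float]) -> List[str]:
    # True -> every statistic, False -> none, str -> comma-separated subset.
    if stop_on_nan is True:
        return list(stats)
    if stop_on_nan is False:
        return []
    return [k.strip() for k in stop_on_nan.split(",") if k.strip()]


def has_nan(stop_on_nan: Union[bool, str], stats: Dict[str, float]) -> bool:
    # Assumption: keys absent from this step's statistics are ignored.
    return any(k in stats and math.isnan(stats[k])
               for k in keys_to_check(stop_on_nan, stats))
```

With stats = {"loss": nan, "energy": 1.0}, the default "loss" triggers an abort, "energy" does not, and False never does.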

Evaluation stage

class jaqmc.workflow.stage.evaluation.EvalStageBuilder(cfg, wavefunction, *, name=None)

Builder for evaluation work stages.

Call build() to create a fully-configured EvaluationWorkStage.

build()

Build a fully-configured EvaluationWorkStage.

Return type:

EvaluationWorkStage

Returns:

A fully-configured EvaluationWorkStage.

configure_estimators(**estimators)

Configure estimators.

Parameters:

**estimators (EstimatorLike) – Named estimators (e.g. kinetic=..., total=...).

Return type:

None

configure_sample_plan(f_log_amplitude, samplers=None)

Configure MCMC sampling.

Parameters:
Return type:

None

configure_writers(writers=None)

Configure writers. When called without arguments, loads the default writers from the config.

Parameters:

writers (Writers | None, default: None) – Pre-built writers, or None to load from config.

Return type:

None

class jaqmc.workflow.stage.evaluation.EvaluationWorkStage(*, sample_plan, estimators, wavefunction, config, name, writers)

Evaluation work stage for sampling and observable estimation.

Runs MCMC sampling and estimator evaluation without parameter updates. Writes per-step statistics through writers and produces a digest.npz summary after all steps complete.

Usage:

```python
from jaqmc.sampler.mcmc import MCMCSampler
from jaqmc.workflow.stage.evaluation import EvaluationWorkStage

# cfg (ConfigManagerLike) and wf (WavefunctionLike) are assumed to exist already.
builder = EvaluationWorkStage.builder(cfg, wf)
sampler = cfg.get("sampler", MCMCSampler)
builder.configure_sample_plan(wf.logpsi, {"electrons": sampler})
builder.configure_estimators(kinetic=..., potential=..., total=...)
evaluation = builder.build()
```

class jaqmc.workflow.stage.evaluation.EvaluationWorkStageConfig(*, check_vma=True, iterations=100, burn_in=0, save_time_interval=600, save_step_interval=1000, digest_step_interval=0)

Configuration for evaluation work stages.

Parameters:

digest_step_interval (int, default: 0) – Log a preview of the accumulated evaluation digest every digest_step_interval steps. The preview shows running statistics (means, variances) computed from all steps so far. Set to 0 to print the digest only at the end.
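A running digest of this kind can be maintained in a single pass with Welford's algorithm. The sketch below is an assumption about the mechanics, not the jaqmc implementation: RunningStats and evaluate are hypothetical, only one scalar statistic is tracked, and a preview is printed after every digest_step_interval completed steps.

```python
from typing import List, Tuple


class RunningStats:
    """Welford accumulator for the mean and variance of one statistic."""

    def __init__(self) -> None:
        self.count = 0
        self.mean = 0.0
        self._m2 = 0.0  # sum of squared deviations from the running mean

    def update(self, value: float) -> None:
        self.count += 1
        delta = value - self.mean
        self.mean += delta / self.count
        self._m2 += delta * (value - self.mean)

    @property
    def variance(self) -> float:
        # Sample variance; undefined (NaN) with fewer than two observations.
        return self._m2 / (self.count - 1) if self.count > 1 else float("nan")


def evaluate(values: List[float], digest_step_interval: int) -> Tuple[List[int], RunningStats]:
    stats = RunningStats()
    preview_steps: List[int] = []
    for step, value in enumerate(values):
        stats.update(value)
        # 0 disables previews; otherwise preview every digest_step_interval steps.
        if digest_step_interval > 0 and (step + 1) % digest_step_interval == 0:
            preview_steps.append(step)
            print(f"step {step}: mean={stats.mean:.4f} var={stats.variance:.4f}")
    print(f"final: mean={stats.mean:.4f} var={stats.variance:.4f}")
    return preview_steps, stats
```

The accumulator needs only constant memory per statistic, which is why a preview can be produced at any step without revisiting earlier samples.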

State classes

class jaqmc.workflow.stage.base.RunContext(save_path, restore_path, signal_handler)

Runtime resources shared with a work stage.

Parameters:
  • save_path (UPath | Path) – Directory where the stage writes checkpoints and outputs.

  • restore_path (UPath | Path) – Checkpoint file or directory used to resume the stage.

  • signal_handler (GracefulKiller) – Handler used to detect graceful termination requests.

class jaqmc.workflow.stage.base.StageAbort(step, state)

Raised by loop() to request a graceful abort (e.g. when a NaN is detected).

class jaqmc.workflow.stage.base.WorkStage

Base class for work stages with a generator-based run loop.

Subclasses implement loop() as a generator yielding (step, state) tuples. run() handles checkpoint resume and saving, signal handling, the writer lifecycle, and per-step timing logs.
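The generator-based split between loop() and run() can be sketched in plain Python as follows. CountingStage and the local StageAbort class are hypothetical stand-ins; the sketch omits rngs, checkpointing, and signal handling, and only shows how a StageAbort raised inside the generator surfaces in run() as a SystemExit.

```python
from typing import Any, Iterator, List, Tuple


class StageAbort(Exception):
    """Local stand-in mirroring the documented StageAbort(step, state)."""

    def __init__(self, step: int, state: Any) -> None:
        super().__init__(f"abort requested at step {step}")
        self.step = step
        self.state = state


class CountingStage:
    """Hypothetical stage whose state is a float that must stay finite."""

    def __init__(self, iterations: int, nan_at_step: int = -1) -> None:
        self.iterations = iterations
        self.nan_at_step = nan_at_step  # step at which to inject a NaN (-1: never)
        self.written: List[Tuple[int, float]] = []  # stands in for writers

    def loop(self, state: float, initial_step: int) -> Iterator[Tuple[int, float]]:
        # Generator: one (step, state) tuple per iteration, resuming at initial_step.
        for step in range(initial_step, self.iterations):
            state = float("nan") if step == self.nan_at_step else state + 1.0
            if state != state:  # NaN is the only float not equal to itself
                raise StageAbort(step, state)
            yield step, state

    def run(self, state: float, initial_step: int = 0) -> float:
        # Drive the generator; a real run() would also save checkpoints here.
        try:
            for step, state in self.loop(state, initial_step):
                self.written.append((step, state))
        except StageAbort as abort:
            raise SystemExit(f"stage aborted at step {abort.step}") from abort
        return state
```

Keeping the per-step logic in a generator lets run() own everything that happens between steps (logging, saving, abort handling) without each subclass reimplementing it.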

abstractmethod create_state(rngs, **kwargs)

Create sharded state for this stage.

Parameters:
  • rngs (PRNGKey) – Initial random key for all random operations.

  • **kwargs (Any) – Stage-specific arguments (e.g. batched_data).

Return type:

Any

Returns:

Sharded state object.

abstractmethod loop(state, initial_step, rngs)

Yield (step, state) tuples for each iteration.

Subclasses must implement this generator. Raise StageAbort to request a graceful abort (e.g. when a NaN is detected).

Parameters:
  • state (Any) – Sharded state after checkpoint resume.

  • initial_step (int) – Step to resume from (0 if fresh).

  • rngs (PRNGKey) – Sharded random keys.

Yields:

Tuples of (step_index, updated_state).

restore_checkpoint(checkpoint_path, template, *, prefix='')

Restore state from a checkpoint.

Parameters:
  • checkpoint_path (str | Path | UPath) – Path to checkpoint file or directory.

  • template (Any) – Template state for deserialization.

  • prefix (str, default: '') – Checkpoint filename prefix to match.

Return type:

Any

Returns:

Restored state.

run(state, context, rngs)

Execute the full run loop.

Resumes from a checkpoint, opens the writers, iterates the generator returned by loop(), and handles checkpoint saves and signal-based aborts.

Parameters:
  • state (Any) – Initial sharded state.

  • context (RunContext) – Run context providing the save and restore paths and the signal handler.

  • rngs (PRNGKey) – Random key for the run.

Return type:

Any

Returns:

Final state after all iterations.

Raises:
  • SystemExit – On abort (generator-initiated or signal-initiated).

  • StageAbort – Caught internally and turned into SystemExit; never propagated to the caller.

class jaqmc.workflow.stage.base.WorkStageConfig(*, check_vma=True, iterations=100, burn_in=0, save_time_interval=600, save_step_interval=1000)

Base configuration for work stages.

Parameters:
  • iterations (int, default: 100) – Total number of iterations to run.

  • burn_in (int, default: 0) – Sampling iterations to discard before the main loop for MCMC equilibration.

  • save_time_interval (int, default: 600) – Minimum wall-clock seconds between checkpoint saves. A checkpoint is written only when both this and save_step_interval are satisfied.

  • save_step_interval (int, default: 1000) – Save checkpoints only at steps that are multiples of this value.

  • check_vma (bool, default: True) – Enable JAX's varying-manual-axes (check_vma) type checking in shard_map.
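The interaction of the two save intervals can be sketched as below, assuming elapsed time is measured from the previous save (or from the stage start) and steps are counted from 1; should_save and save_schedule are hypothetical helpers. With slow steps every multiple of save_step_interval qualifies, while with fast steps the time condition skips some multiples.

```python
from typing import List


def should_save(step: int, now: float, last_save_time: float,
                save_step_interval: int = 1000,
                save_time_interval: float = 600.0) -> bool:
    # Both conditions must hold: a multiple of save_step_interval AND
    # at least save_time_interval seconds since the last save.
    on_step = step % save_step_interval == 0
    enough_time = now - last_save_time >= save_time_interval
    return on_step and enough_time


def save_schedule(iterations: int, seconds_per_step: float,
                  save_step_interval: int = 1000,
                  save_time_interval: float = 600.0) -> List[int]:
    # Simulate a run with a constant step time and record the save steps.
    saves: List[int] = []
    last_save_time = 0.0  # assumption: the clock starts at stage start
    for step in range(1, iterations + 1):
        now = step * seconds_per_step
        if should_save(step, now, last_save_time,
                       save_step_interval, save_time_interval):
            saves.append(step)
            last_save_time = now
    return saves
```

For example, at 1 s per step a checkpoint is written every 1000 steps, while at 0.5 s per step only steps 2000 and 4000 qualify within 5000 iterations.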

class jaqmc.workflow.stage.sampling.SamplingState(params, batched_data, sampler_state, estimator_state)

Base state for sampling-based work stages.

Contains the common fields shared across VMC and evaluation stages. Subclasses can add extra fields (e.g. VMCState adds opt_state).

partition()

Return a pytree with the same structure as this state whose leaves are PartitionSpec objects.

Return type:

Self

class jaqmc.workflow.stage.vmc.VMCState(params, batched_data, sampler_state, estimator_state, opt_state)

State for VMC work stages. Adds opt_state (the optimizer state) to the SamplingState fields.