Work Stages
Work stages are the execution units within a workflow. Each stage encapsulates its own loop together with the components it drives, for example an optimizer, samplers, estimators, and writers.
Stages are created through the builder pattern: call VMCWorkStage.builder(cfg, wf) for training or EvaluationWorkStage.builder(cfg, wf) for evaluation, configure the returned builder with its configure_* methods, and then call build().
VMC stage
- class jaqmc.workflow.stage.vmc.VMCStageBuilder(cfg, wavefunction, *, name=None)
Builder for VMC work stages.
Extends SamplingStageBuilder with optimizer and loss-gradient configuration. Call build() to create a fully configured VMCWorkStage.
- build()
Build a fully configured VMCWorkStage.
- Return type:
VMCWorkStage
- Returns:
A fully configured VMCWorkStage.
- configure_estimators(**estimators)
Configure estimators.
- Parameters:
**estimators (EstimatorLike) – Named estimators (e.g. kinetic=..., total=...).
- configure_loss_grads(loss_grads=jaqmc.estimator.loss_grad.LossAndGrad, *, f_log_psi)
Configure the loss-gradient estimator and add it to the estimator dict.
- Parameters:
f_log_psi (NumericWavefunctionEvaluate) – Log wavefunction used to wire the loss-gradient estimator.
loss_grads (type[Estimator] | Estimator | None, default: jaqmc.estimator.loss_grad.LossAndGrad) – A class (resolved from config and wired), an already-wired instance, or None (no loss gradients). Defaults to LossAndGrad.
- Raises:
ValueError – If configure_estimators() was not called before this method.
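The ordering rule and the three accepted forms of loss_grads (class, instance, None) can be illustrated with a small, self-contained sketch. Everything in it is hypothetical: Estimator, LossAndGrad, ToyBuilder, and the "loss_grads" dict key are local stand-ins rather than jaqmc's classes, and a class argument is simply instantiated with f_log_psi instead of being resolved from config.

```python
from typing import Callable, Dict, Optional, Union


class Estimator:
    """Hypothetical stand-in for an estimator base class."""


class LossAndGrad(Estimator):
    """Hypothetical stand-in for the default loss-gradient estimator."""

    def __init__(self, f_log_psi: Callable) -> None:
        self.f_log_psi = f_log_psi  # the log wavefunction this estimator is wired to


class ToyBuilder:
    """Hypothetical builder reproducing only the documented ordering rule."""

    def __init__(self) -> None:
        self.estimators: Optional[Dict[str, object]] = None

    def configure_estimators(self, **estimators: object) -> None:
        self.estimators = dict(estimators)

    def configure_loss_grads(
        self,
        loss_grads: Union[type, Estimator, None] = LossAndGrad,
        *,
        f_log_psi: Callable,
    ) -> None:
        if self.estimators is None:
            raise ValueError("configure_estimators() must be called first")
        if loss_grads is None:
            return  # None: no loss-gradient estimator is added
        if isinstance(loss_grads, type):
            # A class: instantiate it and wire it to the log wavefunction.
            loss_grads = loss_grads(f_log_psi=f_log_psi)
        # An instance is assumed to be wired already and is stored unchanged.
        self.estimators["loss_grads"] = loss_grads  # hypothetical key name
```

Keeping the estimator dict as None until configure_estimators() runs is what lets the builder detect the wrong call order.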
- configure_optimizer(*, default, f_log_psi, **kwargs)
Configure the optimizer.
- Parameters:
default (str | type) – Default optimizer, given either as the object itself or as its module path.
f_log_psi (NumericWavefunctionEvaluate) – Log wavefunction for optimizer wiring.
**kwargs – Additional keyword arguments for optimizer wiring.
- Raises:
TypeError – If the configured optimizer is not an OptimizerLike.
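Since default may be an object or a path string, and a non-OptimizerLike result raises TypeError, the resolve-then-check step can be sketched with importlib and a runtime-checkable Protocol. This is not jaqmc's resolver: the accepted path forms ("module:Attr", as in the sampler usage example, or a dotted "module.Attr"), the OptimizerLike method names, and build_optimizer are all assumptions. In particular, the sketch does not handle a bare module path such as "jaqmc.optimizer.kfac".

```python
import importlib
from typing import Any, Protocol, Union, runtime_checkable


@runtime_checkable
class OptimizerLike(Protocol):
    """Hypothetical protocol: any object with init() and update() methods."""

    def init(self, params: Any) -> Any: ...

    def update(self, grads: Any, state: Any, params: Any) -> Any: ...


def resolve_object(spec: Union[str, type]) -> Any:
    """Resolve "pkg.module:Attr" or "pkg.module.Attr"; non-string specs pass through."""
    if not isinstance(spec, str):
        return spec
    if ":" in spec:
        module_name, attr = spec.split(":", 1)
    else:
        module_name, _, attr = spec.rpartition(".")
    return getattr(importlib.import_module(module_name), attr)


def build_optimizer(spec: Union[str, type], **kwargs: Any) -> OptimizerLike:
    """Instantiate the resolved optimizer and enforce the OptimizerLike check."""
    optimizer = resolve_object(spec)(**kwargs)
    # runtime_checkable only verifies that the methods exist, not their signatures.
    if not isinstance(optimizer, OptimizerLike):
        raise TypeError(f"{type(optimizer).__name__} is not OptimizerLike")
    return optimizer
```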
- configure_sample_plan(f_log_amplitude, samplers=None)
Configure MCMC sampling.
- Parameters:
f_log_amplitude (NumericWavefunctionEvaluate) – Log amplitude for sampling.
samplers (Mapping[str | tuple[str], SamplerLike] | None, default: None) – Mapping from data field to corresponding sampler.
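The samplers argument maps a data field name, or a tuple of names, to a sampler. One plausible reading, sketched below as an assumption, is that each sampler updates only its own fields and that a tuple key lets one sampler move several fields jointly. Real jaqmc samplers are MCMC kernels driven by f_log_amplitude; here a sampler is a plain function, and run_sample_plan and the field names other than "electrons" are hypothetical.

```python
from typing import Callable, Dict, Mapping, Tuple, Union

# Hypothetical sampler: receives only its own fields and returns their new values.
Sampler = Callable[[Dict[str, float]], Dict[str, float]]
FieldKey = Union[str, Tuple[str, ...]]


def run_sample_plan(
    data: Dict[str, float], samplers: Mapping[FieldKey, Sampler]
) -> Dict[str, float]:
    """Apply each sampler, in mapping order, to the field or fields it is keyed by."""
    data = dict(data)  # do not mutate the caller's dict
    for key, sampler in samplers.items():
        names = (key,) if isinstance(key, str) else tuple(key)
        # A tuple key means one sampler moves several fields jointly;
        # all other fields are held fixed during this update.
        data.update(sampler({name: data[name] for name in names}))
    return data
```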
- class jaqmc.workflow.stage.sampling.SamplingStageBuilder(cfg, wavefunction, *, name=None)
Base builder for sampling-based work stages.
Provides the progressive configure_* API. After build() is called, the resulting stage is guaranteed to be fully configured.
Usage:
    from jaqmc.workflow.stage.vmc import VMCWorkStage

    builder = VMCWorkStage.builder(cfg.scoped("train"), wf)
    sampler = cfg.get_module("sampler", "jaqmc.sampler.mcmc:MCMCSampler")
    builder.configure_sample_plan(wf.logpsi, {"electrons": sampler})
    builder.configure_optimizer(default="jaqmc.optimizer.kfac", f_log_psi=wf.logpsi)
    builder.configure_estimators(kinetic=..., potential=..., total=...)
    builder.configure_loss_grads(f_log_psi=wf.logpsi)
    stage = builder.build()
- Parameters:
cfg (ConfigManagerLike) – Scoped configuration manager (e.g. cfg.scoped("train")).
wavefunction (WavefunctionLike) – Wavefunction instance.
name (str | None, default: None) – Stage name. Defaults to cfg.name or the class name.
- build()
Build the fully configured work stage.
- Raises:
NotImplementedError – Subclasses must override this method.
- config_class
Alias of WorkStageConfig.
- configure_estimators(**estimators)
Configure estimators.
- Parameters:
**estimators (EstimatorLike) – Named estimators (e.g. kinetic=..., total=...).
- configure_sample_plan(f_log_amplitude, samplers=None)
Configure MCMC sampling.
- Parameters:
f_log_amplitude (NumericWavefunctionEvaluate) – Log amplitude for sampling.
samplers (Mapping[str | tuple[str], SamplerLike] | None, default: None) – Mapping from data field to corresponding sampler.
- configure_writers(writers=None)
Configure writers. If writers is not given, the default writers are loaded from the config.
- ensure_configured()
Resolve defaults and validate the configuration. Called by build().
- Raises:
ValueError – If configure_estimators() was not called.
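The configure_* / ensure_configured() / build() contract can be illustrated with a toy builder. This is a minimal sketch, not SamplingStageBuilder itself: ToyStage, ToyStageBuilder, and the default writer names are hypothetical, and only two of the documented steps (defaulting the writers, requiring estimators) are modelled.

```python
from dataclasses import dataclass
from typing import Dict, List, Optional


@dataclass(frozen=True)
class ToyStage:
    """Hypothetical stage: immutable once built."""

    estimators: Dict[str, object]
    writers: List[str]


class ToyStageBuilder:
    """Hypothetical builder: progressive configure_* calls, validated in build()."""

    DEFAULT_WRITERS = ("console", "csv")  # stand-in for defaults read from config

    def __init__(self) -> None:
        self.estimators: Optional[Dict[str, object]] = None
        self.writers: Optional[List[str]] = None

    def configure_estimators(self, **estimators: object) -> None:
        self.estimators = dict(estimators)

    def configure_writers(self, writers: Optional[List[str]] = None) -> None:
        # No argument: fall back to the defaults.
        self.writers = list(self.DEFAULT_WRITERS if writers is None else writers)

    def ensure_configured(self) -> None:
        if self.writers is None:
            self.configure_writers()  # resolve a default that was never set
        if self.estimators is None:
            raise ValueError("configure_estimators() was not called")

    def build(self) -> ToyStage:
        self.ensure_configured()
        return ToyStage(estimators=dict(self.estimators), writers=list(self.writers))
```

Returning a frozen dataclass from build() is one way to honour the guarantee that a built stage is fully configured: nothing can be left unset or changed afterwards.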
- class jaqmc.workflow.stage.vmc.VMCWorkStage(*, sample_plan, estimators, wavefunction, config, name, writers, optimizer)
Variational Monte Carlo work stage for sampling and training.
Performs MCMC sampling, observable estimation, and parameter optimization. For evaluation without training, use EvaluationWorkStage instead.
Usage:
    from jaqmc.sampler.mcmc import MCMCSampler
    from jaqmc.workflow.stage.vmc import VMCWorkStage

    builder = VMCWorkStage.builder(cfg.scoped("train"), wf)
    sampler = cfg.get("sampler", MCMCSampler)
    builder.configure_sample_plan(wf.logpsi, {"electrons": sampler})
    builder.configure_optimizer(default="jaqmc.optimizer.kfac", f_log_psi=wf.logpsi)
    builder.configure_estimators(kinetic=..., potential=..., total=...)
    builder.configure_loss_grads(f_log_psi=wf.logpsi)
    train = builder.build()
Evaluation stage
- class jaqmc.workflow.stage.evaluation.EvalStageBuilder(cfg, wavefunction, *, name=None)
Builder for evaluation work stages.
Call build() to create a fully configured EvaluationWorkStage.
- build()
Build a fully configured EvaluationWorkStage.
- Return type:
EvaluationWorkStage
- Returns:
A fully configured EvaluationWorkStage.
- configure_estimators(**estimators)
Configure estimators.
- Parameters:
**estimators (EstimatorLike) – Named estimators (e.g. kinetic=..., total=...).
- configure_sample_plan(f_log_amplitude, samplers=None)
Configure MCMC sampling.
- Parameters:
f_log_amplitude (NumericWavefunctionEvaluate) – Log amplitude for sampling.
samplers (Mapping[str | tuple[str], SamplerLike] | None, default: None) – Mapping from data field to corresponding sampler.
- class jaqmc.workflow.stage.evaluation.EvaluationWorkStage(*, sample_plan, estimators, wavefunction, config, name, writers)
Evaluation work stage for sampling and observable estimation.
Runs MCMC sampling and estimator evaluation without parameter updates. Writes per-step statistics through the writers and produces a digest.npz summary after all steps complete.
Usage:
    from jaqmc.sampler.mcmc import MCMCSampler
    from jaqmc.workflow.stage.evaluation import EvaluationWorkStage

    builder = EvaluationWorkStage.builder(cfg, wf)
    sampler = cfg.get("sampler", MCMCSampler)
    builder.configure_sample_plan(wf.logpsi, {"electrons": sampler})
    builder.configure_estimators(kinetic=..., potential=..., total=...)
    evaluation = builder.build()
- class jaqmc.workflow.stage.evaluation.EvaluationWorkStageConfig(*, check_vma=True, iterations=100, burn_in=0, save_time_interval=600, save_step_interval=1000, digest_step_interval=0)
Configuration for evaluation work stages.
- Parameters:
digest_step_interval (int, default: 0) – Log a preview of the accumulated evaluation digest every digest_step_interval steps. The preview shows running statistics (means, variances) computed from all steps so far. Set to 0 to print the digest only at the end.
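To make the preview behaviour concrete, the sketch below accumulates running statistics with Welford's online algorithm for the mean and variance, emits a preview every digest_step_interval steps, and always emits a final digest. Assumptions: steps are counted from 1, the variance is the population variance, and RunningStats and evaluate are hypothetical names; jaqmc's digest may accumulate different statistics.

```python
import math
from typing import Dict, Iterable, List


class RunningStats:
    """Welford's online algorithm for the mean and population variance."""

    def __init__(self) -> None:
        self.count = 0
        self.mean = 0.0
        self._m2 = 0.0  # sum of squared deviations from the current mean

    def update(self, x: float) -> None:
        self.count += 1
        delta = x - self.mean
        self.mean += delta / self.count
        self._m2 += delta * (x - self.mean)

    @property
    def variance(self) -> float:
        return self._m2 / self.count if self.count else math.nan


def evaluate(values: Iterable[float], digest_step_interval: int = 0) -> List[Dict[str, float]]:
    """Return the previews that would be logged, followed by the final digest."""
    stats = RunningStats()
    logged: List[Dict[str, float]] = []
    for step, value in enumerate(values, start=1):
        stats.update(value)
        # 0 disables previews; otherwise log every digest_step_interval steps.
        if digest_step_interval > 0 and step % digest_step_interval == 0:
            logged.append({"step": step, "mean": stats.mean, "variance": stats.variance})
    logged.append({"step": stats.count, "mean": stats.mean, "variance": stats.variance})
    return logged
```

An online update keeps memory constant in the number of steps, which is why it suits a preview that is refreshed throughout a long evaluation.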
State classes
- class jaqmc.workflow.stage.base.RunContext(save_path, restore_path, signal_handler)
Runtime resources shared with a work stage.
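RunContext bundles the checkpoint paths with the signal handler that run() consults for signal-based aborts. A common pattern, sketched here under the assumption that the handler simply records a stop request for the loop to poll, is a callable flag installed with signal.signal. StopFlag and ToyRunContext are hypothetical; only the three field names come from the signature above.

```python
import signal
from dataclasses import dataclass, field
from pathlib import Path
from types import FrameType
from typing import Optional


class StopFlag:
    """Hypothetical signal handler that only records that a stop was requested."""

    def __init__(self) -> None:
        self.requested = False

    def __call__(self, signum: int, frame: Optional[FrameType]) -> None:
        self.requested = True  # the run loop is expected to poll this flag

    def install(self, signum: int = signal.SIGTERM) -> None:
        # signal.signal must be called from the main thread.
        signal.signal(signum, self)


@dataclass
class ToyRunContext:
    save_path: Path
    restore_path: Optional[Path]
    signal_handler: StopFlag = field(default_factory=StopFlag)
```

Setting a flag instead of exiting inside the handler lets the loop finish the current step and save a consistent checkpoint before stopping.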
- class jaqmc.workflow.stage.base.StageAbort(step, state)
Raised by loop() to request a graceful abort (e.g. NaN detected).
- class jaqmc.workflow.stage.base.WorkStage
Base class for work stages with generator-based run loop.
Subclasses implement loop() as a generator yielding (step, state) tuples. run() handles checkpoint resume and saving, signal handling, the writers' lifecycle, and time-per-step logging.
- abstractmethod loop(state, initial_step, rngs)
Yield (step, state) tuples for each iteration.
Subclasses must implement this generator. Raise StageAbort to request a graceful abort (e.g. when a NaN is detected).
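A concrete loop() is a generator that advances the state once per step and raises StageAbort when something goes wrong. In the sketch below a scalar state stands in for sampling, estimation, and parameter updates; StageAbort is re-declared locally with the documented (step, state) arguments, and ToyLoopStage and its decay update are hypothetical.

```python
import math
from typing import Any, Iterator, Tuple


class StageAbort(Exception):
    """Local stand-in carrying the step and state at which the stage aborted."""

    def __init__(self, step: int, state: Any) -> None:
        super().__init__(f"aborted at step {step}")
        self.step = step
        self.state = state


class ToyLoopStage:
    """Hypothetical stage whose state is a single float."""

    def __init__(self, iterations: int, decay: float) -> None:
        self.iterations = iterations
        self.decay = decay

    def loop(self, state: float, initial_step: int, rngs: Any) -> Iterator[Tuple[int, float]]:
        # rngs is unused in this sketch; a real stage would split it per step.
        for step in range(initial_step, self.iterations):
            state = state * self.decay  # stand-in for sample / estimate / update
            if math.isnan(state) or math.isinf(state):
                raise StageAbort(step, state)
            yield step, state
```

Starting the range at initial_step is what allows a resumed run to continue where the checkpoint left off.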
- restore_checkpoint(checkpoint_path, template, *, prefix='')
Restore state from a checkpoint.
- run(state, context, rngs)
Execute the full run loop.
Resumes from a checkpoint, opens the writers, iterates the generator returned by loop(), and handles checkpoint saves and signal-based aborts.
- Parameters:
state (Any) – Initial sharded state.
context (RunContext) – Run context holding the save and restore paths and the signal handler.
rngs (PRNGKey) – Random key for the run.
- Returns:
Final state after all iterations.
- Raises:
SystemExit – On abort (generator-initiated or signal-initiated).
StageAbort – Caught internally; not propagated to the caller.
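The control flow that run() is documented to perform (resume, iterate the loop() generator, save, and turn aborts into SystemExit) is sketched below as a plain function. It is not jaqmc's implementation: the restored tuple and the should_stop and save callbacks are hypothetical hooks replacing real checkpoint files, the signal handler, and the writers, and resuming at last_step + 1 is an assumption.

```python
from typing import Any, Callable, Iterator, Optional, Tuple


class StageAbort(Exception):
    """Local stand-in carrying the step and state at which the stage aborted."""

    def __init__(self, step: int, state: Any) -> None:
        super().__init__(f"aborted at step {step}")
        self.step = step
        self.state = state


LoopFn = Callable[[Any, int, Any], Iterator[Tuple[int, Any]]]


def run(
    loop: LoopFn,
    state: Any,
    rngs: Any,
    *,
    restored: Optional[Tuple[int, Any]] = None,
    should_stop: Callable[[], bool] = lambda: False,
    save: Callable[[int, Any], None] = lambda step, state: None,
) -> Any:
    """Resume, drive the loop generator, save, and convert aborts into SystemExit."""
    initial_step = 0
    if restored is not None:
        # Assumption: a checkpoint stores the last completed step and its state.
        last_step, state = restored
        initial_step = last_step + 1
    try:
        for step, state in loop(state, initial_step, rngs):
            save(step, state)  # a real runner would apply its save policy here
            if should_stop():  # e.g. a flag set by a signal handler
                raise StageAbort(step, state)
    except StageAbort as abort:
        # Generator- and signal-initiated aborts both surface as SystemExit.
        raise SystemExit(f"stage aborted at step {abort.step}") from abort
    return state
```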
- class jaqmc.workflow.stage.base.WorkStageConfig(*, check_vma=True, iterations=100, burn_in=0, save_time_interval=600, save_step_interval=1000)
Base configuration for work stages.
- Parameters:
iterations (int, default: 100) – Total number of iterations to run.
burn_in (int, default: 0) – Number of sampling iterations to discard before the main loop, for MCMC equilibration.
save_time_interval (int, default: 600) – Minimum wall-clock seconds between checkpoint saves. A checkpoint is written only when both this and save_step_interval are satisfied.
save_step_interval (int, default: 1000) – Save checkpoints only at steps that are multiples of this value.
check_vma (bool, default: True) – Enable JAX validity checks during shard_map.
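Because a checkpoint is written only when both save conditions hold, the save decision reduces to a small predicate. The sketch below is hypothetical (SavePolicy and its injectable clock are not part of jaqmc) and assumes the time interval is measured from the previous save, with the timer starting when the policy is created.

```python
import time
from typing import Callable


class SavePolicy:
    """Save only if step % save_step_interval == 0 AND enough wall-clock time has passed."""

    def __init__(
        self,
        save_time_interval: float = 600,
        save_step_interval: int = 1000,
        clock: Callable[[], float] = time.monotonic,
    ) -> None:
        self.save_time_interval = save_time_interval
        self.save_step_interval = save_step_interval
        self.clock = clock
        self.last_save = clock()  # assumption: the timer starts at construction

    def should_save(self, step: int) -> bool:
        if step % self.save_step_interval != 0:
            return False  # not on a step boundary
        now = self.clock()
        if now - self.last_save < self.save_time_interval:
            return False  # too soon since the previous save
        self.last_save = now
        return True
```

Injecting the clock keeps the policy deterministic under test; time.monotonic is used by default because it is unaffected by system clock changes.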
- class jaqmc.workflow.stage.sampling.SamplingState(params, batched_data, sampler_state, estimator_state)
Base state for sampling-based work stages.
Contains the common fields shared across VMC and evaluation stages. Subclasses can add extra fields (e.g. VMCState adds opt_state).
- partition()
Return a matching pytree of PartitionSpec.
- Return type:
Self
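partition() returns an object with the same structure as the state but with a sharding spec in place of each field. The sketch below mimics that with frozen dataclasses and dataclasses.replace, which also shows how a subclass's extra field (here opt_state, mirroring the VMCState example) is covered automatically. To avoid a JAX dependency, plain strings stand in for PartitionSpec; ToySamplingState, ToyVMCState, and the rule for which field is split along the walker batch are hypothetical.

```python
from dataclasses import dataclass, fields, replace
from typing import Any


@dataclass(frozen=True)
class ToySamplingState:
    params: Any
    batched_data: Any
    sampler_state: Any
    estimator_state: Any

    def partition(self):
        """Return an instance of the same class holding a spec for every field."""
        # Hypothetical rule: shard the walker batch, replicate everything else.
        specs = {f.name: "replicated" for f in fields(self)}
        specs["batched_data"] = "batch"
        return replace(self, **specs)


@dataclass(frozen=True)
class ToyVMCState(ToySamplingState):
    opt_state: Any  # extra field, picked up by partition() through fields()
```

Because replace() builds an instance of type(self), the spec object has the subclass's type and field layout, which is the "matching" structure the return type Self refers to.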