Utilities

Helper functions for array manipulation, function transformations, clipping, units, supercell construction, checkpointing, and multi-device parallelism.

Function transforms

Wrappers for extracting real/imaginary parts and handling complex-valued JAX functions.

type jaqmc.utils.func_transform.CompatibleFunc = Callable[[Concatenate[Params, DataT, P]], Any]

Callable that receives params first, data second, then extra args.

jaqmc.utils.func_transform.with_real(f)

Wrap f so that only real parts of its outputs are returned.

Parameters:

f (Callable[[ParamSpec(P)], TypeVar(ReturnT)]) – Callable to wrap.

Return type:

Callable[[ParamSpec(P)], TypeVar(ReturnT)]

Returns:

A wrapped function that applies jnp.real to all outputs.

Type Parameters:
  • P – Parameter specification of f (arguments are preserved).

  • ReturnT – Return type of f before applying jnp.real tree-wise.

Examples

>>> import jax.numpy as jnp
>>> from jaqmc.utils.func_transform import with_real
>>> f = lambda x: x + 1j * x**2
>>> float(with_real(f)(2.0))
2.0
jaqmc.utils.func_transform.with_imag(f)

Wrap f so that only imaginary parts of its outputs are returned.

Parameters:

f (Callable[[ParamSpec(P)], TypeVar(ReturnT)]) – Callable to wrap.

Return type:

Callable[[ParamSpec(P)], TypeVar(ReturnT)]

Returns:

A wrapped function that applies jnp.imag to all outputs.

Type Parameters:
  • P – Parameter specification of f (arguments are preserved).

  • ReturnT – Return type of f before applying jnp.imag tree-wise.

Examples

>>> import jax.numpy as jnp
>>> from jaqmc.utils.func_transform import with_imag
>>> f = lambda x: x + 1j * x**2
>>> float(with_imag(f)(2.0))
4.0
jaqmc.utils.func_transform.with_output(f, key)

Wrap f to return only f(...)[key].

Parameters:
  • f – Callable whose output is a mapping.

  • key – Key to extract from the output mapping.

Return type:

Callable[[ParamSpec(P)], Any]

Returns:

A wrapped function that extracts key from the output mapping.

Type Parameters:

P – Parameter specification of f (arguments are preserved).

Examples

>>> from jaqmc.utils.func_transform import with_output
>>> g = lambda x: {"energy": x**2, "force": -2*x}
>>> with_output(g, "energy")(3.0)
9.0
jaqmc.utils.func_transform.transform_maybe_complex(f, jaxfun, argnums=0)

Apply a JAX transform to functions with real or complex outputs.

If f has real outputs, this delegates directly to jaxfun. If f has complex outputs, the real and imaginary parts are transformed separately and recombined into a complex result.

Parameters:
  • f (Callable[[ParamSpec(P)], Any]) – Function to transform.

  • jaxfun – JAX transformation such as jax.grad, jax.hessian, or jax.value_and_grad.

  • argnums (int | Sequence[int], default: 0) – Positional argument index or indices to transform with respect to.

Return type:

Callable[[ParamSpec(P)], Any]

Returns:

Wrapped function with the same call signature as f.

Type Parameters:

P – Parameter specification of f.
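The real/imaginary split described above can be sketched as follows. This is a hypothetical re-implementation (transform_maybe_complex_sketch is not part of jaqmc), assuming f returns a pytree of arrays and that the transformed real and imaginary results can be recombined leaf-wise as re + 1j * im; the library's own edge-case handling may differ.

```python
import jax
import jax.numpy as jnp


def transform_maybe_complex_sketch(f, jaxfun, argnums=0):
    """Hypothetical sketch: apply jaxfun to real and imaginary parts separately."""

    def wrapped(*args):
        # Inspect output dtypes abstractly, without evaluating f on data.
        out = jax.eval_shape(f, *args)
        is_complex = any(
            jnp.issubdtype(leaf.dtype, jnp.complexfloating)
            for leaf in jax.tree_util.tree_leaves(out)
        )
        if not is_complex:
            # Real outputs: delegate directly to the JAX transform.
            return jaxfun(f, argnums=argnums)(*args)
        real_part = lambda *a: jax.tree_util.tree_map(jnp.real, f(*a))
        imag_part = lambda *a: jax.tree_util.tree_map(jnp.imag, f(*a))
        re = jaxfun(real_part, argnums=argnums)(*args)
        im = jaxfun(imag_part, argnums=argnums)(*args)
        # Recombine leaf by leaf into a complex result.
        return jax.tree_util.tree_map(lambda r, i: r + 1j * i, re, im)

    return wrapped


f = lambda x: x + 1j * x**2
grad_f = transform_maybe_complex_sketch(f, jax.grad)
g = grad_f(2.0)  # d/dx Re f = 1 and d/dx Im f = 2x, so g = 1 + 4j
```

Because the recombination is a tree-wise map, the same wrapper also covers transforms with structured outputs such as jax.value_and_grad, whose (value, grad) tuple is recombined element by element.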

jaqmc.utils.func_transform.linearize_maybe_complex(f, *args)

Wrap jax.linearize to handle complex inputs and outputs.

Complex values are split into real and imaginary parts; real values are passed through unchanged.

Parameters:
  • f (Callable) – The function to linearize.

  • *args – Arguments to f.

Return type:

tuple[Any, Callable]

Returns:

A tuple (primal, jvp_fn) where primal is the value of f(*args) and jvp_fn is the function that computes the Jacobian-vector product.

jaqmc.utils.func_transform.grad_maybe_complex(f, argnums=0)

Return jax.grad wrapped to support complex-valued outputs.

Parameters:
  • f (Callable[[ParamSpec(P)], Any]) – Function to differentiate.

  • argnums (int | Sequence[int], default: 0) – Positional argument index or indices to differentiate with respect to.

Return type:

Callable[[ParamSpec(P)], Any]

Returns:

Gradient function with the same call signature as f.

Type Parameters:

P – Parameter specification of f.

jaqmc.utils.func_transform.hessian_maybe_complex(f, argnums=0)

Return jax.hessian wrapped to support complex-valued outputs.

Parameters:
  • f (Callable[[ParamSpec(P)], Any]) – Function to differentiate twice.

  • argnums (int | Sequence[int], default: 0) – Positional argument index or indices to differentiate with respect to.

Return type:

Callable[[ParamSpec(P)], Any]

Returns:

Hessian function with the same call signature as f.

Type Parameters:

P – Parameter specification of f.

jaqmc.utils.func_transform.transform_with_data(f, key, jaxfun)

Apply a JAX derivative transform to a function f(params, data, *args) with respect to one field of data.

Parameters:
  • f (GenericAlias[TypeVar(DataT, bound= Data), ParamSpec(P)]) – Function to differentiate.

  • key (str) – Name of the field in data to differentiate with respect to.

  • jaxfun – JAX transformation to apply, e.g. jax.grad.

Return type:

GenericAlias[TypeVar(DataT, bound= Data), ParamSpec(P)]

Returns:

The transformed function, with the same call signature as f.

Type Parameters:
  • DataT – Concrete Data subtype passed through f and the wrapper.

  • P – Extra parameter specification after params and data.
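The idea can be sketched by closing over everything except the selected field. In this hypothetical sketch, data is a plain dict standing in for jaqmc's Data type (whose actual structure is not shown here), and transform_with_data_sketch and energy are illustrative names only.

```python
import jax
import jax.numpy as jnp


def transform_with_data_sketch(f, key, jaxfun):
    """Hypothetical sketch: transform f(params, data, *args) w.r.t. data[key]."""

    def wrapped(params, data, *args):
        def f_of_field(field):
            # Rebuild data with the selected field replaced by the traced value.
            new_data = {**data, key: field}
            return f(params, new_data, *args)

        return jaxfun(f_of_field)(data[key])

    return wrapped


def energy(params, data):
    # Toy objective that depends on params and on data["x"].
    return params * jnp.sum(data["x"] ** 2)


data = {"x": jnp.array([1.0, 2.0]), "y": jnp.array(0.5)}
grad_x = transform_with_data_sketch(energy, "x", jax.grad)
g = grad_x(3.0, data)  # d/dx of params * sum(x**2) is 2 * params * x
```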

Array utilities

jaqmc.utils.array.array_partitions(sizes)

Returns the indices for splitting an array into separate partitions.

Parameters:

sizes (Sequence[int]) – Size of each of N partitions. The dimension of the array along the relevant axis is assumed to be sum(sizes).

Return type:

Sequence[int]

Returns:

Sequence of indices (length len(sizes)-1) at which an array should be split to give the desired partitions.
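Since the partitions are contiguous, the split points are the running totals of sizes without the final total. A minimal sketch under that reading (array_partitions_sketch is a hypothetical name, not the library function):

```python
import numpy as np


def array_partitions_sketch(sizes):
    """Hypothetical sketch: cumulative sizes, dropping the last (= total)."""
    return np.cumsum(sizes)[:-1].tolist()


sizes = [2, 3, 1]
x = np.arange(sum(sizes))
idx = array_partitions_sketch(sizes)  # [2, 5]
parts = np.split(x, idx)              # chunks of length 2, 3 and 1
```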

jaqmc.utils.array.split_nonempty_channels(x, sizes)

Split an array into non-empty channels along its first axis.

Parameters:
  • x (Array) – Array to split. Its first axis must have length sum(sizes).

  • sizes (Sequence[int]) – Channel sizes along the first axis. Zero-sized channels are omitted from the result.

Return type:

list[Array]

Returns:

Non-empty channel slices of x. If zero or one channel is non-empty, the result is [x].

Raises:

ValueError – If x.shape[0] does not equal sum(sizes).
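The documented behavior can be sketched in NumPy as follows. This is a hypothetical re-implementation for illustration (split_nonempty_channels_sketch is not the library function), assuming zero-sized channels are simply dropped before computing split points.

```python
import numpy as np


def split_nonempty_channels_sketch(x, sizes):
    """Hypothetical sketch of the documented behavior."""
    if x.shape[0] != sum(sizes):
        raise ValueError(f"x.shape[0]={x.shape[0]} != sum(sizes)={sum(sizes)}")
    nonempty = [s for s in sizes if s > 0]
    if len(nonempty) <= 1:
        # Zero or one non-empty channel: return the whole array unsplit.
        return [x]
    # Split only at the boundaries between non-empty channels.
    return np.split(x, np.cumsum(nonempty)[:-1], axis=0)


x = np.arange(10).reshape(5, 2)
chunks = split_nonempty_channels_sketch(x, [2, 0, 3])  # empty middle channel dropped
```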

jaqmc.utils.array.match_first_axis_of(x, target)

Reshape an array for broadcasting against a higher-rank target.

Parameters:
  • x (Array) – Array whose existing axes should be preserved.

  • target (Array) – Array whose rank determines how many trailing singleton axes are appended to x.

Returns:

x with enough trailing singleton axes to match target.ndim.
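A typical use is weighting per-walker tensors by a per-walker scalar. The sketch below shows the reshape the description implies; match_first_axis_of_sketch is a hypothetical stand-in, not the library function.

```python
import numpy as np


def match_first_axis_of_sketch(x, target):
    """Hypothetical sketch: append trailing singleton axes up to target.ndim."""
    return x.reshape(x.shape + (1,) * (target.ndim - x.ndim))


weights = np.array([1.0, 2.0, 3.0, 4.0])  # one weight per walker
values = np.ones((4, 3, 2))               # one (3, 2) tensor per walker
w = match_first_axis_of_sketch(weights, values)  # shape (4, 1, 1)
weighted = w * values                     # broadcasts along the first axis
```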

Clipping

jaqmc.utils.clip.iqr_clip(x, scale=100.0)

Return the clipped complex observables by applying an interquartile-range (IQR) clip.

The clip is applied to the real and imaginary parts separately.

Return type:

Array

jaqmc.utils.clip.iqr_clip_real(x, scale=100.0)

Return the observables clipped based on their interquartile range (IQR).

Return type:

Array
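The exact clipping bounds are not stated above. The sketch below assumes a common IQR rule, clipping to [Q1 − scale·IQR, Q3 + scale·IQR], and shows how the complex variant reuses the real one on each part. Both the bounds and the helper names are assumptions for illustration; the library may use a different center (for example the median), and a multi-device version would also need to compute quantiles over all walkers.

```python
import numpy as np


def iqr_clip_real_sketch(x, scale=100.0):
    """Hypothetical: clip to [Q1 - scale*IQR, Q3 + scale*IQR] (assumed bounds)."""
    q1, q3 = np.percentile(x, [25, 75])
    iqr = q3 - q1
    return np.clip(x, q1 - scale * iqr, q3 + scale * iqr)


def iqr_clip_sketch(x, scale=100.0):
    """Hypothetical: clip real and imaginary parts separately."""
    if np.iscomplexobj(x):
        return iqr_clip_real_sketch(x.real, scale) + 1j * iqr_clip_real_sketch(x.imag, scale)
    return iqr_clip_real_sketch(x, scale)


e_loc = np.array([1.0, 2.0, 3.0, 4.0, 100.0])
clipped = iqr_clip_sketch(e_loc, scale=1.0)  # Q1=2, Q3=4, so 100.0 is clipped to 6.0
```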

Units

class jaqmc.utils.units.LengthUnit(*values)

Length unit used in system configuration.

bohr

Atomic units of length.

angstrom

Angstrom units.

Supercell construction

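jaqmc.utils.supercell.get_reciprocal_vectors(lattice)

Computes reciprocal lattice vectors.

Formula:

\(\mathbf{B} = 2\pi (\mathbf{A}^{-1})^T\)

where \(\mathbf{A}\) stacks the lattice vectors \(\mathbf{a}_i\) and \(\mathbf{B}\) stacks the reciprocal vectors \(\mathbf{b}_i\) in the same layout (both as rows or both as columns), so that \(\mathbf{a}_i \cdot \mathbf{b}_j = 2\pi\delta_{ij}\).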

Parameters:

lattice (Array) – The lattice vectors.

Return type:

Array

Returns:

Reciprocal lattice vectors.
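A quick numerical check of the formula, assuming the lattice vectors are the rows of the input array. The FCC example and the helper name reciprocal_vectors_sketch are illustrative only.

```python
import numpy as np


def reciprocal_vectors_sketch(lattice):
    """Sketch of B = 2*pi*(A^{-1})^T with lattice vectors as rows of A."""
    return 2 * np.pi * np.linalg.inv(lattice).T


# FCC primitive vectors (rows) with cubic lattice constant a = 1.
a = 1.0
lattice = 0.5 * a * np.array([[0.0, 1.0, 1.0], [1.0, 0.0, 1.0], [1.0, 1.0, 0.0]])
recip = reciprocal_vectors_sketch(lattice)

# Duality check: a_i . b_j = 2*pi*delta_ij.
assert np.allclose(lattice @ recip.T, 2 * np.pi * np.eye(3))
```

As expected, the reciprocal lattice of FCC comes out body-centered cubic.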

jaqmc.utils.supercell.get_supercell_kpts(S, original_reciprocal_vectors)

Generates supercell k-points in the primitive cell’s first Brillouin zone.

These are the k-points of the primitive cell that fold into the Gamma point of the supercell. They satisfy the condition:

\[\mathbf{k} \cdot \mathbf{S}^{-1} \pmod 1 = 0\]

Algorithm Explanation:

This function finds integer vectors \(\mathbf{n}\) such that the fractional coordinates \(\mathbf{n} \cdot \mathbf{S}^{-1}\) lie within the primitive Brillouin zone.

For non-diagonal \(\mathbf{S}\) (e.g., transforming an FCC primitive cell to a conventional cell), the valid integers \(\mathbf{n}\) form a skewed volume. The algorithm:

  1. Finds the bounding box of this skewed volume in integer space.

  2. Scans all integers within the box.

  3. Filters for points that map back into the unit cube.

Parameters:
  • S (Array) – Supercell matrix with shape (3, 3).

  • original_reciprocal_vectors (Array) – Reciprocal vectors of the primitive cell (3, 3).

Return type:

Array

Returns:

Array of k-points with shape (N_k, 3).
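The three algorithm steps can be sketched as follows. This hypothetical version (supercell_kpts_sketch is not the library function) assumes the fractional coordinates are f = n S⁻¹ in the basis of the primitive reciprocal vectors, stacked as the rows of recip, and keeps f in the half-open unit cube [0, 1)³; jaqmc's row/column convention for S may differ. The number of points returned should equal |det S|.

```python
import itertools

import numpy as np


def supercell_kpts_sketch(S, recip, eps=1e-8):
    """Hypothetical sketch of the bounding-box scan described above."""
    S = np.asarray(S, dtype=float)
    S_inv = np.linalg.inv(S)
    # 1. Bounding box: images of the unit-cube corners under f -> f @ S.
    corners = np.array(list(itertools.product([0.0, 1.0], repeat=3)))
    images = corners @ S
    lo = np.floor(images.min(axis=0)).astype(int)
    hi = np.ceil(images.max(axis=0)).astype(int)
    # 2. Scan every integer vector n inside the box.
    axes = [np.arange(l, h + 1) for l, h in zip(lo, hi)]
    n = np.array(list(itertools.product(*axes)), dtype=float)
    # 3. Keep n whose fractional coordinates n @ S^{-1} lie in [0, 1)^3.
    frac = n @ S_inv
    mask = np.all((frac > -eps) & (frac < 1.0 - eps), axis=1)
    return frac[mask] @ recip


recip = 2 * np.pi * np.eye(3)  # simple cubic, a = 1
k_diag = supercell_kpts_sketch(np.diag([2, 2, 2]), recip)  # 2x2x2 supercell

# FCC primitive -> conventional cubic cell (non-diagonal S, |det S| = 4).
S_fcc = np.array([[-1, 1, 1], [1, -1, 1], [1, 1, -1]])
recip_fcc = 2 * np.pi * np.array([[-1.0, 1.0, 1.0], [1.0, -1.0, 1.0], [1.0, 1.0, -1.0]])
k_fcc = supercell_kpts_sketch(S_fcc, recip_fcc)
```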

jaqmc.utils.supercell.get_supercell_copies(latvec, S)

Calculates translation vectors to tile the supercell with the primitive cell.

The vectors \(\mathbf{R}\) are used to map the primitive cell to the supercell.

Parameters:
  • latvec (Array) – Primitive lattice vectors (3, 3).

  • S (Array) – Supercell matrix with shape (3, 3).

Return type:

Array

Returns:

Array of translation vectors with shape (N_cells, 3).

Checkpointing

class jaqmc.utils.checkpoint.NumPyCheckpointManager(save_path, restore_path=None, *, prefix='')

Manage saving and restoring checkpoints as NumPy .npz files.

Checkpoints are stored as PyTrees flattened into named arrays, and can be restored given a reference PyTree that defines the target structure.

save_path

Base path where new checkpoints are written.

restore_path

Path used when searching for existing checkpoints.

restore(fallback)

Restore the latest checkpoint from restore_path if available.

The manager searches for the newest ckpt_*.npz file under restore_path (or uses it directly if it is already a file), and falls back to the provided reference PyTree when nothing can be restored.

Parameters:

fallback (TypeVar(ValueT)) – Reference PyTree to use for structure and default values when no checkpoint exists or all are unreadable.

Returns:

  • step – The initial step of this run (i.e. saved step + 1).

  • restored – The restored PyTree, or fallback if no valid checkpoint is found.

Return type:

tuple[int, TypeVar(ValueT)]

Type Parameters:

ValueT – Reference-tree type that is preserved in the restored value.

static restore_from_file(restore_path, fallback)

Restore a checkpoint from a single .npz file.

Parameters:
  • restore_path (UPath) – Path to the checkpoint file.

  • fallback (TypeVar(ValueT)) – Reference PyTree used to infer the target structure and provide default values.

Returns:

  • step – The initial step of this run (i.e. saved step + 1).

  • restored – The restored PyTree, or fallback if no valid checkpoint is found.

Return type:

tuple[int, TypeVar(ValueT)]

Type Parameters:

ValueT – Reference-tree type that is preserved in the restored value.

Raises:

ValueError – If restore_path is not a file.

save(step, data)

Save a checkpoint for the given step.

Parameters:
  • step (int) – Step index associated with this checkpoint.

  • data – PyTree to serialize into the checkpoint.

type jaqmc.utils.checkpoint.PathLike = str | Path | UPath

Filesystem path accepted by checkpoint readers and writers.

jaqmc.utils.checkpoint.tree_to_npz(tree)

Flatten a PyTree into named arrays for saving as an .npz file.

Parameters:

tree (Any) – PyTree to be saved.

Return type:

dict[str, Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray]

Returns:

A dict mapping array names to arrays, to be passed to np.savez.

jaqmc.utils.checkpoint.tree_from_npz(npf, ref_pytree)

Restore a PyTree from NPZ or HDF5 storage.

Parameters:
  • npf (Mapping[str, ndarray | Group]) – Mapping-like object for the opened storage file.

  • ref_pytree (Any) – Reference PyTree whose structure is used to rebuild the restored values.

Return type:

Any

Returns:

A PyTree with the same structure as ref_pytree and values loaded from npf.
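The flatten-and-rebuild round trip that tree_to_npz and tree_from_npz describe can be sketched with JAX key paths. In this sketch, naming each leaf with jax.tree_util.keystr is an assumption, not necessarily jaqmc's naming scheme; the two helpers are hypothetical, and HDF5 input and default-value fallback are not reproduced.

```python
import io

import jax
import numpy as np


def tree_to_npz_sketch(tree):
    """Hypothetical: name each leaf by its key path, e.g. "['b'][0]"."""
    leaves, _ = jax.tree_util.tree_flatten_with_path(tree)
    return {jax.tree_util.keystr(path): np.asarray(leaf) for path, leaf in leaves}


def tree_from_npz_sketch(npf, ref_pytree):
    """Hypothetical: look up each leaf of ref_pytree by the same key path."""
    leaves, treedef = jax.tree_util.tree_flatten_with_path(ref_pytree)
    values = [np.asarray(npf[jax.tree_util.keystr(path)]) for path, _ in leaves]
    return jax.tree_util.tree_unflatten(treedef, values)


params = {"w": np.ones((2, 3)), "b": [np.zeros(3), np.arange(2.0)]}
flat = tree_to_npz_sketch(params)

buf = io.BytesIO()  # in-memory stand-in for a checkpoint file
np.savez(buf, **flat)
buf.seek(0)
with np.load(buf) as npf:
    restored = tree_from_npz_sketch(npf, params)
```

Keying leaves by path rather than by position means the reference tree only has to match structurally; the stored names do not depend on flattening order.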

Configuration helpers

class jaqmc.utils.config.ConfigManagerLike(*args, **kwargs)

Protocol implemented by full and scoped configuration managers.

get(name, default)
Overloads:
  • self, name (str), default (type[ValueT]) → ValueT

  • self, name (str), default (ValueT) → ValueT

Retrieve a configuration value with type safety.

The supported types are:

Parameters:
  • name (str) – The configuration key to retrieve.

  • default (type[TypeVar(ValueT)] | TypeVar(ValueT)) – A default value or a type/class to use as the schema/default.

Returns:

The configuration value, with the same type as default.

Type Parameters:

ValueT – Type inferred from default and preserved in the return value.

Raises:

NotImplementedError – If the type of default is not supported.

get_collection(name, defaults=None, context=None)

Instantiate a collection of modules from configuration.

Parameters:
  • name (str) – The configuration section name (e.g. “writers”).

  • defaults (dict[str, str | dict] | None, default: None) – A dictionary of {key: default_module_path} for standard items. Users can disable these by setting them to None in config.

  • context (Mapping[str, Any] | None, default: None) – Runtime context dictionary for auto-wiring dependencies.

Returns:

A dictionary of {key: instantiated_object}.

Return type:

dict[str, Any]

get_module(name, default_module='')
Overloads:
  • self, name (str), default_module (ModuleT) → ModuleT

  • self, name (str), default_module (str) → Any

Instantiate a class or function specified in the configuration.

Parameters:
  • name (str) – Configuration key pointing to the module settings.

  • default_module (str | Callable | type, default: '') – Default module or its path if not specified in config.

Returns:

The initialized object or result of the function call.

Type Parameters:

ModuleT – Module/class/callable type preserved when not using string paths.

property name: str

Dot-separated scope prefix for this manager.

Multi-device parallelism

class jaqmc.utils.parallel_jax.DistributedConfig(*, coordinator_address=None, num_processes=1, process_id=0, initialization_timeout=300, wait_second_before_connect=10.0)

Configuration for initializing JAX distributed runtime.

jaqmc.utils.parallel_jax.make_mesh()

Create a 1-D device mesh along the batch axis.

Returns:

A JAX mesh spanning all available devices.
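A minimal sketch of such a mesh, assuming the batch axis is named "batch" (a hypothetical stand-in for jaqmc's BATCH_AXIS_NAME, whose actual value is not shown here) and that every device returned by jax.devices() participates:

```python
import jax
import numpy as np
from jax.sharding import Mesh

BATCH_AXIS = "batch"  # hypothetical stand-in for BATCH_AXIS_NAME


def make_mesh_sketch():
    """Hypothetical sketch: one mesh axis spanning every available device."""
    return Mesh(np.array(jax.devices()), (BATCH_AXIS,))


mesh = make_mesh_sketch()
```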

jaqmc.utils.parallel_jax.make_sharding(partition)

Convert partition specs into NamedSharding objects on the default mesh.

Returns:

A pytree of NamedSharding matching the input structure.
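The conversion amounts to a tree map over partition specs. This sketch builds its own 1-D mesh with a hypothetical "batch" axis instead of jaqmc's default mesh. Passing is_leaf keeps each PartitionSpec intact, since some JAX versions implement it as a tuple subclass.

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), ("batch",))  # hypothetical axis name


def make_sharding_sketch(partition, mesh=mesh):
    """Hypothetical sketch: wrap every PartitionSpec leaf in a NamedSharding."""
    return jax.tree_util.tree_map(
        lambda spec: NamedSharding(mesh, spec),
        partition,
        is_leaf=lambda x: isinstance(x, P),
    )


shardings = make_sharding_sketch({"walkers": P("batch"), "params": P()})
# Walkers are split along the batch axis; params are replicated.
x = jax.device_put(np.zeros((4 * jax.device_count(), 3)), shardings["walkers"])
```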

jaqmc.utils.parallel_jax.jit_sharded(fn, *, in_specs, out_specs, check_vma=True, donate_argnums=None)

JIT-compile a function with shard_map in one call.

Parameters:
  • fn – Function to compile.

  • in_specs – Input partition specs for shard_map.

  • out_specs – Output partition specs for shard_map.

  • check_vma (default: True) – Whether to enable validity checks during shard_map.

  • donate_argnums (default: None) – Argument indices to donate (passed to jax.jit).

Returns:

JIT-compiled, shard-mapped function.

jaqmc.utils.parallel_jax.pvary(x)

Mark x as varying across the batch axis inside shard_map.

Return type:

TypeVar(ValueT)

Returns:

The input value annotated as varying, or unchanged outside shard_map.

Type Parameters:

ValueT – Arbitrary pytree-like value type preserved across the call.

jaqmc.utils.parallel_jax.pmean(x)

Average x across devices along the batch axis.

Return type:

TypeVar(ValueT)

Returns:

The mean of x across all devices, or x unchanged outside shard_map.

Type Parameters:

ValueT – Arbitrary pytree-like value type preserved across the call.

jaqmc.utils.parallel_jax.all_gather(x)

Gather arrays from all devices along the batch axis.

Collects arrays sharded across devices and materializes the complete array on each device. This is useful for checkpointing or when you need to access the full dataset on each process.

This function is intended to mimic the behavior of jax.experimental.multihost_utils.process_allgather(x, tiled=True), which is only available in JAX >= 0.8.

Parameters:

x (TypeVar(ValueT)) – Array or pytree of arrays to gather. Each array should be sharded along the leading dimension corresponding to BATCH_AXIS_NAME.

Return type:

TypeVar(ValueT)

Returns:

Gathered array or pytree with the same structure and shape as the input. The sharding changes from sharded to replicated: each device now holds a complete copy of the full array instead of a single shard.

Type Parameters:

ValueT – Arbitrary pytree-like value type preserved across the call.

jaqmc.utils.parallel_jax.addressable_data(x)

Return the process-local shard of a potentially sharded array.

For any jax.Array (sharded or replicated), returns the addressable (local) portion as a concrete array without sharding metadata. This is needed for init functions (KFAC, samplers) that trace the computation and cannot handle arrays with sharding information.

Parameters:

x – Input value, possibly a sharded jax.Array.

Returns:

The local (addressable) portion of the array, or x unchanged if it is not a jax.Array.

Array type aliases

type jaqmc.array_types.PRNGKey = Array
type jaqmc.array_types.Params = ArrayTree

Parameter PyTree for wavefunction.

type jaqmc.array_types.PyTree = Any
type jaqmc.array_types.ArrayTree = Array | Sequence[ArrayTree] | Mapping[str, ArrayTree]

Native PyTree whose leaves are JAX arrays.

type jaqmc.array_types.ArrayLikeTree = Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | Sequence[ArrayLikeTree] | Mapping[str, ArrayLikeTree]

Native PyTree whose leaves can be converted to JAX arrays.