Utilities
Helper functions for array manipulation, function transformations, clipping, units, supercell construction, checkpointing, configuration, and multi-device parallelism.
Function transforms
Wrappers for extracting real/imaginary parts and handling complex-valued JAX functions.
- type jaqmc.utils.func_transform.CompatibleFunc = Callable[[Concatenate[Params, DataT, P]], Any]
Callable that receives params first, data second, and then any extra arguments.
- jaqmc.utils.func_transform.with_real(f)
Wrap f so that only the real parts of its outputs are returned.
- Parameters:
f (Callable[[ParamSpec(P)], TypeVar(ReturnT)]) – Callable to wrap.
- Return type:
- Returns:
A wrapped function that applies jnp.real to all outputs.
- Type Parameters:
P – Parameter specification of f (arguments are preserved).
ReturnT – Return type of f before jnp.real is applied tree-wise.
Examples
>>> import jax.numpy as jnp
>>> from jaqmc.utils.func_transform import with_real
>>> f = lambda x: x + 1j * x**2
>>> float(with_real(f)(2.0))
2.0
- jaqmc.utils.func_transform.with_imag(f)
Wrap f so that only the imaginary parts of its outputs are returned.
- Parameters:
f (Callable[[ParamSpec(P)], TypeVar(ReturnT)]) – Callable to wrap.
- Return type:
- Returns:
A wrapped function that applies jnp.imag to all outputs.
- Type Parameters:
P – Parameter specification of f (arguments are preserved).
ReturnT – Return type of f before jnp.imag is applied tree-wise.
Examples
>>> import jax.numpy as jnp
>>> from jaqmc.utils.func_transform import with_imag
>>> f = lambda x: x + 1j * x**2
>>> float(with_imag(f)(2.0))
4.0
- jaqmc.utils.func_transform.with_output(f, key)
Wrap f so that it returns only f(...)[key].
- Parameters:
- Return type:
- Returns:
A wrapped function that extracts key from the output mapping.
- Type Parameters:
P – Parameter specification of f (arguments are preserved).
Examples
>>> from jaqmc.utils.func_transform import with_output
>>> g = lambda x: {"energy": x**2, "force": -2*x}
>>> with_output(g, "energy")(3.0)
9.0
- jaqmc.utils.func_transform.transform_maybe_complex(f, jaxfun, argnums=0)
Apply a JAX transform to functions with real or complex outputs.
If f has real outputs, this delegates directly to jaxfun. If f has complex outputs, the real and imaginary parts are transformed separately and recombined into a complex result.
- Parameters:
- Return type:
- Returns:
Wrapped function with the same call signature as f.
- Type Parameters:
P – Parameter specification of f.
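The split-and-recombine strategy can be sketched in a few lines of JAX. transform_maybe_complex_sketch is a hypothetical stand-in, not the library's implementation: it assumes f returns a single array rather than a pytree, and that jaxfun accepts an argnums keyword, as jax.grad and jax.hessian do. Passing jax.grad or jax.hessian gives roughly the behaviour that grad_maybe_complex and hessian_maybe_complex describe.

```python
import jax
import jax.numpy as jnp


def transform_maybe_complex_sketch(f, jaxfun, argnums=0):
    """Hypothetical sketch: apply jaxfun to f, splitting complex outputs.

    Assumes f returns a single array (not a pytree) and that jaxfun takes
    an ``argnums`` keyword, as jax.grad and jax.hessian do.
    """

    def wrapped(*args):
        # Inspect the output dtype abstractly, without evaluating f on data.
        out = jax.eval_shape(f, *args)
        if not jnp.issubdtype(out.dtype, jnp.complexfloating):
            # Real output: delegate directly to the JAX transform.
            return jaxfun(f, argnums=argnums)(*args)
        # Complex output: transform Re f and Im f separately...
        re = jaxfun(lambda *a: jnp.real(f(*a)), argnums=argnums)(*args)
        im = jaxfun(lambda *a: jnp.imag(f(*a)), argnums=argnums)(*args)
        # ...then recombine them leaf by leaf into a complex result.
        return jax.tree_util.tree_map(lambda r, i: r + 1j * i, re, im)

    return wrapped


f = lambda x: x + 1j * x**2
grad_f = transform_maybe_complex_sketch(f, jax.grad)     # d/dx = 1 + 2ix
hess_f = transform_maybe_complex_sketch(f, jax.hessian)  # d2/dx2 = 2i
print(grad_f(2.0))  # value 1 + 4j
print(hess_f(2.0))  # value 2j
```

Splitting into real and imaginary parts is needed because jax.grad, by default, only accepts real-valued scalar outputs.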
- jaqmc.utils.func_transform.linearize_maybe_complex(f, *args)
Wrap jax.linearize to handle complex inputs and outputs. Complex values are split into real and imaginary parts when needed; real values are passed through unchanged.
- jaqmc.utils.func_transform.grad_maybe_complex(f, argnums=0)
Return jax.grad wrapped to support complex-valued outputs.
- Parameters:
- Return type:
- Returns:
Gradient function with the same call signature as f.
- Type Parameters:
P – Parameter specification of f.
- jaqmc.utils.func_transform.hessian_maybe_complex(f, argnums=0)
Return jax.hessian wrapped to support complex-valued outputs.
- Parameters:
- Return type:
- Returns:
Hessian function with the same call signature as f.
- Type Parameters:
P – Parameter specification of f.
- jaqmc.utils.func_transform.transform_with_data(f, key, jaxfun)
Build the gradient of functions of the form f(params, data, *args).
- Parameters:
- Return type:
- Returns:
The gradient function.
- Type Parameters:
DataT – Concrete Data subtype passed through f and the wrapper.
P – Extra parameter specification after params and data.
Array utilities
- jaqmc.utils.array.array_partitions(sizes)
Return the indices for splitting an array into separate partitions of the given sizes.
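A minimal sketch of what such split indices typically look like, assuming (the documentation does not confirm this) that they are the cumulative sums of sizes with the last one dropped, which is the form numpy.split expects. array_partitions_sketch is a hypothetical name, not the library function.

```python
import numpy as np


def array_partitions_sketch(sizes):
    """Hypothetical sketch: split points are the running totals, minus the last."""
    return np.cumsum(sizes)[:-1]


x = np.arange(5)
idx = array_partitions_sketch([2, 3])  # split before index 2
parts = np.split(x, idx)               # pieces of length 2 and 3
print(idx, parts)
```

A zero entry in sizes produces an empty piece, which is exactly the case that split_nonempty_channels filters out.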
- jaqmc.utils.array.split_nonempty_channels(x, sizes)
Split an array into non-empty channels along its first axis.
- Parameters:
- Return type:
- Returns:
Non-empty channel slices of x. If zero or one channel is non-empty, the result is [x].
- Raises:
ValueError – If x.shape[0] does not equal sum(sizes).
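The documented contract of split_nonempty_channels can be mirrored in NumPy. The function below is a hypothetical sketch, not the library code; it assumes the split points are cumulative sums of sizes.

```python
import numpy as np


def split_nonempty_channels_sketch(x, sizes):
    """Hypothetical sketch of the documented contract of split_nonempty_channels."""
    if x.shape[0] != sum(sizes):
        raise ValueError(f"x.shape[0]={x.shape[0]} != sum(sizes)={sum(sizes)}")
    # Split along the first axis at the boundaries between channels.
    pieces = np.split(x, np.cumsum(sizes)[:-1], axis=0)
    nonempty = [p for p in pieces if p.shape[0] > 0]
    # With zero or one non-empty channel, return the whole array unsplit.
    if len(nonempty) <= 1:
        return [x]
    return nonempty


x = np.zeros((5, 3))
print([p.shape for p in split_nonempty_channels_sketch(x, (3, 2))])
print(len(split_nonempty_channels_sketch(x, (5, 0))))
```

Returning [x] in the degenerate case lets callers iterate over channels uniformly, for example when all electrons share one spin.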
Clipping
Units
Supercell construction
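- jaqmc.utils.supercell.get_reciprocal_vectors(lattice)
Compute the reciprocal lattice vectors.
- Formula:
\(\mathbf{B} = 2\pi (\mathbf{A}^{-1})^T\), where the lattice vectors \(\mathbf{a}_i\) and the reciprocal vectors \(\mathbf{b}_i\) are stacked consistently (both as rows or both as columns) in \(\mathbf{A}\) and \(\mathbf{B}\), so that \(\mathbf{a}_i \cdot \mathbf{b}_j = 2\pi\delta_{ij}\).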
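A short NumPy check of the formula, assuming lattice is a 3×3 array whose rows are the lattice vectors. reciprocal_vectors_sketch is a hypothetical helper, not the library function.

```python
import numpy as np


def reciprocal_vectors_sketch(lattice):
    """B = 2*pi*(A^{-1})^T, assuming the rows of A are the lattice vectors."""
    return 2 * np.pi * np.linalg.inv(lattice).T


# FCC primitive vectors for a conventional cube of side 2.
A = np.array([[0.0, 1.0, 1.0], [1.0, 0.0, 1.0], [1.0, 1.0, 0.0]])
B = reciprocal_vectors_sketch(A)
# Duality: a_i . b_j = 2*pi*delta_ij, i.e. A @ B.T = 2*pi*I.
print(np.allclose(A @ B.T, 2 * np.pi * np.eye(3)))
```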
- jaqmc.utils.supercell.get_supercell_kpts(S, original_reciprocal_vectors)
Generates supercell k-points in the primitive cell’s first Brillouin zone.
These are the k-points of the primitive cell that fold into the Gamma point of the supercell. They satisfy the condition:
\[\mathbf{k} \cdot \mathbf{S}^{-1} \pmod 1 = 0\]
Algorithm explanation:
This function finds integer vectors \(\mathbf{n}\) such that the fractional coordinates \(\mathbf{n} \cdot \mathbf{S}^{-1}\) lie within the primitive Brillouin zone.
For non-diagonal \(\mathbf{S}\) (e.g., transforming an FCC primitive cell to a conventional cell), the valid integer vectors \(\mathbf{n}\) form a skewed volume. The algorithm:
1. Finds the bounding box of this skewed volume in integer space.
2. Scans all integer vectors within the box.
3. Filters for points that map back into the unit cube.
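The three steps can be sketched in NumPy. supercell_kpts_sketch is a hypothetical helper that follows the description literally: it takes the fractional coordinates to be the row vector n @ inv(S), keeps those in the half-open unit cube [0, 1)^3, and assumes the rows of recip are the primitive reciprocal vectors. Whether the library uses a row or column convention for S is an assumption here.

```python
import itertools

import numpy as np


def supercell_kpts_sketch(S, recip):
    """Hypothetical sketch of the bounding-box scan for supercell k-points.

    Assumes fractional coordinates are n @ inv(S) (row vectors) and that the
    rows of ``recip`` are the primitive reciprocal lattice vectors.
    """
    S = np.asarray(S, dtype=float)
    S_inv = np.linalg.inv(S)
    # 1. Bounding box: n = f @ S for f in the unit cube, so map its 8 corners.
    corners = np.array(list(itertools.product([0.0, 1.0], repeat=3))) @ S
    lo = np.floor(corners.min(axis=0)).astype(int)
    hi = np.ceil(corners.max(axis=0)).astype(int)
    # 2. Scan every integer vector inside the box.
    ranges = [range(l, h + 1) for l, h in zip(lo, hi)]
    n = np.array(list(itertools.product(*ranges)), dtype=float)
    # 3. Keep points whose fractional coordinates lie in [0, 1)^3.
    frac = n @ S_inv
    eps = 1e-8
    keep = np.all((frac > -eps) & (frac < 1 - eps), axis=1)
    # Convert the surviving fractional coordinates to Cartesian k-points.
    return frac[keep] @ recip


# FCC primitive -> conventional cubic cell: |det S| = 4 k-points.
S_fcc = np.array([[-1, 1, 1], [1, -1, 1], [1, 1, -1]])
print(supercell_kpts_sketch(S_fcc, np.eye(3)))
```

For an integer S, the half-open parallelepiped [0,1)^3 S contains exactly |det S| integer vectors, so the number of k-points returned equals the number of primitive cells in the supercell.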
Checkpointing
- class jaqmc.utils.checkpoint.NumPyCheckpointManager(save_path, restore_path=None, *, prefix='')
Manage saving and restoring checkpoints as NumPy .npz files.
Checkpoints are stored as PyTrees flattened into named arrays, and can be restored given a reference PyTree that defines the target structure.
- restore(fallback)
Restore the latest checkpoint from restore_path, if available.
The manager searches for the newest ckpt_*.npz file under restore_path (or uses restore_path directly if it is already a file), and falls back to the provided reference PyTree when nothing can be restored.
- Parameters:
fallback (TypeVar(ValueT)) – Reference PyTree to use for structure and default values when no checkpoint exists or all checkpoints are unreadable.
- Returns:
step – The initial step of this run (i.e. saved step + 1).
restored – The restored PyTree, or fallback if no valid checkpoint is found.
- Return type:
- Type Parameters:
ValueT – Reference-tree type that is preserved in the restored value.
- static restore_from_file(restore_path, fallback)
Restore a checkpoint from a single .npz file.
- Parameters:
- Returns:
step – The initial step of this run (i.e. saved step + 1).
restored – The restored PyTree, or fallback if no valid checkpoint is found.
- Return type:
- Type Parameters:
ValueT – Reference-tree type that is preserved in the restored value.
- Raises:
ValueError – If restore_path is not a file.
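The save and restore flow described for NumPyCheckpointManager can be approximated with jax.tree_util and numpy.savez. This is a sketch under stated assumptions: the ckpt_<step>.npz naming with a zero-padded step, the leaf_<i> array keys, the stored step entry, and returning step 0 on fallback are hypothetical choices, and all helper names are invented for illustration.

```python
import glob
import os
import zipfile

import jax
import numpy as np


def save_checkpoint_sketch(directory, step, tree):
    """Flatten a PyTree into named arrays and write them to one .npz file."""
    leaves = jax.tree_util.tree_leaves(tree)
    arrays = {f"leaf_{i}": np.asarray(leaf) for i, leaf in enumerate(leaves)}
    # Zero-padded step so lexicographic order matches step order (assumed naming).
    path = os.path.join(directory, f"ckpt_{step:08d}.npz")
    np.savez(path, step=step, **arrays)
    return path


def restore_from_file_sketch(path, fallback):
    """Rebuild a PyTree with the structure of ``fallback`` from a single .npz file."""
    if not os.path.isfile(path):
        raise ValueError(f"{path} is not a file")
    treedef = jax.tree_util.tree_structure(fallback)
    with np.load(path) as data:
        leaves = [data[f"leaf_{i}"] for i in range(treedef.num_leaves)]
        saved_step = int(data["step"])
    # The run resumes at the step after the saved one.
    return saved_step + 1, jax.tree_util.tree_unflatten(treedef, leaves)


def restore_sketch(restore_path, fallback):
    """Try the newest ckpt_*.npz first; return (0, fallback) if nothing is usable."""
    if os.path.isfile(restore_path):
        candidates = [restore_path]
    else:
        pattern = os.path.join(restore_path, "ckpt_*.npz")
        candidates = sorted(glob.glob(pattern), reverse=True)
    for path in candidates:
        try:
            return restore_from_file_sketch(path, fallback)
        except (OSError, KeyError, ValueError, zipfile.BadZipFile):
            continue  # unreadable or mismatched file: try an older checkpoint
    return 0, fallback
```

Only the array leaves are written to disk. The tree structure comes from the fallback at restore time, which is why a reference PyTree is required.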
- type jaqmc.utils.checkpoint.PathLike = str | Path | UPath
Filesystem path accepted by checkpoint readers and writers.
Configuration helpers
- class jaqmc.utils.config.ConfigManagerLike(*args, **kwargs)
Protocol implemented by full and scoped configuration managers.
- get(name, default)
- Overloads:
self, name (str), default (type[ValueT]) → ValueT
self, name (str), default (ValueT) → ValueT
Retrieve a configuration value with type safety.
The supported types are:
- Parameters:
- Returns:
The configuration value, with the same type as default.
- Type Parameters:
ValueT – Type inferred from default and preserved in the return value.
- Raises:
NotImplementedError – If the type of default is not supported.
- get_collection(name, defaults=None, context=None)
Instantiate a collection of modules from configuration.
- Parameters:
name (str) – The configuration section name (e.g. “writers”).
defaults (dict[str, str | dict] | None, default: None) – A dictionary of {key: default_module_path} for standard items. Users can disable these by setting them to None in the configuration.
context (Mapping[str, Any] | None, default: None) – Runtime context dictionary for auto-wiring dependencies.
- Returns:
instantiated_object}.
- Return type:
- get_module(name, default_module='')
- Overloads:
self, name (str), default_module (ModuleT) → ModuleT
self, name (str), default_module (str) → Any
Instantiate a class or function specified in the configuration.
- Parameters:
- Returns:
The initialized object or result of the function call.
- Type Parameters:
ModuleT – Module/class/callable type preserved when not using string paths.
Multi-device parallelism
- class jaqmc.utils.parallel_jax.DistributedConfig(*, coordinator_address=None, num_processes=1, process_id=0, initialization_timeout=300, wait_second_before_connect=10.0)
Configuration for initializing the JAX distributed runtime.
- jaqmc.utils.parallel_jax.make_mesh()
Create a 1-D device mesh along the batch axis.
- Returns:
A JAX mesh spanning all available devices.
- jaqmc.utils.parallel_jax.make_sharding(partition)
Convert partition specs into NamedSharding objects on the default mesh.
- Returns:
A pytree of NamedSharding objects matching the input structure.
- jaqmc.utils.parallel_jax.jit_sharded(fn, *, in_specs, out_specs, check_vma=True, donate_argnums=None)
JIT-compile a function wrapped in shard_map in a single call.
- Parameters:
fn – Function to compile.
in_specs – Input partition specs for shard_map.
out_specs – Output partition specs for shard_map.
check_vma (default: True) – Whether to enable shard_map's varying-manual-axes checks.
donate_argnums (default: None) – Argument indices to donate (passed to jax.jit).
- Returns:
JIT-compiled, shard-mapped function.
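The mesh, sharding, jit_sharded, and pmean pieces fit together roughly as follows. This is a hypothetical sketch, not the library code: the axis name "batch" stands in for the library's BATCH_AXIS_NAME, and the import fallback covers JAX versions where shard_map lives in jax.experimental and takes check_rep instead of check_vma.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

try:  # newer JAX: top-level shard_map with a check_vma flag
    from jax import shard_map
    _CHECK_KW = "check_vma"
except ImportError:  # older JAX: experimental shard_map with check_rep
    from jax.experimental.shard_map import shard_map
    _CHECK_KW = "check_rep"

BATCH = "batch"  # hypothetical stand-in for the library's BATCH_AXIS_NAME


def make_mesh_sketch():
    """1-D mesh over all visible devices with a single batch axis."""
    return Mesh(np.array(jax.devices()), (BATCH,))


def jit_sharded_sketch(fn, *, in_specs, out_specs, check_vma=True, donate_argnums=None):
    """jax.jit(shard_map(fn, ...)) in one call, mirroring the documented parameters."""
    mapped = shard_map(
        fn,
        mesh=make_mesh_sketch(),
        in_specs=in_specs,
        out_specs=out_specs,
        **{_CHECK_KW: check_vma},
    )
    donate = () if donate_argnums is None else donate_argnums
    return jax.jit(mapped, donate_argnums=donate)


def global_mean(x):
    # Each device averages its local shard; pmean then averages across devices.
    return jax.lax.pmean(jnp.mean(x), BATCH)


# Shard the input along the batch axis (the role make_sharding plays above).
x = jax.device_put(jnp.arange(8.0), NamedSharding(make_mesh_sketch(), P(BATCH)))
f = jit_sharded_sketch(global_mean, in_specs=(P(BATCH),), out_specs=P())
print(f(x))  # 3.5, provided the device count divides 8
```

Because every shard has the same size, the mean of the per-device means equals the global mean. The result is replicated, so out_specs is the empty spec P().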
- jaqmc.utils.parallel_jax.pvary(x)
Mark x as varying across the batch axis inside shard_map.
- Return type:
TypeVar(ValueT)
- Returns:
The input value annotated as varying, or unchanged outside shard_map.
- Type Parameters:
ValueT – Arbitrary pytree-like value type preserved across the call.
- jaqmc.utils.parallel_jax.pmean(x)
Average x across devices along the batch axis.
- Return type:
TypeVar(ValueT)
- Returns:
The mean of x across all devices, or x unchanged outside shard_map.
- Type Parameters:
ValueT – Arbitrary pytree-like value type preserved across the call.
- jaqmc.utils.parallel_jax.all_gather(x)
Gather arrays from all devices along the batch axis.
Collects arrays sharded across devices and materializes the complete array on each device. This is useful for checkpointing or whenever the full dataset must be accessible on each process.
This function is intended to mimic the behavior of jax.experimental.multihost_utils.process_allgather(x, tiled=True), which is only available for JAX >= 0.8.X.
- Parameters:
x (TypeVar(ValueT)) – Array or pytree of arrays to gather. Each array should be sharded along the leading dimension corresponding to BATCH_AXIS_NAME.
- Return type:
TypeVar(ValueT)
- Returns:
Gathered array or pytree with the same structure and shape as the input. The sharding changes from sharded to replicated: each device now holds a complete copy of the full array instead of just a shard.
- Type Parameters:
ValueT – Arbitrary pytree-like value type preserved across the call.
- jaqmc.utils.parallel_jax.addressable_data(x)
Return the process-local shard of a potentially sharded array.
For any jax.Array (sharded or replicated), returns the addressable (local) portion as a concrete array without sharding metadata. This is needed for init functions (KFAC, samplers) that trace the computation and cannot handle arrays with sharding information.
- Parameters:
x – Input value, possibly a sharded jax.Array.
- Returns:
The local (addressable) portion of the array, or x unchanged if it is not a jax.Array.