# Custom Writers
Write a custom writer when you need to record training statistics to a destination beyond console, CSV, or HDF5 — for example, a database, a monitoring dashboard, or a custom binary format.
## The Write Lifecycle
Writers follow a simple lifecycle managed by the training loop:
- `open()` is called once when the stage starts. Set up resources here — open files, establish connections, create tables. In distributed runs, `open()` runs only on the master process, so you don’t need to guard against multiple writers.
- `write(step, stats)` is called every training step. `stats` is a flat dictionary containing the output of all estimators’ `reduce()` — keys like `total_energy`, `pmove`, `energy:kinetic_var`, etc. Values are JAX/NumPy scalars; use `self.to_scalar(val)` to convert them to Python floats if your destination requires it.
- `open()` cleanup runs when the stage ends (after `yield`). Close file handles, flush buffers, disconnect.

**Resumption:** When training resumes from a checkpoint, `open()` receives `initial_step` — the step where training will restart. If your writer persists to a file, truncate any data at or beyond this point so stale entries from a previous (interrupted) run are discarded.
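The lifecycle above can be sketched with a minimal stand-in class; `RecordingWriter`, the `"vmc"` stage name, and the event log are illustrative and not part of jaqmc (the real base class lives in `jaqmc.writer.base`):

```python
from contextlib import contextmanager

# Minimal stand-in for the writer lifecycle: open() sets up, the code
# after yield cleans up, and write() is called once per training step.
class RecordingWriter:
    def __init__(self):
        self.events = []

    @contextmanager
    def open(self, stage_name, initial_step=0):
        self.events.append(("open", stage_name, initial_step))  # setup
        try:
            yield
        finally:
            self.events.append(("close", stage_name))  # cleanup after yield

    def write(self, step, stats):
        self.events.append(("write", step, stats["total_energy"]))

# How a training loop drives the lifecycle:
writer = RecordingWriter()
with writer.open("vmc", initial_step=0):
    for step in range(3):
        writer.write(step, {"total_energy": -1.0 - 0.1 * step})

print(writer.events[0])   # ('open', 'vmc', 0)
print(writer.events[-1])  # ('close', 'vmc')
```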
## Building a Writer
Subclass `Writer`:
```python
from contextlib import contextmanager

from jaqmc.writer.base import Writer
from jaqmc.utils.config import configurable_dataclass

@configurable_dataclass
class MyWriter(Writer):
    log_dir: str = "/tmp/logs"  # config field — tunable via YAML
```
`open` manages the resource lifecycle. All I/O setup goes here — never in `__init__`. In distributed runs, multiple processes instantiate the writer during configuration, but only the master process enters `open()`. If you put file creation in `__init__`, every process would create (and fight over) the same files.
```python
    @contextmanager
    def open(self, working_dir, stage_name, initial_step=0):
        path = working_dir / f"{stage_name}_my_log.txt"
        self._file = open(path, "a")
        # If resuming, truncate stale entries
        try:
            yield
        finally:
            self._file.close()
```
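The truncation step left as a comment above can be filled in many ways; here is one hedged sketch for a line-oriented log, assuming each row begins with the step number followed by a comma (the helper name `truncate_at_step` is hypothetical, not part of jaqmc):

```python
from pathlib import Path

def truncate_at_step(path: Path, initial_step: int) -> None:
    """Keep only rows whose leading step number is < initial_step."""
    if not path.exists():
        return  # nothing to truncate on a fresh run
    lines = path.read_text().splitlines(keepends=True)
    kept = [ln for ln in lines if int(ln.split(",", 1)[0]) < initial_step]
    path.write_text("".join(kept))
```

Call it in `open()` before opening the file for appending, so stale rows from an interrupted run never survive a resume.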
`write` records one step’s statistics. Keep it fast — it runs every iteration inside the training loop:
```python
    def write(self, step, stats):
        energy = self.to_scalar(stats.get("total_energy", float("nan")))
        pmove = self.to_scalar(stats.get("pmove", float("nan")))
        self._file.write(f"{step},{energy},{pmove}\n")
```
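This pattern can be exercised standalone; the sketch below stubs `to_scalar` with a plain NumPy conversion and writes to an in-memory buffer (the `SketchWriter` class and its stub are assumptions for illustration; the real `to_scalar` is provided by the `Writer` base class):

```python
import io
import numpy as np

class SketchWriter:
    def __init__(self):
        self._file = io.StringIO()  # stand-in for a real file handle

    @staticmethod
    def to_scalar(val):
        # Stub: converts NumPy (and JAX) 0-d values to Python floats
        return float(np.asarray(val))

    def write(self, step, stats):
        energy = self.to_scalar(stats.get("total_energy", float("nan")))
        pmove = self.to_scalar(stats.get("pmove", float("nan")))
        self._file.write(f"{step},{energy},{pmove}\n")

w = SketchWriter()
w.write(0, {"total_energy": np.float64(-2.5), "pmove": np.float64(0.55)})
print(w._file.getvalue().strip())  # 0,-2.5,0.55
```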
## Getting Started
- `ConsoleWriter` — simplest writer. Shows `to_scalar()` usage and selective field display.
- `CSVWriter` — file-based writer. Shows `open()` with file handle management and header writing.
- `HDF5Writer` — chunked array writes. Shows `initial_step` handling for checkpoint truncation.