altametris.sara.core.base_callback

Base callback system for training lifecycle hooks.

Provides an extensible callback mechanism for customizing training behavior without modifying core training logic. Supports:

  • Training lifecycle hooks (start, end, epoch, batch)
  • Validation hooks
  • Custom callback implementation via inheritance
  • Callback manager for organizing multiple callbacks

Example

>>> class MyCallback(BaseCallback):
...     def on_epoch_end(self, epoch, metrics):
...         print(f"Epoch {epoch}: loss={metrics.get('loss')}")
...
>>> manager = CallbackManager([MyCallback()])
>>> manager.on_epoch_end(epoch=1, metrics={"loss": 0.5})
Epoch 1: loss=0.5
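The example above relies on the manager forwarding each hook call to every registered callback. The sketch below shows that dispatch pattern in a self-contained form. `SketchCallback`, `SketchManager` and `RecordingCallback` are simplified stand-ins written for illustration, not the actual `altametris.sara` classes. Resolving hooks by name with `getattr` is an assumption about how the real manager works.

```python
from __future__ import annotations

from abc import ABC
from typing import Any


class SketchCallback(ABC):
    """Stand-in for BaseCallback: hooks are optional no-ops (only one shown)."""

    def on_epoch_end(self, epoch: int, metrics: dict[str, Any], **kwargs: Any) -> None:
        pass


class SketchManager:
    """Stand-in for CallbackManager: forwards each hook to every callback in order."""

    def __init__(self, callbacks: list[SketchCallback] | None = None) -> None:
        self.callbacks = list(callbacks or [])

    def _dispatch(self, hook_name: str, *args: Any, **kwargs: Any) -> None:
        for callback in self.callbacks:
            # Resolve the hook by name; un-overridden hooks fall back to the base no-op.
            getattr(callback, hook_name)(*args, **kwargs)

    def on_epoch_end(self, epoch: int, metrics: dict[str, Any], **kwargs: Any) -> None:
        self._dispatch("on_epoch_end", epoch, metrics, **kwargs)


class RecordingCallback(SketchCallback):
    """Hypothetical callback that keeps every reported epoch loss."""

    def __init__(self) -> None:
        self.losses: list[float] = []

    def on_epoch_end(self, epoch: int, metrics: dict[str, Any], **kwargs: Any) -> None:
        self.losses.append(metrics["loss"])


recorder = RecordingCallback()
manager = SketchManager([recorder, SketchCallback()])
manager.on_epoch_end(epoch=1, metrics={"loss": 0.5})
manager.on_epoch_end(epoch=2, metrics={"loss": 0.3})
print(recorder.losses)  # [0.5, 0.3]
```

Dispatching by hook name keeps each public `on_*` method on the manager a one-line wrapper. It also lets callbacks that do not override a hook inherit the base no-op.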

Attributes

logger

Classes

BaseCallback

Base class for training callbacks.

CallbackManager

Manages multiple callbacks and orchestrates their execution.

Module Contents

altametris.sara.core.base_callback.logger

class altametris.sara.core.base_callback.BaseCallback

Bases: abc.ABC

Base class for training callbacks.

Callbacks provide hooks into the training lifecycle, allowing custom behavior at key points without modifying training code. All hook methods are optional; override only the ones you need.

Hook execution order:

on_train_start()
-> for each epoch:
    -> on_epoch_start()
    -> for each batch:
        -> on_batch_start()
        -> on_batch_end()
    -> on_validation_start()
    -> on_validation_end()
    -> on_epoch_end()
on_train_end()
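The ordering above can be checked with a small driver loop. The sketch below assumes a hypothetical `run_training` function standing in for the real trainer and a hypothetical `OrderRecorder` callback that logs hook names. With one epoch of two batches, validation runs after the batch loop and before `on_epoch_end`, as documented. The metric values are placeholders.

```python
from __future__ import annotations

from typing import Any


class OrderRecorder:
    """Hypothetical callback that appends the name of every hook it receives."""

    def __init__(self) -> None:
        self.calls: list[str] = []

    def on_train_start(self, **kwargs: Any) -> None:
        self.calls.append("on_train_start")

    def on_train_end(self, **kwargs: Any) -> None:
        self.calls.append("on_train_end")

    def on_epoch_start(self, epoch: int, **kwargs: Any) -> None:
        self.calls.append("on_epoch_start")

    def on_epoch_end(self, epoch: int, metrics: dict[str, Any], **kwargs: Any) -> None:
        self.calls.append("on_epoch_end")

    def on_batch_start(self, batch: int, **kwargs: Any) -> None:
        self.calls.append("on_batch_start")

    def on_batch_end(self, batch: int, metrics: dict[str, Any], **kwargs: Any) -> None:
        self.calls.append("on_batch_end")

    def on_validation_start(self, **kwargs: Any) -> None:
        self.calls.append("on_validation_start")

    def on_validation_end(self, metrics: dict[str, Any], **kwargs: Any) -> None:
        self.calls.append("on_validation_end")


def run_training(callback: OrderRecorder, epochs: int, batches_per_epoch: int) -> None:
    """Hypothetical trainer loop that fires the hooks in the documented order."""
    callback.on_train_start(epochs=epochs)
    for epoch in range(epochs):  # 0-indexed epochs
        callback.on_epoch_start(epoch)
        for batch in range(batches_per_epoch):  # 0-indexed within the epoch
            callback.on_batch_start(batch)
            batch_loss = 1.0 / (batch + 1)  # placeholder for a real training step
            callback.on_batch_end(batch, {"batch_loss": batch_loss})
        callback.on_validation_start()
        callback.on_validation_end({"val_loss": 0.5})  # placeholder metrics
        callback.on_epoch_end(epoch, {"loss": 0.5})
    callback.on_train_end(metrics={"loss": 0.5})


recorder = OrderRecorder()
run_training(recorder, epochs=1, batches_per_epoch=2)
print(recorder.calls)
```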

Example

>>> class LoggingCallback(BaseCallback):
...     def on_epoch_end(self, epoch, metrics):
...         print(f"Epoch {epoch}: {metrics}")
...
>>> callback = LoggingCallback()
>>> callback.on_epoch_end(epoch=1, metrics={"loss": 0.5})
Epoch 1: {'loss': 0.5}

on_train_start(**kwargs: Any) -> None

Called at the start of training.

Parameters:

**kwargs – Additional context (model, config, etc.)

Example

>>> def on_train_start(self, model, epochs, **kwargs):
...     print(f"Starting training for {epochs} epochs")

on_train_end(**kwargs: Any) -> None

Called at the end of training.

Parameters:

**kwargs – Additional context (final metrics, model path, etc.)

Example

>>> def on_train_end(self, metrics, **kwargs):
...     print(f"Training complete: {metrics}")
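Because `on_train_start` and `on_train_end` bracket the whole run, they pair naturally for measuring wall-clock training time. The sketch below assumes a hypothetical `TimerCallback`. Its `BaseCallback` is a minimal local stand-in that defines only these two hooks as no-ops. The `time.sleep` call merely simulates training.

```python
from __future__ import annotations

import time
from abc import ABC
from typing import Any


class BaseCallback(ABC):
    """Minimal stand-in: only the two hooks used below, both optional no-ops."""

    def on_train_start(self, **kwargs: Any) -> None:
        pass

    def on_train_end(self, **kwargs: Any) -> None:
        pass


class TimerCallback(BaseCallback):
    """Hypothetical callback measuring total training duration."""

    def __init__(self) -> None:
        self._start: float | None = None
        self.elapsed: float | None = None

    def on_train_start(self, **kwargs: Any) -> None:
        self._start = time.perf_counter()  # monotonic, high-resolution clock

    def on_train_end(self, **kwargs: Any) -> None:
        if self._start is not None:
            self.elapsed = time.perf_counter() - self._start
            print(f"Training took {self.elapsed:.2f}s")


timer = TimerCallback()
timer.on_train_start(epochs=3)
time.sleep(0.01)  # stands in for the actual training
timer.on_train_end(metrics={"loss": 0.1})
```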

on_epoch_start(epoch: int, **kwargs: Any) -> None

Called at the start of each epoch.

Parameters:
  • epoch – Current epoch number (0-indexed)

  • **kwargs – Additional context

Example

>>> def on_epoch_start(self, epoch, **kwargs):
...     print(f"Starting epoch {epoch}")

on_epoch_end(epoch: int, metrics: dict[str, Any], **kwargs: Any) -> None

Called at the end of each epoch.

Parameters:
  • epoch – Current epoch number (0-indexed)

  • metrics – Metrics computed during epoch (loss, mAP, etc.)

  • **kwargs – Additional context

Example

>>> def on_epoch_end(self, epoch, metrics, **kwargs):
...     if metrics["loss"] < 0.1:
...         print("Loss threshold reached!")
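A common use of `on_epoch_end` is early stopping. The sketch below is a hypothetical `EarlyStopping` callback. It assumes two things: the trainer reports a `"loss"` entry in `metrics`, and the trainer checks a `should_stop` flag after each epoch. Neither the flag nor the `patience` and `min_delta` parameters is part of the documented API, and `BaseCallback` is a minimal local stand-in. Using `metrics.get` avoids a `KeyError` when the key is missing.

```python
from __future__ import annotations

from abc import ABC
from typing import Any


class BaseCallback(ABC):
    """Minimal stand-in defining only the hook used below."""

    def on_epoch_end(self, epoch: int, metrics: dict[str, Any], **kwargs: Any) -> None:
        pass


class EarlyStopping(BaseCallback):
    """Hypothetical callback: flag a stop after `patience` epochs without improvement."""

    def __init__(self, patience: int = 3, min_delta: float = 0.0) -> None:
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.wait = 0
        self.should_stop = False

    def on_epoch_end(self, epoch: int, metrics: dict[str, Any], **kwargs: Any) -> None:
        loss = metrics.get("loss")
        if loss is None:
            return  # nothing to compare this epoch
        if loss < self.best - self.min_delta:
            self.best = loss
            self.wait = 0  # improvement: reset the counter
        else:
            self.wait += 1
            if self.wait >= self.patience:
                self.should_stop = True


stopper = EarlyStopping(patience=2)
for epoch, loss in enumerate([0.9, 0.7, 0.72, 0.71, 0.5]):
    stopper.on_epoch_end(epoch, {"loss": loss})
    if stopper.should_stop:
        print(f"Stopping after epoch {epoch}")  # Stopping after epoch 3
        break
```

Setting a flag rather than raising keeps the decision with the trainer. That matches the documented role of callbacks as hooks that observe training without modifying it.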

on_batch_start(batch: int, **kwargs: Any) -> None

Called at the start of each batch.

Parameters:
  • batch – Current batch number (0-indexed within epoch)

  • **kwargs – Additional context (batch data, etc.)

Example

>>> def on_batch_start(self, batch, **kwargs):
...     if batch % 100 == 0:
...         print(f"Processing batch {batch}")

on_batch_end(batch: int, metrics: dict[str, Any], **kwargs: Any) -> None

Called at the end of each batch.

Parameters:
  • batch – Current batch number (0-indexed within epoch)

  • metrics – Batch metrics (batch_loss, etc.)

  • **kwargs – Additional context

Example

>>> def on_batch_end(self, batch, metrics, **kwargs):
...     if metrics["batch_loss"] > 10.0:
...         raise TrainingError("Loss explosion detected")
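The example above uses a fixed threshold. A loss that becomes NaN would slip past `> 10.0`, because every ordered comparison with NaN is false. The sketch below shows a guard that checks both conditions. `LossGuard` and its `max_loss` parameter are hypothetical. `TrainingError` and `BaseCallback` are defined locally as stand-ins for the project's classes.

```python
from __future__ import annotations

import math
from abc import ABC
from typing import Any


class TrainingError(Exception):
    """Stand-in for the project's TrainingError."""


class BaseCallback(ABC):
    """Minimal stand-in defining only the hook used below."""

    def on_batch_end(self, batch: int, metrics: dict[str, Any], **kwargs: Any) -> None:
        pass


class LossGuard(BaseCallback):
    """Hypothetical callback that aborts training on a diverging batch loss."""

    def __init__(self, max_loss: float = 10.0) -> None:
        self.max_loss = max_loss

    def on_batch_end(self, batch: int, metrics: dict[str, Any], **kwargs: Any) -> None:
        loss = metrics.get("batch_loss")
        if loss is None:
            return
        # NaN compares False against everything, so test for it explicitly.
        if math.isnan(loss) or loss > self.max_loss:
            raise TrainingError(f"Loss explosion detected at batch {batch}: {loss}")


guard = LossGuard(max_loss=10.0)
guard.on_batch_end(0, {"batch_loss": 2.5})  # passes silently
try:
    guard.on_batch_end(1, {"batch_loss": float("nan")})
except TrainingError as exc:
    print(exc)
```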

on_validation_start(**kwargs: Any) -> None

Called at the start of validation.

Parameters:

**kwargs – Additional context

Example

>>> def on_validation_start(self, **kwargs):
...     print("Starting validation...")

on_validation_end(metrics: dict[str, Any], **kwargs: Any) -> None

Called at the end of validation.

Parameters:
  • metrics – Validation metrics (val_loss, mAP, etc.)

  • **kwargs – Additional context

Example

>>> def on_validation_end(self, metrics, **kwargs):
...     print(f"Validation mAP: {metrics.get('mAP')}")
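Validation metrics are typically used to decide when to keep a checkpoint. The sketch below assumes that `"mAP"` is present in `metrics`, as in the example above, and that higher is better. `BestMetricTracker` and its `improved` flag are hypothetical, and `BaseCallback` is a minimal local stand-in. Actually saving weights when `improved` is true is left to the caller.

```python
from __future__ import annotations

from abc import ABC
from typing import Any


class BaseCallback(ABC):
    """Minimal stand-in defining only the hook used below."""

    def on_validation_end(self, metrics: dict[str, Any], **kwargs: Any) -> None:
        pass


class BestMetricTracker(BaseCallback):
    """Hypothetical callback remembering the best value of one validation metric."""

    def __init__(self, key: str = "mAP") -> None:
        self.key = key
        self.best: float | None = None
        self.improved = False  # True when the latest validation set a new best

    def on_validation_end(self, metrics: dict[str, Any], **kwargs: Any) -> None:
        value = metrics.get(self.key)
        self.improved = value is not None and (self.best is None or value > self.best)
        if self.improved:
            self.best = value


tracker = BestMetricTracker()
for m in [0.41, 0.55, 0.52]:
    tracker.on_validation_end({"mAP": m})
    print(tracker.improved, tracker.best)
```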

class altametris.sara.core.base_callback.CallbackManager(callbacks: list[BaseCallback] | None = None)

Manages multiple callbacks and orchestrates their execution.

The manager forwards each lifecycle event to every registered callback in order, re-raises any exception raised by a callback as a TrainingError with added context, and emits log messages for debugging.

Parameters:

callbacks – List of callback instances to manage

Example

>>> callback1 = LoggingCallback()
>>> callback2 = MetricsCallback()
>>> manager = CallbackManager([callback1, callback2])
>>> manager.on_epoch_end(epoch=1, metrics={"loss": 0.5})

callbacks = []

_validate_callbacks() -> None

Validate that all callbacks are BaseCallback instances.

Raises:

TrainingError – If any callback is not a BaseCallback instance

add_callback(callback: BaseCallback) -> None

Add a callback to the manager.

Parameters:

callback – Callback instance to add

Raises:

TrainingError – If callback is not a BaseCallback instance

Example

>>> manager = CallbackManager()
>>> manager.add_callback(LoggingCallback())

remove_callback(callback: BaseCallback) -> None

Remove a callback from the manager.

Parameters:

callback – Callback instance to remove

Example

>>> manager.remove_callback(callback1)
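`add_callback`, `remove_callback` and `_validate_callbacks` amount to type-checked list maintenance. The sketch below assumes that validation uses `isinstance` against `BaseCallback`. It also assumes that removal uses `list.remove`, which compares with `==` (identity for plain objects) and raises `ValueError` for an unregistered callback. The real method's behavior in that case is not documented here. `ManagerSketch`, `BaseCallback`, `LoggingCallback` and `TrainingError` are local stand-ins.

```python
from __future__ import annotations

from abc import ABC


class TrainingError(Exception):
    """Stand-in for the project's TrainingError."""


class BaseCallback(ABC):
    """Minimal stand-in; hooks omitted for brevity."""


class ManagerSketch:
    """Hypothetical stand-in for CallbackManager's registration methods."""

    def __init__(self, callbacks: list[BaseCallback] | None = None) -> None:
        self.callbacks: list[BaseCallback] = list(callbacks or [])
        self._validate_callbacks()

    def _validate_callbacks(self) -> None:
        for cb in self.callbacks:
            if not isinstance(cb, BaseCallback):
                raise TrainingError(f"Expected BaseCallback, got {type(cb).__name__}")

    def add_callback(self, callback: BaseCallback) -> None:
        if not isinstance(callback, BaseCallback):
            raise TrainingError(f"Expected BaseCallback, got {type(callback).__name__}")
        self.callbacks.append(callback)

    def remove_callback(self, callback: BaseCallback) -> None:
        # list.remove compares with ==, which defaults to identity for plain objects.
        self.callbacks.remove(callback)

    def __len__(self) -> int:
        return len(self.callbacks)


class LoggingCallback(BaseCallback):
    """Hypothetical concrete callback."""


manager = ManagerSketch()
cb = LoggingCallback()
manager.add_callback(cb)
print(len(manager))  # 1
manager.remove_callback(cb)
print(len(manager))  # 0
```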

_execute_callback_hook(hook_name: str, *args: Any, **kwargs: Any) -> None

Execute a specific hook on all callbacks.

Parameters:
  • hook_name – Name of the hook method to call

  • *args – Positional arguments for the hook

  • **kwargs – Keyword arguments for the hook

Raises:

TrainingError – If any callback raises an error (re-raised with context)
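The dispatch-and-wrap behavior described for `_execute_callback_hook` can be sketched as follows. This is an illustration under assumptions, not the library's code. It assumes three things:

  • Hooks are resolved by name with `getattr`.
  • The first failing callback stops dispatch.
  • The original exception is chained with `raise ... from`, so it stays available as `__cause__`.

`TrainingError`, `execute_callback_hook`, `OkCallback` and `BrokenCallback` are hypothetical local stand-ins.

```python
from __future__ import annotations

import logging
from typing import Any

logger = logging.getLogger(__name__)


class TrainingError(Exception):
    """Stand-in for the project's TrainingError."""


def execute_callback_hook(callbacks: list[Any], hook_name: str, *args: Any, **kwargs: Any) -> None:
    """Call `hook_name` on every callback in order; wrap any failure in TrainingError."""
    for callback in callbacks:
        try:
            getattr(callback, hook_name)(*args, **kwargs)
        except Exception as exc:
            # Re-raise with context: which callback, which hook, and the original error.
            raise TrainingError(f"{type(callback).__name__}.{hook_name} failed: {exc}") from exc
        logger.debug("Executed %s on %s", hook_name, type(callback).__name__)


class OkCallback:
    def on_epoch_end(self, epoch: int, metrics: dict[str, Any], **kwargs: Any) -> None:
        pass


class BrokenCallback:
    def on_epoch_end(self, epoch: int, metrics: dict[str, Any], **kwargs: Any) -> None:
        raise KeyError("mAP")  # e.g. a callback expecting a metric that is absent


try:
    execute_callback_hook([OkCallback(), BrokenCallback()], "on_epoch_end", 1, {"loss": 0.5})
except TrainingError as err:
    print(err)
    print(type(err.__cause__).__name__)
```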

on_train_start(**kwargs: Any) -> None

Execute on_train_start hook on all callbacks.

on_train_end(**kwargs: Any) -> None

Execute on_train_end hook on all callbacks.

on_epoch_start(epoch: int, **kwargs: Any) -> None

Execute on_epoch_start hook on all callbacks.

on_epoch_end(epoch: int, metrics: dict[str, Any], **kwargs: Any) -> None

Execute on_epoch_end hook on all callbacks.

on_batch_start(batch: int, **kwargs: Any) -> None

Execute on_batch_start hook on all callbacks.

on_batch_end(batch: int, metrics: dict[str, Any], **kwargs: Any) -> None

Execute on_batch_end hook on all callbacks.

on_validation_start(**kwargs: Any) -> None

Execute on_validation_start hook on all callbacks.

on_validation_end(metrics: dict[str, Any], **kwargs: Any) -> None

Execute on_validation_end hook on all callbacks.

__len__() -> int

Return number of registered callbacks.

__repr__() -> str

Return string representation of callback manager.