altametris.sara.core.base_callback
Base callback system for training lifecycle hooks.
Provides an extensible callback mechanism for customizing training behavior without modifying core training logic. Supports:
- Training lifecycle hooks (start, end, epoch, batch)
- Validation hooks
- Custom callbacks implemented via inheritance
- A callback manager for organizing multiple callbacks
Example
>>> class MyCallback(BaseCallback):
...     def on_epoch_end(self, epoch, metrics, **kwargs):
...         print(f"Epoch {epoch}: loss={metrics.get('loss')}")
...
>>> manager = CallbackManager([MyCallback()])
>>> manager.on_epoch_end(epoch=1, metrics={"loss": 0.5})
Epoch 1: loss=0.5
Attributes
- logger
Classes
- BaseCallback – Base class for training callbacks.
- CallbackManager – Manages multiple callbacks and orchestrates their execution.
Module Contents
- altametris.sara.core.base_callback.logger
- class altametris.sara.core.base_callback.BaseCallback
Bases: abc.ABC
Base class for training callbacks.
Callbacks provide hooks into the training lifecycle, allowing custom behavior at key points without modifying training code. All hook methods are optional; override only the ones you need.
Hook execution order:
on_train_start()
for each epoch:
    on_epoch_start()
    for each batch:
        on_batch_start()
        on_batch_end()
    on_validation_start()
    on_validation_end()
    on_epoch_end()
on_train_end()
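The order above can be sketched as a plain training loop. `run_training`, `RecordingCallback`, and the dummy metric values are hypothetical illustrations, not part of altametris.sara; the sketch assumes validation runs once per epoch after the batch loop, and the real trainer may pass different keyword arguments.

```python
from typing import Any


class RecordingCallback:
    """Hypothetical callback that records the name of every hook it receives.

    Real callbacks would subclass BaseCallback; this stand-in keeps the
    sketch runnable without altametris installed.
    """

    def __init__(self) -> None:
        self.calls: list[str] = []

    def __getattr__(self, name: str) -> Any:
        # Only invoked for attributes not found normally, i.e. the on_* hooks.
        if not name.startswith("on_"):
            raise AttributeError(name)

        def hook(*args: Any, **kwargs: Any) -> None:
            self.calls.append(name)

        return hook


def run_training(callback: Any, epochs: int, batches_per_epoch: int) -> None:
    """Hypothetical trainer loop firing hooks in the documented order."""
    callback.on_train_start(epochs=epochs)
    for epoch in range(epochs):  # epochs are 0-indexed
        callback.on_epoch_start(epoch)
        for batch in range(batches_per_epoch):  # 0-indexed within the epoch
            callback.on_batch_start(batch)
            callback.on_batch_end(batch, {"batch_loss": 0.0})  # dummy metrics
        # Assumption: validation runs once per epoch, after the batch loop.
        callback.on_validation_start()
        callback.on_validation_end({"val_loss": 0.0})  # dummy metrics
        callback.on_epoch_end(epoch, {"loss": 0.0})  # dummy metrics
    callback.on_train_end(metrics={"loss": 0.0})


recorder = RecordingCallback()
run_training(recorder, epochs=1, batches_per_epoch=2)
print(recorder.calls)
```

With one epoch of two batches, the batch hooks appear twice between `on_epoch_start` and the validation hooks.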
Example
>>> class LoggingCallback(BaseCallback):
...     def on_epoch_end(self, epoch, metrics, **kwargs):
...         print(f"Epoch {epoch}: {metrics}")
...
>>> callback = LoggingCallback()
>>> callback.on_epoch_end(epoch=1, metrics={"loss": 0.5})
Epoch 1: {'loss': 0.5}
- on_train_start(**kwargs: Any) -> None
Called at the start of training.
- Parameters:
**kwargs – Additional context (model, config, etc.)
Example
>>> def on_train_start(self, model, epochs, **kwargs):
...     print(f"Starting training for {epochs} epochs")
- on_train_end(**kwargs: Any) -> None
Called at the end of training.
- Parameters:
**kwargs – Additional context (final metrics, model path, etc.)
Example
>>> def on_train_end(self, metrics, **kwargs):
...     print(f"Training complete: {metrics}")
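Because the same callback instance receives both `on_train_start` and `on_train_end`, it can keep state between them. A minimal sketch, assuming a hypothetical `TrainingTimerCallback` that measures wall-clock training time; the `try`/`except ImportError` fallback exists only so the snippet runs without altametris installed.

```python
import time
from typing import Any

try:
    from altametris.sara.core.base_callback import BaseCallback
except ImportError:  # hypothetical stand-in so the sketch runs standalone
    class BaseCallback:  # type: ignore[no-redef]
        pass


class TrainingTimerCallback(BaseCallback):
    """Hypothetical callback measuring wall-clock training time."""

    def __init__(self) -> None:
        self._start: float | None = None
        self.elapsed: float | None = None

    def on_train_start(self, **kwargs: Any) -> None:
        # perf_counter is monotonic, so it is safe for measuring durations.
        self._start = time.perf_counter()

    def on_train_end(self, **kwargs: Any) -> None:
        if self._start is not None:
            self.elapsed = time.perf_counter() - self._start
            print(f"Training took {self.elapsed:.2f}s")


timer = TrainingTimerCallback()
timer.on_train_start(epochs=3)
timer.on_train_end(metrics={"loss": 0.1})
```

Accepting `**kwargs` in both hooks keeps the callback compatible with whatever extra context the trainer passes.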
- on_epoch_start(epoch: int, **kwargs: Any) -> None
Called at the start of each epoch.
- Parameters:
epoch – Current epoch number (0-indexed)
**kwargs – Additional context
Example
>>> def on_epoch_start(self, epoch, **kwargs):
...     print(f"Starting epoch {epoch}")
- on_epoch_end(epoch: int, metrics: dict[str, Any], **kwargs: Any) -> None
Called at the end of each epoch.
- Parameters:
epoch – Current epoch number (0-indexed)
metrics – Metrics computed during epoch (loss, mAP, etc.)
**kwargs – Additional context
Example
>>> def on_epoch_end(self, epoch, metrics, **kwargs):
...     if metrics.get("loss", float("inf")) < 0.1:
...         print("Loss threshold reached!")
- on_batch_start(batch: int, **kwargs: Any) -> None
Called at the start of each batch.
- Parameters:
batch – Current batch number (0-indexed within epoch)
**kwargs – Additional context (batch data, etc.)
Example
>>> def on_batch_start(self, batch, **kwargs):
...     if batch % 100 == 0:
...         print(f"Processing batch {batch}")
- on_batch_end(batch: int, metrics: dict[str, Any], **kwargs: Any) -> None
Called at the end of each batch.
- Parameters:
batch – Current batch number (0-indexed within epoch)
metrics – Batch metrics (batch_loss, etc.)
**kwargs – Additional context
Example
>>> def on_batch_end(self, batch, metrics, **kwargs):
...     if metrics.get("batch_loss", 0.0) > 10.0:
...         raise TrainingError("Loss explosion detected")
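Since `batch` restarts at 0 each epoch, per-epoch aggregates need resetting in `on_epoch_start`. A sketch of a hypothetical `RunningLossCallback` that averages the `batch_loss` metric (the key used in the example above) over each epoch, skipping batches that do not report it:

```python
from typing import Any

try:
    from altametris.sara.core.base_callback import BaseCallback
except ImportError:  # hypothetical stand-in so the sketch runs standalone
    class BaseCallback:  # type: ignore[no-redef]
        pass


class RunningLossCallback(BaseCallback):
    """Hypothetical callback averaging batch_loss over each epoch."""

    def __init__(self) -> None:
        self._total = 0.0
        self._count = 0
        self.epoch_means: list[float] = []

    def on_epoch_start(self, epoch: int, **kwargs: Any) -> None:
        # Batch indices restart at 0 each epoch, so reset the accumulators too.
        self._total, self._count = 0.0, 0

    def on_batch_end(self, batch: int, metrics: dict[str, Any], **kwargs: Any) -> None:
        loss = metrics.get("batch_loss")
        if loss is not None:
            self._total += float(loss)
            self._count += 1

    def on_epoch_end(self, epoch: int, metrics: dict[str, Any], **kwargs: Any) -> None:
        # Only record a mean if at least one batch reported a loss.
        if self._count:
            self.epoch_means.append(self._total / self._count)


cb = RunningLossCallback()
cb.on_epoch_start(epoch=0)
for batch, loss in enumerate([1.0, 2.0, 3.0]):
    cb.on_batch_end(batch=batch, metrics={"batch_loss": loss})
cb.on_epoch_end(epoch=0, metrics={})
print(cb.epoch_means)
```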
- on_validation_start(**kwargs: Any) -> None
Called at the start of validation.
- Parameters:
**kwargs – Additional context
Example
>>> def on_validation_start(self, **kwargs):
...     print("Starting validation...")
- on_validation_end(metrics: dict[str, Any], **kwargs: Any) -> None
Called at the end of validation.
- Parameters:
metrics – Validation metrics (val_loss, mAP, etc.)
**kwargs – Additional context
Example
>>> def on_validation_end(self, metrics, **kwargs):
...     print(f"Validation mAP: {metrics.get('mAP')}")
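A common use of `on_validation_end` is early stopping. The documented API does not define how a callback asks the trainer to stop, so the hypothetical `EarlyStoppingCallback` below only sets a `should_stop` flag that a trainer would have to check itself. It assumes the monitored metric (`val_loss` by default) is lower-is-better.

```python
from typing import Any

try:
    from altametris.sara.core.base_callback import BaseCallback
except ImportError:  # hypothetical stand-in so the sketch runs standalone
    class BaseCallback:  # type: ignore[no-redef]
        pass


class EarlyStoppingCallback(BaseCallback):
    """Hypothetical early stopping driven by a lower-is-better metric."""

    def __init__(self, monitor: str = "val_loss", patience: int = 3, min_delta: float = 0.0) -> None:
        self.monitor = monitor
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.wait = 0  # validations since the last improvement
        self.should_stop = False

    def on_validation_end(self, metrics: dict[str, Any], **kwargs: Any) -> None:
        value = metrics.get(self.monitor)
        if value is None:
            return  # metric not reported this round; nothing to compare
        if value < self.best - self.min_delta:
            self.best = value
            self.wait = 0
        else:
            self.wait += 1
            self.should_stop = self.wait >= self.patience


stopper = EarlyStoppingCallback(patience=2)
for val_loss in [0.8, 0.5, 0.55, 0.6]:
    stopper.on_validation_end(metrics={"val_loss": val_loss})
print(stopper.best, stopper.should_stop)
```

After two validations without improving on 0.5, the flag is set.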
- class altametris.sara.core.base_callback.CallbackManager(callbacks: list[BaseCallback] | None = None)
Manages multiple callbacks and orchestrates their execution.
The manager dispatches each lifecycle event to every registered callback in registration order, re-raises any exception raised by a callback as a TrainingError with added context, and emits log messages to aid debugging.
- Parameters:
callbacks – List of callback instances to manage. Defaults to None, in which case the manager starts with no callbacks.
Example
>>> callback1 = LoggingCallback()
>>> callback2 = MetricsCallback()
>>> manager = CallbackManager([callback1, callback2])
>>> manager.on_epoch_end(epoch=1, metrics={"loss": 0.5})
- callbacks = []
- _validate_callbacks() -> None
Validate that all callbacks are BaseCallback instances.
- Raises:
TrainingError – If any callback is not a BaseCallback instance
- add_callback(callback: BaseCallback) -> None
Add a callback to the manager.
- Parameters:
callback – Callback instance to add
- Raises:
TrainingError – If callback is not a BaseCallback instance
Example
>>> manager = CallbackManager()
>>> manager.add_callback(LoggingCallback())
- remove_callback(callback: BaseCallback) -> None
Remove a callback from the manager.
- Parameters:
callback – Callback instance to remove
Example
>>> manager.remove_callback(callback1)
- _execute_callback_hook(hook_name: str, *args: Any, **kwargs: Any) -> None
Execute a specific hook on all callbacks.
- Parameters:
hook_name – Name of the hook method to call
*args – Positional arguments for the hook
**kwargs – Keyword arguments for the hook
- Raises:
TrainingError – If any callback raises an error (re-raised with context)
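One plausible shape for `_execute_callback_hook` is to look up the hook by name on each callback and wrap any exception it raises. The free function and the `TrainingError` stand-in below are hypothetical. With this design, a failing callback prevents the callbacks registered after it from running for that event.

```python
from typing import Any


class TrainingError(Exception):
    """Hypothetical stand-in for the library's TrainingError."""


def execute_callback_hook(
    callbacks: list[Any], hook_name: str, *args: Any, **kwargs: Any
) -> None:
    """Hypothetical free-function version of CallbackManager._execute_callback_hook."""
    for callback in callbacks:  # callbacks run in list order
        hook = getattr(callback, hook_name)
        try:
            hook(*args, **kwargs)
        except Exception as exc:
            # Name the failing callback and hook; `from exc` keeps the
            # original exception available as __cause__.
            raise TrainingError(
                f"{type(callback).__name__}.{hook_name} failed: {exc}"
            ) from exc


class Recorder:
    def __init__(self) -> None:
        self.seen: list[int] = []

    def on_epoch_end(self, epoch: int, metrics: dict[str, Any], **kwargs: Any) -> None:
        self.seen.append(epoch)


class Broken:
    def on_epoch_end(self, epoch: int, metrics: dict[str, Any], **kwargs: Any) -> None:
        raise ValueError("boom")


recorder = Recorder()
caught = None
try:
    execute_callback_hook([recorder, Broken()], "on_epoch_end", 1, {"loss": 0.5})
except TrainingError as err:
    caught = err  # `err` is unbound after the except block, so keep a reference
print(caught)
```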
- on_train_start(**kwargs: Any) -> None
Execute the on_train_start hook on all callbacks.
- on_train_end(**kwargs: Any) -> None
Execute the on_train_end hook on all callbacks.
- on_epoch_start(epoch: int, **kwargs: Any) -> None
Execute the on_epoch_start hook on all callbacks.
- on_epoch_end(epoch: int, metrics: dict[str, Any], **kwargs: Any) -> None
Execute the on_epoch_end hook on all callbacks.
- on_batch_start(batch: int, **kwargs: Any) -> None
Execute the on_batch_start hook on all callbacks.
- on_batch_end(batch: int, metrics: dict[str, Any], **kwargs: Any) -> None
Execute the on_batch_end hook on all callbacks.
- on_validation_start(**kwargs: Any) -> None
Execute the on_validation_start hook on all callbacks.
- on_validation_end(metrics: dict[str, Any], **kwargs: Any) -> None
Execute the on_validation_end hook on all callbacks.
- __len__() -> int
Return the number of registered callbacks.
- __repr__() -> str
Return a string representation of the callback manager.