altametris.sara.core.base_callback
==================================

.. py:module:: altametris.sara.core.base_callback

.. autoapi-nested-parse::

   Base callback system for training lifecycle hooks.

   Provides an extensible callback mechanism for customizing training behavior
   without modifying core training logic.

   Supports:

   - Training lifecycle hooks (start, end, epoch, batch)
   - Validation hooks
   - Custom callback implementation via inheritance
   - Callback manager for organizing multiple callbacks

   .. rubric:: Example

   >>> class MyCallback(BaseCallback):
   ...     def on_epoch_end(self, epoch, metrics, **kwargs):
   ...         print(f"Epoch {epoch}: loss={metrics.get('loss')}")
   ...
   >>> manager = CallbackManager([MyCallback()])
   >>> manager.on_epoch_end(epoch=1, metrics={"loss": 0.5})
   Epoch 1: loss=0.5



Attributes
----------

.. autoapisummary::

   altametris.sara.core.base_callback.logger


Classes
-------

.. autoapisummary::

   altametris.sara.core.base_callback.BaseCallback
   altametris.sara.core.base_callback.CallbackManager


Module Contents
---------------

.. py:data:: logger

.. py:class:: BaseCallback

   Bases: :py:obj:`abc.ABC`


   Base class for training callbacks.

   Callbacks provide hooks into the training lifecycle, allowing custom
   behavior at key points without modifying training code.

   All hook methods are optional; override only the ones you need.

   Hook execution order::

       on_train_start()
       -> for each epoch:
           -> on_epoch_start()
           -> for each batch:
               -> on_batch_start()
               -> on_batch_end()
           -> on_validation_start()
           -> on_validation_end()
           -> on_epoch_end()
       on_train_end()

   .. rubric:: Example

   >>> class LoggingCallback(BaseCallback):
   ...     def on_epoch_end(self, epoch, metrics, **kwargs):
   ...         print(f"Epoch {epoch}: {metrics}")
   ...
   >>> callback = LoggingCallback()
   >>> callback.on_epoch_end(epoch=1, metrics={"loss": 0.5})
   Epoch 1: {'loss': 0.5}


   .. py:method:: on_train_start(**kwargs: Any) -> None

      Called at the start of training.

      :param \*\*kwargs: Additional context (model, config, etc.)

      .. rubric:: Example

      >>> def on_train_start(self, model, epochs, **kwargs):
      ...     print(f"Starting training for {epochs} epochs")


   .. py:method:: on_train_end(**kwargs: Any) -> None

      Called at the end of training.

      :param \*\*kwargs: Additional context (final metrics, model path, etc.)

      .. rubric:: Example

      >>> def on_train_end(self, metrics, **kwargs):
      ...     print(f"Training complete: {metrics}")


   .. py:method:: on_epoch_start(epoch: int, **kwargs: Any) -> None

      Called at the start of each epoch.

      :param epoch: Current epoch number (0-indexed)
      :param \*\*kwargs: Additional context

      .. rubric:: Example

      >>> def on_epoch_start(self, epoch, **kwargs):
      ...     print(f"Starting epoch {epoch}")


   .. py:method:: on_epoch_end(epoch: int, metrics: dict[str, Any], **kwargs: Any) -> None

      Called at the end of each epoch.

      :param epoch: Current epoch number (0-indexed)
      :param metrics: Metrics computed during the epoch (loss, mAP, etc.)
      :param \*\*kwargs: Additional context

      .. rubric:: Example

      >>> def on_epoch_end(self, epoch, metrics, **kwargs):
      ...     if metrics["loss"] < 0.1:
      ...         print("Loss threshold reached!")


   .. py:method:: on_batch_start(batch: int, **kwargs: Any) -> None

      Called at the start of each batch.

      :param batch: Current batch number (0-indexed within the epoch)
      :param \*\*kwargs: Additional context (batch data, etc.)

      .. rubric:: Example

      >>> def on_batch_start(self, batch, **kwargs):
      ...     if batch % 100 == 0:
      ...         print(f"Processing batch {batch}")


   .. py:method:: on_batch_end(batch: int, metrics: dict[str, Any], **kwargs: Any) -> None

      Called at the end of each batch.

      :param batch: Current batch number (0-indexed within the epoch)
      :param metrics: Batch metrics (batch_loss, etc.)
      :param \*\*kwargs: Additional context

      .. rubric:: Example

      >>> def on_batch_end(self, batch, metrics, **kwargs):
      ...     if metrics["batch_loss"] > 10.0:
      ...         raise TrainingError("Loss explosion detected")


   .. py:method:: on_validation_start(**kwargs: Any) -> None

      Called at the start of validation.

      :param \*\*kwargs: Additional context

      .. rubric:: Example

      >>> def on_validation_start(self, **kwargs):
      ...     print("Starting validation...")


   .. py:method:: on_validation_end(metrics: dict[str, Any], **kwargs: Any) -> None

      Called at the end of validation.

      :param metrics: Validation metrics (val_loss, mAP, etc.)
      :param \*\*kwargs: Additional context

      .. rubric:: Example

      >>> def on_validation_end(self, metrics, **kwargs):
      ...     print(f"Validation mAP: {metrics.get('mAP')}")


.. py:class:: CallbackManager(callbacks: Optional[list[BaseCallback]] = None)

   Manages multiple callbacks and orchestrates their execution.

   The manager dispatches every lifecycle event to all registered callbacks in
   order, re-raises any exception from a callback as a ``TrainingError`` with
   context, and provides logging for debugging.

   :param callbacks: List of callback instances to manage

   .. rubric:: Example

   >>> callback1 = LoggingCallback()
   >>> callback2 = MetricsCallback()
   >>> manager = CallbackManager([callback1, callback2])
   >>> manager.on_epoch_end(epoch=1, metrics={"loss": 0.5})


   .. py:attribute:: callbacks
      :value: []


   .. py:method:: _validate_callbacks() -> None

      Validate that all callbacks are BaseCallback instances.

      :raises TrainingError: If any callback is not a BaseCallback instance


   .. py:method:: add_callback(callback: BaseCallback) -> None

      Add a callback to the manager.

      :param callback: Callback instance to add
      :raises TrainingError: If callback is not a BaseCallback instance

      .. rubric:: Example

      >>> manager = CallbackManager()
      >>> manager.add_callback(LoggingCallback())


   .. py:method:: remove_callback(callback: BaseCallback) -> None

      Remove a callback from the manager.

      :param callback: Callback instance to remove

      .. rubric:: Example

      >>> manager.remove_callback(callback1)


   .. py:method:: _execute_callback_hook(hook_name: str, *args: Any, **kwargs: Any) -> None

      Execute a specific hook on all callbacks.

      :param hook_name: Name of the hook method to call
      :param \*args: Positional arguments for the hook
      :param \*\*kwargs: Keyword arguments for the hook
      :raises TrainingError: If any callback raises an error (re-raised with context)


   .. py:method:: on_train_start(**kwargs: Any) -> None

      Execute the ``on_train_start`` hook on all callbacks.


   .. py:method:: on_train_end(**kwargs: Any) -> None

      Execute the ``on_train_end`` hook on all callbacks.


   .. py:method:: on_epoch_start(epoch: int, **kwargs: Any) -> None

      Execute the ``on_epoch_start`` hook on all callbacks.


   .. py:method:: on_epoch_end(epoch: int, metrics: dict[str, Any], **kwargs: Any) -> None

      Execute the ``on_epoch_end`` hook on all callbacks.


   .. py:method:: on_batch_start(batch: int, **kwargs: Any) -> None

      Execute the ``on_batch_start`` hook on all callbacks.


   .. py:method:: on_batch_end(batch: int, metrics: dict[str, Any], **kwargs: Any) -> None

      Execute the ``on_batch_end`` hook on all callbacks.


   .. py:method:: on_validation_start(**kwargs: Any) -> None

      Execute the ``on_validation_start`` hook on all callbacks.


   .. py:method:: on_validation_end(metrics: dict[str, Any], **kwargs: Any) -> None

      Execute the ``on_validation_end`` hook on all callbacks.


   .. py:method:: __len__() -> int

      Return the number of registered callbacks.


   .. py:method:: __repr__() -> str

      Return a string representation of the callback manager.
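You can exercise the hook execution order and the manager's dispatch without a real trainer. What follows is a minimal sketch, not the package's implementation. ``TrainingError``, ``BaseCallback`` and ``CallbackManager`` are local stand-ins that only mirror the documented signatures. ``run_training`` and ``RecordingCallback`` are hypothetical names introduced for illustration. The dispatch strategy is an assumption consistent with the ``:raises:`` note on ``_execute_callback_hook``: each hook is looked up with ``getattr`` and called in list order, and the first failure is wrapped in ``TrainingError``.

```python
# Sketch only: local stand-ins that mirror the documented signatures of
# altametris.sara.core.base_callback; this is NOT the package's implementation.
from abc import ABC
from typing import Any, Optional


class TrainingError(Exception):
    """Hypothetical stand-in for the package's TrainingError."""


class BaseCallback(ABC):
    """Stand-in base class: every hook is an optional no-op."""

    def on_train_start(self, **kwargs: Any) -> None: pass
    def on_train_end(self, **kwargs: Any) -> None: pass
    def on_epoch_start(self, epoch: int, **kwargs: Any) -> None: pass
    def on_epoch_end(self, epoch: int, metrics: dict, **kwargs: Any) -> None: pass
    def on_batch_start(self, batch: int, **kwargs: Any) -> None: pass
    def on_batch_end(self, batch: int, metrics: dict, **kwargs: Any) -> None: pass
    def on_validation_start(self, **kwargs: Any) -> None: pass
    def on_validation_end(self, metrics: dict, **kwargs: Any) -> None: pass


class CallbackManager:
    """Stand-in manager: validates callbacks, dispatches hooks, wraps errors."""

    def __init__(self, callbacks: Optional[list] = None) -> None:
        self.callbacks = list(callbacks or [])
        for cb in self.callbacks:  # assumed equivalent of _validate_callbacks()
            if not isinstance(cb, BaseCallback):
                raise TrainingError(f"Not a BaseCallback: {type(cb).__name__}")

    def _execute_callback_hook(self, hook_name: str, *args: Any, **kwargs: Any) -> None:
        # Call the hook on each callback in list order; the first failure is
        # re-raised as TrainingError naming the callback and the hook.
        for cb in self.callbacks:
            try:
                getattr(cb, hook_name)(*args, **kwargs)
            except Exception as exc:
                raise TrainingError(f"{type(cb).__name__}.{hook_name}: {exc}") from exc

    # Each public hook forwards its arguments unchanged to every callback.
    def on_train_start(self, **kw: Any) -> None:
        self._execute_callback_hook("on_train_start", **kw)
    def on_train_end(self, **kw: Any) -> None:
        self._execute_callback_hook("on_train_end", **kw)
    def on_epoch_start(self, epoch: int, **kw: Any) -> None:
        self._execute_callback_hook("on_epoch_start", epoch, **kw)
    def on_epoch_end(self, epoch: int, metrics: dict, **kw: Any) -> None:
        self._execute_callback_hook("on_epoch_end", epoch, metrics, **kw)
    def on_batch_start(self, batch: int, **kw: Any) -> None:
        self._execute_callback_hook("on_batch_start", batch, **kw)
    def on_batch_end(self, batch: int, metrics: dict, **kw: Any) -> None:
        self._execute_callback_hook("on_batch_end", batch, metrics, **kw)
    def on_validation_start(self, **kw: Any) -> None:
        self._execute_callback_hook("on_validation_start", **kw)
    def on_validation_end(self, metrics: dict, **kw: Any) -> None:
        self._execute_callback_hook("on_validation_end", metrics, **kw)


def run_training(manager: CallbackManager, epochs: int, batches: int) -> None:
    """Hypothetical trainer loop firing hooks in the documented order (dummy metrics)."""
    manager.on_train_start(epochs=epochs)
    for epoch in range(epochs):
        manager.on_epoch_start(epoch)
        for batch in range(batches):
            manager.on_batch_start(batch)
            manager.on_batch_end(batch, {"batch_loss": 0.1})
        manager.on_validation_start()
        manager.on_validation_end({"val_loss": 0.2})
        manager.on_epoch_end(epoch, {"loss": 0.1})
    manager.on_train_end(metrics={"loss": 0.1})


class RecordingCallback(BaseCallback):
    """Hypothetical callback: overrides a few hooks and records their order."""

    def __init__(self) -> None:
        self.events: list = []

    def on_train_start(self, **kwargs: Any) -> None:
        self.events.append("train_start")

    def on_epoch_start(self, epoch: int, **kwargs: Any) -> None:
        self.events.append(f"epoch_start:{epoch}")

    def on_validation_end(self, metrics: dict, **kwargs: Any) -> None:
        self.events.append("val_end")

    def on_epoch_end(self, epoch: int, metrics: dict, **kwargs: Any) -> None:
        self.events.append(f"epoch_end:{epoch}")

    def on_train_end(self, **kwargs: Any) -> None:
        self.events.append("train_end")
```

Start with ``recorder = RecordingCallback()``. Calling ``run_training(CallbackManager([recorder]), epochs=2, batches=3)`` then leaves ``recorder.events`` equal to ``['train_start', 'epoch_start:0', 'val_end', 'epoch_end:0', 'epoch_start:1', 'val_end', 'epoch_end:1', 'train_end']``. Validation therefore finishes before ``on_epoch_end`` fires, as in the hook order above. Hooks the recorder does not override fall through to the no-op defaults.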
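Because every hook is optional, a useful callback can override a single method and read only the metrics it cares about. The sketch below remembers the epoch with the lowest value of one metric, as reported to ``on_epoch_end``. ``BestMetricCallback`` is a hypothetical name. The ``BaseCallback`` here is a trimmed stand-in that shows only the hook being overridden. The default key ``"loss"`` follows the examples above; substitute whatever key your trainer actually puts in ``metrics``.

```python
# Sketch only: BaseCallback is a trimmed local stand-in (one hook shown), and
# BestMetricCallback is a hypothetical example, not part of the package.
import math
from abc import ABC
from typing import Any, Optional


class BaseCallback(ABC):
    """Trimmed stand-in: only the hook this example overrides is shown."""

    def on_epoch_end(self, epoch: int, metrics: dict, **kwargs: Any) -> None:
        pass


class BestMetricCallback(BaseCallback):
    """Remember the epoch with the lowest value of one metric."""

    def __init__(self, key: str = "loss") -> None:
        self.key = key
        self.best_value = math.inf
        self.best_epoch: Optional[int] = None

    def on_epoch_end(self, epoch: int, metrics: dict, **kwargs: Any) -> None:
        value = metrics.get(self.key)
        if value is None:
            return  # metric not reported this epoch; nothing to compare
        if value < self.best_value:
            self.best_value = value
            self.best_epoch = epoch
```

Using ``metrics.get`` rather than ``metrics[...]`` keeps the callback from failing on epochs where the metric is missing.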