altametris.sara.core.base_trainer

Base trainer class for model training.

Provides a common interface for training ML models with support for:

  • Training and validation loops

  • Callback system integration

  • Device management

  • Progress tracking

Example

>>> class MyTrainer(BaseTrainer):
...     def train(self, dataset_config, epochs=100):
...         for epoch in range(epochs):
...             self.callback_manager.on_epoch_start(epoch=epoch)
...             # Training logic
...             self.callback_manager.on_epoch_end(epoch=epoch, metrics=metrics)

Attributes

Classes

BaseTrainer

Abstract base class for model training.

Module Contents

altametris.sara.core.base_trainer.logger
class altametris.sara.core.base_trainer.BaseTrainer(model: altametris.sara.core.base_model.BaseModel, device: str = 'auto', callbacks: list[altametris.sara.core.base_callback.BaseCallback] | None = None)

Bases: abc.ABC

Abstract base class for model training.

Provides common training infrastructure:

  • Callback management for training lifecycle events

  • Device management

  • Training state tracking

  • Configuration validation

Parameters:
  • model – Model to train

  • device – Device to train on ("cpu", "cuda", "mps", "auto")

  • callbacks – List of callbacks for training events

Example

>>> model = MyModel()
>>> trainer = MyTrainer(model=model, device="cuda")
>>> trainer.train(dataset_config="data.yaml", epochs=100)

model

device = 'auto'

callback_manager

_is_training = False

_current_epoch = 0
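The documentation does not specify how "auto" is resolved to a concrete device. A minimal sketch of one plausible resolution order, assuming a PyTorch backend (an assumption; the helper name `resolve_device` is illustrative, not part of this API):

```python
def resolve_device(device: str = "auto") -> str:
    """Hypothetical helper: map 'auto' to the best available backend."""
    if device != "auto":
        return device
    try:
        import torch  # assumption: the library trains with PyTorch
    except ImportError:
        return "cpu"
    if torch.cuda.is_available():
        return "cuda"
    # Apple Silicon GPU; guard the attribute for older PyTorch versions
    mps = getattr(torch.backends, "mps", None)
    if mps is not None and mps.is_available():
        return "mps"
    return "cpu"
```

Explicit device strings pass through unchanged, so "cuda" stays "cuda" even when unavailable; only "auto" triggers detection.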

abstract train(dataset_config: Any, epochs: int = 100, **kwargs: Any) → dict[str, Any]

Train the model.

Parameters:
  • dataset_config – Dataset configuration (path, dict, etc.)

  • epochs – Number of training epochs

  • **kwargs – Additional training arguments

Returns:

Training results and metrics

Raises:

TrainingError – If training fails

Note

Must be implemented by subclasses
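A concrete train() typically toggles the training flag, advances the epoch counter, and fires the epoch callbacks around each pass. The sketch below is self-contained: the callback manager is a stand-in stub, and the loss values are placeholders (in real code both come from BaseTrainer and the actual training step):

```python
from typing import Any

class _StubCallbackManager:
    """Stand-in for the real callback manager, for illustration only."""
    def on_epoch_start(self, epoch: int) -> None: ...
    def on_epoch_end(self, epoch: int, metrics: dict) -> None: ...

class MyTrainer:
    """Sketch of a concrete trainer; real code would subclass BaseTrainer."""
    def __init__(self) -> None:
        self.callback_manager = _StubCallbackManager()
        self._is_training = False
        self._current_epoch = 0

    def train(self, dataset_config: Any, epochs: int = 100,
              **kwargs: Any) -> dict[str, Any]:
        self._is_training = True
        history: list[dict[str, float]] = []
        try:
            for epoch in range(epochs):
                self._current_epoch = epoch
                self.callback_manager.on_epoch_start(epoch=epoch)
                # Placeholder training step; a real trainer would run the
                # model over dataset_config here and collect real metrics.
                metrics = {"loss": 1.0 / (epoch + 1)}
                self.callback_manager.on_epoch_end(epoch=epoch, metrics=metrics)
                history.append(metrics)
        finally:
            # Reset the flag even if an epoch raises TrainingError
            self._is_training = False
        return {"epochs": epochs, "final_metrics": history[-1]}
```

Wrapping the loop in try/finally keeps _is_training consistent even when training fails partway through.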

abstract validate(dataset_config: Any, **kwargs: Any) → dict[str, Any]

Validate the model.

Parameters:
  • dataset_config – Validation dataset configuration

  • **kwargs – Additional validation arguments

Returns:

Validation metrics

Raises:

TrainingError – If validation fails

Note

Must be implemented by subclasses
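Validation usually mirrors the training loop without weight updates. A minimal sketch of a concrete validate(); the metric names and the loop are placeholders, not the library's actual output:

```python
from typing import Any

class MyTrainer:
    """Illustration only; real code would subclass BaseTrainer."""
    def validate(self, dataset_config: Any, **kwargs: Any) -> dict[str, Any]:
        samples_seen = 0
        for _batch in range(4):   # stand-in for iterating a validation loader
            samples_seen += 8     # stand-in batch size
        # Real code would aggregate per-batch metrics from the model here
        return {"samples": samples_seen, "loss": 0.42}
```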

add_callback(callback: altametris.sara.core.base_callback.BaseCallback) → None

Add a callback to the trainer.

Parameters:

callback – Callback instance

Example

>>> trainer.add_callback(MyCallback())

remove_callback(callback: altametris.sara.core.base_callback.BaseCallback) → None

Remove a callback from the trainer.

Parameters:

callback – Callback instance to remove

property is_training: bool

Check if currently training.

property current_epoch: int

Get current epoch number.
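Both properties are read-only views over the private attributes listed above (_is_training, _current_epoch). A minimal illustration of the pattern, extracted from the full class for clarity:

```python
class TrainerState:
    """Sketch of the read-only property pattern used by BaseTrainer."""
    def __init__(self) -> None:
        self._is_training = False
        self._current_epoch = 0

    @property
    def is_training(self) -> bool:
        # No setter: callers can observe but not forge training state
        return self._is_training

    @property
    def current_epoch(self) -> int:
        return self._current_epoch
```

Callbacks can safely poll these during training without risk of mutating trainer state.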

_validate_dataset_config(config: Any) → None

Validate dataset configuration.

Parameters:

config – Dataset configuration to validate

Raises:

TrainingError – If configuration is invalid

_validate_training_args(epochs: int, **kwargs: Any) → None

Validate training arguments.

Parameters:
  • epochs – Number of epochs

  • **kwargs – Additional arguments

Raises:

TrainingError – If arguments are invalid
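The exact checks are not documented. A plausible sketch of what argument validation might look like; TrainingError here is a stand-in class, since the real exception's import path is not shown in this section:

```python
class TrainingError(Exception):
    """Stand-in for the library's TrainingError."""

def validate_training_args(epochs: int, **kwargs: object) -> None:
    """Hypothetical checks; the real method's rules may differ."""
    if not isinstance(epochs, int) or isinstance(epochs, bool):
        raise TrainingError(f"epochs must be an int, got {type(epochs).__name__}")
    if epochs <= 0:
        raise TrainingError(f"epochs must be positive, got {epochs}")
```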

save_checkpoint(path: pathlib.Path, epoch: int, metrics: dict[str, Any]) → None

Save training checkpoint.

Parameters:
  • path – Path to save checkpoint

  • epoch – Current epoch

  • metrics – Current metrics

Example

>>> trainer.save_checkpoint(
...     path=Path("checkpoints/epoch_10.pt"),
...     epoch=10,
...     metrics={"loss": 0.5}
... )

load_checkpoint(path: pathlib.Path) → dict[str, Any]

Load training checkpoint.

Parameters:

path – Path to checkpoint file

Returns:

Checkpoint data

Raises:

TrainingError – If checkpoint cannot be loaded
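The serialization format is not specified here (with a PyTorch backend it would likely be torch.save/torch.load). A dependency-free sketch of the save/load round trip using pickle, with TrainingError as a stand-in class:

```python
import pickle
from pathlib import Path
from typing import Any

class TrainingError(Exception):
    """Stand-in for the library's TrainingError."""

def save_checkpoint(path: Path, epoch: int, metrics: dict[str, Any]) -> None:
    """Write epoch and metrics to disk (format is an assumption)."""
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("wb") as f:
        pickle.dump({"epoch": epoch, "metrics": metrics}, f)

def load_checkpoint(path: Path) -> dict[str, Any]:
    """Read a checkpoint back, raising TrainingError when it is missing."""
    if not path.exists():
        raise TrainingError(f"Checkpoint not found: {path}")
    with path.open("rb") as f:
        return pickle.load(f)
```

A loaded checkpoint dict can then seed resume-from-epoch logic in a concrete train() implementation.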

__repr__() → str

String representation of trainer.
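The exact format of the string is not documented; one conventional sketch, showing the trainer class name and device (the field choice is an assumption):

```python
class MyTrainer:
    """Illustration only; real code would subclass BaseTrainer."""
    def __init__(self, device: str = "auto") -> None:
        self.device = device

    def __repr__(self) -> str:
        # Hypothetical format: class name plus the configured device
        return f"{type(self).__name__}(device={self.device!r})"
```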