altametris.sara.core.base_trainer¶
Base trainer class for model training.
Provides a common interface for training ML models with support for:
- Training and validation loops
- Callback system integration
- Device management
- Progress tracking
Example
>>> class MyTrainer(BaseTrainer):
...     def train(self, dataset_config, epochs=100):
...         for epoch in range(epochs):
...             self.callback_manager.on_epoch_start(epoch=epoch)
...             # Training logic
...             self.callback_manager.on_epoch_end(epoch=epoch, metrics=metrics)
Attributes¶
- logger
Classes¶
- BaseTrainer: Abstract base class for model training.
Module Contents¶
- altametris.sara.core.base_trainer.logger¶
- class altametris.sara.core.base_trainer.BaseTrainer(model: altametris.sara.core.base_model.BaseModel, device: str = 'auto', callbacks: list[altametris.sara.core.base_callback.BaseCallback] | None = None)¶
Bases: abc.ABC
Abstract base class for model training.
Provides common training infrastructure:
- Callback management for training lifecycle events
- Device management
- Training state tracking
- Configuration validation
- Parameters:
model – Model to train
device – Device to train on ("cpu", "cuda", "mps", "auto")
callbacks – List of callbacks for training events
Example
>>> model = MyModel()
>>> trainer = MyTrainer(model=model, device="cuda")
>>> trainer.train(dataset_config="data.yaml", epochs=100)
- model¶
- device = 'auto'¶
- callback_manager¶
- _is_training = False¶
- _current_epoch = 0¶
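The docs state that device accepts "cpu", "cuda", "mps", or "auto", but do not show how "auto" is resolved. A plausible sketch of such resolution logic (the helper name `resolve_device` and the fallback order are assumptions, not the library's actual code):

```python
def resolve_device(requested: str = "auto") -> str:
    """Hypothetical resolution of device="auto"; the real logic is not shown in the docs."""
    if requested != "auto":
        return requested
    try:
        import torch  # optional dependency in this sketch
    except ImportError:
        return "cpu"
    if torch.cuda.is_available():
        return "cuda"
    # MPS backend may be absent on older torch builds, so probe defensively
    mps = getattr(torch.backends, "mps", None)
    if mps is not None and mps.is_available():
        return "mps"
    return "cpu"
```

Explicit device strings pass through unchanged, so `resolve_device("cpu")` is always `"cpu"` regardless of available hardware.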
- abstract train(dataset_config: Any, epochs: int = 100, **kwargs: Any) → dict[str, Any]¶
Train the model.
- Parameters:
dataset_config – Dataset configuration (path, dict, etc.)
epochs – Number of training epochs
**kwargs – Additional training arguments
- Returns:
Training results and metrics
- Raises:
TrainingError – If training fails
Note
Must be implemented by subclasses
- abstract validate(dataset_config: Any, **kwargs: Any) → dict[str, Any]¶
Validate the model.
- Parameters:
dataset_config – Validation dataset configuration
**kwargs – Additional validation arguments
- Returns:
Validation metrics
- Raises:
TrainingError – If validation fails
Note
Must be implemented by subclasses
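A minimal concrete subclass might wire `train` and `validate` together as below. This is a self-contained sketch: `_StubCallbackManager`, the placeholder loss, and the returned keys are illustrative stand-ins, since the real `altametris.sara` classes and metric schema are not shown here.

```python
from typing import Any


class _StubCallbackManager:
    """Stand-in for the trainer's callback manager (assumed interface)."""

    def on_epoch_start(self, epoch: int) -> None:
        pass

    def on_epoch_end(self, epoch: int, metrics: dict[str, Any]) -> None:
        pass


class ConstantLossTrainer:
    """Toy trainer mirroring the BaseTrainer contract without the real base class."""

    def __init__(self) -> None:
        self.callback_manager = _StubCallbackManager()
        self._is_training = False
        self._current_epoch = 0

    def train(self, dataset_config: Any, epochs: int = 100, **kwargs: Any) -> dict[str, Any]:
        self._is_training = True
        history: list[dict[str, Any]] = []
        for epoch in range(epochs):
            self._current_epoch = epoch
            self.callback_manager.on_epoch_start(epoch=epoch)
            metrics = {"loss": 1.0 / (epoch + 1)}  # placeholder for a real training step
            self.callback_manager.on_epoch_end(epoch=epoch, metrics=metrics)
            history.append(metrics)
        self._is_training = False
        return {"epochs": epochs, "final_loss": history[-1]["loss"]}

    def validate(self, dataset_config: Any, **kwargs: Any) -> dict[str, Any]:
        return {"val_loss": 0.0}  # placeholder validation pass
```

Note how the callback hooks bracket each epoch, matching the lifecycle shown in the module-level example.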
- add_callback(callback: altametris.sara.core.base_callback.BaseCallback) → None¶
Add a callback to the trainer.
- Parameters:
callback – Callback instance
Example
>>> trainer.add_callback(MyCallback())
- remove_callback(callback: altametris.sara.core.base_callback.BaseCallback) → None¶
Remove a callback from the trainer.
- Parameters:
callback – Callback instance to remove
- property is_training: bool¶
Check if currently training.
- property current_epoch: int¶
Get current epoch number.
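The two properties are read-only views over the private `_is_training` and `_current_epoch` attributes; a sketch of the assumed pattern, which prevents external code from mutating the training state directly:

```python
class TrainerState:
    """Illustrates the read-only property pattern assumed by BaseTrainer's docs."""

    def __init__(self) -> None:
        self._is_training = False
        self._current_epoch = 0

    @property
    def is_training(self) -> bool:
        """True while a training run is in progress."""
        return self._is_training

    @property
    def current_epoch(self) -> int:
        """Zero-based index of the epoch currently being processed."""
        return self._current_epoch
```

Assigning to `is_training` from outside raises `AttributeError`, since no setter is defined.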
- _validate_dataset_config(config: Any) → None¶
Validate dataset configuration.
- Parameters:
config – Dataset configuration to validate
- Raises:
TrainingError – If configuration is invalid
- _validate_training_args(epochs: int, **kwargs: Any) → None¶
Validate training arguments.
- Parameters:
epochs – Number of epochs
**kwargs – Additional arguments
- Raises:
TrainingError – If arguments are invalid
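The docs only say that invalid arguments raise `TrainingError`; the specific checks are not listed. A hedged sketch of what such validation might look like (the `batch_size` check and error messages are assumptions):

```python
from typing import Any


class TrainingError(Exception):
    """Stand-in for the library's TrainingError."""


def validate_training_args(epochs: int, **kwargs: Any) -> None:
    """Hypothetical checks mirroring _validate_training_args."""
    if not isinstance(epochs, int) or epochs <= 0:
        raise TrainingError(f"epochs must be a positive integer, got {epochs!r}")
    # Example of an optional-kwarg check; the real set of validated kwargs is unknown
    batch_size = kwargs.get("batch_size")
    if batch_size is not None and batch_size <= 0:
        raise TrainingError(f"batch_size must be positive, got {batch_size!r}")
```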
- save_checkpoint(path: pathlib.Path, epoch: int, metrics: dict[str, Any]) → None¶
Save training checkpoint.
- Parameters:
path – Path to save checkpoint
epoch – Current epoch
metrics – Current metrics
Example
>>> trainer.save_checkpoint(
...     path=Path("checkpoints/epoch_10.pt"),
...     epoch=10,
...     metrics={"loss": 0.5}
... )
- load_checkpoint(path: pathlib.Path) → dict[str, Any]¶
Load training checkpoint.
- Parameters:
path – Path to checkpoint file
- Returns:
Checkpoint data
- Raises:
TrainingError – If checkpoint cannot be loaded
- __repr__() → str¶
String representation of trainer.