altametris.sara.core.base_trainer
=================================

.. py:module:: altametris.sara.core.base_trainer

.. autoapi-nested-parse::

   Base trainer class for model training.

   Provides a common interface for training ML models with support for:

   - Training and validation loops
   - Callback system integration
   - Device management
   - Progress tracking

   .. rubric:: Example

   >>> class MyTrainer(BaseTrainer):
   ...     def train(self, dataset_config, epochs=100):
   ...         for epoch in range(epochs):
   ...             self.callback_manager.on_epoch_start(epoch=epoch)
   ...             # Training logic
   ...             self.callback_manager.on_epoch_end(epoch=epoch, metrics=metrics)

Attributes
----------

.. autoapisummary::

   altametris.sara.core.base_trainer.logger

Classes
-------

.. autoapisummary::

   altametris.sara.core.base_trainer.BaseTrainer

Module Contents
---------------

.. py:data:: logger

.. py:class:: BaseTrainer(model: altametris.sara.core.base_model.BaseModel, device: str = 'auto', callbacks: Optional[list[altametris.sara.core.base_callback.BaseCallback]] = None)

   Bases: :py:obj:`abc.ABC`

   Abstract base class for model training.

   Provides common training infrastructure:

   - Callback management for training lifecycle events
   - Device management
   - Training state tracking
   - Configuration validation

   :param model: Model to train
   :param device: Device to train on ("cpu", "cuda", "mps", "auto")
   :param callbacks: List of callbacks for training events

   .. rubric:: Example

   >>> model = MyModel()
   >>> trainer = MyTrainer(model=model, device="cuda")
   >>> trainer.train(dataset_config="data.yaml", epochs=100)

   .. py:attribute:: model

   .. py:attribute:: device
      :value: 'auto'

   .. py:attribute:: callback_manager

   .. py:attribute:: _is_training
      :value: False

   .. py:attribute:: _current_epoch
      :value: 0

   .. py:method:: train(dataset_config: Any, epochs: int = 100, **kwargs: Any) -> dict[str, Any]
      :abstractmethod:

      Train the model.

      :param dataset_config: Dataset configuration (path, dict, etc.)
      :param epochs: Number of training epochs
      :param \*\*kwargs: Additional training arguments
      :returns: Training results and metrics
      :raises TrainingError: If training fails

      .. note:: Must be implemented by subclasses

   .. py:method:: validate(dataset_config: Any, **kwargs: Any) -> dict[str, Any]
      :abstractmethod:

      Validate the model.

      :param dataset_config: Validation dataset configuration
      :param \*\*kwargs: Additional validation arguments
      :returns: Validation metrics
      :raises TrainingError: If validation fails

      .. note:: Must be implemented by subclasses

   .. py:method:: add_callback(callback: altametris.sara.core.base_callback.BaseCallback) -> None

      Add a callback to the trainer.

      :param callback: Callback instance

      .. rubric:: Example

      >>> trainer.add_callback(MyCallback())

   .. py:method:: remove_callback(callback: altametris.sara.core.base_callback.BaseCallback) -> None

      Remove a callback from the trainer.

      :param callback: Callback instance to remove

   .. py:property:: is_training
      :type: bool

      Check if currently training.

   .. py:property:: current_epoch
      :type: int

      Get current epoch number.

   .. py:method:: _validate_dataset_config(config: Any) -> None

      Validate dataset configuration.

      :param config: Dataset configuration to validate
      :raises TrainingError: If configuration is invalid

   .. py:method:: _validate_training_args(epochs: int, **kwargs: Any) -> None

      Validate training arguments.

      :param epochs: Number of epochs
      :param \*\*kwargs: Additional arguments
      :raises TrainingError: If arguments are invalid

   .. py:method:: save_checkpoint(path: pathlib.Path, epoch: int, metrics: dict[str, Any]) -> None

      Save training checkpoint.

      :param path: Path to save checkpoint
      :param epoch: Current epoch
      :param metrics: Current metrics

      .. rubric:: Example

      >>> trainer.save_checkpoint(
      ...     path=Path("checkpoints/epoch_10.pt"),
      ...     epoch=10,
      ...     metrics={"loss": 0.5}
      ... )

   .. py:method:: load_checkpoint(path: pathlib.Path) -> dict[str, Any]

      Load training checkpoint.
      :param path: Path to checkpoint file
      :returns: Checkpoint data
      :raises TrainingError: If checkpoint cannot be loaded

   .. py:method:: __repr__() -> str

      String representation of the trainer.
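The subclass-and-callback pattern documented above can be exercised end to end. The following is a minimal, self-contained sketch: `BaseTrainer`, `BaseCallback`, and `CallbackManager` here are stand-ins mirroring the documented interface (the real classes live in `altametris.sara.core`), and `ToyTrainer`, `EpochLogger`, and the fake loss metric are hypothetical.

```python
from abc import ABC, abstractmethod
from typing import Any, Optional


class BaseCallback:
    """Stand-in for altametris.sara.core.base_callback.BaseCallback."""

    def on_epoch_start(self, epoch: int) -> None: ...
    def on_epoch_end(self, epoch: int, metrics: dict[str, Any]) -> None: ...


class CallbackManager:
    """Hypothetical manager that fans lifecycle events out to callbacks."""

    def __init__(self, callbacks: Optional[list[BaseCallback]] = None) -> None:
        self.callbacks = list(callbacks or [])

    def on_epoch_start(self, epoch: int) -> None:
        for cb in self.callbacks:
            cb.on_epoch_start(epoch=epoch)

    def on_epoch_end(self, epoch: int, metrics: dict[str, Any]) -> None:
        for cb in self.callbacks:
            cb.on_epoch_end(epoch=epoch, metrics=metrics)


class BaseTrainer(ABC):
    """Stand-in mirroring the documented BaseTrainer interface."""

    def __init__(self, model: Any, device: str = "auto",
                 callbacks: Optional[list[BaseCallback]] = None) -> None:
        self.model = model
        self.device = device
        self.callback_manager = CallbackManager(callbacks)
        self._is_training = False
        self._current_epoch = 0

    @abstractmethod
    def train(self, dataset_config: Any, epochs: int = 100,
              **kwargs: Any) -> dict[str, Any]:
        ...


class EpochLogger(BaseCallback):
    """Toy callback that records the epochs it observes."""

    def __init__(self) -> None:
        self.seen: list[int] = []

    def on_epoch_end(self, epoch: int, metrics: dict[str, Any]) -> None:
        self.seen.append(epoch)


class ToyTrainer(BaseTrainer):
    def train(self, dataset_config, epochs=100, **kwargs):
        self._is_training = True
        metrics: dict[str, Any] = {}
        for epoch in range(epochs):
            self._current_epoch = epoch
            self.callback_manager.on_epoch_start(epoch=epoch)
            metrics = {"loss": 1.0 / (epoch + 1)}  # stand-in "training" step
            self.callback_manager.on_epoch_end(epoch=epoch, metrics=metrics)
        self._is_training = False
        return metrics


logger_cb = EpochLogger()
trainer = ToyTrainer(model=None, device="cpu", callbacks=[logger_cb])
result = trainer.train(dataset_config="data.yaml", epochs=3)
print(logger_cb.seen)  # [0, 1, 2]
```

Note how the trainer, not the callbacks, owns the loop: each callback only reacts to `on_epoch_start`/`on_epoch_end`, which is what lets `add_callback`/`remove_callback` change behavior without touching the training logic.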
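`save_checkpoint`/`load_checkpoint` imply a payload bundling the epoch and metrics, but the serialization backend is not specified in this module (PyTorch-based trainers typically use `torch.save`). Below is a dependency-free sketch of the same round trip using `pickle`; the module-level `save_checkpoint`/`load_checkpoint` functions and the `TrainingError` class are stand-ins, not the real implementations.

```python
import pickle
from pathlib import Path
from typing import Any


class TrainingError(Exception):
    """Stand-in for the TrainingError raised by the real trainer."""


def save_checkpoint(path: Path, epoch: int, metrics: dict[str, Any]) -> None:
    # Bundle training state into one payload; a real trainer would also
    # include model and optimizer state.
    payload = {"epoch": epoch, "metrics": metrics}
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_bytes(pickle.dumps(payload))


def load_checkpoint(path: Path) -> dict[str, Any]:
    # Wrap I/O and deserialization failures in TrainingError, matching
    # the documented ":raises TrainingError:" contract.
    try:
        return pickle.loads(path.read_bytes())
    except (OSError, pickle.UnpicklingError) as exc:
        raise TrainingError(f"Cannot load checkpoint {path}") from exc


ckpt_path = Path("checkpoints/epoch_10.pt")
save_checkpoint(ckpt_path, epoch=10, metrics={"loss": 0.5})
data = load_checkpoint(ckpt_path)
print(data)  # {'epoch': 10, 'metrics': {'loss': 0.5}}
```

Returning the raw checkpoint dict (rather than mutating the trainer in place) matches the documented `-> dict[str, Any]` signature and leaves resume logic to the caller.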