altametris.sara.core

Core abstractions shared by all ML libraries.

Exceptions

AltametrisException

Base exception for all Altametris libraries.

ConfigurationError

Exception raised for configuration-related errors.

ExportError

Exception raised for export-related errors.

InferenceError

Exception raised for inference-related errors.

ModelError

Exception raised for model-related errors.

TrainingError

Exception raised for training-related errors.

Classes

BaseCallback

Base class for training callbacks.

CallbackManager

Manages multiple callbacks and orchestrates their execution.

BaseDetector

Abstract base class for model inference/detection.

BaseExporter

Abstract base class for model export.

BaseModel

Abstract base class for all ML models.

BaseTrainer

Abstract base class for model training.

Package Contents

class altametris.sara.core.BaseCallback

Bases: abc.ABC

Base class for training callbacks.

Callbacks provide hooks into the training lifecycle, allowing custom behavior at key points without modifying the training code. All hook methods are optional; override only the ones you need.

Hook execution order:

on_train_start()
-> for each epoch:
    -> on_epoch_start()
    -> for each batch:
        -> on_batch_start()
        -> on_batch_end()
    -> on_validation_start()
    -> on_validation_end()
    -> on_epoch_end()
on_train_end()

Example

>>> class LoggingCallback(BaseCallback):
...     def on_epoch_end(self, epoch, metrics):
...         print(f"Epoch {epoch}: {metrics}")
...
>>> callback = LoggingCallback()
>>> callback.on_epoch_end(epoch=1, metrics={"loss": 0.5})

on_train_start(**kwargs: Any) → None

Called at the start of training.

Parameters:

**kwargs – Additional context (model, config, etc.)

Example

>>> def on_train_start(self, model, epochs, **kwargs):
...     print(f"Starting training for {epochs} epochs")

on_train_end(**kwargs: Any) → None

Called at the end of training.

Parameters:

**kwargs – Additional context (final metrics, model path, etc.)

Example

>>> def on_train_end(self, metrics, **kwargs):
...     print(f"Training complete: {metrics}")

on_epoch_start(epoch: int, **kwargs: Any) → None

Called at the start of each epoch.

Parameters:
  • epoch – Current epoch number (0-indexed)

  • **kwargs – Additional context

Example

>>> def on_epoch_start(self, epoch, **kwargs):
...     print(f"Starting epoch {epoch}")

on_epoch_end(epoch: int, metrics: dict[str, Any], **kwargs: Any) → None

Called at the end of each epoch.

Parameters:
  • epoch – Current epoch number (0-indexed)

  • metrics – Metrics computed during epoch (loss, mAP, etc.)

  • **kwargs – Additional context

Example

>>> def on_epoch_end(self, epoch, metrics, **kwargs):
...     if metrics["loss"] < 0.1:
...         print("Loss threshold reached!")

on_batch_start(batch: int, **kwargs: Any) → None

Called at the start of each batch.

Parameters:
  • batch – Current batch number (0-indexed within epoch)

  • **kwargs – Additional context (batch data, etc.)

Example

>>> def on_batch_start(self, batch, **kwargs):
...     if batch % 100 == 0:
...         print(f"Processing batch {batch}")

on_batch_end(batch: int, metrics: dict[str, Any], **kwargs: Any) → None

Called at the end of each batch.

Parameters:
  • batch – Current batch number (0-indexed within epoch)

  • metrics – Batch metrics (batch_loss, etc.)

  • **kwargs – Additional context

Example

>>> def on_batch_end(self, batch, metrics, **kwargs):
...     if metrics["batch_loss"] > 10.0:
...         raise TrainingError("Loss explosion detected")

on_validation_start(**kwargs: Any) → None

Called at the start of validation.

Parameters:

**kwargs – Additional context

Example

>>> def on_validation_start(self, **kwargs):
...     print("Starting validation...")

on_validation_end(metrics: dict[str, Any], **kwargs: Any) → None

Called at the end of validation.

Parameters:
  • metrics – Validation metrics (val_loss, mAP, etc.)

  • **kwargs – Additional context

Example

>>> def on_validation_end(self, metrics, **kwargs):
...     print(f"Validation mAP: {metrics.get('mAP')}")

class altametris.sara.core.CallbackManager(callbacks: list[BaseCallback] | None = None)

Manages multiple callbacks and orchestrates their execution.

The manager ensures all callbacks receive lifecycle events in order, handles exceptions gracefully, and provides logging for debugging.

Parameters:

callbacks – List of callback instances to manage

Example

>>> callback1 = LoggingCallback()
>>> callback2 = MetricsCallback()
>>> manager = CallbackManager([callback1, callback2])
>>> manager.on_epoch_end(epoch=1, metrics={"loss": 0.5})

callbacks = []

_validate_callbacks() → None

Validate that all callbacks are BaseCallback instances.

Raises:

TrainingError – If any callback is not a BaseCallback instance

add_callback(callback: BaseCallback) → None

Add a callback to the manager.

Parameters:

callback – Callback instance to add

Raises:

TrainingError – If callback is not a BaseCallback instance

Example

>>> manager = CallbackManager()
>>> manager.add_callback(LoggingCallback())

remove_callback(callback: BaseCallback) → None

Remove a callback from the manager.

Parameters:

callback – Callback instance to remove

Example

>>> manager.remove_callback(callback1)

_execute_callback_hook(hook_name: str, *args: Any, **kwargs: Any) → None

Execute a specific hook on all callbacks.

Parameters:
  • hook_name – Name of the hook method to call

  • *args – Positional arguments for the hook

  • **kwargs – Keyword arguments for the hook

Raises:

TrainingError – If any callback raises an error (re-raised with context)
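
A minimal sketch of how such a dispatcher is commonly written (logging details omitted; the exact implementation may differ):

>>> def _execute_callback_hook(self, hook_name, *args, **kwargs):
...     for callback in self.callbacks:
...         try:
...             getattr(callback, hook_name)(*args, **kwargs)
...         except Exception as e:
...             raise TrainingError(
...                 f"Callback {type(callback).__name__} failed in {hook_name}",
...                 {"hook": hook_name, "error": str(e)},
...             ) from e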

on_train_start(**kwargs: Any) → None

Execute on_train_start hook on all callbacks.

on_train_end(**kwargs: Any) → None

Execute on_train_end hook on all callbacks.

on_epoch_start(epoch: int, **kwargs: Any) → None

Execute on_epoch_start hook on all callbacks.

on_epoch_end(epoch: int, metrics: dict[str, Any], **kwargs: Any) → None

Execute on_epoch_end hook on all callbacks.

on_batch_start(batch: int, **kwargs: Any) → None

Execute on_batch_start hook on all callbacks.

on_batch_end(batch: int, metrics: dict[str, Any], **kwargs: Any) → None

Execute on_batch_end hook on all callbacks.

on_validation_start(**kwargs: Any) → None

Execute on_validation_start hook on all callbacks.

on_validation_end(metrics: dict[str, Any], **kwargs: Any) → None

Execute on_validation_end hook on all callbacks.

__len__() → int

Return number of registered callbacks.

__repr__() → str

Return string representation of callback manager.

class altametris.sara.core.BaseDetector(model_path: str | pathlib.Path, device: str = 'auto', warmup: bool = False, **kwargs: Any)

Bases: abc.ABC

Abstract base class for model inference/detection.

Provides common inference infrastructure:
  • Model loading and initialization

  • Device management

  • Warmup capability

  • Prediction interface

Parameters:
  • model_path – Path to model weights

  • device – Device for inference (“cpu”, “cuda”, “mps”, “auto”)

  • warmup – Whether to run warmup inference on initialization

Example

>>> detector = MyDetector(model_path="weights/best.pt", device="cuda")
>>> results = detector.predict(source="image.jpg")

model_path
_device
model = None
_is_initialized = False

_resolve_device(device: str) → torch.device

Resolve device string to torch.device.

Parameters:

device – Device string

Returns:

Resolved torch.device

property device: torch.device

Get current device.

property is_initialized: bool

Check if detector is initialized.

abstract _load_model(model_path: pathlib.Path, **kwargs: Any) → None

Load model from path.

Parameters:
  • model_path – Path to model file

  • **kwargs – Additional loading arguments

Raises:

ModelError – If model cannot be loaded

Note

Must be implemented by subclasses. Should set self.model.

abstract predict(source: Any, **kwargs: Any) → Any

Run inference on source.

Parameters:
  • source – Input source (image path, array, video, etc.)

  • **kwargs – Inference parameters (conf, iou, etc.)

Returns:

Prediction results (format depends on detector type)

Raises:

InferenceError – If prediction fails

Note

Must be implemented by subclasses
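
To make the contract concrete, here is a hypothetical TorchScript-backed subclass implementing both abstract methods (a sketch; any loading backend works):

>>> class TorchScriptDetector(BaseDetector):
...     def _load_model(self, model_path, **kwargs):
...         try:
...             self.model = torch.jit.load(str(model_path), map_location=self.device)
...             self.model.eval()
...         except Exception as e:
...             raise ModelError("Failed to load model", {"path": str(model_path)}) from e
...     def predict(self, source, **kwargs):
...         # source is assumed to be a preprocessed torch.Tensor here
...         self.validate_source(source)
...         with torch.no_grad():
...             return self.model(source.to(self.device))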

warmup(iterations: int = 3) → None

Warmup the model with dummy inference.

Useful for GPU models to pre-allocate memory and compile kernels.

Parameters:

iterations – Number of warmup iterations

Example

>>> detector.warmup(iterations=5)
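
Internally, a warmup typically just pushes dummy inputs through the model; a sketch, assuming an image-like input shape of (1, 3, 640, 640):

>>> def warmup(self, iterations=3):
...     dummy = torch.randn(1, 3, 640, 640, device=self.device)
...     with torch.no_grad():
...         for _ in range(iterations):
...             self.model(dummy)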

validate_source(source: Any) → None

Validate input source.

Parameters:

source – Input source to validate

Raises:

InferenceError – If source is invalid
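
What counts as "invalid" is detector-specific; a sketch of typical checks (illustrative only):

>>> def validate_source(self, source):
...     if source is None:
...         raise InferenceError("Source must not be None")
...     if isinstance(source, (str, Path)) and not Path(source).exists():
...         raise InferenceError("Source file not found", {"source": str(source)})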

__call__(source: Any, **kwargs: Any) → Any

Callable interface for prediction.

Parameters:
  • source – Input source

  • **kwargs – Inference parameters

Returns:

Prediction results

Example

>>> detector = MyDetector(model_path="weights/best.pt")
>>> results = detector("image.jpg", conf=0.5)

__repr__() → str

String representation of detector.

class altametris.sara.core.BaseExporter(model_path: str | pathlib.Path, output_dir: str | pathlib.Path | None = None)

Bases: abc.ABC

Abstract base class for model export.

Provides common export infrastructure:
  • Multi-format export support

  • Export validation

  • Path management

  • Format-specific configuration

Parameters:
  • model_path – Path to source model

  • output_dir – Directory for exported models

Example

>>> exporter = MyExporter(model_path="weights/best.pt")
>>> onnx_path = exporter.export(format="onnx", imgsz=640)
>>> exporter.validate_export(onnx_path)

SUPPORTED_FORMATS = ['onnx', 'torchscript', 'tensorrt', 'coreml', 'tflite']
model_path
output_dir

abstract export(format: str = 'onnx', **kwargs: Any) → pathlib.Path

Export model to specified format.

Parameters:
  • format – Export format (onnx, tensorrt, etc.)

  • **kwargs – Format-specific export arguments

Returns:

Path to exported model

Raises:

ExportError – If export fails

Note

Must be implemented by subclasses

abstract validate_export(export_path: pathlib.Path, **kwargs: Any) → bool

Validate exported model.

Parameters:
  • export_path – Path to exported model

  • **kwargs – Validation arguments

Returns:

True if validation passes

Raises:

ExportError – If validation fails

Note

Must be implemented by subclasses
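
A minimal ONNX-only subclass sketch showing how export() and validate_export() fit together (assumes the model file holds a full torch module and that the onnx package is installed):

>>> class OnnxExporter(BaseExporter):
...     def export(self, format="onnx", imgsz=640, **kwargs):
...         self.validate_format(format)
...         out = self.get_export_path(format)
...         model = torch.load(self.model_path)  # assumption: a pickled nn.Module
...         torch.onnx.export(model, torch.randn(1, 3, imgsz, imgsz), str(out))
...         return out
...     def validate_export(self, export_path, **kwargs):
...         import onnx
...         onnx.checker.check_model(str(export_path))
...         return True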

validate_format(format: str) → None

Validate export format is supported.

Parameters:

format – Format string to validate

Raises:

ExportError – If format is not supported
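
Usage is a simple guard before any export work; "openvino" is not in SUPPORTED_FORMATS, so it is rejected:

>>> exporter.validate_format("onnx")      # supported, returns silently
>>> exporter.validate_format("openvino")  # raises ExportError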

get_export_path(format: str, suffix: str | None = None) → pathlib.Path

Generate export file path.

Parameters:
  • format – Export format

  • suffix – Optional suffix before extension

Returns:

Path for exported model

Example

>>> exporter.get_export_path("onnx")
Path("weights/best.onnx")
>>> exporter.get_export_path("onnx", suffix="_fp16")
Path("weights/best_fp16.onnx")

cleanup_export(export_path: pathlib.Path) → None

Clean up temporary export files.

Parameters:

export_path – Path to export file to remove

Example

>>> exporter.cleanup_export(Path("weights/temp.onnx"))

export_all(formats: list[str] | None = None, **kwargs: Any) → dict[str, pathlib.Path | None]

Export model to multiple formats.

Parameters:
  • formats – List of formats to export (default: all supported)

  • **kwargs – Export arguments

Returns:

Dictionary mapping format to export path (None if export failed)

Example

>>> exports = exporter.export_all(formats=["onnx", "tensorrt"])
>>> print(exports["onnx"])
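
The None-on-failure contract suggests an implementation along these lines (a sketch, not the exact code):

>>> def export_all(self, formats=None, **kwargs):
...     results = {}
...     for fmt in formats or self.SUPPORTED_FORMATS:
...         try:
...             results[fmt] = self.export(format=fmt, **kwargs)
...         except ExportError:
...             results[fmt] = None
...     return results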

__repr__() → str

String representation of exporter.

class altametris.sara.core.BaseModel(config: dict[str, Any] | None = None, device: str = 'auto')

Bases: abc.ABC, torch.nn.Module

Abstract base class for all ML models.

Provides common functionality for the model lifecycle:
  • Model initialization and configuration

  • Weight loading and saving

  • Device management (CPU, CUDA, MPS)

  • Forward pass abstraction

All concrete model implementations must inherit from this class and implement the abstract methods.

Parameters:
  • config – Model configuration dictionary

  • device – Device to run model on (“cpu”, “cuda”, “mps”, or “auto”)

Example

>>> class YoloModel(BaseModel):
...     def forward(self, x):
...         return self.model(x)
...     def load_weights(self, path):
...         self.model = YOLO(path)
...     def save_weights(self, path):
...         self.model.save(path)

config
_device
_is_initialized = False

_resolve_device(device: str) → torch.device

Resolve device string to torch.device.

Parameters:

device – Device string (“cpu”, “cuda”, “mps”, “auto”)

Returns:

Resolved torch.device

Raises:

ModelError – If device is invalid or not available
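
A sketch of the usual "auto" resolution logic (the concrete availability checks may differ):

>>> def _resolve_device(self, device):
...     if device == "auto":
...         if torch.cuda.is_available():
...             return torch.device("cuda")
...         if torch.backends.mps.is_available():
...             return torch.device("mps")
...         return torch.device("cpu")
...     try:
...         return torch.device(device)
...     except RuntimeError as e:
...         raise ModelError(f"Invalid device: {device}") from e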

property device: torch.device

Get current device.

to_device(device: str | torch.device) → BaseModel

Move model to specified device.

Parameters:

device – Target device

Returns:

Self for chaining

Example

>>> model.to_device("cuda")

validate_config(config: dict[str, Any]) → None

Validate model configuration.

Parameters:

config – Configuration to validate

Raises:

ConfigurationError – If configuration is invalid

Example

>>> model.validate_config({"input_size": 640})

abstract forward(x: torch.Tensor, **kwargs: Any) → Any

Forward pass through the model.

Parameters:
  • x – Input tensor

  • **kwargs – Additional arguments

Returns:

Model output

Note

Must be implemented by subclasses

abstract load_weights(path: str | pathlib.Path, **kwargs: Any) → None

Load model weights from file.

Parameters:
  • path – Path to weights file

  • **kwargs – Additional loading arguments

Raises:

ModelError – If weights cannot be loaded

Note

Must be implemented by subclasses

abstract save_weights(path: str | pathlib.Path, **kwargs: Any) → None

Save model weights to file.

Parameters:
  • path – Path to save weights

  • **kwargs – Additional saving arguments

Raises:

ModelError – If weights cannot be saved

Note

Must be implemented by subclasses

get_num_parameters() → int

Get total number of model parameters.

Returns:

Total parameter count

Example

>>> num_params = model.get_num_parameters()
>>> print(f"Model has {num_params:,} parameters")

get_trainable_parameters() → int

Get number of trainable parameters.

Returns:

Trainable parameter count
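
Both counters are typically one-liners over torch.nn.Module.parameters():

>>> total = sum(p.numel() for p in model.parameters())
>>> trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)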

freeze() → None

Freeze all model parameters (set requires_grad=False).

Example

>>> model.freeze()
>>> # Model parameters won't be updated during training

unfreeze() → None

Unfreeze all model parameters (set requires_grad=True).

Example

>>> model.unfreeze()
>>> # Model parameters will be updated during training
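
The two are commonly paired for staged fine-tuning:

>>> model.freeze()    # start with the pretrained weights frozen
>>> # ... train only task-specific layers added on top ...
>>> model.unfreeze()  # then fine-tune the whole network
>>> # ... continue training, usually at a lower learning rate ...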

summary() → dict[str, Any]

Get model summary with key statistics.

Returns:

Dictionary with model info

Example

>>> summary = model.summary()
>>> print(f"Parameters: {summary['total_parameters']}")

__repr__() → str

String representation of model.

class altametris.sara.core.BaseTrainer(model: altametris.sara.core.base_model.BaseModel, device: str = 'auto', callbacks: list[altametris.sara.core.base_callback.BaseCallback] | None = None)

Bases: abc.ABC

Abstract base class for model training.

Provides common training infrastructure:
  • Callback management for training lifecycle events

  • Device management

  • Training state tracking

  • Configuration validation

Parameters:
  • model – Model to train

  • device – Device to train on (“cpu”, “cuda”, “mps”, “auto”)

  • callbacks – List of callbacks for training events

Example

>>> model = MyModel()
>>> trainer = MyTrainer(model=model, device="cuda")
>>> trainer.train(dataset_config="data.yaml", epochs=100)

model
device = 'auto'
callback_manager
_is_training = False
_current_epoch = 0

abstract train(dataset_config: Any, epochs: int = 100, **kwargs: Any) → dict[str, Any]

Train the model.

Parameters:
  • dataset_config – Dataset configuration (path, dict, etc.)

  • epochs – Number of training epochs

  • **kwargs – Additional training arguments

Returns:

Training results and metrics

Raises:

TrainingError – If training fails

Note

Must be implemented by subclasses

abstract validate(dataset_config: Any, **kwargs: Any) → dict[str, Any]

Validate the model.

Parameters:
  • dataset_config – Validation dataset configuration

  • **kwargs – Additional validation arguments

Returns:

Validation metrics

Raises:

TrainingError – If validation fails

Note

Must be implemented by subclasses
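
A concrete train() is expected to drive the callback hooks in the order documented on BaseCallback; a skeletal sketch (batch loop and optimization elided):

>>> def train(self, dataset_config, epochs=100, **kwargs):
...     self._validate_dataset_config(dataset_config)
...     self._validate_training_args(epochs, **kwargs)
...     self.callback_manager.on_train_start(model=self.model, epochs=epochs)
...     metrics = {}
...     for epoch in range(epochs):
...         self._current_epoch = epoch
...         self.callback_manager.on_epoch_start(epoch=epoch)
...         # ... iterate batches, calling on_batch_start/on_batch_end ...
...         self.callback_manager.on_validation_start()
...         metrics = self.validate(dataset_config)
...         self.callback_manager.on_validation_end(metrics=metrics)
...         self.callback_manager.on_epoch_end(epoch=epoch, metrics=metrics)
...     self.callback_manager.on_train_end(metrics=metrics)
...     return metrics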

add_callback(callback: altametris.sara.core.base_callback.BaseCallback) → None

Add a callback to the trainer.

Parameters:

callback – Callback instance

Example

>>> trainer.add_callback(MyCallback())

remove_callback(callback: altametris.sara.core.base_callback.BaseCallback) → None

Remove a callback from the trainer.

Parameters:

callback – Callback instance to remove

property is_training: bool

Check if currently training.

property current_epoch: int

Get current epoch number.

_validate_dataset_config(config: Any) → None

Validate dataset configuration.

Parameters:

config – Dataset configuration to validate

Raises:

TrainingError – If configuration is invalid

_validate_training_args(epochs: int, **kwargs: Any) → None

Validate training arguments.

Parameters:
  • epochs – Number of epochs

  • **kwargs – Additional arguments

Raises:

TrainingError – If arguments are invalid

save_checkpoint(path: pathlib.Path, epoch: int, metrics: dict[str, Any]) → None

Save training checkpoint.

Parameters:
  • path – Path to save checkpoint

  • epoch – Current epoch

  • metrics – Current metrics

Example

>>> trainer.save_checkpoint(
...     path=Path("checkpoints/epoch_10.pt"),
...     epoch=10,
...     metrics={"loss": 0.5}
... )

load_checkpoint(path: pathlib.Path) → dict[str, Any]

Load training checkpoint.

Parameters:

path – Path to checkpoint file

Returns:

Checkpoint data

Raises:

TrainingError – If checkpoint cannot be loaded
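
Example (assuming the checkpoint dict mirrors what save_checkpoint was given):

>>> checkpoint = trainer.load_checkpoint(Path("checkpoints/epoch_10.pt"))
>>> checkpoint["epoch"], checkpoint["metrics"]["loss"]
(10, 0.5)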

__repr__() → str

String representation of trainer.

exception altametris.sara.core.AltametrisException(message: str, details: dict[str, Any] | None = None)

Bases: Exception

Base exception for all Altametris libraries.

All custom exceptions inherit from this base class to allow catching all Altametris-specific errors with a single except clause.

Parameters:
  • message – Error message describing what went wrong

  • details – Optional dictionary with additional error context

Example

>>> try:
...     raise AltametrisException("Something went wrong", {"code": 500})
... except AltametrisException as e:
...     print(e)
...     print(e.details)

message
details

__str__() → str

Return string representation of exception.

__repr__() → str

Return repr representation of exception.

exception altametris.sara.core.ConfigurationError(message: str, details: dict[str, Any] | None = None)

Bases: AltametrisException

Exception raised for configuration-related errors.

Scenarios:
  • Invalid configuration parameter

  • Missing required configuration key

  • Configuration validation failed

  • Incompatible configuration values

  • Invalid YAML/JSON format

Example

>>> raise ConfigurationError(
...     "Invalid batch_size",
...     {"parameter": "batch_size", "value": -1, "expected": "> 0"}
... )

exception altametris.sara.core.ExportError(message: str, details: dict[str, Any] | None = None)

Bases: AltametrisException

Exception raised for export-related errors.

Scenarios:
  • Unsupported export format

  • Export validation failed

  • Missing export dependencies (onnx, tensorrt)

  • Dynamic shapes not supported

  • Export optimization failed

Example

>>> raise ExportError(
...     "ONNX export failed",
...     {"format": "onnx", "reason": "Dynamic axes not supported"}
... )

exception altametris.sara.core.InferenceError(message: str, details: dict[str, Any] | None = None)

Bases: AltametrisException

Exception raised for inference-related errors.

Scenarios:
  • Invalid input format (wrong shape, type)

  • Device mismatch (model on CUDA, input on CPU)

  • Prediction failure

  • Post-processing error

  • Batch size too large for memory

Example

>>> raise InferenceError(
...     "Invalid input shape",
...     {"expected": (3, 640, 640), "got": (640, 640)}
... )

exception altametris.sara.core.ModelError(message: str, details: dict[str, Any] | None = None)

Bases: AltametrisException

Exception raised for model-related errors.

Scenarios:
  • Failed to load model weights

  • Invalid model architecture

  • Corrupted checkpoint file

  • Incompatible model version

  • Missing required model files

Example

>>> raise ModelError(
...     "Failed to load weights",
...     {"path": "model.pt", "reason": "File not found"}
... )

exception altametris.sara.core.TrainingError(message: str, details: dict[str, Any] | None = None)

Bases: AltametrisException

Exception raised for training-related errors.

Scenarios:
  • Dataset not found or invalid format

  • Training divergence (NaN loss)

  • Out of memory during training

  • Invalid hyperparameters

  • Checkpoint save failure

Example

>>> raise TrainingError(
...     "Training diverged",
...     {"epoch": 10, "loss": float('nan'), "lr": 0.001}
... )