altametris.sara.core¶
Core abstractions shared by all ML libraries.
Submodules¶
Exceptions¶
AltametrisException | Base exception for all Altametris libraries.
ConfigurationError | Exception raised for configuration-related errors.
ExportError | Exception raised for export-related errors.
InferenceError | Exception raised for inference-related errors.
ModelError | Exception raised for model-related errors.
TrainingError | Exception raised for training-related errors.
Classes¶
BaseCallback | Base class for training callbacks.
CallbackManager | Manages multiple callbacks and orchestrates their execution.
BaseDetector | Abstract base class for model inference/detection.
BaseExporter | Abstract base class for model export.
BaseModel | Abstract base class for all ML models.
BaseTrainer | Abstract base class for model training.
Package Contents¶
- class altametris.sara.core.BaseCallback¶
Bases: abc.ABC
Base class for training callbacks.
Callbacks provide hooks into the training lifecycle, allowing custom behavior at key points without modifying training code. All hook methods are optional - override only the ones you need.
Hook execution order:
on_train_start()
-> for each epoch:
    -> on_epoch_start()
    -> for each batch:
        -> on_batch_start()
        -> on_batch_end()
    -> on_validation_start()
    -> on_validation_end()
    -> on_epoch_end()
on_train_end()
Example
>>> class LoggingCallback(BaseCallback):
...     def on_epoch_end(self, epoch, metrics):
...         print(f"Epoch {epoch}: {metrics}")
...
>>> callback = LoggingCallback()
>>> callback.on_epoch_end(epoch=1, metrics={"loss": 0.5})
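The hook execution order can be traced with a callback that records each hook as it fires. This is a minimal sketch under stated assumptions: `BaseCallback` below is a local stand-in with no-op hooks, validation is assumed to run once per epoch after the batch loop, and `RecordingCallback` and `run_toy_training` are hypothetical names, not part of the library.

```python
from typing import Any


class BaseCallback:
    """Local stand-in for altametris.sara.core.BaseCallback: every hook is a no-op."""

    def on_train_start(self, **kwargs: Any) -> None: pass
    def on_train_end(self, **kwargs: Any) -> None: pass
    def on_epoch_start(self, epoch: int, **kwargs: Any) -> None: pass
    def on_epoch_end(self, epoch: int, metrics: dict, **kwargs: Any) -> None: pass
    def on_batch_start(self, batch: int, **kwargs: Any) -> None: pass
    def on_batch_end(self, batch: int, metrics: dict, **kwargs: Any) -> None: pass
    def on_validation_start(self, **kwargs: Any) -> None: pass
    def on_validation_end(self, metrics: dict, **kwargs: Any) -> None: pass


class RecordingCallback(BaseCallback):
    """Hypothetical callback that records the name of each hook as it fires."""

    def __init__(self) -> None:
        self.events = []

    def on_train_start(self, **kwargs: Any) -> None:
        self.events.append("train_start")

    def on_train_end(self, **kwargs: Any) -> None:
        self.events.append("train_end")

    def on_epoch_start(self, epoch: int, **kwargs: Any) -> None:
        self.events.append("epoch_start")

    def on_epoch_end(self, epoch: int, metrics: dict, **kwargs: Any) -> None:
        self.events.append("epoch_end")

    def on_batch_start(self, batch: int, **kwargs: Any) -> None:
        self.events.append("batch_start")

    def on_batch_end(self, batch: int, metrics: dict, **kwargs: Any) -> None:
        self.events.append("batch_end")

    def on_validation_start(self, **kwargs: Any) -> None:
        self.events.append("validation_start")

    def on_validation_end(self, metrics: dict, **kwargs: Any) -> None:
        self.events.append("validation_end")


def run_toy_training(callback: BaseCallback, epochs: int = 1, batches: int = 2) -> None:
    """Hypothetical driver that fires the hooks in the documented order."""
    callback.on_train_start(epochs=epochs)
    for epoch in range(epochs):
        callback.on_epoch_start(epoch)
        for batch in range(batches):
            callback.on_batch_start(batch)
            callback.on_batch_end(batch, metrics={"batch_loss": 0.5})
        # Assumption: validation runs once per epoch, after the batch loop.
        callback.on_validation_start()
        callback.on_validation_end(metrics={"val_loss": 0.4})
        callback.on_epoch_end(epoch, metrics={"loss": 0.5})
    callback.on_train_end(metrics={"loss": 0.5})


recorder = RecordingCallback()
run_toy_training(recorder, epochs=1, batches=2)
print(recorder.events)
```

A callback that overrides only `on_epoch_end` would go through the same loop and ignore the other seven events.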
- on_train_start(**kwargs: Any) None¶
Called at the start of training.
- Parameters:
**kwargs – Additional context (model, config, etc.)
Example
>>> def on_train_start(self, model, epochs, **kwargs):
...     print(f"Starting training for {epochs} epochs")
- on_train_end(**kwargs: Any) None¶
Called at the end of training.
- Parameters:
**kwargs – Additional context (final metrics, model path, etc.)
Example
>>> def on_train_end(self, metrics, **kwargs):
...     print(f"Training complete: {metrics}")
- on_epoch_start(epoch: int, **kwargs: Any) None¶
Called at the start of each epoch.
- Parameters:
epoch – Current epoch number (0-indexed)
**kwargs – Additional context
Example
>>> def on_epoch_start(self, epoch, **kwargs):
...     print(f"Starting epoch {epoch}")
- on_epoch_end(epoch: int, metrics: dict[str, Any], **kwargs: Any) None¶
Called at the end of each epoch.
- Parameters:
epoch – Current epoch number (0-indexed)
metrics – Metrics computed during epoch (loss, mAP, etc.)
**kwargs – Additional context
Example
>>> def on_epoch_end(self, epoch, metrics, **kwargs):
...     if metrics["loss"] < 0.1:
...         print("Loss threshold reached!")
- on_batch_start(batch: int, **kwargs: Any) None¶
Called at the start of each batch.
- Parameters:
batch – Current batch number (0-indexed within epoch)
**kwargs – Additional context (batch data, etc.)
Example
>>> def on_batch_start(self, batch, **kwargs):
...     if batch % 100 == 0:
...         print(f"Processing batch {batch}")
- on_batch_end(batch: int, metrics: dict[str, Any], **kwargs: Any) None¶
Called at the end of each batch.
- Parameters:
batch – Current batch number (0-indexed within epoch)
metrics – Batch metrics (batch_loss, etc.)
**kwargs – Additional context
Example
>>> def on_batch_end(self, batch, metrics, **kwargs):
...     if metrics["batch_loss"] > 10.0:
...         raise TrainingError("Loss explosion detected")
- on_validation_start(**kwargs: Any) None¶
Called at the start of validation.
- Parameters:
**kwargs – Additional context
Example
>>> def on_validation_start(self, **kwargs):
...     print("Starting validation...")
- on_validation_end(metrics: dict[str, Any], **kwargs: Any) None¶
Called at the end of validation.
- Parameters:
metrics – Validation metrics (val_loss, mAP, etc.)
**kwargs – Additional context
Example
>>> def on_validation_end(self, metrics, **kwargs):
...     print(f"Validation mAP: {metrics.get('mAP')}")
- class altametris.sara.core.CallbackManager(callbacks: list[BaseCallback] | None = None)¶
Manages multiple callbacks and orchestrates their execution.
The manager ensures all callbacks receive lifecycle events in order, handles exceptions gracefully, and provides logging for debugging.
- Parameters:
callbacks – List of callback instances to manage
Example
>>> callback1 = LoggingCallback()
>>> callback2 = MetricsCallback()
>>> manager = CallbackManager([callback1, callback2])
>>> manager.on_epoch_end(epoch=1, metrics={"loss": 0.5})
- callbacks = []¶
- _validate_callbacks() None¶
Validate that all callbacks are BaseCallback instances.
- Raises:
TrainingError – If any callback is not a BaseCallback instance
- add_callback(callback: BaseCallback) None¶
Add a callback to the manager.
- Parameters:
callback – Callback instance to add
- Raises:
TrainingError – If callback is not a BaseCallback instance
Example
>>> manager = CallbackManager()
>>> manager.add_callback(LoggingCallback())
- remove_callback(callback: BaseCallback) None¶
Remove a callback from the manager.
- Parameters:
callback – Callback instance to remove
Example
>>> manager.remove_callback(callback1)
- _execute_callback_hook(hook_name: str, *args: Any, **kwargs: Any) None¶
Execute a specific hook on all callbacks.
- Parameters:
hook_name – Name of the hook method to call
*args – Positional arguments for the hook
**kwargs – Keyword arguments for the hook
- Raises:
TrainingError – If any callback raises an error (re-raised with context)
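Dispatching a hook by name and re-raising failures with context could look roughly like the sketch below. It is an illustration, not the library's implementation: `TrainingError` is a local stand-in, `execute_hook` is a hypothetical free function, and both the skipping of callbacks without the hook and the keys placed in `details` are assumptions.

```python
from __future__ import annotations

import logging
from typing import Any

logger = logging.getLogger(__name__)


class TrainingError(Exception):
    """Local stand-in for altametris.sara.core.TrainingError."""

    def __init__(self, message: str, details: dict[str, Any] | None = None) -> None:
        super().__init__(message)
        self.message = message
        self.details = details or {}


def execute_hook(callbacks: list[Any], hook_name: str, *args: Any, **kwargs: Any) -> None:
    """Call `hook_name` on each callback, in registration order."""
    for callback in callbacks:
        hook = getattr(callback, hook_name, None)
        if hook is None:
            continue  # assumption: callbacks without this hook are skipped
        try:
            hook(*args, **kwargs)
        except Exception as exc:
            name = type(callback).__name__
            logger.error("Callback %s failed in %s: %s", name, hook_name, exc)
            # Re-raise with context; `from exc` keeps the original traceback.
            raise TrainingError(
                f"Callback {name}.{hook_name} failed: {exc}",
                {"callback": name, "hook": hook_name},
            ) from exc


class CountingCallback:
    """Hypothetical callback that remembers the epochs it has seen."""

    def __init__(self) -> None:
        self.seen: list[int] = []

    def on_epoch_end(self, epoch: int, metrics: dict[str, Any], **kwargs: Any) -> None:
        self.seen.append(epoch)


class FailingCallback:
    """Hypothetical callback whose hook always fails."""

    def on_epoch_end(self, epoch: int, metrics: dict[str, Any], **kwargs: Any) -> None:
        raise ValueError("boom")


counter = CountingCallback()
execute_hook([counter], "on_epoch_end", 1, {"loss": 0.5})
```

Chaining with `raise ... from exc` leaves the original exception available as `__cause__`, so the wrapped `TrainingError` does not hide where the failure started.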
- on_train_start(**kwargs: Any) None¶
Execute on_train_start hook on all callbacks.
- on_train_end(**kwargs: Any) None¶
Execute on_train_end hook on all callbacks.
- on_epoch_start(epoch: int, **kwargs: Any) None¶
Execute on_epoch_start hook on all callbacks.
- on_epoch_end(epoch: int, metrics: dict[str, Any], **kwargs: Any) None¶
Execute on_epoch_end hook on all callbacks.
- on_batch_start(batch: int, **kwargs: Any) None¶
Execute on_batch_start hook on all callbacks.
- on_batch_end(batch: int, metrics: dict[str, Any], **kwargs: Any) None¶
Execute on_batch_end hook on all callbacks.
- on_validation_start(**kwargs: Any) None¶
Execute on_validation_start hook on all callbacks.
- on_validation_end(metrics: dict[str, Any], **kwargs: Any) None¶
Execute on_validation_end hook on all callbacks.
- __len__() int¶
Return number of registered callbacks.
- __repr__() str¶
Return string representation of callback manager.
- class altametris.sara.core.BaseDetector(model_path: str | pathlib.Path, device: str = 'auto', warmup: bool = False, **kwargs: Any)¶
Bases: abc.ABC
Abstract base class for model inference/detection.
Provides common inference infrastructure:
- Model loading and initialization
- Device management
- Warmup capability
- Prediction interface
- Parameters:
model_path – Path to model weights
device – Device for inference (“cpu”, “cuda”, “mps”, “auto”)
warmup – Whether to run warmup inference on initialization
Example
>>> detector = MyDetector(model_path="weights/best.pt", device="cuda")
>>> results = detector.predict(source="image.jpg")
- model_path¶
- _device¶
- model = None¶
- _is_initialized = False¶
- _resolve_device(device: str) torch.device¶
Resolve device string to torch.device.
- Parameters:
device – Device string
- Returns:
Resolved torch.device
- property device: torch.device¶
Get current device.
- property is_initialized: bool¶
Check if detector is initialized.
- abstract _load_model(model_path: pathlib.Path, **kwargs: Any) None¶
Load model from path.
- Parameters:
model_path – Path to model file
**kwargs – Additional loading arguments
- Raises:
ModelError – If model cannot be loaded
Note
Must be implemented by subclasses. Should set self.model.
- abstract predict(source: Any, **kwargs: Any) Any¶
Run inference on source.
- Parameters:
source – Input source (image path, array, video, etc.)
**kwargs – Inference parameters (conf, iou, etc.)
- Returns:
Prediction results (format depends on detector type)
- Raises:
InferenceError – If prediction fails
Note
Must be implemented by subclasses
- warmup(iterations: int = 3) None¶
Warmup the model with dummy inference.
Useful for GPU models to pre-allocate memory and compile kernels.
- Parameters:
iterations – Number of warmup iterations
Example
>>> detector.warmup(iterations=5)
- validate_source(source: Any) None¶
Validate input source.
- Parameters:
source – Input source to validate
- Raises:
InferenceError – If source is invalid
- __call__(source: Any, **kwargs: Any) Any¶
Callable interface for prediction.
- Parameters:
source – Input source
**kwargs – Inference parameters
- Returns:
Prediction results
Example
>>> detector = MyDetector(model_path="weights/best.pt")
>>> results = detector("image.jpg", conf=0.5)
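A concrete detector implements `_load_model` and `predict`. The sketch below uses a simplified local stand-in for `BaseDetector` (no device handling or warmup). It assumes the constructor calls `_load_model` and that `__call__` forwards to `predict`. `ThresholdDetector` and its fake "weights" are hypothetical.

```python
from __future__ import annotations

from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any


class BaseDetector(ABC):
    """Simplified local stand-in for altametris.sara.core.BaseDetector."""

    def __init__(self, model_path: str | Path, **kwargs: Any) -> None:
        self.model_path = Path(model_path)
        self.model: Any = None
        # Assumption: the model is loaded during construction.
        self._load_model(self.model_path, **kwargs)

    @abstractmethod
    def _load_model(self, model_path: Path, **kwargs: Any) -> None:
        """Load the model and set self.model."""

    @abstractmethod
    def predict(self, source: Any, **kwargs: Any) -> Any:
        """Run inference on source."""

    def __call__(self, source: Any, **kwargs: Any) -> Any:
        # Callable interface: forwards to predict().
        return self.predict(source, **kwargs)


class ThresholdDetector(BaseDetector):
    """Toy detector: the 'model' is a list of scores filtered by `conf`."""

    def _load_model(self, model_path: Path, **kwargs: Any) -> None:
        # A real subclass would read weights from model_path; this one fakes them.
        self.model = [0.2, 0.6, 0.9]

    def predict(self, source: Any, conf: float = 0.5, **kwargs: Any) -> list[float]:
        return [score for score in self.model if score >= conf]


detector = ThresholdDetector(model_path="weights/best.pt")
results = detector("image.jpg", conf=0.5)
```

Because both methods are abstract, instantiating the base class directly raises `TypeError`, which catches incomplete subclasses early.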
- __repr__() str¶
String representation of detector.
- class altametris.sara.core.BaseExporter(model_path: str | pathlib.Path, output_dir: str | pathlib.Path | None = None)¶
Bases: abc.ABC
Abstract base class for model export.
Provides common export infrastructure:
- Multi-format export support
- Export validation
- Path management
- Format-specific configuration
- Parameters:
model_path – Path to source model
output_dir – Directory for exported models
Example
>>> exporter = MyExporter(model_path="weights/best.pt")
>>> onnx_path = exporter.export(format="onnx", imgsz=640)
>>> exporter.validate_export(onnx_path)
- SUPPORTED_FORMATS = ['onnx', 'torchscript', 'tensorrt', 'coreml', 'tflite']¶
- model_path¶
- output_dir¶
- abstract export(format: str = 'onnx', **kwargs: Any) pathlib.Path¶
Export model to specified format.
- Parameters:
format – Export format (onnx, tensorrt, etc.)
**kwargs – Format-specific export arguments
- Returns:
Path to exported model
- Raises:
ExportError – If export fails
Note
Must be implemented by subclasses
- abstract validate_export(export_path: pathlib.Path, **kwargs: Any) bool¶
Validate exported model.
- Parameters:
export_path – Path to exported model
**kwargs – Validation arguments
- Returns:
True if validation passes
- Raises:
ExportError – If validation fails
Note
Must be implemented by subclasses
- validate_format(format: str) None¶
Validate export format is supported.
- Parameters:
format – Format string to validate
- Raises:
ExportError – If format is not supported
- get_export_path(format: str, suffix: str | None = None) pathlib.Path¶
Generate export file path.
- Parameters:
format – Export format
suffix – Optional suffix before extension
- Returns:
Path for exported model
Example
>>> exporter.get_export_path("onnx")
Path("weights/best.onnx")
>>> exporter.get_export_path("onnx", suffix="_fp16")
Path("weights/best_fp16.onnx")
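The two results in the example are consistent with building the path from the source model's stem, the optional suffix, and the format name used as the extension. The free function below is a sketch of that rule only. It is hypothetical, it assumes a missing `output_dir` means "next to the source model", and it does not cover formats whose file extension differs from the format name.

```python
from __future__ import annotations

from pathlib import Path


def get_export_path(
    model_path: str | Path,
    output_dir: str | Path | None,
    fmt: str,
    suffix: str | None = None,
) -> Path:
    """Build `<dir>/<stem><suffix>.<fmt>` for an exported model."""
    model_path = Path(model_path)
    # Assumption: without an output_dir, exports land beside the source model.
    directory = Path(output_dir) if output_dir is not None else model_path.parent
    stem = model_path.stem + (suffix or "")
    return directory / f"{stem}.{fmt}"


default_path = get_export_path("weights/best.pt", None, "onnx")
fp16_path = get_export_path("weights/best.pt", None, "onnx", suffix="_fp16")
```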
- cleanup_export(export_path: pathlib.Path) None¶
Clean up temporary export files.
- Parameters:
export_path – Path to export file to remove
Example
>>> exporter.cleanup_export(Path("weights/temp.onnx"))
- export_all(formats: list[str] | None = None, **kwargs: Any) dict[str, pathlib.Path | None]¶
Export model to multiple formats.
- Parameters:
formats – List of formats to export (default: all supported)
**kwargs – Export arguments
- Returns:
Dictionary mapping format to export path (None if export failed)
Example
>>> exports = exporter.export_all(formats=["onnx", "tensorrt"])
>>> print(exports["onnx"])
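The "None if export failed" contract can be sketched as a loop that records a path per format and maps failures to `None`. This is an illustration under assumptions: `ExportError` is a local stand-in, only `ExportError` is assumed to be caught, and `fake_export` is a hypothetical export function that fails for one format.

```python
from __future__ import annotations

from pathlib import Path
from typing import Any, Callable

SUPPORTED_FORMATS = ["onnx", "torchscript", "tensorrt", "coreml", "tflite"]


class ExportError(Exception):
    """Local stand-in for altametris.sara.core.ExportError."""


def export_all(
    export: Callable[..., Path],
    formats: list[str] | None = None,
    **kwargs: Any,
) -> dict[str, Path | None]:
    """Export to each format; a failed format maps to None instead of aborting."""
    results: dict[str, Path | None] = {}
    for fmt in formats if formats is not None else SUPPORTED_FORMATS:
        try:
            results[fmt] = export(format=fmt, **kwargs)
        except ExportError:
            # Assumption: only ExportError is swallowed; other errors propagate.
            results[fmt] = None
    return results


def fake_export(format: str, **kwargs: Any) -> Path:
    """Hypothetical exporter that pretends TensorRT is not installed."""
    if format == "tensorrt":
        raise ExportError("Missing export dependency: tensorrt")
    return Path("weights") / f"best.{format}"


exports = export_all(fake_export, formats=["onnx", "tensorrt"])
```

Callers should check for `None` before using an entry, since one failed format does not stop the others from being exported.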
- __repr__() str¶
String representation of exporter.
- class altametris.sara.core.BaseModel(config: dict[str, Any] | None = None, device: str = 'auto')¶
Bases: abc.ABC, torch.nn.Module
Abstract base class for all ML models.
Provides common functionality for model lifecycle:
- Model initialization and configuration
- Weight loading and saving
- Device management (CPU, CUDA, MPS)
- Forward pass abstraction
All concrete model implementations must inherit from this class and implement the abstract methods.
- Parameters:
config – Model configuration dictionary
device – Device to run model on (“cpu”, “cuda”, “mps”, or “auto”)
Example
>>> class YoloModel(BaseModel):
...     def forward(self, x):
...         return self.model(x)
...     def load_weights(self, path):
...         self.model = YOLO(path)
...     def save_weights(self, path):
...         self.model.save(path)
- config¶
- _device¶
- _is_initialized = False¶
- _resolve_device(device: str) torch.device¶
Resolve device string to torch.device.
- Parameters:
device – Device string (“cpu”, “cuda”, “mps”, “auto”)
- Returns:
Resolved torch.device
- Raises:
ModelError – If device is invalid or not available
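The resolution rule itself is not spelled out here. The sketch below assumes that "auto" prefers CUDA, then MPS, then CPU, and that an explicitly requested but unavailable device raises `ModelError`. `resolve_device_name` and its availability flags are hypothetical, and `ModelError` is a local stand-in. In a torch-based implementation the flags would presumably come from `torch.cuda.is_available()` and `torch.backends.mps.is_available()`, with the result wrapped in `torch.device`. Indexed strings such as "cuda:0" are not handled.

```python
from __future__ import annotations

VALID_DEVICES = ("cpu", "cuda", "mps", "auto")


class ModelError(Exception):
    """Local stand-in for altametris.sara.core.ModelError."""


def resolve_device_name(device: str, cuda_available: bool, mps_available: bool) -> str:
    """Map a device string to a concrete device name."""
    if device not in VALID_DEVICES:
        raise ModelError(f"Invalid device: {device!r}")
    if device == "auto":
        # Assumed preference order: CUDA, then MPS, then CPU.
        if cuda_available:
            return "cuda"
        if mps_available:
            return "mps"
        return "cpu"
    if device == "cuda" and not cuda_available:
        raise ModelError("CUDA requested but not available")
    if device == "mps" and not mps_available:
        raise ModelError("MPS requested but not available")
    return device


chosen = resolve_device_name("auto", cuda_available=False, mps_available=True)
```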
- property device: torch.device¶
Get current device.
- to_device(device: str | torch.device) BaseModel¶
Move model to specified device.
- Parameters:
device – Target device
- Returns:
Self for chaining
Example
>>> model.to_device("cuda")
- validate_config(config: dict[str, Any]) None¶
Validate model configuration.
- Parameters:
config – Configuration to validate
- Raises:
ConfigurationError – If configuration is invalid
Example
>>> model.validate_config({"input_size": 640})
- abstract forward(x: torch.Tensor, **kwargs: Any) Any¶
Forward pass through the model.
- Parameters:
x – Input tensor
**kwargs – Additional arguments
- Returns:
Model output
Note
Must be implemented by subclasses
- abstract load_weights(path: str | pathlib.Path, **kwargs: Any) None¶
Load model weights from file.
- Parameters:
path – Path to weights file
**kwargs – Additional loading arguments
- Raises:
ModelError – If weights cannot be loaded
Note
Must be implemented by subclasses
- abstract save_weights(path: str | pathlib.Path, **kwargs: Any) None¶
Save model weights to file.
- Parameters:
path – Path to save weights
**kwargs – Additional saving arguments
- Raises:
ModelError – If weights cannot be saved
Note
Must be implemented by subclasses
- get_num_parameters() int¶
Get total number of model parameters.
- Returns:
Total parameter count
Example
>>> num_params = model.get_num_parameters()
>>> print(f"Model has {num_params:,} parameters")
- get_trainable_parameters() int¶
Get number of trainable parameters.
- Returns:
Trainable parameter count
- freeze() None¶
Freeze all model parameters (set requires_grad=False).
Example
>>> model.freeze()
>>> # Model parameters won't be updated during training
- unfreeze() None¶
Unfreeze all model parameters (set requires_grad=True).
Example
>>> model.unfreeze()
>>> # Model parameters will be updated during training
- summary() dict[str, Any]¶
Get model summary with key statistics.
- Returns:
Dictionary with model info
Example
>>> summary = model.summary()
>>> print(f"Parameters: {summary['total_parameters']}")
- __repr__() str¶
String representation of model.
- class altametris.sara.core.BaseTrainer(model: altametris.sara.core.base_model.BaseModel, device: str = 'auto', callbacks: list[altametris.sara.core.base_callback.BaseCallback] | None = None)¶
Bases: abc.ABC
Abstract base class for model training.
Provides common training infrastructure:
- Callback management for training lifecycle events
- Device management
- Training state tracking
- Configuration validation
- Parameters:
model – Model to train
device – Device to train on (“cpu”, “cuda”, “mps”, “auto”)
callbacks – List of callbacks for training events
Example
>>> model = MyModel()
>>> trainer = MyTrainer(model=model, device="cuda")
>>> trainer.train(dataset_config="data.yaml", epochs=100)
- model¶
- device = 'auto'¶
- callback_manager¶
- _is_training = False¶
- _current_epoch = 0¶
- abstract train(dataset_config: Any, epochs: int = 100, **kwargs: Any) dict[str, Any]¶
Train the model.
- Parameters:
dataset_config – Dataset configuration (path, dict, etc.)
epochs – Number of training epochs
**kwargs – Additional training arguments
- Returns:
Training results and metrics
- Raises:
TrainingError – If training fails
Note
Must be implemented by subclasses
- abstract validate(dataset_config: Any, **kwargs: Any) dict[str, Any]¶
Validate the model.
- Parameters:
dataset_config – Validation dataset configuration
**kwargs – Additional validation arguments
- Returns:
Validation metrics
- Raises:
TrainingError – If validation fails
Note
Must be implemented by subclasses
- add_callback(callback: altametris.sara.core.base_callback.BaseCallback) None¶
Add a callback to the trainer.
- Parameters:
callback – Callback instance
Example
>>> trainer.add_callback(MyCallback())
- remove_callback(callback: altametris.sara.core.base_callback.BaseCallback) None¶
Remove a callback from the trainer.
- Parameters:
callback – Callback instance to remove
- property is_training: bool¶
Check if currently training.
- property current_epoch: int¶
Get current epoch number.
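The two properties read like read-only views over the private `_is_training` and `_current_epoch` attributes. The sketch below shows one way a `train` implementation might keep them up to date. `ToyTrainer` is hypothetical and stands alone (no model, device, or callbacks), and resetting the flag in a `finally` block is an assumption rather than documented behavior.

```python
from __future__ import annotations

from typing import Any


class ToyTrainer:
    """Hypothetical trainer showing only the state-tracking part."""

    def __init__(self) -> None:
        self._is_training = False
        self._current_epoch = 0

    @property
    def is_training(self) -> bool:
        return self._is_training

    @property
    def current_epoch(self) -> int:
        return self._current_epoch

    def train(self, epochs: int = 3) -> dict[str, Any]:
        observed: list[tuple[bool, int]] = []
        self._is_training = True
        try:
            for epoch in range(epochs):
                self._current_epoch = epoch
                # Real work (batches, validation, callbacks) would go here.
                observed.append((self.is_training, self.current_epoch))
        finally:
            # Assumption: the flag is cleared even if an epoch raises.
            self._is_training = False
        return {"observed": observed}


trainer = ToyTrainer()
history = trainer.train(epochs=2)
```

Exposing the state through properties without setters keeps callers from changing it by accident.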
- _validate_dataset_config(config: Any) None¶
Validate dataset configuration.
- Parameters:
config – Dataset configuration to validate
- Raises:
TrainingError – If configuration is invalid
- _validate_training_args(epochs: int, **kwargs: Any) None¶
Validate training arguments.
- Parameters:
epochs – Number of epochs
**kwargs – Additional arguments
- Raises:
TrainingError – If arguments are invalid
- save_checkpoint(path: pathlib.Path, epoch: int, metrics: dict[str, Any]) None¶
Save training checkpoint.
- Parameters:
path – Path to save checkpoint
epoch – Current epoch
metrics – Current metrics
Example
>>> trainer.save_checkpoint(
...     path=Path("checkpoints/epoch_10.pt"),
...     epoch=10,
...     metrics={"loss": 0.5}
... )
- load_checkpoint(path: pathlib.Path) dict[str, Any]¶
Load training checkpoint.
- Parameters:
path – Path to checkpoint file
- Returns:
Checkpoint data
- Raises:
TrainingError – If checkpoint cannot be loaded
- __repr__() str¶
String representation of trainer.
- exception altametris.sara.core.AltametrisException(message: str, details: dict[str, Any] | None = None)¶
Bases: Exception
Base exception for all Altametris libraries.
All custom exceptions inherit from this base class to allow catching all Altametris-specific errors with a single except clause.
- Parameters:
message – Error message describing what went wrong
details – Optional dictionary with additional error context
Example
>>> try:
...     raise AltametrisException("Something went wrong", {"code": 500})
... except AltametrisException as e:
...     print(e)
...     print(e.details)
- message¶
- details¶
- __str__() str¶
Return string representation of exception.
- __repr__() str¶
Return repr representation of exception.
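The single-except-clause pattern can be sketched with local stand-ins that mirror the documented constructor. The classes below are not the library's own, defaulting `details` to an empty dictionary is an assumption, and `load` and `describe` are hypothetical helpers.

```python
from __future__ import annotations

from typing import Any


class AltametrisException(Exception):
    """Local stand-in mirroring the documented (message, details) constructor."""

    def __init__(self, message: str, details: dict[str, Any] | None = None) -> None:
        super().__init__(message)
        self.message = message
        # Assumption: a missing details argument becomes an empty dict.
        self.details = details or {}


class ModelError(AltametrisException):
    """Local stand-in for the model-related subclass."""


def load(path: str) -> None:
    """Hypothetical loader that always fails."""
    raise ModelError("Failed to load weights", {"path": path})


def describe(error: AltametrisException) -> tuple[str, str, dict[str, Any]]:
    """Summarize any Altametris error as (type name, message, details)."""
    return (type(error).__name__, error.message, error.details)


try:
    load("model.pt")
except AltametrisException as error:  # also catches every subclass
    caught = describe(error)
```

Code that needs to react differently to one error type can list that subclass in an earlier `except` clause and keep the base class as the fallback.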
- exception altametris.sara.core.ConfigurationError(message: str, details: dict[str, Any] | None = None)¶
Bases: AltametrisException
Exception raised for configuration-related errors.
Scenarios:
- Invalid configuration parameter
- Missing required configuration key
- Configuration validation failed
- Incompatible configuration values
- Invalid YAML/JSON format
Example
>>> raise ConfigurationError(
...     "Invalid batch_size",
...     {"parameter": "batch_size", "value": -1, "expected": "> 0"}
... )
- exception altametris.sara.core.ExportError(message: str, details: dict[str, Any] | None = None)¶
Bases: AltametrisException
Exception raised for export-related errors.
Scenarios:
- Unsupported export format
- Export validation failed
- Missing export dependencies (onnx, tensorrt)
- Dynamic shapes not supported
- Export optimization failed
Example
>>> raise ExportError(
...     "ONNX export failed",
...     {"format": "onnx", "reason": "Dynamic axes not supported"}
... )
- exception altametris.sara.core.InferenceError(message: str, details: dict[str, Any] | None = None)¶
Bases: AltametrisException
Exception raised for inference-related errors.
Scenarios:
- Invalid input format (wrong shape, type)
- Device mismatch (model on CUDA, input on CPU)
- Prediction failure
- Post-processing error
- Batch size too large for memory
Example
>>> raise InferenceError(
...     "Invalid input shape",
...     {"expected": (3, 640, 640), "got": (640, 640)}
... )
- exception altametris.sara.core.ModelError(message: str, details: dict[str, Any] | None = None)¶
Bases: AltametrisException
Exception raised for model-related errors.
Scenarios:
- Failed to load model weights
- Invalid model architecture
- Corrupted checkpoint file
- Incompatible model version
- Missing required model files
Example
>>> raise ModelError(
...     "Failed to load weights",
...     {"path": "model.pt", "reason": "File not found"}
... )
- exception altametris.sara.core.TrainingError(message: str, details: dict[str, Any] | None = None)¶
Bases: AltametrisException
Exception raised for training-related errors.
Scenarios:
- Dataset not found or invalid format
- Training divergence (NaN loss)
- Out of memory during training
- Invalid hyperparameters
- Checkpoint save failure
Example
>>> raise TrainingError(
...     "Training diverged",
...     {"epoch": 10, "loss": float('nan'), "lr": 0.001}
... )