altametris.sara.core.base_model

Base model class for all ML models.

Provides a common interface for model loading, saving, and configuration across different ML frameworks and architectures.

Example

>>> import torch
>>> class MyModel(BaseModel):
...     def forward(self, x):
...         return self.model(x)
...     def load_weights(self, path):
...         self.model.load_state_dict(torch.load(path))
...     def save_weights(self, path):
...         torch.save(self.model.state_dict(), path)

Attributes

logger

Classes

BaseModel

Abstract base class for all ML models.

Module Contents

altametris.sara.core.base_model.logger
class altametris.sara.core.base_model.BaseModel(config: dict[str, Any] | None = None, device: str = 'auto')

Bases: abc.ABC, torch.nn.Module

Abstract base class for all ML models.

Provides common functionality for the model lifecycle:
  • Model initialization and configuration

  • Weight loading and saving

  • Device management (CPU, CUDA, MPS)

  • Forward pass abstraction

All concrete model implementations must inherit from this class and implement the abstract methods.

Parameters:
  • config – Model configuration dictionary (optional; defaults to None)

  • device – Device to run the model on (“cpu”, “cuda”, “mps”, or “auto”; defaults to “auto”)

Example

>>> from ultralytics import YOLO
>>> class YoloModel(BaseModel):
...     def forward(self, x):
...         return self.model(x)
...     def load_weights(self, path):
...         self.model = YOLO(path)
...     def save_weights(self, path):
...         self.model.save(path)

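Because BaseModel combines abc.ABC with torch.nn.Module, Python refuses to instantiate any subclass that leaves an abstract method unimplemented. The sketch below illustrates this with a hypothetical stand-in, SketchBaseModel, that only reproduces the bases and the three abstract methods (forward, load_weights, save_weights); the real class also handles config and device setup, which is omitted here. Incomplete and Complete are illustrative names.

```python
from __future__ import annotations

from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any

import torch
from torch import nn


class SketchBaseModel(ABC, nn.Module):
    """Hypothetical stand-in with the same bases as BaseModel (abc.ABC, torch.nn.Module)."""

    @abstractmethod
    def forward(self, x: torch.Tensor, **kwargs: Any) -> Any: ...

    @abstractmethod
    def load_weights(self, path: str | Path, **kwargs: Any) -> None: ...

    @abstractmethod
    def save_weights(self, path: str | Path, **kwargs: Any) -> None: ...


class Incomplete(SketchBaseModel):
    # Only forward() is implemented, so the class is still abstract.
    def forward(self, x, **kwargs):
        return x


class Complete(SketchBaseModel):
    def __init__(self) -> None:
        super().__init__()  # resolves to nn.Module.__init__ (ABC defines no __init__)
        self.model = nn.Linear(4, 2)

    def forward(self, x, **kwargs):
        return self.model(x)

    def load_weights(self, path, **kwargs):
        self.model.load_state_dict(torch.load(path, map_location="cpu"))

    def save_weights(self, path, **kwargs):
        torch.save(self.model.state_dict(), path)


try:
    Incomplete()
except TypeError as exc:
    print(f"Refused: {exc}")

print(Complete()(torch.zeros(1, 4)).shape)
```

The TypeError is raised at instantiation time, not at class definition time, so a missing method surfaces as soon as someone tries to build the model.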
config
_device
_is_initialized = False
_resolve_device(device: str) → torch.device

Resolve device string to torch.device.

Parameters:

device – Device string (“cpu”, “cuda”, “mps”, “auto”)

Returns:

Resolved torch.device

Raises:

ModelError – If device is invalid or not available
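The resolution logic can be sketched as a pure function, assuming "auto" prefers CUDA, then MPS, then CPU (the actual preference order in _resolve_device is not documented here). The availability flags are passed in so the sketch runs without a GPU; in real code they would typically come from torch.cuda.is_available() and torch.backends.mps.is_available(), and the result would be wrapped in torch.device(...). ModelError is a hypothetical stand-in for the library's exception, and resolve_device_name is an illustrative name.

```python
class ModelError(Exception):
    """Hypothetical stand-in for the library's ModelError."""


VALID_DEVICES = ("cpu", "cuda", "mps", "auto")


def resolve_device_name(device: str, cuda_available: bool, mps_available: bool) -> str:
    """Map a device string to a concrete device name (sketch, not the library code)."""
    device = device.lower()
    if device not in VALID_DEVICES:
        raise ModelError(f"Invalid device {device!r}; expected one of {VALID_DEVICES}")
    if device == "auto":
        # Assumed preference order: CUDA, then MPS, then CPU.
        if cuda_available:
            return "cuda"
        if mps_available:
            return "mps"
        return "cpu"
    # Explicit requests fail loudly instead of silently falling back to CPU.
    if device == "cuda" and not cuda_available:
        raise ModelError("CUDA requested but not available")
    if device == "mps" and not mps_available:
        raise ModelError("MPS requested but not available")
    return device


print(resolve_device_name("auto", cuda_available=False, mps_available=True))
```

Note that this sketch rejects indexed devices such as "cuda:1", since the documented accepted values are only the four strings above.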

property device: torch.device

Get current device.

to_device(device: str | torch.device) → BaseModel

Move model to specified device.

Parameters:

device – Target device

Returns:

Self for chaining

Example

>>> model.to_device("cuda")
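Returning self is what makes chaining possible. A minimal sketch, assuming to_device records the target in a private _device attribute (as the listed _device attribute suggests) and delegates the actual move to nn.Module.to; ChainableModel is a hypothetical class, and the real method may also route strings like "auto" through _resolve_device, which this sketch skips.

```python
from __future__ import annotations

import torch
from torch import nn


class ChainableModel(nn.Module):
    """Hypothetical minimal model showing a chainable to_device()."""

    def __init__(self) -> None:
        super().__init__()
        self.model = nn.Linear(4, 2)
        self._device = torch.device("cpu")

    @property
    def device(self) -> torch.device:
        return self._device

    def to_device(self, device: str | torch.device) -> ChainableModel:
        self._device = torch.device(device)
        self.to(self._device)  # nn.Module.to moves parameters and buffers in place
        return self  # returning self enables chaining


# nn.Module.eval() also returns self, so calls can be chained.
model = ChainableModel().to_device("cpu").eval()
print(model.device)
```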
validate_config(config: dict[str, Any]) → None

Validate model configuration.

Parameters:

config – Configuration to validate

Raises:

ConfigurationError – If configuration is invalid

Example

>>> model.validate_config({"input_size": 640})
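The validation rules themselves are not documented, so the following is only a sketch of the pattern: check the container type, then check individual keys and raise ConfigurationError with a message that names the offending value. The input_size rule (a positive integer) is a hypothetical example taken from the call above, and ConfigurationError here is a local stand-in for the library's exception.

```python
from __future__ import annotations

from typing import Any


class ConfigurationError(Exception):
    """Hypothetical stand-in for the library's ConfigurationError."""


def validate_config(config: dict[str, Any]) -> None:
    """Sketch: check the container type, then a hypothetical input_size rule."""
    if not isinstance(config, dict):
        raise ConfigurationError(f"config must be a dict, got {type(config).__name__}")
    input_size = config.get("input_size")
    if input_size is not None:
        # bool is a subclass of int, so exclude it explicitly.
        if not isinstance(input_size, int) or isinstance(input_size, bool) or input_size <= 0:
            raise ConfigurationError(f"input_size must be a positive int, got {input_size!r}")


validate_config({"input_size": 640})  # passes silently, mirroring the "→ None" signature
```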
abstract forward(x: torch.Tensor, **kwargs: Any) → Any

Forward pass through the model.

Parameters:
  • x – Input tensor

  • **kwargs – Additional arguments

Returns:

Model output

Note

Must be implemented by subclasses.
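Since BaseModel is a torch.nn.Module, callers normally invoke the model as model(x) rather than model.forward(x): nn.Module.__call__ runs any registered hooks and then dispatches to forward, passing keyword arguments through to **kwargs. The sketch below uses a hypothetical TinyModel and a hypothetical softmax keyword purely to show that pass-through; neither is part of BaseModel.

```python
import torch
from torch import nn


class TinyModel(nn.Module):
    """Hypothetical concrete model; only forward() matters here."""

    def __init__(self) -> None:
        super().__init__()
        self.model = nn.Linear(4, 3)

    def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        logits = self.model(x)
        # "softmax" is a hypothetical keyword argument, not part of BaseModel.
        if kwargs.get("softmax", False):
            return logits.softmax(dim=-1)
        return logits


model = TinyModel()
x = torch.randn(2, 4)
logits = model(x)               # nn.Module.__call__ runs hooks, then forward()
probs = model(x, softmax=True)  # extra keyword arguments reach forward() via **kwargs
print(logits.shape, probs.sum(dim=-1))
```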

abstract load_weights(path: str | pathlib.Path, **kwargs: Any) → None

Load model weights from file.

Parameters:
  • path – Path to weights file

  • **kwargs – Additional loading arguments

Raises:

ModelError – If weights cannot be loaded

Note

Must be implemented by subclasses.
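One way a subclass can honor the "Raises: ModelError" contract is to check the path up front and translate low-level failures (missing file, unreadable pickle, mismatched state_dict keys or shapes) into ModelError. This is a sketch assuming the weights are a plain state_dict saved with torch.save; LoadableModel is hypothetical and ModelError is a local stand-in. Defaulting map_location to "cpu" lets GPU-saved checkpoints load on CPU-only machines, while still allowing callers to override it through **kwargs.

```python
from __future__ import annotations

import pickle
from pathlib import Path
from typing import Any

import torch
from torch import nn


class ModelError(Exception):
    """Hypothetical stand-in for the library's ModelError."""


class LoadableModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.model = nn.Linear(4, 2)

    def load_weights(self, path: str | Path, **kwargs: Any) -> None:
        path = Path(path)
        if not path.is_file():
            raise ModelError(f"Weights file not found: {path}")
        kwargs.setdefault("map_location", "cpu")  # caller may override
        try:
            state_dict = torch.load(path, **kwargs)
            # Raises RuntimeError on missing/unexpected keys or shape mismatches.
            self.model.load_state_dict(state_dict)
        except (RuntimeError, OSError, pickle.UnpicklingError) as exc:
            raise ModelError(f"Could not load weights from {path}: {exc}") from exc
```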

abstract save_weights(path: str | pathlib.Path, **kwargs: Any) → None

Save model weights to file.

Parameters:
  • path – Path to save weights

  • **kwargs – Additional saving arguments

Raises:

ModelError – If weights cannot be saved

Note

Must be implemented by subclasses.
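A save_weights implementation is easiest to check with a round trip: save, load into a fresh instance, and compare tensors. The sketch below assumes weights are stored as a state_dict via torch.save, creates missing parent directories, and wraps filesystem errors in a local ModelError stand-in; SavableModel is hypothetical, and extra **kwargs are forwarded to torch.save.

```python
from __future__ import annotations

import tempfile
from pathlib import Path
from typing import Any

import torch
from torch import nn


class ModelError(Exception):
    """Hypothetical stand-in for the library's ModelError."""


class SavableModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.model = nn.Linear(4, 2)

    def save_weights(self, path: str | Path, **kwargs: Any) -> None:
        path = Path(path)
        try:
            path.parent.mkdir(parents=True, exist_ok=True)  # create missing folders
            torch.save(self.model.state_dict(), path, **kwargs)
        except OSError as exc:
            raise ModelError(f"Could not save weights to {path}: {exc}") from exc

    def load_weights(self, path: str | Path, **kwargs: Any) -> None:
        self.model.load_state_dict(torch.load(path, map_location="cpu"))


with tempfile.TemporaryDirectory() as tmp:
    target = Path(tmp) / "checkpoints" / "model.pt"
    source = SavableModel()
    source.save_weights(target)
    restored = SavableModel()  # freshly initialized, so weights differ before loading
    restored.load_weights(target)
    same = all(
        torch.equal(a, b)
        for a, b in zip(source.state_dict().values(), restored.state_dict().values())
    )
    print(same)
```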

get_num_parameters() → int

Get total number of model parameters.

Returns:

Total parameter count

Example

>>> num_params = model.get_num_parameters()
>>> print(f"Model has {num_params:,} parameters")
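The count is the sum of element counts over all parameter tensors. A minimal sketch of that computation on a plain nn.Module (get_num_parameters is written here as a free function for illustration; the method on BaseModel may differ in detail). For the network below, the two Linear layers contribute 4·8 + 8 = 40 and 8·2 + 2 = 18 parameters, giving 58.

```python
from torch import nn


def get_num_parameters(module: nn.Module) -> int:
    # numel() is the number of elements in each parameter tensor;
    # parameters() yields shared (tied) parameters only once.
    return sum(p.numel() for p in module.parameters())


net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
print(f"Model has {get_num_parameters(net):,} parameters")  # Model has 58 parameters
```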
get_trainable_parameters() → int

Get number of trainable parameters.

Returns:

Trainable parameter count
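"Trainable" means parameters with requires_grad=True, so the count drops when part of the network is frozen. A minimal sketch as a free function (an illustrative stand-in for the method), freezing the first layer of the same 58-parameter network so that only the last layer's 18 parameters remain trainable:

```python
from torch import nn


def get_trainable_parameters(module: nn.Module) -> int:
    # Only count tensors that will receive gradients.
    return sum(p.numel() for p in module.parameters() if p.requires_grad)


net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
for p in net[0].parameters():  # freeze the first layer only
    p.requires_grad_(False)
print(get_trainable_parameters(net))  # 18
```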

freeze() → None

Freeze all model parameters (set requires_grad=False).

Example

>>> model.freeze()
>>> # Model parameters won't be updated during training
unfreeze() → None

Unfreeze all model parameters (set requires_grad=True).

Example

>>> model.unfreeze()
>>> # Model parameters will be updated during training
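Both methods amount to toggling requires_grad on every parameter; nn.Module.requires_grad_(False) is an equivalent one-liner. The sketch below uses a hypothetical FreezableModel. One behavior worth keeping in mind: freezing does not switch the module to eval mode, so dropout and batch-norm layers still behave as in training until .eval() is called.

```python
from torch import nn


class FreezableModel(nn.Module):
    """Hypothetical model illustrating freeze()/unfreeze()."""

    def __init__(self) -> None:
        super().__init__()
        self.model = nn.Linear(4, 2)

    def freeze(self) -> None:
        for p in self.parameters():
            p.requires_grad = False  # optimizers skip tensors without gradients

    def unfreeze(self) -> None:
        for p in self.parameters():
            p.requires_grad = True


model = FreezableModel()
model.freeze()
print(any(p.requires_grad for p in model.parameters()), model.training)  # False True
```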
summary() → dict[str, Any]

Get model summary with key statistics.

Returns:

Dictionary of model statistics, such as “total_parameters”

Example

>>> summary = model.summary()
>>> print(f"Parameters: {summary['total_parameters']}")
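Only the "total_parameters" key is confirmed by the example above; the other keys in this sketch (trainable and frozen counts, device, class name) are hypothetical illustrations of the kind of statistics such a summary can collect from a plain nn.Module. summarize is an illustrative free function, not the library's method.

```python
from __future__ import annotations

from typing import Any

from torch import nn


def summarize(module: nn.Module) -> dict[str, Any]:
    total = sum(p.numel() for p in module.parameters())
    trainable = sum(p.numel() for p in module.parameters() if p.requires_grad)
    first = next(module.parameters(), None)  # None for parameter-free modules
    return {
        "total_parameters": total,  # key used in the documented example
        "trainable_parameters": trainable,  # hypothetical key
        "frozen_parameters": total - trainable,  # hypothetical key
        "device": str(first.device) if first is not None else "cpu",  # hypothetical key
        "model_class": type(module).__name__,  # hypothetical key
    }


info = summarize(nn.Linear(4, 2))
print(f"Parameters: {info['total_parameters']}")  # Parameters: 10
```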
__repr__() → str

String representation of the model.