altametris.sara.core.base_model
Base model class for all ML models.
Provides a common interface for model loading, saving, and configuration across different ML frameworks and architectures.
Example
>>> class MyModel(BaseModel):
...     def forward(self, x):
...         return self.model(x)
...     def load_weights(self, path):
...         self.model.load_state_dict(torch.load(path))
...     def save_weights(self, path):
...         torch.save(self.model.state_dict(), path)
Attributes
logger
Classes
BaseModel – Abstract base class for all ML models.
Module Contents
- altametris.sara.core.base_model.logger
- class altametris.sara.core.base_model.BaseModel(config: dict[str, Any] | None = None, device: str = 'auto')
Bases: abc.ABC, torch.nn.Module
Abstract base class for all ML models.
Provides common functionality for the model lifecycle:
- Model initialization and configuration
- Weight loading and saving
- Device management (CPU, CUDA, MPS)
- Forward pass abstraction
All concrete model implementations must inherit from this class and implement the abstract methods forward(), load_weights(), and save_weights().
- Parameters:
config – Model configuration dictionary (defaults to None)
device – Device to run the model on (“cpu”, “cuda”, “mps”, or “auto”; defaults to “auto”)
Example
>>> class YoloModel(BaseModel):
...     def forward(self, x):
...         return self.model(x)
...     def load_weights(self, path):
...         self.model = YOLO(path)
...     def save_weights(self, path):
...         self.model.save(path)
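The requirement that subclasses implement every abstract method is enforced by `abc.ABC`: a subclass that leaves one out cannot be instantiated. The sketch below illustrates this with a hypothetical stand-in, `SketchBaseModel`, which declares the same three abstract methods but does not inherit from `torch.nn.Module`, so it runs without PyTorch. `Incomplete`, `Complete`, and `can_instantiate` are illustrative names only.

```python
from abc import ABC, abstractmethod
from typing import Any


class SketchBaseModel(ABC):
    """Hypothetical stand-in for BaseModel: same abstract methods, no torch."""

    @abstractmethod
    def forward(self, x: Any, **kwargs: Any) -> Any: ...

    @abstractmethod
    def load_weights(self, path: str, **kwargs: Any) -> None: ...

    @abstractmethod
    def save_weights(self, path: str, **kwargs: Any) -> None: ...


class Incomplete(SketchBaseModel):
    # save_weights is not implemented, so this class is still abstract.
    def forward(self, x, **kwargs):
        return x

    def load_weights(self, path, **kwargs):
        pass


class Complete(Incomplete):
    # Implementing the last abstract method makes the class concrete.
    def save_weights(self, path, **kwargs):
        pass


def can_instantiate(cls: type) -> bool:
    """Return True if cls() succeeds, False if abstract methods block it."""
    try:
        cls()
    except TypeError:
        return False
    return True
```

Here `can_instantiate(Incomplete)` returns `False` and `can_instantiate(Complete)` returns `True`. The same rule applies to the examples above, which is why each of them defines `forward`, `load_weights`, and `save_weights`.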
- config
- _device
- _is_initialized = False
- _resolve_device(device: str) → torch.device
Resolve device string to torch.device.
- Parameters:
device – Device string (“cpu”, “cuda”, “mps”, “auto”)
- Returns:
Resolved torch.device
- Raises:
ModelError – If device is invalid or not available
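A minimal sketch of how a device string could be resolved, assuming that “auto” prefers CUDA, then MPS, then CPU (the actual order used by `_resolve_device` is not documented here). So that it runs without PyTorch, availability is passed in as booleans and the result is a plain string. In real code the flags would typically come from `torch.cuda.is_available()` and `torch.backends.mps.is_available()`, and the string would be wrapped in `torch.device(...)`. `resolve_device` and the local `ModelError` are hypothetical.

```python
VALID_DEVICES = {"cpu", "cuda", "mps", "auto"}


class ModelError(Exception):
    """Hypothetical stand-in for the library's ModelError."""


def resolve_device(device: str, cuda_available: bool, mps_available: bool) -> str:
    """Resolve a device string, assuming the order cuda > mps > cpu for "auto"."""
    if device not in VALID_DEVICES:
        raise ModelError(f"Invalid device: {device!r}")
    if device == "auto":
        # Pick the first available backend in the assumed preference order.
        if cuda_available:
            return "cuda"
        if mps_available:
            return "mps"
        return "cpu"
    # An explicitly requested accelerator must actually be present.
    if device == "cuda" and not cuda_available:
        raise ModelError("CUDA requested but not available")
    if device == "mps" and not mps_available:
        raise ModelError("MPS requested but not available")
    return device
```

Raising on an unavailable explicit device, rather than silently falling back to CPU, matches the documented “invalid or not available” error case.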
- property device: torch.device
Get current device.
- to_device(device: str | torch.device) → BaseModel
Move model to specified device.
- Parameters:
device – Target device
- Returns:
The model itself, to allow method chaining
Example
>>> model.to_device("cuda")
- validate_config(config: dict[str, Any]) → None
Validate model configuration.
- Parameters:
config – Configuration to validate
- Raises:
ConfigurationError – If configuration is invalid
Example
>>> model.validate_config({"input_size": 640})
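The validation rules themselves are not documented here. A minimal sketch of the pattern, assuming a hypothetical rule that `input_size`, when present, must be a positive integer; it is written as a free function, and `ConfigurationError` is a local stand-in for the library's exception. As in the documented signature, it returns `None` on success and raises on failure.

```python
from typing import Any


class ConfigurationError(Exception):
    """Hypothetical stand-in for the library's ConfigurationError."""


def validate_config(config: dict[str, Any]) -> None:
    """Raise ConfigurationError if config breaks the (hypothetical) rules."""
    if not isinstance(config, dict):
        raise ConfigurationError("config must be a dict")
    if "input_size" in config:
        size = config["input_size"]
        # bool is a subclass of int, so reject it explicitly.
        if isinstance(size, bool) or not isinstance(size, int) or size <= 0:
            raise ConfigurationError(f"input_size must be a positive int, got {size!r}")
```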
- abstract forward(x: torch.Tensor, **kwargs: Any) → Any
Forward pass through the model.
- Parameters:
x – Input tensor
**kwargs – Additional arguments
- Returns:
Model output
Note
Must be implemented by subclasses
- abstract load_weights(path: str | pathlib.Path, **kwargs: Any) → None
Load model weights from file.
- Parameters:
path – Path to weights file
**kwargs – Additional loading arguments
- Raises:
ModelError – If weights cannot be loaded
Note
Must be implemented by subclasses
- abstract save_weights(path: str | pathlib.Path, **kwargs: Any) → None
Save model weights to file.
- Parameters:
path – Path to save weights
**kwargs – Additional saving arguments
- Raises:
ModelError – If weights cannot be saved
Note
Must be implemented by subclasses
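Both `load_weights` and `save_weights` are documented to raise `ModelError` on failure. One way a subclass could meet that contract is to catch the underlying I/O or parsing error and re-raise it as `ModelError`, chaining the original with `from` so the root cause stays visible. This sketch stores weights as JSON only to stay dependency-free (a PyTorch subclass would use `torch.save` and `torch.load` instead); `JsonWeightsModel` and the local `ModelError` are hypothetical.

```python
from __future__ import annotations

import json
from pathlib import Path
from typing import Any


class ModelError(Exception):
    """Hypothetical stand-in for the library's ModelError."""


class JsonWeightsModel:
    """Hypothetical model that keeps its 'weights' in a plain dict."""

    def __init__(self) -> None:
        self.weights: dict[str, list[float]] = {}

    def save_weights(self, path: str | Path, **kwargs: Any) -> None:
        try:
            Path(path).write_text(json.dumps(self.weights))
        except OSError as exc:
            # Surface filesystem failures as the documented ModelError.
            raise ModelError(f"Cannot save weights to {path}") from exc

    def load_weights(self, path: str | Path, **kwargs: Any) -> None:
        try:
            self.weights = json.loads(Path(path).read_text())
        except (OSError, json.JSONDecodeError) as exc:
            # Missing or corrupt files both become ModelError.
            raise ModelError(f"Cannot load weights from {path}") from exc
```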
- get_num_parameters() → int
Get total number of model parameters.
- Returns:
Total parameter count
Example
>>> num_params = model.get_num_parameters()
>>> print(f"Model has {num_params:,} parameters")
- get_trainable_parameters() → int
Get number of trainable parameters.
- Returns:
Trainable parameter count
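In PyTorch these counts are conventionally computed by summing `numel()` over `self.parameters()`, filtering on `requires_grad` for the trainable count; that `BaseModel` does exactly this is an assumption. To run without PyTorch, this sketch models each parameter as a hypothetical `Param` with a shape and a `requires_grad` flag, with `numel` computed as the product of the shape.

```python
from dataclasses import dataclass
from math import prod


@dataclass
class Param:
    """Hypothetical stand-in for torch.nn.Parameter."""
    shape: tuple[int, ...]
    requires_grad: bool = True

    def numel(self) -> int:
        # Number of elements is the product of the dimensions.
        return prod(self.shape)


def get_num_parameters(params: list[Param]) -> int:
    return sum(p.numel() for p in params)


def get_trainable_parameters(params: list[Param]) -> int:
    return sum(p.numel() for p in params if p.requires_grad)


# A Linear(640, 10)-like layer: 10*640 weights plus 10 biases, bias frozen here.
params = [Param((10, 640)), Param((10,), requires_grad=False)]
print(get_num_parameters(params))        # 6410
print(get_trainable_parameters(params))  # 6400
```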
- freeze() → None
Freeze all model parameters (set requires_grad=False).
Example
>>> model.freeze()
>>> # Model parameters won't be updated during training
- unfreeze() → None
Unfreeze all model parameters (set requires_grad=True).
Example
>>> model.unfreeze()
>>> # Model parameters will be updated during training
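As their docstrings state, `freeze()` and `unfreeze()` set `requires_grad` on every parameter; with PyTorch this is typically a loop over `self.parameters()` (or, equivalently, `self.requires_grad_(False)`). The sketch below shows the effect on the trainable count, using a hypothetical `Param` stand-in and a hypothetical `FreezableModel` so it runs without PyTorch.

```python
from dataclasses import dataclass
from math import prod


@dataclass
class Param:
    """Hypothetical stand-in for torch.nn.Parameter."""
    shape: tuple[int, ...]
    requires_grad: bool = True


class FreezableModel:
    """Hypothetical model holding a flat list of parameters."""

    def __init__(self, params: list[Param]) -> None:
        self.params = params

    def freeze(self) -> None:
        # Exclude every parameter from gradient updates.
        for p in self.params:
            p.requires_grad = False

    def unfreeze(self) -> None:
        # Make every parameter trainable again.
        for p in self.params:
            p.requires_grad = True

    def trainable_count(self) -> int:
        return sum(prod(p.shape) for p in self.params if p.requires_grad)


model = FreezableModel([Param((10, 640)), Param((10,))])
model.freeze()
print(model.trainable_count())  # 0
model.unfreeze()
print(model.trainable_count())  # 6410
```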
- summary() → dict[str, Any]
Get model summary with key statistics.
- Returns:
Dictionary with model info
Example
>>> summary = model.summary()
>>> print(f"Parameters: {summary['total_parameters']}")
- __repr__() → str
String representation of model.