altametris.sara.core.base_model
================================

.. py:module:: altametris.sara.core.base_model

.. autoapi-nested-parse::

   Base model class for all ML models.

   Provides a common interface for model loading, saving, and configuration
   across different ML frameworks and architectures.

   .. rubric:: Example

   >>> class MyModel(BaseModel):
   ...     def forward(self, x):
   ...         return self.model(x)
   ...     def load_weights(self, path):
   ...         self.model.load_state_dict(torch.load(path))


Attributes
----------

.. autoapisummary::

   altametris.sara.core.base_model.logger


Classes
-------

.. autoapisummary::

   altametris.sara.core.base_model.BaseModel


Module Contents
---------------

.. py:data:: logger

.. py:class:: BaseModel(config: Optional[dict[str, Any]] = None, device: str = 'auto')

   Bases: :py:obj:`abc.ABC`, :py:obj:`torch.nn.Module`

   Abstract base class for all ML models.

   Provides common functionality for the model lifecycle:

   - Model initialization and configuration
   - Weight loading and saving
   - Device management (CPU, CUDA, MPS)
   - Forward pass abstraction

   All concrete model implementations must inherit from this class
   and implement the abstract methods.

   :param config: Model configuration dictionary
   :param device: Device to run model on ("cpu", "cuda", "mps", or "auto")

   .. rubric:: Example

   >>> class YoloModel(BaseModel):
   ...     def forward(self, x):
   ...         return self.model(x)
   ...     def load_weights(self, path):
   ...         self.model = YOLO(path)
   ...     def save_weights(self, path):
   ...         self.model.save(path)

   .. py:attribute:: config

   .. py:attribute:: _device

   .. py:attribute:: _is_initialized
      :value: False

   .. py:method:: _resolve_device(device: str) -> torch.device

      Resolve device string to torch.device.

      :param device: Device string ("cpu", "cuda", "mps", "auto")

      :returns: Resolved torch.device

      :raises ModelError: If device is invalid or not available

   .. py:property:: device
      :type: torch.device

      Get current device.

   .. py:method:: to_device(device: Union[str, torch.device]) -> BaseModel

      Move model to specified device.

      :param device: Target device

      :returns: Self for chaining

      .. rubric:: Example

      >>> model.to_device("cuda")

   .. py:method:: validate_config(config: dict[str, Any]) -> None

      Validate model configuration.

      :param config: Configuration to validate

      :raises ConfigurationError: If configuration is invalid

      .. rubric:: Example

      >>> model.validate_config({"input_size": 640})

   .. py:method:: forward(x: torch.Tensor, **kwargs: Any) -> Any
      :abstractmethod:

      Forward pass through the model.

      :param x: Input tensor
      :param \*\*kwargs: Additional arguments

      :returns: Model output

      .. note:: Must be implemented by subclasses

   .. py:method:: load_weights(path: Union[str, pathlib.Path], **kwargs: Any) -> None
      :abstractmethod:

      Load model weights from file.

      :param path: Path to weights file
      :param \*\*kwargs: Additional loading arguments

      :raises ModelError: If weights cannot be loaded

      .. note:: Must be implemented by subclasses

   .. py:method:: save_weights(path: Union[str, pathlib.Path], **kwargs: Any) -> None
      :abstractmethod:

      Save model weights to file.

      :param path: Path to save weights
      :param \*\*kwargs: Additional saving arguments

      :raises ModelError: If weights cannot be saved

      .. note:: Must be implemented by subclasses

   .. py:method:: get_num_parameters() -> int

      Get total number of model parameters.

      :returns: Total parameter count

      .. rubric:: Example

      >>> num_params = model.get_num_parameters()
      >>> print(f"Model has {num_params:,} parameters")

   .. py:method:: get_trainable_parameters() -> int

      Get number of trainable parameters.

      :returns: Trainable parameter count
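
      A quick way to relate this count to :py:meth:`get_num_parameters` when
      checking how much of a model is frozen (``model`` is any ``BaseModel``
      subclass instance; the printed message is only illustrative):

      .. rubric:: Example

      >>> trainable = model.get_trainable_parameters()
      >>> total = model.get_num_parameters()
      >>> print(f"Trainable parameters: {trainable:,} of {total:,}")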

   .. py:method:: freeze() -> None

      Freeze all model parameters (set requires_grad=False).

      .. rubric:: Example

      >>> model.freeze()
      >>> # Model parameters won't be updated during training

   .. py:method:: unfreeze() -> None

      Unfreeze all model parameters (set requires_grad=True).

      .. rubric:: Example

      >>> model.unfreeze()
      >>> # Model parameters will be updated during training

   .. py:method:: summary() -> dict[str, Any]

      Get model summary with key statistics.

      :returns: Dictionary with model info

      .. rubric:: Example

      >>> summary = model.summary()
      >>> print(f"Parameters: {summary['total_parameters']}")

   .. py:method:: __repr__() -> str

      String representation of model.
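
   A fuller end-to-end sketch of the intended lifecycle, assuming the
   constructor can be called with the keyword arguments shown in the class
   signature above. ``LinearProbe``, its layer sizes, and the ``"probe.pt"``
   path are illustrative and not part of this module:

   .. rubric:: Example

   >>> import torch
   >>> from altametris.sara.core.base_model import BaseModel
   >>> class LinearProbe(BaseModel):
   ...     def __init__(self, config=None, device="auto"):
   ...         super().__init__(config=config, device=device)
   ...         self.model = torch.nn.Linear(16, 2)
   ...     def forward(self, x):
   ...         return self.model(x)
   ...     def load_weights(self, path):
   ...         self.model.load_state_dict(torch.load(path))
   ...     def save_weights(self, path):
   ...         torch.save(self.model.state_dict(), path)
   >>> probe = LinearProbe(device="cpu")
   >>> n = probe.get_num_parameters()  # 16 * 2 weights + 2 biases = 34
   >>> probe.freeze()                  # parameters won't be updated during training
   >>> info = probe.summary()
   >>> probe.save_weights("probe.pt")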