# Source code for ablkit.learning.abl_model

"""
This module contains the class ABLModel, which provides a unified interface for different
machine learning models.

Copyright (c) 2024 LAMDA.  All rights reserved.
"""

import pickle
from typing import Any, Dict

import numpy as np

from ..data.structures import ListData
from ..utils import reform_list


class ABLModel:
    """
    Serialize data and provide a unified interface for different machine learning models.

    Parameters
    ----------
    base_model : Machine Learning Model
        The machine learning base model used for training and prediction. This model should
        implement the ``fit`` and ``predict`` methods. It's recommended, but not required, for
        the model to also implement ``predict_proba`` (used to populate ``pred_prob``) and
        ``extract_features`` (used to populate ``data_example.embeddings`` for distance
        functions such as ``similarity``). The two optional methods are used independently
        of each other.

    Raises
    ------
    NotImplementedError
        If ``base_model`` has no ``fit`` or no ``predict`` attribute.
    """

    def __init__(self, base_model: Any) -> None:
        if not (hasattr(base_model, "fit") and hasattr(base_model, "predict")):
            raise NotImplementedError("The base_model should implement fit and predict methods.")
        self.base_model = base_model

    def _extract_embeddings(self, data_X: Any) -> Any:
        """
        Return ``base_model.extract_features(X=data_X)``, or ``None`` when unavailable.

        Best-effort: a base model without ``extract_features``, or one whose
        ``extract_features`` raises ``AttributeError``, yields ``None`` instead of an error.
        """
        extract = getattr(self.base_model, "extract_features", None)
        if extract is None:
            return None
        try:
            return extract(X=data_X)
        except AttributeError:
            return None

    def predict(self, data_examples: ListData) -> Dict[str, Any]:
        """
        Predict the labels and probabilities for the given data.

        If the base model has ``predict_proba``, the labels are the ``argmax`` over axis 1 of
        the returned probabilities (this assumes a ``(num_samples, num_classes)`` array -- TODO
        confirm for every base model). Otherwise the labels come from ``predict`` and the
        probabilities are ``None``.

        As a side effect, ``data_examples.pred_idx`` and ``data_examples.pred_prob`` are set,
        and ``data_examples.embeddings`` is set when the base model provides features through
        ``extract_features``. Each of these is passed through ``reform_list`` against
        ``data_examples.X`` (presumably regrouping the flat outputs per example).

        Parameters
        ----------
        data_examples : ListData
            A batch of data to predict on.

        Returns
        -------
        dict
            A dictionary containing the predicted labels (``"label"``) and probabilities
            (``"prob"``).
        """
        model = self.base_model
        data_X = data_examples.flatten("X")

        if hasattr(model, "predict_proba"):
            prob = model.predict_proba(X=data_X)
            label = prob.argmax(axis=1)
            prob = reform_list(prob, data_examples.X)
        else:
            prob = None
            label = model.predict(X=data_X)

        # Feature extraction does not depend on predict_proba being available.
        embeddings = self._extract_embeddings(data_X)

        label = reform_list(label, data_examples.X)
        data_examples.pred_idx = label
        data_examples.pred_prob = prob
        if embeddings is not None:
            data_examples.embeddings = reform_list(embeddings, data_examples.X)
        return {"label": label, "prob": prob}

    def train(self, data_examples: ListData) -> float:
        """
        Train the model on the given data.

        Parameters
        ----------
        data_examples : ListData
            A batch of data to train on, which typically contains the data, ``X``, and the
            corresponding labels, ``abduced_idx``.

        Returns
        -------
        float
            Whatever ``base_model.fit`` returns; expected to be the loss value of the trained
            model.
        """
        data_X = data_examples.flatten("X")
        data_y = data_examples.flatten("abduced_idx")
        return self.base_model.fit(X=data_X, y=data_y)

    def valid(self, data_examples: ListData) -> float:
        """
        Validate the model on the given data.

        Parameters
        ----------
        data_examples : ListData
            A batch of data to validate on, which typically contains the data, ``X``, and the
            corresponding labels, ``abduced_idx``.

        Returns
        -------
        float
            The value returned by ``base_model.score``; expected to be the accuracy of the
            trained model.
        """
        data_X = data_examples.flatten("X")
        data_y = data_examples.flatten("abduced_idx")
        return self.base_model.score(X=data_X, y=data_y)

    def _model_operation(self, operation: str, *args, **kwargs) -> None:
        """
        Run ``operation`` ("save" or "load") on the base model, with a pickle fallback.

        If ``self.base_model`` has a method named ``operation``, it is called with all given
        arguments. Otherwise the base model is pickled to (``save``) or unpickled from
        (``load``) the file named by the ``f"{operation}_path"`` keyword argument.

        Raises
        ------
        ValueError
            If the pickle fallback is needed and ``operation`` is not "save"/"load", or the
            path keyword argument is missing or ``None``.
        NotImplementedError
            If the pickle fallback fails. The underlying error is chained as ``__cause__``.
        """
        model = self.base_model
        if hasattr(model, operation):
            getattr(model, operation)(*args, **kwargs)
            return

        if operation not in ("save", "load"):
            raise ValueError(f"Unsupported model operation: {operation!r}")

        path_key = f"{operation}_path"
        path = kwargs.get(path_key)
        # Treat an explicit ``None`` like a missing path instead of letting open(None) fail.
        if path is None:
            raise ValueError(f"'{path_key}' should not be None")

        try:
            if operation == "save":
                with open(path, "wb") as file:
                    pickle.dump(model, file, protocol=pickle.HIGHEST_PROTOCOL)
            else:
                # SECURITY: pickle.load can execute arbitrary code; only load trusted files.
                with open(path, "rb") as file:
                    self.base_model = pickle.load(file)
        # Besides PickleError, pickling/unpickling commonly fails with TypeError or
        # AttributeError (unpicklable objects, missing classes) and EOFError (empty file).
        except (OSError, EOFError, TypeError, AttributeError, pickle.PickleError) as exc:
            raise NotImplementedError(
                f"{type(model).__name__} object doesn't have the {operation} method "
                f"and the default pickle-based {operation} method failed."
            ) from exc

    def save(self, *args, **kwargs) -> None:
        """
        Save the model to a file.

        If ``self.base_model`` has a ``save`` method, this delegates to it, and the arguments
        passed here should match those it expects. Otherwise the base model is pickled to the
        file given by the ``save_path`` keyword argument.

        Raises
        ------
        ValueError
            If the pickle fallback is used and ``save_path`` is missing or ``None``.
        NotImplementedError
            If the pickle fallback fails (e.g. the file cannot be opened or the model cannot
            be pickled).
        """
        self._model_operation("save", *args, **kwargs)

    def load(self, *args, **kwargs) -> None:
        """
        Load the model from a file.

        If ``self.base_model`` has a ``load`` method, this delegates to it, and the arguments
        passed here should match those it expects. Otherwise ``self.base_model`` is replaced by
        the object unpickled from the file given by the ``load_path`` keyword argument. Only
        load pickle files from trusted sources.

        Raises
        ------
        ValueError
            If the pickle fallback is used and ``load_path`` is missing or ``None``.
        NotImplementedError
            If the pickle fallback fails (e.g. the file is missing, empty, or not unpicklable).
        """
        self._model_operation("load", *args, **kwargs)
# =============================================================================
# Multi-label variants
# =============================================================================


class MultiLabelABLModel(ABLModel):
    """
    Multi-label variant of :class:`ABLModel`.

    The standard :class:`ABLModel.predict` selects a single class index per instance via
    ``argmax`` over a softmax distribution. This class is for multi-label settings, where each
    instance can have multiple active labels. It instead marks a label as active when its
    per-label sigmoid probability is strictly greater than ``threshold`` (0.5 by default). The
    resulting binary indicator vectors are stored on ``pred_idx``.

    Pair it with :class:`~ablkit.learning.MultiLabelBasicNN` for the typical multi-label
    workflow. That model provides ``predict_proba`` returning ``(num_samples, num_labels)``
    sigmoid probabilities.

    Parameters
    ----------
    base_model : Machine Learning Model
        The base model; must implement ``fit`` and ``predict`` (see :class:`ABLModel`).
    threshold : float, optional
        Probability above which a label is predicted active. Defaults to 0.5.
    """

    # Class-level default so instances created without __init__ still have a threshold.
    threshold: float = 0.5

    def __init__(self, base_model: Any, threshold: float = 0.5) -> None:
        super().__init__(base_model)
        self.threshold = threshold

    def predict(self, data_examples: ListData) -> Dict[str, Any]:
        """
        Predict per-label binary indicators and per-label probabilities.

        As a side effect, sets ``data_examples.pred_idx`` and ``data_examples.pred_prob``. It
        also sets ``data_examples.embeddings`` when the base model provides features through
        ``extract_features``, consistent with :meth:`ABLModel.predict`.

        Parameters
        ----------
        data_examples : ListData
            A batch of data to predict on.

        Returns
        -------
        Dict[str, Any]
            A dictionary with two keys:

            - ``"label"``: binary indicator vectors grouped per example.
            - ``"prob"``: per-label probabilities grouped per example, or ``None`` if the
              base model does not expose ``predict_proba``.
        """
        model = self.base_model
        data_X = data_examples.flatten("X")

        if hasattr(model, "predict_proba"):
            prob = model.predict_proba(X=data_X)
            # np.where also accepts array-likes (via __array__) that lack .astype.
            label = np.where(prob > self.threshold, 1, 0).astype(int)
            prob = reform_list(prob, data_examples.X)
        else:
            prob = None
            label = model.predict(X=data_X)

        # Best-effort feature extraction, mirroring ABLModel.predict.
        embeddings = None
        if hasattr(model, "extract_features"):
            try:
                embeddings = model.extract_features(X=data_X)
            except AttributeError:
                embeddings = None

        label = reform_list(label, data_examples.X)
        data_examples.pred_idx = label
        data_examples.pred_prob = prob
        if embeddings is not None:
            data_examples.embeddings = reform_list(embeddings, data_examples.X)
        return {"label": label, "prob": prob}