import os.path as osp
from typing import Any, Iterator, List, Optional, Tuple, Union

from ..data.evaluation.base_metric import BaseMetric
from ..data.structures.list_data import ListData
from ..learning import ABLModel
from ..reasoning import A3BLReasoner
from ..utils import print_log
from .simple_bridge import SimpleBridge
# (Documentation-viewer "[docs]" link marker removed; it is not part of the module code.)
class A3BLBridge(SimpleBridge):
    """
    An ambiguity-aware implementation for bridging machine learning and reasoning parts.

    Reference: https://github.com/Hao-Yuan-He/A3BL

    Involves the following five steps:

        - Predict class probabilities and indices for the given data examples.
        - Map indices into pseudo-labels.
        - Enumerate all valid pseudo-labels.
        - Revise pseudo-labels to label distribution based on the class probabilities.
        - Train the model.

    Parameters
    ----------
    model : ABLModel
        The machine learning model wrapped in ``ABLModel``, used for prediction
        and training. The wrapped base model should expose ``extract_features``
        so embeddings are available for the soft-label aggregation.
    reasoner : A3BLReasoner
        The reasoning part wrapped in ``A3BLReasoner``, used for pseudo-label
        enumeration and soft-label aggregation.
    metric_list : List[BaseMetric]
        A list of metrics used for evaluating the model's performance.
    """

    def __init__(
        self,
        model: ABLModel,
        reasoner: A3BLReasoner,
        metric_list: List[BaseMetric],
    ):
        # No extra state: this override only narrows the accepted reasoner type.
        super().__init__(model, reasoner, metric_list)

    def abduce_soft_label(self, data_examples: ListData) -> List[List[Any]]:
        """
        Revise predicted pseudo-labels to a soft label, given data examples using abduction.

        Parameters
        ----------
        data_examples : ListData
            Data examples containing predicted pseudo-labels.

        Returns
        -------
        List[List[Any]]
            A list of abduced soft labels for the given data examples.
        """
        # The return value of ``batch_abduce`` is ignored; the soft labels are read
        # back from ``data_examples`` (presumably attached there by the reasoner --
        # verify against ``A3BLReasoner.batch_abduce``).
        self.reasoner.batch_abduce(data_examples)
        return data_examples.abduced_soft_label

    def train_data_iter(
        self,
        train_data: Union[
            ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], List[Any]]
        ],
        val_data: Optional[
            Union[
                ListData,
                Tuple[List[List[Any]], Optional[List[List[Any]]], Optional[List[Any]]],
            ]
        ] = None,
        segment_size: Union[int, float] = 1.0,
    ) -> Iterator[Tuple[ListData, ListData]]:
        """
        Iterate over the training data one segment at a time.

        This is a generator, so preprocessing and the ``segment_size`` checks run
        when iteration starts rather than when the method is called.

        Parameters
        ----------
        train_data : Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], List[Any]]]
            Training data, passed through ``self.data_preprocess("train", ...)``.
        val_data : Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], Optional[List[Any]]]], optional # noqa: E501 pylint: disable=line-too-long
            Validation data, passed through ``self.data_preprocess("val", ...)``.
            If None, the preprocessed training examples are used instead.
            Defaults to None.
        segment_size : Union[int, float]
            An ``int`` is the number of examples per segment and must be positive.
            A ``float`` must lie in ``(0, 1]`` and is the fraction of the training
            examples per segment, rounded down but never below one example.
            Defaults to 1.0.

        Yields
        ------
        Tuple[ListData, ListData]
            The current training segment and the full validation examples (the
            same validation object is yielded with every segment).

        Raises
        ------
        ValueError
            If ``segment_size`` is neither an ``int`` nor a ``float``, or is
            outside the valid range for its type.
        """
        data_examples = self.data_preprocess("train", train_data)
        if val_data is not None:
            val_data_examples = self.data_preprocess("val", val_data)
        else:
            val_data_examples = data_examples

        if isinstance(segment_size, int):
            if segment_size <= 0:
                raise ValueError("segment_size should be positive.")
        elif isinstance(segment_size, float):
            if 0 < segment_size <= 1:
                # A small fraction of a small dataset truncates to 0, which would
                # make the segment arithmetic below divide by zero; use at least
                # one example per segment.
                segment_size = max(1, int(segment_size * len(data_examples)))
            else:
                raise ValueError("segment_size should be in (0, 1].")
        else:
            raise ValueError("segment_size should be int or float.")

        # Stepping by ``segment_size`` yields ceil(len / segment_size) slices; the
        # last one may be shorter, and empty data yields nothing.
        for start in range(0, len(data_examples), segment_size):
            yield data_examples[start : start + segment_size], val_data_examples

    def train(
        self,
        train_data: Union[
            ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], List[Any]]
        ],
        val_data: Optional[
            Union[
                ListData,
                Tuple[List[List[Any]], Optional[List[List[Any]]], Optional[List[Any]]],
            ]
        ] = None,
        loops: int = 50,
        segment_size: Union[int, float] = 1.0,
        eval_interval: int = 1,
        save_interval: Optional[int] = None,
        save_dir: Optional[str] = None,
    ):
        """
        A typical training pipeline of Abductive Learning.

        Parameters
        ----------
        train_data : Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], List[Any]]]
            Training data should be in the form of ``(X, gt_pseudo_label, Y)`` or a ``ListData``
            object with ``X``, ``gt_pseudo_label`` and ``Y`` attributes.

            - ``X`` is a list of sublists representing the input data.
            - ``gt_pseudo_label`` is only used to evaluate the performance of the ``ABLModel`` but
              not to train. ``gt_pseudo_label`` can be ``None``.
            - ``Y`` is a list representing the ground truth reasoning result for each sublist
              in ``X``.
        val_data : Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], Optional[List[Any]]]], optional # noqa: E501 pylint: disable=line-too-long
            Validation data should be in the same format as ``train_data``. Both ``gt_pseudo_label``
            and ``Y`` can be either None or not, which depends on the evaluation metrics in
            ``self.metric_list``. If ``val_data`` is None, ``train_data`` will be used to validate
            the model during training time. Defaults to None.
        loops : int
            Learning part and Reasoning part will be iteratively optimized
            for ``loops`` times. Defaults to 50.
        segment_size : Union[int, float]
            Data will be split into segments of this size and data in each segment
            will be used together to train the model. Defaults to 1.0.
        eval_interval : int
            The model will be evaluated every ``eval_interval`` loop during training,
            and always after the final loop. Defaults to 1.
        save_interval : int, optional
            The model will be saved every ``save_interval`` loop during training,
            and always after the final loop. If None, the model is never saved.
            Defaults to None.
        save_dir : str, optional
            Directory to save the model. Required when ``save_interval`` is not None.
            Defaults to None.

        Raises
        ------
        ValueError
            If ``save_interval`` is given without ``save_dir``, or if ``segment_size``
            is invalid (see ``train_data_iter``).
        """
        # Fail before any training happens instead of letting ``osp.join(None, ...)``
        # raise a TypeError at the first checkpoint.
        if save_interval is not None and save_dir is None:
            raise ValueError("save_dir should be specified when save_interval is not None.")

        for loop in range(loops):
            is_last_loop = loop == loops - 1

            # Remains None when the iterator yields no segment (empty training data).
            val_examples = None
            segments = self.train_data_iter(train_data, val_data, segment_size)
            for seg_idx, (train_segment, val_examples) in enumerate(segments, start=1):
                print_log(
                    f"loop(train) [{loop + 1}/{loops}] segment(train) [{seg_idx}] ",
                    logger="current",
                )
                # NOTE(review): this pipeline uses ``abduce_pseudo_label`` (not defined
                # in this class, presumably inherited from ``SimpleBridge``);
                # ``abduce_soft_label`` is never called here -- confirm this is intended.
                self.predict(train_segment)
                self.idx_to_pseudo_label(train_segment)
                self.abduce_pseudo_label(train_segment)
                self.filter_pseudo_label(train_segment)
                self.pseudo_label_to_idx(train_segment)
                self.model.train(train_segment)

            if (loop + 1) % eval_interval == 0 or is_last_loop:
                if val_examples is None:
                    print_log(
                        f"Eval skipped: loop(val) [{loop + 1}], no training segment was produced",
                        logger="current",
                    )
                else:
                    print_log(f"Eval start: loop(val) [{loop + 1}]", logger="current")
                    self._valid(val_examples)

            should_save = save_interval is not None and (
                (loop + 1) % save_interval == 0 or is_last_loop
            )
            if should_save:
                print_log(f"Saving model: loop(save) [{loop + 1}]", logger="current")
                self.model.save(
                    save_path=osp.join(save_dir, f"model_checkpoint_loop_{loop + 1}.pth")
                )