"""
This module contains the base class for the Bridge part.
Copyright (c) 2024 LAMDA. All rights reserved.
"""
from abc import ABCMeta, abstractmethod
from typing import Any, List, Optional, Tuple, Union
from ..data.structures import ListData
from ..learning import ABLModel
from ..reasoning import Reasoner
[docs]
class BaseBridge(metaclass=ABCMeta):
    """
    A base class for bridging learning and reasoning parts.

    This class provides necessary methods that need to be overridden in subclasses
    to construct a typical pipeline of Abductive Learning (corresponding to ``train``),
    which involves the following four methods:

        - predict: Predict class indices on the given data examples.
        - idx_to_pseudo_label: Map indices into pseudo-labels.
        - abduce_pseudo_label: Revise pseudo-labels based on abductive reasoning.
        - pseudo_label_to_idx: Map revised pseudo-labels back into indices.

    Parameters
    ----------
    model : ABLModel
        The machine learning model wrapped in ``ABLModel``, which is mainly used for
        prediction and model training.
    reasoner : Reasoner
        The reasoning part wrapped in ``Reasoner``, which is used for pseudo-label revision.
        Any object exposing both ``idx_to_label`` and ``label_to_idx`` is accepted.

    Raises
    ------
    TypeError
        If ``model`` is not an ``ABLModel`` instance, or if ``reasoner`` lacks the
        ``idx_to_label`` or ``label_to_idx`` attribute.
    """

    def __init__(self, model: ABLModel, reasoner: Reasoner) -> None:
        if not isinstance(model, ABLModel):
            raise TypeError(f"Expected an instance of ABLModel, but received type: {type(model)}")
        # The reasoner is validated structurally (duck typing) rather than with
        # isinstance, so alternative reasoner implementations are accepted as long
        # as they expose both mapping attributes.
        if not (hasattr(reasoner, "idx_to_label") and hasattr(reasoner, "label_to_idx")):
            raise TypeError(
                "Expected a reasoner exposing idx_to_label / label_to_idx (e.g. Reasoner "
                f"or VerificationReasoner), but received type: {type(reasoner)}"
            )

        self.model = model
        self.reasoner = reasoner

    @abstractmethod
    def predict(self, data_examples: ListData) -> Tuple[List[List[Any]], List[List[Any]]]:
        """Placeholder for predicting class indices from input."""

    @abstractmethod
    def abduce_pseudo_label(self, data_examples: ListData) -> List[List[Any]]:
        """Placeholder for revising pseudo-labels based on abductive reasoning."""

    @abstractmethod
    def idx_to_pseudo_label(self, data_examples: ListData) -> List[List[Any]]:
        """Placeholder for mapping indices to pseudo-labels."""

    @abstractmethod
    def pseudo_label_to_idx(self, data_examples: ListData) -> List[List[Any]]:
        """Placeholder for mapping pseudo-labels to indices."""

    def filter_pseudo_label(self, data_examples: ListData) -> ListData:
        """
        Default filter function for pseudo-label.

        Keeps only the examples whose entry in ``data_examples.abduced_pseudo_label``
        is truthy (e.g. a non-empty list) and drops the rest.

        Parameters
        ----------
        data_examples : ListData
            Data examples carrying an ``abduced_pseudo_label`` attribute. The object
            is modified in place through its ``update`` method.

        Returns
        -------
        ListData
            The same ``data_examples`` object, after filtering.
        """
        non_empty_idx = [
            i for i, label in enumerate(data_examples.abduced_pseudo_label) if label
        ]
        # Index with the surviving positions, then write the selection back into
        # the original object so that callers holding a reference see the result.
        data_examples.update(data_examples[non_empty_idx])
        return data_examples

    @abstractmethod
    def train(
        self,
        train_data: Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], List[Any]]],
    ):
        """Placeholder for training loop of Abductive Learning."""

    @abstractmethod
    def valid(
        self,
        val_data: Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], List[Any]]],
    ) -> None:
        """Placeholder for model validation."""

    @abstractmethod
    def test(
        self,
        test_data: Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], List[Any]]],
    ) -> None:
        """Placeholder for model test."""