Source code for ablkit.bridge.verification_bridge

"""
Bridge for Verification Learning.

:class:`VerificationBridge` replaces the single-candidate abduction step of
:class:`SimpleBridge` with a top-K enumeration provided by
:class:`~ablkit.reasoning.VerificationReasoner`. For each
segment the bridge trains the model once per top-K candidate, exposing
the model to up to K assignments that are consistent with the knowledge base.

Reference: https://github.com/VerificationLearning/VerificationLearning
"""

from typing import Any, List, Optional, Tuple, Union

from ..data.evaluation import BaseMetric
from ..data.structures import ListData
from ..learning import ABLModel
from ..reasoning.reasoner import VerificationReasoner
from ..utils import print_log
from .simple_bridge import SimpleBridge


class VerificationBridge(SimpleBridge):
    """
    Bridge implementing the Verification Learning training loop.

    Parameters
    ----------
    model : ABLModel
        Wrapped learning model.
    reasoner : VerificationReasoner
        Top-K reasoner. Its ``batch_top_k`` method supplies one candidate
        list per example; the number of training passes run for a segment
        equals the length of the longest candidate list returned for it.
    metric_list : List[BaseMetric]
        Evaluation metrics, identical to :class:`SimpleBridge`.

    Raises
    ------
    TypeError
        If ``reasoner`` is not a :class:`VerificationReasoner`.
    """

    def __init__(
        self,
        model: ABLModel,
        reasoner: VerificationReasoner,
        metric_list: List[BaseMetric],
    ) -> None:
        # Fail fast: the training loop below depends on ``batch_top_k``.
        if not isinstance(reasoner, VerificationReasoner):
            raise TypeError(
                "VerificationBridge requires a VerificationReasoner; "
                f"got {type(reasoner).__name__}."
            )
        super().__init__(model, reasoner, metric_list)

    def train(
        self,
        train_data: Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], List[Any]]],
        val_data: Optional[
            Union[
                ListData,
                Tuple[List[List[Any]], Optional[List[List[Any]]], Optional[List[Any]]],
            ]
        ] = None,
        loops: int = 50,
        segment_size: Union[int, float] = 1.0,
        eval_interval: int = 1,
        save_interval: Optional[int] = None,
        save_dir: Optional[str] = None,
    ) -> None:
        """
        Verification Learning training loop.

        For each segment we predict once, enumerate the top-K consistent
        candidates, then run a ``model.train`` pass per candidate.

        Parameters
        ----------
        train_data : ListData or tuple
            Training data, passed through ``data_preprocess("train", ...)``.
        val_data : ListData or tuple, optional
            Validation data, passed through ``data_preprocess("val", ...)``.
            When ``None``, the preprocessed training data is used instead.
        loops : int
            Number of passes over the whole training set.
        segment_size : int or float
            Requested segment size; converted to an example count by
            ``_resolve_segment_size`` together with the dataset length.
        eval_interval : int
            Forwarded to ``_maybe_eval`` after every loop.
        save_interval : int, optional
            Forwarded to ``_maybe_save`` after every loop.
        save_dir : str, optional
            Forwarded to ``_maybe_save`` after every loop.
        """
        data_examples = self.data_preprocess("train", train_data)
        if val_data is not None:
            val_data_examples = self.data_preprocess("val", val_data)
        else:
            val_data_examples = data_examples

        segment_size = self._resolve_segment_size(segment_size, len(data_examples))
        # Ceiling division: the last segment may be shorter than the others.
        num_segments = (len(data_examples) - 1) // segment_size + 1

        for loop in range(loops):
            for seg_idx in range(num_segments):
                print_log(
                    f"loop(train) [{loop + 1}/{loops}] segment(train) "
                    f"[{seg_idx + 1}/{num_segments}] ",
                    logger="current",
                )
                sub_data_examples = data_examples[
                    seg_idx * segment_size : (seg_idx + 1) * segment_size
                ]
                self._train_one_segment_verification(sub_data_examples)

            self._maybe_eval(val_data_examples, loop, loops, eval_interval)
            self._maybe_save(loop, loops, save_interval, save_dir)

    def _train_one_segment_verification(self, sub_data_examples: ListData) -> None:
        """
        Predict, enumerate top-K candidates, then train once per candidate.

        Each example's k-th training pass uses its k-th candidate (or, if the
        example yielded fewer than k candidates, its last available
        candidate, repeated). An example with no candidates at all is given
        an empty label so that ``filter_pseudo_label`` can discard it.

        Parameters
        ----------
        sub_data_examples : ListData
            One segment of training data. It is modified in place by the
            prediction, filtering and label-conversion steps.

        Raises
        ------
        RuntimeError
            If ``filter_pseudo_label`` changes the number of examples in a
            way that cannot be matched to the examples with empty labels.
        """
        self.predict(sub_data_examples)
        self.idx_to_pseudo_label(sub_data_examples)

        per_example_candidates = self.reasoner.batch_top_k(sub_data_examples)
        if not per_example_candidates:
            return

        max_k = max(len(cands) for cands in per_example_candidates)
        for k_idx in range(max_k):
            # Clamp the index so short candidate lists repeat their last
            # entry; an empty list has no last entry, so fall back to an
            # empty label instead of indexing into it.
            labels = [
                cands[min(k_idx, len(cands) - 1)] if cands else []
                for cands in per_example_candidates
            ]
            sub_data_examples.abduced_pseudo_label = labels
            self.filter_pseudo_label(sub_data_examples)

            remaining = len(sub_data_examples)
            if remaining != len(per_example_candidates):
                # The filter shrank the data in place, so the candidate lists
                # must shrink with it or later passes would pair labels with
                # the wrong examples.
                # NOTE(review): assumes ``filter_pseudo_label`` drops exactly
                # the examples whose abduced label is empty - confirm against
                # SimpleBridge.
                kept = [
                    cands
                    for cands, label in zip(per_example_candidates, labels)
                    if label
                ]
                if len(kept) != remaining:
                    raise RuntimeError(
                        "filter_pseudo_label left "
                        f"{remaining} examples but {len(kept)} examples have "
                        "a non-empty candidate label; cannot keep candidates "
                        "aligned with the data."
                    )
                per_example_candidates = kept

            self.pseudo_label_to_idx(sub_data_examples)
            if remaining == 0:
                # Nothing survived filtering, so no later pass can train.
                return
            self.model.train(sub_data_examples)