"""
This module contains the class Reasoner, which is used for minimizing the inconsistency
between the knowledge base and learning models.
Copyright (c) 2024 LAMDA. All rights reserved.
"""
import heapq
import inspect
from typing import Any, Callable, Iterator, List, Optional, Tuple, Union
import numpy as np
from zoopt import Dimension, Objective, Opt, Parameter, Solution
from ..data.structures import ListData
from ..reasoning import KBBase
from ..utils.utils import (
avg_confidence_dist,
confidence_dist,
hamming_dist,
rejection_dist,
similarity_dist,
)
[docs]
class Reasoner:
"""
Reasoner for minimizing the inconsistency between the knowledge base and learning models.
Parameters
----------
kb : class KBBase
The knowledge base to be used for reasoning.
dist_func : Union[str, Callable], optional
The distance function used to determine the cost list between each
candidate and the given prediction. The cost is also referred to as a consistency
measure, wherein the candidate with lowest cost is selected as the final
abduced label. It can be either a string representing a predefined distance
function or a callable function. The available predefined distance functions:
'hamming' | 'confidence' | 'avg_confidence' | 'similarity' | 'rejection'.
'hamming' directly calculates the Hamming distance between the predicted
pseudo-label in the data example and each candidate. 'confidence' and
'avg_confidence' calculate the confidence distance between the predicted
probabilities and each candidate, defined as ``1 - product`` and
``1 - average`` of the candidate's per-symbol probabilities respectively.
'similarity' compares candidates against the geometry of the model's
embeddings (requires the base model to expose ``extract_features``;
``ABLModel`` then stores the result on ``data_example.embeddings``).
'rejection' combines confidence distance with a candidate-complexity penalty,
favoring shorter candidates when scores are close.
Alternatively, the callable function should have the signature
``dist_func(data_example, candidates, candidate_idxs, reasoning_results)`` and must
return a cost list. Each element in this cost list should be a numerical value
representing the cost for each candidate, and the list should have the same length
as candidates. Defaults to 'confidence'.
idx_to_label : dict, optional
A mapping from index in the base model to label. If not provided, a default
order-based index to label mapping is created. Defaults to None.
max_revision : Union[int, float], optional
The upper limit on the number of revisions for each data example when
performing abductive reasoning. If float, denotes the fraction of the total
length that can be revised. A value of -1 implies no restriction on the
number of revisions. Defaults to -1.
require_more_revision : int, optional
Specifies additional number of revisions permitted beyond the minimum required
when performing abductive reasoning. Defaults to 0.
use_zoopt : bool, optional
Whether to use ZOOpt library during abductive reasoning. Defaults to False.
"""
def __init__(
self,
kb: KBBase,
dist_func: Union[str, Callable] = "confidence",
idx_to_label: Optional[dict] = None,
max_revision: Union[int, float] = -1,
require_more_revision: int = 0,
use_zoopt: bool = False,
):
self.kb = kb
self._check_valid_dist(dist_func)
self.dist_func = dist_func
self.use_zoopt = use_zoopt
self.max_revision = max_revision
self.require_more_revision = require_more_revision
if idx_to_label is None:
self.idx_to_label = {
index: label for index, label in enumerate(self.kb.pseudo_label_list)
}
else:
self._check_valid_idx_to_label(idx_to_label)
self.idx_to_label = idx_to_label
self.label_to_idx = dict(zip(self.idx_to_label.values(), self.idx_to_label.keys()))
def _check_valid_dist(self, dist_func):
if isinstance(dist_func, str):
valid = ["hamming", "confidence", "avg_confidence", "similarity", "rejection"]
if dist_func not in valid:
raise NotImplementedError(
f"Valid options for predefined dist_func are {valid}, "
f"but got {dist_func!r}."
)
return
elif callable(dist_func):
params = inspect.signature(dist_func).parameters.values()
if len(params) != 4:
raise ValueError(
"User-defined dist_func must have exactly four parameters, "
+ f"but got {len(params)}."
)
return
else:
raise TypeError(
f"dist_func must be a string or a callable function, but got {type(dist_func)}."
)
def _check_valid_idx_to_label(self, idx_to_label):
if not isinstance(idx_to_label, dict):
raise TypeError(f"idx_to_label should be dict, but got {type(idx_to_label)}.")
for key, value in idx_to_label.items():
if not isinstance(key, int):
raise ValueError(f"All keys in the idx_to_label must be integers, but got {key}.")
if value not in self.kb.pseudo_label_list:
raise ValueError(
"All values in the idx_to_label must be in the pseudo_label_list, "
+ f"but got {value}."
)
def _get_one_candidate(
self,
data_example: ListData,
candidates: List[List[Any]],
reasoning_results: List[Any],
) -> List[Any]:
"""
Due to the nondeterminism of abductive reasoning, there could be multiple candidates
satisfying the knowledge base. When this happens, return one candidate that has the
minimum cost. If no candidates are provided, an empty list is returned.
Parameters
----------
data_example : ListData
Data example.
candidates : List[List[Any]]
Multiple possible candidates.
reasoning_results : List[Any]
Corresponding reasoning results of the candidates.
Returns
-------
List[Any]
A selected candidate.
"""
if len(candidates) == 0:
return []
elif len(candidates) == 1:
return candidates[0]
else:
cost_array = self._get_cost_list(data_example, candidates, reasoning_results)
candidate = candidates[np.argmin(cost_array)]
return candidate
def _get_cost_list(
self,
data_example: ListData,
candidates: List[List[Any]],
reasoning_results: List[Any],
) -> Union[List[Union[int, float]], np.ndarray]:
"""
Get the list of costs between each candidate and the given data example.
Parameters
----------
data_example : ListData
Data example.
candidates : List[List[Any]]
Multiple possible candidates.
reasoning_results : List[Any]
Corresponding reasoning results of the candidates.
Returns
-------
Union[List[Union[int, float]], np.ndarray]
The list of costs.
"""
if self.dist_func == "hamming":
return hamming_dist(data_example.pred_pseudo_label, candidates)
elif self.dist_func == "confidence":
candidates_idxs = [[self.label_to_idx[x] for x in c] for c in candidates]
return confidence_dist(data_example.pred_prob, candidates_idxs)
elif self.dist_func == "avg_confidence":
candidates_idxs = [[self.label_to_idx[x] for x in c] for c in candidates]
return avg_confidence_dist(data_example.pred_prob, candidates_idxs)
elif self.dist_func == "similarity":
embeddings = getattr(data_example, "embeddings", None)
if embeddings is None:
raise ValueError(
"dist_func='similarity' requires the base model to expose an "
"extract_features(X=...) method so ABLModel can populate "
"data_example.embeddings."
)
candidates_idxs = [[self.label_to_idx[x] for x in c] for c in candidates]
return similarity_dist(embeddings, candidates_idxs=candidates_idxs)
elif self.dist_func == "rejection":
candidates_idxs = [[self.label_to_idx[x] for x in c] for c in candidates]
return rejection_dist(data_example.pred_prob, candidates_idxs=candidates_idxs)
else:
candidates_idxs = [[self.label_to_idx[x] for x in c] for c in candidates]
cost_list = self.dist_func(data_example, candidates, candidates_idxs, reasoning_results)
if len(cost_list) != len(candidates):
raise ValueError(
"The length of the array returned by dist_func must be equal to the number "
+ f"of candidates. Expected length {len(candidates)}, but got {len(cost_list)}."
)
return cost_list
def _zoopt_get_solution(
self,
symbol_num: int,
data_example: ListData,
max_revision_num: int,
) -> Solution:
"""
Get the optimal solution using ZOOpt library. From the solution, we can get a list of
boolean values, where '1' (True) indicates the indices chosen to be revised.
Parameters
----------
symbol_num : int
Number of total symbols.
data_example : ListData
Data example.
max_revision_num : int
Specifies the maximum number of revisions allowed.
Returns
-------
Solution
The solution for ZOOpt library.
"""
dimension = Dimension(size=symbol_num, regs=[[0, 1]] * symbol_num, tys=[False] * symbol_num)
objective = Objective(
lambda sol: self.zoopt_score(symbol_num, data_example, sol),
dim=dimension,
constraint=lambda sol: self._constrain_revision_num(sol, max_revision_num),
)
parameter = Parameter(
budget=self.zoopt_budget(symbol_num), intermediate_result=False, autoset=True
)
solution = Opt.min(objective, parameter)
return solution
[docs]
def zoopt_score(
self,
symbol_num: int,
data_example: ListData,
sol: Solution,
) -> int:
"""
Set the score for a solution. A lower score suggests that ZOOpt library
has a higher preference for this solution.
Parameters
----------
symbol_num : int
Number of total symbols.
data_example : ListData
Data example.
sol: Solution
The solution for ZOOpt library.
Returns
-------
int
The score for the solution.
"""
revision_idx = np.where(sol.get_x() != 0)[0]
candidates, reasoning_results = self.kb.revise_at_idx(
data_example.pred_pseudo_label, data_example.Y, data_example.X, revision_idx
)
if len(candidates) > 0:
return np.min(self._get_cost_list(data_example, candidates, reasoning_results))
else:
return symbol_num
[docs]
def zoopt_budget(self, symbol_num: int) -> int:
"""
Set the budget for ZOOpt optimization. The budget can be dynamic relying on
the number of symbols considered, e.g., the default implementation shown below.
Alternatively, it can be a fixed value, such as simply setting it to 100.
Parameters
----------
symbol_num : int
The number of symbols to be considered in the ZOOpt optimization process.
Returns
-------
int
The budget for ZOOpt optimization.
"""
return 10 * symbol_num
def _constrain_revision_num(self, solution: Solution, max_revision_num: int) -> int:
"""
Constrain that the total number of revisions chosen by the solution does not exceed
maximum number of revisions allowed.
"""
x = solution.get_x()
return max_revision_num - x.sum()
def _get_max_revision_num(self, max_revision: Union[int, float], symbol_num: int) -> int:
"""
Get the maximum revision number according to input ``max_revision``.
"""
if not isinstance(max_revision, (int, float)):
raise TypeError(f"Parameter must be of type int or float, but got {type(max_revision)}")
if max_revision == -1:
return symbol_num
if isinstance(max_revision, float):
if not 0 <= max_revision <= 1:
raise ValueError(
"If max_revision is a float, it must be between 0 and 1, "
+ f"but got {max_revision}"
)
return round(symbol_num * max_revision)
if max_revision < 0:
raise ValueError(
f"If max_revision is an int, it must be non-negative, but got {max_revision}"
)
return max_revision
[docs]
def abduce(self, data_example: ListData) -> List[Any]:
"""
Perform abductive reasoning on the given data example.
Parameters
----------
data_example : ListData
Data example.
Returns
-------
List[Any]
A revised pseudo-labels of the example through abductive reasoning, which is compatible
with the knowledge base.
"""
symbol_num = data_example.elements_num("pred_pseudo_label")
max_revision_num = self._get_max_revision_num(self.max_revision, symbol_num)
if self.use_zoopt:
solution = self._zoopt_get_solution(symbol_num, data_example, max_revision_num)
revision_idx = np.where(solution.get_x() != 0)[0]
candidates, reasoning_results = self.kb.revise_at_idx(
pseudo_label=data_example.pred_pseudo_label,
y=data_example.Y,
x=data_example.X,
revision_idx=revision_idx,
)
else:
candidates, reasoning_results = self.kb.abduce_candidates(
pseudo_label=data_example.pred_pseudo_label,
y=data_example.Y,
x=data_example.X,
max_revision_num=max_revision_num,
require_more_revision=self.require_more_revision,
)
candidate = self._get_one_candidate(data_example, candidates, reasoning_results)
return candidate
[docs]
def batch_abduce(self, data_examples: ListData) -> List[List[Any]]:
"""
Perform abductive reasoning on the given prediction data examples.
For detailed information, refer to ``abduce``.
"""
abduced_pseudo_label = [self.abduce(data_example) for data_example in data_examples]
data_examples.abduced_pseudo_label = abduced_pseudo_label
return abduced_pseudo_label
[docs]
def batch_supervised_abduce(self, data_examples: ListData) -> List[List[Any]]:
"""
Perform abductive reasoning on the given prediction data examples, using supervised data
when gt_pseudo_label is given.
"""
abduced_pseudo_label = [
(
data_example.gt_pseudo_label
if data_example.gt_pseudo_label
else self.abduce(data_example)
)
for data_example in data_examples
]
data_examples.abduced_pseudo_label = abduced_pseudo_label
return abduced_pseudo_label
def __call__(self, data_examples: ListData) -> List[List[Any]]:
return self.batch_abduce(data_examples)
# =============================================================================
# A3BL: Ambiguity-Aware Abductive Learning
#
# Reference: https://github.com/Hao-Yuan-He/A3BL
# =============================================================================
[docs]
class A3BLReasoner(Reasoner):
"""
Reasoner for minimizing the inconsistency between the knowledge base and learning models.
Parameters
----------
kb : class KBBase
The knowledge base to be used for reasoning.
dist_func : Union[str, Callable], optional
The distance function used to determine the cost list between each
candidate and the given prediction. The cost is also referred to as a consistency
measure, wherein the candidate with the lowest cost is selected as the final
abduced label. It can be either a string representing a predefined distance
function or a callable function. The available predefined distance functions:
'hamming' | 'confidence' | 'avg_confidence' | 'similarity' | 'rejection'.
See :class:`Reasoner` for the full description of each option.
Defaults to 'confidence'.
idx_to_label : dict, optional
A mapping from index in the base model to label. If not provided, a default
order-based index to label mapping is created. Defaults to None.
max_revision : Union[int, float], optional
The upper limit on the number of revisions for each data example when
performing abductive reasoning. If float, denotes the fraction of the total
length that can be revised. A value of -1 implies no restriction on the
number of revisions. Defaults to -1.
require_more_revision : int, optional
Specifies additional number of revisions permitted beyond the minimum required
when performing abductive reasoning. Defaults to 0.
use_zoopt : bool, optional
Whether to use ZOOpt library during abductive reasoning. Defaults to False.
topK : int, optional
Number of top-ranked candidates to keep when forming the soft label. ``-1``
keeps all candidates. Defaults to 16.
temperature : float, optional
Softmax temperature used when aggregating candidate probabilities into a
soft label. Lower values produce sharper distributions. Defaults to 0.2.
multi_label : bool, optional
Whether the underlying task is multi-label (each symbol is a binary vector
rather than a single class index). Defaults to False.
"""
def __init__(
self,
kb,
dist_func="confidence",
idx_to_label=None,
max_revision: Union[int, float] = -1,
require_more_revision: int = 0,
use_zoopt: bool = False,
topK: int = 16,
temperature: float = 0.2,
multi_label: bool = False,
):
super().__init__(
kb, dist_func, idx_to_label, max_revision, require_more_revision, use_zoopt
)
import torch
self.topK = topK
self.temperature = temperature
self.class_num = len(self.kb.pseudo_label_list)
self.multi_label = multi_label
self.device = "cuda" if torch.cuda.is_available() else "cpu"
def _confidence_dist(
self, pred_probs: np.ndarray, candidate_idxs: List[List[Any]], temp: float = 1.0
) -> np.ndarray:
from scipy.special import softmax
candidates_array = np.array(candidate_idxs)
_, symbol_num = candidates_array.shape
row_indices = np.arange(symbol_num)[:, np.newaxis]
selected_probs = pred_probs[row_indices, candidates_array.T]
candidate_probs = np.sum(selected_probs, axis=0) / temp
return softmax(candidate_probs)
def _confidence_dist_multi_label(
self, pred_probs: np.ndarray, candidate_idxs: List[List[Any]], temp: float = 1.0
) -> np.ndarray:
from scipy.special import softmax
candidate_probs = pred_probs @ np.array(candidate_idxs).T / temp
return softmax(candidate_probs.squeeze(axis=0))
def _candidates_idxs(self, candidates: List[List[Any]]):
return [[self.label_to_idx[x] for x in c] for c in candidates]
def _topk(
self, candidates: List[Any], candidate_probs: np.ndarray, K: int = -1
) -> Tuple[List[List[Any]], List[Any]]:
"""
Performs a top-k selection from the candidate_set based on candidate_probs.
If `K` is set to -1, all candidates are chosen.
Returns a tuple containing the selected candidates and their corresponding probabilities.
"""
import heapq
if K == -1 or len(candidates) <= K:
return candidates, candidate_probs
# Iterate over all candidates and maintain a heap of size K with the largest probabilities
heap = []
for i, (candidate, prob) in enumerate(zip(candidates, candidate_probs)):
if i < K:
heapq.heappush(heap, (prob, candidate))
else:
if prob > heap[0][0]:
heapq.heappop(heap)
heapq.heappush(heap, (prob, candidate))
# Extract top-k elements from the heap,
# and reverse them to get the highest probabilities first
topk_probs, topk_candidates = zip(*heap)
return list(topk_candidates), list(topk_probs)
[docs]
def multi_label_aggregate(self, candidates: List[List[int]], candidate_probs: List[float]):
"""
An multi-label version of A3BL.
"""
import torch
with torch.no_grad():
symbol_num = len(candidates[0])
aggregate_label = torch.zeros(size=(symbol_num, 1))
for candidate, prob in zip(candidates, candidate_probs):
for i, item in enumerate(candidate):
if item == 1:
aggregate_label[i] += prob
return list(aggregate_label.unbind(1))
[docs]
def aggregate(self, candidates: List[List[int]], candidate_probs: List[float]):
import torch
import torch.nn.functional as F
with torch.no_grad():
candidates_tensor = torch.tensor(candidates, device=self.device, dtype=torch.long)
probs_tensor = torch.tensor(candidate_probs, device=self.device, dtype=torch.float32)
one_hot = F.one_hot(candidates_tensor, num_classes=self.class_num).float() # [N, M, C]
weighted_one_hot = one_hot * probs_tensor.unsqueeze(-1).unsqueeze(-1) # [N, M, C]
aggregate_label = weighted_one_hot.sum(dim=0) # [M, C]
return [tensor.cpu() for tensor in aggregate_label.unbind(0)]
[docs]
def abduce(self, data_example: ListData) -> Tuple[List[Any], List[Any]]:
"""
Perform abduction and get a soft label distribution aggregated from
all valid candidates that satisfy the underlying rules.
Parameters
----------
data_example : ListData
Data example.
Returns
-------
soft_label : List[Any]
Soft label aggregated from the top-k valid candidates.
pseudo_label : List[Any]
Hard pseudo-label revision (the top-1 candidate) that is
consistent with the knowledge base.
"""
max_revision_num = data_example.elements_num("pred_pseudo_label")
max_revision_num = self._get_max_revision_num(self.max_revision, max_revision_num)
candidates, _ = self.kb.abduce_candidates(
pseudo_label=data_example.pred_pseudo_label,
y=data_example.Y,
x=data_example.X,
max_revision_num=max_revision_num,
require_more_revision=self.require_more_revision,
)
if len(candidates) == 0:
return [], []
confidence_dist_cal = (
self._confidence_dist if not self.multi_label else self._confidence_dist_multi_label
)
candidate_probs = confidence_dist_cal(
data_example.pred_prob, self._candidates_idxs(candidates), self.temperature
)
topk_candidates, topk_candidates_probs = self._topk(candidates, candidate_probs, self.topK)
aggregated_labels = (
self.aggregate(topk_candidates, topk_candidates_probs)
if not self.multi_label
else self.multi_label_aggregate(topk_candidates, topk_candidates_probs)
)
return aggregated_labels, topk_candidates[0]
[docs]
def batch_abduce(self, data_examples: ListData) -> List[List[Any]]:
"""
Perform abductive reasoning on the given prediction data examples.
For detailed information, refer to ``abduce``.
"""
abduced_soft_label, abduced_pseudo_label = zip(
*[self.abduce(data_example) for data_example in data_examples]
)
data_examples.abduced_soft_label = abduced_soft_label
data_examples.abduced_pseudo_label = abduced_pseudo_label
return abduced_soft_label
def __call__(self, data_examples: ListData) -> List[List[Any]]:
return self.batch_abduce(data_examples)
# =============================================================================
# Verification Learning
#
# Walks the per-symbol probability lattice in descending joint-probability
# order, collecting the first top_k assignments that satisfy the knowledge
# base. Reference: https://github.com/VerificationLearning/VerificationLearning
# =============================================================================
[docs]
def enumerate_label_assignments(
pred_prob: np.ndarray, max_iter: int = 10000
) -> Iterator[Tuple[List[int], float, List[float]]]:
"""
Yield label-index assignments for a single data example in descending
joint-probability order. The walk is a Lawler-style best-first search:
each state is the tuple of per-symbol rank indices, and successors are
generated by advancing any one symbol to its next-best class.
Parameters
----------
pred_prob : np.ndarray
Per-symbol probability matrix with shape ``(num_symbols, num_classes)``.
max_iter : int, optional
Hard cap on the number of yields. Defaults to 10000.
Yields
------
labels : List[int]
Class indices for each symbol.
joint_prob : float
Product of the chosen per-symbol probabilities.
per_symbol_probs : List[float]
The chosen probability for each symbol.
"""
pred_prob = np.asarray(pred_prob, dtype=float)
num_symbols, num_classes = pred_prob.shape
if num_symbols == 0:
return
sorted_indices = np.argsort(-pred_prob, axis=1)
sorted_probs = np.take_along_axis(pred_prob, sorted_indices, axis=1)
initial_state = (0,) * num_symbols
initial_prob = float(np.prod(sorted_probs[:, 0]))
seen = {initial_state}
heap: List[Tuple[float, Tuple[int, ...]]] = [(-initial_prob, initial_state)]
yields = 0
while heap and yields < max_iter:
neg_prob, state = heapq.heappop(heap)
joint_prob = -neg_prob
labels = [int(sorted_indices[i, state[i]]) for i in range(num_symbols)]
per_symbol_probs = [float(sorted_probs[i, state[i]]) for i in range(num_symbols)]
yield labels, joint_prob, per_symbol_probs
yields += 1
for sym in range(num_symbols):
next_rank = state[sym] + 1
if next_rank >= num_classes:
continue
new_state = state[:sym] + (next_rank,) + state[sym + 1:]
if new_state in seen:
continue
seen.add(new_state)
current_p = sorted_probs[sym, state[sym]]
next_p = sorted_probs[sym, next_rank]
if current_p <= 0:
new_joint = 0.0
else:
new_joint = joint_prob * (next_p / current_p)
heapq.heappush(heap, (-new_joint, new_state))
[docs]
def top_k_satisfying(
pred_prob: np.ndarray,
predicate: Callable[[List[Any]], bool],
top_k: int = 1,
max_iter: int = 10000,
idx_to_label: Optional[dict] = None,
) -> Tuple[List[List[Any]], List[float]]:
"""
Walk label assignments in descending joint-probability order and return
the first ``top_k`` that satisfy ``predicate``. If none is found within
``max_iter`` iterations the single highest-probability assignment is
returned as a fallback so callers always receive a usable label.
Parameters
----------
pred_prob : np.ndarray
Per-symbol probability matrix with shape ``(num_symbols, num_classes)``.
predicate : Callable[[List[Any]], bool]
Function called on each candidate label sequence; truthy means the
candidate is consistent with the knowledge base.
top_k : int, optional
Maximum number of satisfying candidates to return. Defaults to 1.
max_iter : int, optional
Hard cap on enumeration steps. Defaults to 10000.
idx_to_label : dict, optional
Optional mapping from class index to pseudo-label. When omitted, the
raw class indices are returned.
Returns
-------
candidates : List[List[Any]]
Label assignments that satisfy ``predicate`` (or the fallback).
probs : List[float]
Joint probability of each returned candidate.
"""
matches: List[List[Any]] = []
probs: List[float] = []
fallback: Optional[Tuple[List[Any], float]] = None
for labels_idx, joint_prob, _ in enumerate_label_assignments(pred_prob, max_iter):
labels = (
[idx_to_label[i] for i in labels_idx] if idx_to_label is not None else labels_idx
)
if fallback is None:
fallback = (labels, joint_prob)
if predicate(labels):
matches.append(labels)
probs.append(joint_prob)
if len(matches) >= top_k:
break
if not matches and fallback is not None:
matches.append(fallback[0])
probs.append(fallback[1])
return matches, probs
[docs]
class VerificationReasoner:
"""
Reasoner used by :class:`~ablkit.bridge.VerificationBridge`. Rather than
picking a single best candidate via a distance function, it enumerates
the top ``top_k`` label assignments that satisfy the knowledge base,
ordered by joint probability. The bridge then trains the model on each
of those candidates.
Parameters
----------
kb : KBBase
The knowledge base used to verify candidates. ``kb.logic_forward``
must return the reasoning result so it can be compared with each
data example's ``Y``.
top_k : int, optional
Number of satisfying candidates to enumerate per example.
Defaults to 1.
max_iter : int, optional
Maximum number of enumeration steps per example before giving up
and returning the fallback. Defaults to 10000.
idx_to_label : dict, optional
A mapping from base-model index to pseudo-label. If omitted a
default order-based mapping is built from ``kb.pseudo_label_list``.
"""
def __init__(
self,
kb: KBBase,
top_k: int = 1,
max_iter: int = 10000,
idx_to_label: Optional[dict] = None,
) -> None:
if top_k < 1:
raise ValueError("top_k must be >= 1.")
if max_iter < 1:
raise ValueError("max_iter must be >= 1.")
self.kb = kb
self.top_k = top_k
self.max_iter = max_iter
if idx_to_label is None:
idx_to_label = dict(enumerate(kb.pseudo_label_list))
self.idx_to_label = idx_to_label
self.label_to_idx = {label: idx for idx, label in idx_to_label.items()}
[docs]
def top_k_candidates(
self, pred_prob: np.ndarray, y: Any
) -> Tuple[List[List[Any]], List[float]]:
"""
Return up to ``top_k`` label assignments for one data example whose
``kb.logic_forward`` matches ``y``.
"""
def predicate(labels: List[Any]) -> bool:
return self.kb.logic_forward(labels) == y
return top_k_satisfying(
pred_prob,
predicate,
top_k=self.top_k,
max_iter=self.max_iter,
idx_to_label=self.idx_to_label,
)
[docs]
def batch_top_k(self, data_examples) -> List[List[List[Any]]]:
"""
Run :meth:`top_k_candidates` on every example in ``data_examples``.
Stores the result on ``data_examples.top_k_candidates`` and
``data_examples.top_k_probs``. Returns the list of per-example
candidate lists.
"""
all_candidates: List[List[List[Any]]] = []
all_probs: List[List[float]] = []
for data_example in data_examples:
cands, probs = self.top_k_candidates(data_example.pred_prob, data_example.Y)
all_candidates.append(cands)
all_probs.append(probs)
data_examples.top_k_candidates = all_candidates
data_examples.top_k_probs = all_probs
return all_candidates