"""
This module contains the classes KBBase, GroundKB, and PrologKB, which provide wrappers
for different kinds of knowledge bases.
Copyright (c) 2024 LAMDA. All rights reserved.
"""
import bisect
import inspect
import logging
import os
from abc import ABC, abstractmethod
from collections import defaultdict
from itertools import combinations, product
from multiprocessing import Pool
from typing import Any, Callable, List, Optional, Tuple

import numpy as np

from ..utils.cache import abl_cache
from ..utils.logger import print_log
from ..utils.utils import flatten, hamming_dist, reform_list, to_hashable
class KBBase(ABC):
    """
    Base class for knowledge base.

    Parameters
    ----------
    pseudo_label_list : List[Any]
        List of possible pseudo-labels. It's recommended to arrange the pseudo-labels in this
        list so that each aligns with its corresponding index in the base model: the first with
        the 0th index, the second with the 1st, and so forth.
    max_err : float, optional
        The upper tolerance limit when comparing the similarity between the reasoning result of
        pseudo-labels and the ground truth. This is only applicable when the reasoning
        result is of a numerical type. This is particularly relevant for regression problems where
        exact matches might not be feasible. Defaults to 1e-10.
    use_cache : bool, optional
        Whether to use abl_cache for previously abduced candidates to speed up subsequent
        operations. Defaults to True. The cache is automatically disabled when
        ``logic_forward`` also takes the example ``x`` as an argument.
    key_func : Callable, optional
        A function employed for hashing in abl_cache. This is only operational when use_cache
        is set to True. Defaults to ``to_hashable``.
    cache_size : int, optional
        The cache size in abl_cache. This is only operational when use_cache is set to
        True. Defaults to 4096.

    Notes
    -----
    Users should derive from this base class to build their own knowledge base. For the
    user-built KB (a derived subclass), it's only required for the user to provide the
    ``pseudo_label_list`` and override the ``logic_forward`` function (specifying how to
    perform logical reasoning). After that, other operations (e.g. how to perform abductive
    reasoning) will be automatically set up.
    """

    def __init__(
        self,
        pseudo_label_list: List[Any],
        max_err: float = 1e-10,
        use_cache: bool = True,
        key_func: Callable = to_hashable,
        cache_size: int = 4096,
    ):
        if not isinstance(pseudo_label_list, list):
            raise TypeError(f"pseudo_label_list should be list, got {type(pseudo_label_list)}")
        self.pseudo_label_list = pseudo_label_list
        self.max_err = max_err
        self.use_cache = use_cache
        self.key_func = key_func
        self.cache_size = cache_size

        # Number of parameters of ``logic_forward`` excluding ``self``: 2 means the
        # user's implementation also expects the example ``x`` (see ``revise_at_idx``).
        argspec = inspect.getfullargspec(self.logic_forward)
        self._num_args = len(argspec.args) - 1
        if (
            self._num_args == 2 and self.use_cache
        ):  # If the logic_forward function has 2 arguments, then disable cache
            self.use_cache = False
            print_log(
                "The logic_forward function has 2 arguments, so the cache is disabled. ",
                logger="current",
                level=logging.WARNING,
            )

    @abstractmethod
    def logic_forward(self, pseudo_label: List[Any], x: Optional[List[Any]] = None) -> Any:
        """
        How to perform (deductive) logical reasoning, i.e. matching an example's
        pseudo-labels to its reasoning result. Users are required to provide this.

        Parameters
        ----------
        pseudo_label : List[Any]
            Pseudo-labels of an example.
        x : List[Any], optional
            The example. If deductive logical reasoning does not require any
            information from the example, the overridden function provided by the user can omit
            this parameter.

        Returns
        -------
        Any
            The reasoning result.
        """

    def abduce_candidates(
        self,
        pseudo_label: List[Any],
        y: Any,
        x: List[Any],
        max_revision_num: int,
        require_more_revision: int,
    ) -> Tuple[List[List[Any]], List[Any]]:
        """
        Perform abductive reasoning to get candidates compatible with the knowledge base.

        Parameters
        ----------
        pseudo_label : List[Any]
            Pseudo-labels of an example (to be revised by abductive reasoning).
        y : Any
            Ground truth of the reasoning result for the example.
        x : List[Any]
            The example. If the information from the example
            is not required in the reasoning process, then this parameter will not have
            any effect.
        max_revision_num : int
            The upper limit on the number of revised labels for each example.
        require_more_revision : int
            Specifies additional number of revisions permitted beyond the minimum required.

        Returns
        -------
        Tuple[List[List[Any]], List[Any]]
            A tuple of two elements. The first element is a list of candidate revisions,
            i.e., revised pseudo-labels of the example that are compatible with the knowledge
            base. The second element is a list of reasoning results corresponding to each
            candidate, i.e., the outcome of the ``logic_forward`` function.
        """
        return self._abduce_by_search(pseudo_label, y, x, max_revision_num, require_more_revision)

    def _check_equal(self, reasoning_result: Any, y: Any) -> bool:
        """
        Check whether the reasoning result of a pseudo-label example is equal to the ground
        truth (or, when both are numerical, within ``max_err`` of it).

        Parameters
        ----------
        reasoning_result : Any
            Output of ``logic_forward``; ``None`` is treated as "no result".
        y : Any
            Ground truth of the reasoning result.

        Returns
        -------
        bool
            The result of the check.
        """
        if reasoning_result is None:
            return False
        if isinstance(reasoning_result, (int, float)) and isinstance(y, (int, float)):
            return abs(reasoning_result - y) <= self.max_err
        else:
            return reasoning_result == y

    def revise_at_idx(
        self,
        pseudo_label: List[Any],
        y: Any,
        x: List[Any],
        revision_idx: List[int],
    ) -> Tuple[List[List[Any]], List[Any]]:
        """
        Revise the pseudo-labels at specified index positions.

        Every combination of values from ``pseudo_label_list`` is tried at the positions in
        ``revision_idx``; the combinations whose reasoning result matches ``y`` are kept.

        Parameters
        ----------
        pseudo_label : List[Any]
            Pseudo-labels of an example (to be revised).
        y : Any
            Ground truth of the reasoning result for the example.
        x : List[Any]
            The example. If the information from the example
            is not required in the reasoning process, then this parameter will not have
            any effect.
        revision_idx : List[int]
            A list specifying indices of where revisions should be made to the pseudo-labels.

        Returns
        -------
        Tuple[List[List[Any]], List[Any]]
            A tuple of two elements. The first element is a list of candidate revisions,
            i.e., revised pseudo-labels of the example that are compatible with the knowledge
            base. The second element is a list of reasoning results corresponding to each
            candidate, i.e., the outcome of the ``logic_forward`` function.
        """
        candidates, reasoning_results = [], []
        abduce_c = product(self.pseudo_label_list, repeat=len(revision_idx))
        for c in abduce_c:
            candidate = pseudo_label.copy()
            for i, idx in enumerate(revision_idx):
                candidate[idx] = c[i]
            # Only forward ``x`` when the user's logic_forward declares it.
            reasoning_result = self.logic_forward(
                candidate, *((x,) if self._num_args == 2 else ())
            )
            if self._check_equal(reasoning_result, y):
                candidates.append(candidate)
                reasoning_results.append(reasoning_result)
        return candidates, reasoning_results

    def _revision(
        self,
        revision_num: int,
        pseudo_label: List[Any],
        y: Any,
        x: List[Any],
    ) -> Tuple[List[List[Any]], List[Any]]:
        """
        For a specified number of labels in an example's pseudo-labels to revise, iterate through
        all possible indices to find any candidates that are compatible with the knowledge base.

        Returns the concatenated ``(candidates, reasoning_results)`` over every choice of
        ``revision_num`` positions.
        """
        new_candidates, new_reasoning_results = [], []
        revision_idx_list = combinations(range(len(pseudo_label)), revision_num)
        for revision_idx in revision_idx_list:
            candidates, reasoning_results = self.revise_at_idx(pseudo_label, y, x, revision_idx)
            new_candidates.extend(candidates)
            new_reasoning_results.extend(reasoning_results)
        return new_candidates, new_reasoning_results

    @abl_cache()
    def _abduce_by_search(
        self,
        pseudo_label: List[Any],
        y: Any,
        x: List[Any],
        max_revision_num: int,
        require_more_revision: int,
    ) -> Tuple[List[List[Any]], List[Any]]:
        """
        Perform abductive reasoning by exhaustive search. Specifically, begin with 0 and
        continuously increase the number of labels to revise, until
        candidates that are compatible with the knowledge base are found.

        Parameters
        ----------
        pseudo_label : List[Any]
            Pseudo-labels of an example (to be revised).
        y : Any
            Ground truth of the reasoning result for the example.
        x : List[Any]
            The example. If the information from the example
            is not required in the reasoning process, then this parameter will not have
            any effect.
        max_revision_num : int
            The upper limit on the number of revisions.
        require_more_revision : int
            If larger than 0, then after having found any candidates compatible with the
            knowledge base, continue to increase the number of labels to
            revise to get more possible compatible candidates.

        Returns
        -------
        Tuple[List[List[Any]], List[Any]]
            A tuple of two elements. The first element is a list of candidate revisions,
            i.e., revised pseudo-labels of the example that are compatible with the knowledge
            base. The second element is a list of reasoning results corresponding to each
            candidate, i.e., the outcome of the ``logic_forward`` function. Both lists are
            empty when no compatible candidate exists within ``max_revision_num`` revisions.
        """
        candidates, reasoning_results = [], []
        for revision_num in range(len(pseudo_label) + 1):
            new_candidates, new_reasoning_results = self._revision(revision_num, pseudo_label, y, x)
            candidates.extend(new_candidates)
            reasoning_results.extend(new_reasoning_results)
            if len(candidates) > 0:
                min_revision_num = revision_num
                break
            if revision_num >= max_revision_num:
                return [], []
        else:
            # Every possible number of revisions was tried (max_revision_num exceeds the
            # number of labels) and nothing matched; without this, ``min_revision_num``
            # would be unbound below.
            return [], []

        for revision_num in range(
            min_revision_num + 1, min_revision_num + require_more_revision + 1
        ):
            if revision_num > max_revision_num:
                return candidates, reasoning_results
            new_candidates, new_reasoning_results = self._revision(revision_num, pseudo_label, y, x)
            candidates.extend(new_candidates)
            reasoning_results.extend(new_reasoning_results)
        return candidates, reasoning_results

    def __repr__(self):
        return (
            f"{self.__class__.__name__} is a KB with "
            f"pseudo_label_list={self.pseudo_label_list!r}, "
            f"max_err={self.max_err!r}, "
            f"use_cache={self.use_cache!r}."
        )
class GroundKB(KBBase):
    """
    Knowledge base with a ground KB (GKB). Ground KB is a knowledge base prebuilt upon
    class initialization, storing all potential candidates along with their respective
    reasoning result. Ground KB can accelerate abductive reasoning in ``abduce_candidates``.

    Parameters
    ----------
    pseudo_label_list : List[Any]
        Refer to class ``KBBase``.
    GKB_len_list : List[int]
        List of possible lengths for pseudo-labels of an example.
    max_err : float, optional
        Refer to class ``KBBase``.

    Notes
    -----
    Users can also inherit from this class to build their own knowledge base. Similar
    to ``KBBase``, users are only required to provide the ``pseudo_label_list`` and override
    the ``logic_forward`` function. Additionally, users should provide the ``GKB_len_list``.
    After that, other operations (e.g. auto-construction of GKB, and how to perform
    abductive reasoning) will be automatically set up.
    """

    def __init__(
        self,
        pseudo_label_list: List[Any],
        GKB_len_list: List[int],
        max_err: float = 1e-10,
    ):
        super().__init__(pseudo_label_list, max_err)
        if not isinstance(GKB_len_list, list):
            raise TypeError(f"GKB_len_list should be list, but got {type(GKB_len_list)}")
        if self._num_args == 2:
            raise NotImplementedError(
                "GroundKB only supports 1-argument logic_forward, but got "
                + f"{self._num_args}-argument logic_forward"
            )
        self.GKB_len_list = GKB_len_list
        # Maps pseudo-label length -> {reasoning result -> [candidate tuples]}.
        self.GKB = {}
        X, Y = self._get_GKB()
        for x, y in zip(X, Y):
            self.GKB.setdefault(len(x), defaultdict(list))[y].append(x)

    def _get_XY_list(self, args):
        """
        Pool worker for ``_get_GKB``: for ``args = (pre_x, post_x_it)``, evaluate
        ``logic_forward`` on ``(pre_x,) + post_x`` for each ``post_x`` and return the
        ``(candidate, result)`` pairs whose result is not None.
        """
        pre_x, post_x_it = args[0], args[1]
        XY_list = []
        for post_x in post_x_it:
            x = (pre_x,) + post_x
            y = self.logic_forward(x)
            if y is not None:
                XY_list.append((x, y))
        return XY_list

    def _get_GKB(self):
        """
        Prebuild the GKB according to ``pseudo_label_list`` and ``GKB_len_list``.

        Returns
        -------
        Tuple[Sequence[tuple], Sequence[Any]]
            All valid candidates and their reasoning results. When the results are
            numerical they are sorted ascending, which ``_find_candidate_GKB`` relies on.
        """
        X, Y = [], []
        for length in self.GKB_len_list:
            # Materialize the suffixes: the work items are pickled to worker processes and
            # itertools iterators cannot be pickled on Python 3.14+.
            post_x_list = list(product(self.pseudo_label_list, repeat=length - 1))
            arg_list = [(pre_x, post_x_list) for pre_x in self.pseudo_label_list]
            if not arg_list:
                # Empty pseudo_label_list: nothing to enumerate (and Pool(0) would raise).
                continue
            # One task per leading label, but never more processes than CPUs.
            processes = min(len(arg_list), os.cpu_count() or 1)
            with Pool(processes=processes) as pool:
                ret_list = pool.map(self._get_XY_list, arg_list)
            for XY_list in ret_list:
                if len(XY_list) == 0:
                    continue
                part_X, part_Y = zip(*XY_list)
                X.extend(part_X)
                Y.extend(part_Y)
        if Y and isinstance(Y[0], (int, float)):
            X, Y = zip(*sorted(zip(X, Y), key=lambda pair: pair[1]))
        return X, Y

    def abduce_candidates(
        self,
        pseudo_label: List[Any],
        y: Any,
        x: List[Any],
        max_revision_num: int,
        require_more_revision: int,
    ) -> Tuple[List[List[Any]], List[Any]]:
        """
        Perform abductive reasoning by directly retrieving compatible candidates from
        the prebuilt GKB. In this way, the time-consuming exhaustive search can be
        avoided.

        Parameters
        ----------
        pseudo_label : List[Any]
            Pseudo-labels of an example (to be revised by abductive reasoning).
        y : Any
            Ground truth of the reasoning result for the example.
        x : List[Any]
            The example (unused in GroundKB).
        max_revision_num : int
            The upper limit on the number of revised labels for each example.
        require_more_revision : int
            Specifies additional number of revisions permitted beyond the minimum required.

        Returns
        -------
        Tuple[List[List[Any]], List[Any]]
            A tuple of two elements. The first element is a list of candidate revisions,
            i.e., revised pseudo-labels of the example that are compatible with the knowledge
            base. The second element is a list of reasoning results corresponding to each
            candidate, i.e., the outcome of the ``logic_forward`` function.
        """
        # A length listed in GKB_len_list may still be absent from the GKB when no
        # pseudo-label sequence of that length produced a result.
        if len(pseudo_label) not in self.GKB:
            return [], []
        all_candidates, all_reasoning_results = self._find_candidate_GKB(pseudo_label, y)
        if len(all_candidates) == 0:
            return [], []
        cost_list = hamming_dist(pseudo_label, all_candidates)
        min_revision_num = np.min(cost_list)
        revision_num = min(max_revision_num, min_revision_num + require_more_revision)
        idxs = np.where(cost_list <= revision_num)[0]
        candidates = [all_candidates[idx] for idx in idxs]
        reasoning_results = [all_reasoning_results[idx] for idx in idxs]
        return candidates, reasoning_results

    def _find_candidate_GKB(
        self, pseudo_label: List[Any], y: Any
    ) -> Tuple[List[List[Any]], List[Any]]:
        """
        Retrieve compatible candidates from the prebuilt GKB. For numerical reasoning results,
        return all candidates and their corresponding reasoning results which fall within the
        [y - max_err, y + max_err] range.

        Lookups never insert keys into the GKB (the inner dicts are ``defaultdict``s, so
        plain indexing would).
        """
        potential_candidates = self.GKB.get(len(pseudo_label), {})
        if isinstance(y, (int, float)):
            # Keys are in ascending order because _get_GKB sorts numerical results.
            key_list = list(potential_candidates.keys())
            low_key = bisect.bisect_left(key_list, y - self.max_err)
            high_key = bisect.bisect_right(key_list, y + self.max_err)
            all_candidates, all_reasoning_results = [], []
            for key in key_list[low_key:high_key]:
                for candidate in potential_candidates[key]:
                    all_candidates.append(candidate)
                    all_reasoning_results.append(key)
        else:
            all_candidates = potential_candidates.get(y, [])
            all_reasoning_results = [y] * len(all_candidates)
        return all_candidates, all_reasoning_results

    def __repr__(self):
        GKB_info_parts = []
        for i in self.GKB_len_list:
            # Count candidates, not distinct reasoning results (the dict keys).
            if i in self.GKB:
                num_candidates = sum(len(group) for group in self.GKB[i].values())
            else:
                num_candidates = 0
            GKB_info_parts.append(f"{num_candidates} candidates of length {i}")
        GKB_info = ", ".join(GKB_info_parts)
        return (
            f"{self.__class__.__name__} is a KB with "
            f"pseudo_label_list={self.pseudo_label_list!r}, "
            f"max_err={self.max_err!r}, "
            f"use_cache={self.use_cache!r}. "
            f"It has a prebuilt GKB with "
            f"GKB_len_list={self.GKB_len_list!r}, "
            f"and there are "
            f"{GKB_info}"
            f" in the GKB."
        )
class PrologKB(KBBase):
    """
    Knowledge base provided by a Prolog (.pl) file.

    Parameters
    ----------
    pseudo_label_list : List[Any]
        Refer to class ``KBBase``.
    pl_file : str
        Prolog file containing the KB.

    Notes
    -----
    Users can instantiate this class to build their own knowledge base. During the
    instantiation, users are only required to provide the ``pseudo_label_list`` and ``pl_file``.
    To use the default logic forward and abductive reasoning methods in this class, in the
    Prolog (.pl) file, there needs to be a rule which is strictly formatted as
    ``logic_forward(Pseudo_labels, Res).``, e.g., ``logic_forward([A,B], C) :- C is A+B``.
    For specifics, refer to the ``logic_forward`` and ``get_query_string`` functions in this
    class. Users are also welcome to override related functions for more flexible support.
    """

    def __init__(self, pseudo_label_list: List[Any], pl_file: str):
        super().__init__(pseudo_label_list)
        try:
            import pyswip  # pylint: disable=import-outside-toplevel
        except (IndexError, ImportError):
            print_log(
                "A Prolog-based knowledge base is in use. Please install SWI-Prolog using "
                "the command 'sudo apt-get install swi-prolog' for Linux users, or download "
                "it following the guide in "
                "https://github.com/yuce/pyswip/blob/master/INSTALL.md "
                "for Windows and Mac users.",
                logger="current",
                level=logging.WARNING,
            )
            # The KB cannot work without pyswip: surface the original error rather than
            # failing on the next line with an unrelated UnboundLocalError.
            raise
        self.pl_file = pl_file
        # Validate the path before starting a Prolog engine.
        if not os.path.exists(self.pl_file):
            raise FileNotFoundError(f"The Prolog file {self.pl_file} does not exist.")
        self.prolog = pyswip.Prolog()
        self.prolog.consult(self.pl_file)

    def logic_forward(self, pseudo_label: List[Any], x: Optional[List[Any]] = None) -> Any:
        """
        Consult prolog with the query ``logic_forward(pseudo_labels, Res).``, and set the
        returned ``Res`` as the reasoning results. To use this default function, there must be
        a ``logic_forward`` method in the pl file to perform reasoning.
        Otherwise, users would override this function.

        Parameters
        ----------
        pseudo_label : List[Any]
            Pseudo-labels of an example.
        x : List[Any]
            The corresponding input example. If the information from the input
            is not required in the reasoning process, then this parameter will not have
            any effect.

        Returns
        -------
        Any
            The binding of ``Res`` in the first solution, with the atoms ``true``/``false``
            converted to Python booleans, or ``None`` if the query has no solution.
        """
        solutions = list(self.prolog.query(f"logic_forward({pseudo_label}, Res)."))
        if not solutions:
            # No reasoning result exists; None is treated as "no result" by _check_equal.
            return None
        result = solutions[0]["Res"]
        if result == "true":
            return True
        if result == "false":
            return False
        return result

    def _revision_pseudo_label(
        self,
        pseudo_label: List[Any],
        revision_idx: List[int],
    ) -> str:
        """
        Render ``pseudo_label`` as a Prolog list string in which each position listed in
        ``revision_idx`` (indices into the flattened labels) is replaced by an unquoted
        Prolog variable ``P<idx>``.
        """
        import re  # pylint: disable=import-outside-toplevel

        revision_pseudo_label = pseudo_label.copy()
        revision_pseudo_label = flatten(revision_pseudo_label)
        for idx in revision_idx:
            revision_pseudo_label[idx] = "P" + str(idx)
        revision_pseudo_label = reform_list(revision_pseudo_label, pseudo_label)
        # str() quotes the placeholder strings ('P0'); strip the quotes so Prolog reads
        # them as variables rather than atoms.
        regex = r"'P\d+'"
        return re.sub(regex, lambda x: x.group().replace("'", ""), str(revision_pseudo_label))

    def get_query_string(
        self,
        pseudo_label: List[Any],
        y: Any,
        x: List[Any],  # pylint: disable=unused-argument
        revision_idx: List[int],
    ) -> str:
        """
        Get the query to be used for consulting Prolog.
        This is a default function for demo, users would override this function to adapt to
        their own Prolog file. In this demo function, return query
        ``logic_forward([kept_labels, Revise_labels], Res).``.

        Parameters
        ----------
        pseudo_label : List[Any]
            Pseudo-labels of an example (to be revised by abductive reasoning).
        y : Any
            Ground truth of the reasoning result for the example. If it is ``None`` (or a
            non-empty list whose first element is ``None``), it is omitted from the query.
        x : List[Any]
            The corresponding input example. If the information from the input
            is not required in the reasoning process, then this parameter will not have
            any effect.
        revision_idx : List[int]
            A list specifying indices of where revisions should be made to the pseudo-labels.

        Returns
        -------
        str
            A string of the query.
        """
        query_string = "logic_forward("
        query_string += self._revision_pseudo_label(pseudo_label, revision_idx)
        # Guard against an empty list before peeking at its first element.
        key_is_none_flag = y is None or (isinstance(y, list) and len(y) > 0 and y[0] is None)
        query_string += f",{y})." if not key_is_none_flag else ")."
        return query_string

    def revise_at_idx(
        self,
        pseudo_label: List[Any],
        y: Any,
        x: List[Any],
        revision_idx: List[int],
    ) -> Tuple[List[List[Any]], List[Any]]:
        """
        Revise the pseudo-labels at specified index positions by querying Prolog.

        Parameters
        ----------
        pseudo_label : List[Any]
            Pseudo-labels of an example (to be revised).
        y : Any
            Ground truth of the reasoning result for the example.
        x : List[Any]
            The corresponding input example. If the information from the input
            is not required in the reasoning process, then this parameter will not have
            any effect.
        revision_idx : List[int]
            A list specifying indices of where revisions should be made to the pseudo-labels.

        Returns
        -------
        Tuple[List[List[Any]], List[Any]]
            A tuple of two elements. The first element is a list of candidate revisions,
            i.e., revised pseudo-labels of the example that are compatible with the knowledge
            base. The second element is a list of reasoning results corresponding to each
            candidate, i.e., the outcome of the ``logic_forward`` function.
        """
        candidates, reasoning_results = [], []
        query_string = self.get_query_string(pseudo_label, y, x, revision_idx)
        save_pseudo_label = pseudo_label
        pseudo_label = flatten(pseudo_label)
        # Each solution's bindings are read positionally: this relies on the query binding
        # exactly one variable per entry of revision_idx, in that order.
        abduce_c = [list(z.values()) for z in self.prolog.query(query_string)]
        for c in abduce_c:
            candidate = pseudo_label.copy()
            for i, idx in enumerate(revision_idx):
                candidate[idx] = c[i]
            candidate = reform_list(candidate, save_pseudo_label)
            candidates.append(candidate)
            reasoning_results.append(y)
        return candidates, reasoning_results

    def __repr__(self):
        return (
            f"{self.__class__.__name__} is a KB with "
            f"pseudo_label_list={self.pseudo_label_list!r}, "
            f"defined by "
            f"Prolog file {self.pl_file!r}."
        )