"""
This module contains the ReasoningMetric, which is used for evaluating the model performance
on tasks that need reasoning.
Copyright (c) 2024 LAMDA. All rights reserved.
"""
from typing import Optional
from ...reasoning import KBBase
from ..structures import ListData
from .base_metric import BaseMetric
class ReasoningMetric(BaseMetric):
    """
    A metrics class for evaluating the model performance on tasks that need reasoning.

    This class is designed to calculate the accuracy of the reasoning results. Reasoning
    results are generated by first using the learning part to predict pseudo-labels
    and then using a knowledge base (KB) to perform logical reasoning. The reasoning results
    are then compared with the ground truth to calculate the accuracy.

    Parameters
    ----------
    kb : KBBase
        An instance of a knowledge base, used for logical reasoning and validation.
        It is required: every example handled by ``process`` is reasoned through
        ``kb.logic_forward`` and compared with its ground truth via ``kb._check_equal``.
    prefix : str, optional
        The prefix that will be added to the metrics names to disambiguate homonymous
        metrics of different tasks. Inherits from BaseMetric. Defaults to None.

    Notes
    -----
    The ``ReasoningMetric`` expects data_examples to have the attributes
    ``pred_pseudo_label``, ``Y``, and ``X``, corresponding to the predicted pseudo-labels,
    ground truth of reasoning results, and input data, respectively.
    """

    def __init__(self, kb: KBBase, prefix: Optional[str] = None) -> None:
        super().__init__(prefix)
        self.kb = kb

    # pylint: disable=protected-access
    def process(self, data_examples: ListData) -> None:
        """
        Process a batch of data examples.

        This method takes in a batch of data examples, each containing predicted
        pseudo-labels (pred_pseudo_label), ground truth of reasoning results (Y), and
        input data (X). It evaluates the reasoning accuracy of each example by comparing
        the logical reasoning result (derived using the knowledge base) of the predicted
        pseudo-labels against Y. The result of this comparison (1 for correct reasoning,
        0 for incorrect) is appended to ``self.results``.

        Parameters
        ----------
        data_examples : ListData
            A batch of data examples.
        """
        kb = self.kb
        # When kb._num_args == 2, logic_forward is also given the raw input X.
        # NOTE(review): _num_args presumably counts logic_forward's parameters; confirm in KBBase.
        # The check is loop-invariant, so it is evaluated once per batch.
        pass_x = kb._num_args == 2
        for pred_pseudo_label, y, x in zip(
            data_examples.pred_pseudo_label, data_examples.Y, data_examples.X
        ):
            forward_args = (pred_pseudo_label, x) if pass_x else (pred_pseudo_label,)
            reasoning_result = kb.logic_forward(*forward_args)
            self.results.append(1 if kb._check_equal(reasoning_result, y) else 0)

    def compute_metrics(self) -> dict:
        """
        Compute the reasoning accuracy metrics from ``self.results``.

        It calculates the fraction of correctly reasoned examples over all examples.

        Returns
        -------
        dict
            A dictionary containing the computed metrics. It includes the key
            'reasoning_accuracy' which maps to the calculated reasoning accuracy,
            represented as a float between 0 and 1. If no examples have been
            processed, the accuracy is reported as 0.0 instead of raising
            ``ZeroDivisionError``.
        """
        results = self.results
        if not results:
            # Nothing was processed; avoid dividing by zero.
            accuracy = 0.0
        else:
            accuracy = sum(results) / len(results)
        return {"reasoning_accuracy": accuracy}