Evaluation Metrics
In this section, we will look at how to build evaluation metrics.
from ablkit.data.evaluation import BaseMetric, SymbolAccuracy, ReasoningMetric
ABLkit separates the evaluation process from model training and testing by implementing it as an independent class, BaseMetric. Training and testing are implemented in the BaseBridge class, so metrics are used by this class and its subclasses. After a bridge is built with a list of BaseMetric instances, the bridge.valid method uses these metrics to evaluate model performance during training and testing.
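To make this division of labour concrete, here is a minimal sketch of how a bridge-style validation step might drive a list of metrics: every batch is passed to each metric's process method, and the per-metric result dicts are then merged into one report. It does not import ABLkit. ExactMatchCount and run_validation are hypothetical names, and the actual bridge.valid is not reproduced here.

```python
from typing import Any, Dict, List


class ExactMatchCount:
    """Hypothetical toy metric: fraction of examples whose prediction equals the target."""

    def __init__(self) -> None:
        self.results: List[int] = []

    def process(self, batch: Dict[str, List[Any]]) -> None:
        # Store one 0/1 flag per example in this batch
        for pred, target in zip(batch["pred"], batch["target"]):
            self.results.append(int(pred == target))

    def compute_metrics(self, results: List[int]) -> Dict[str, float]:
        return {"exact_match": sum(results) / len(results)}


def run_validation(metrics: list, batches: List[Dict[str, List[Any]]]) -> Dict[str, float]:
    """Hypothetical stand-in for a bridge's validation step (not ABLkit's bridge.valid)."""
    for metric in metrics:
        metric.results = []  # start each evaluation from a clean state
    # Feed every batch to every metric
    for batch in batches:
        for metric in metrics:
            metric.process(batch)
    # Merge the per-metric result dicts into a single report
    report: Dict[str, float] = {}
    for metric in metrics:
        report.update(metric.compute_metrics(metric.results))
    return report


batches = [
    {"pred": [1, 2, 3], "target": [1, 2, 0]},
    {"pred": [4], "target": [4]},
]
print(run_validation([ExactMatchCount()], batches))  # {'exact_match': 0.75}
```

Because the reports are merged into one dict, two metrics that return the same key would overwrite each other, which is one reason to give metrics distinct names.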
To customize our own metrics, we need to inherit from BaseMetric and implement the process and compute_metrics methods.
The process method accepts a batch of model predictions and, after processing the batch, saves the relevant information to the self.results property.
The compute_metrics method uses all the information saved in self.results to calculate and return a dict that holds the evaluation results.
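A minimal sketch of this two-method contract, using Python's abc module in place of ABLkit's BaseMetric. SketchBaseMetric, MeanAbsoluteError and the evaluate helper are hypothetical; evaluate (compute over self.results, then clear it) is an assumption about how a caller might combine the two methods. The compute_metrics(results) signature follows the SymbolAccuracy example in this section.

```python
from abc import ABC, abstractmethod
from typing import Any, Dict, List


class SketchBaseMetric(ABC):
    """Hypothetical stand-in mirroring the process/compute_metrics contract."""

    def __init__(self) -> None:
        # Per-batch information accumulated by process()
        self.results: List[Any] = []

    @abstractmethod
    def process(self, data_examples: Any) -> None:
        """Consume one batch and append whatever is needed to self.results."""

    @abstractmethod
    def compute_metrics(self, results: List[Any]) -> Dict[str, float]:
        """Turn everything accumulated in results into a dict of metric values."""

    def evaluate(self) -> Dict[str, float]:
        # Hypothetical helper: compute over all batches seen so far, then reset
        metrics = self.compute_metrics(self.results)
        self.results.clear()
        return metrics


class MeanAbsoluteError(SketchBaseMetric):
    """Example subclass: stores (sum of errors, count) for each batch."""

    def process(self, data_examples: Dict[str, List[float]]) -> None:
        errors = [abs(p - t) for p, t in zip(data_examples["pred"], data_examples["target"])]
        self.results.append((sum(errors), len(errors)))

    def compute_metrics(self, results: List[Any]) -> Dict[str, float]:
        total_error = sum(s for s, _ in results)
        total_count = sum(n for _, n in results)
        return {"mae": total_error / total_count}


metric = MeanAbsoluteError()
metric.process({"pred": [1.0, 2.0], "target": [1.5, 2.0]})
metric.process({"pred": [3.0], "target": [2.0]})
print(metric.evaluate())  # {'mae': 0.5}
```

Storing (sum, count) pairs rather than one mean per batch keeps batches of different sizes correctly weighted when compute_metrics aggregates them.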
We can also pass a string to the prefix argument of the __init__ method. This string is automatically prepended to the output metric names. For example, if we set prefix="mnist_add", the output metric name will be mnist_add/character_accuracy.
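A minimal sketch of the renaming, assuming the prefix and each metric name are joined with a slash (so prefix="mnist_add" turns character_accuracy into mnist_add/character_accuracy). apply_prefix is a hypothetical helper, not an ABLkit function.

```python
from typing import Dict, Optional


def apply_prefix(metrics: Dict[str, float], prefix: Optional[str]) -> Dict[str, float]:
    """Hypothetical helper: prepend "<prefix>/" to every metric name."""
    if not prefix:
        # No prefix (None or ""): names are returned unchanged
        return dict(metrics)
    return {f"{prefix}/{name}": value for name, value in metrics.items()}


print(apply_prefix({"character_accuracy": 0.97}, "mnist_add"))
# {'mnist_add/character_accuracy': 0.97}
```

Prefixing lets two instances of the same metric class, for example one per dataset, report side by side without their keys colliding.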
We provide two basic metrics, namely SymbolAccuracy and ReasoningMetric, which are used to evaluate the accuracy of the machine learning model’s predictions and the accuracy of the final reasoning results, respectively. Using SymbolAccuracy as an example, the following code shows how to implement a custom metric.
from typing import Optional, Sequence

from ablkit.data.evaluation import BaseMetric


class SymbolAccuracy(BaseMetric):
    def __init__(self, prefix: Optional[str] = None) -> None:
        # prefix is used to distinguish different metrics
        super().__init__(prefix)

    def process(self, data_examples: Sequence[dict]) -> None:
        # pred_pseudo_label and gt_pseudo_label are both of type List[List[Any]]
        # and have the same length
        pred_pseudo_label = data_examples.pred_pseudo_label
        gt_pseudo_label = data_examples.gt_pseudo_label
        # Store the fraction of correctly predicted symbols for each example
        for pred_z, z in zip(pred_pseudo_label, gt_pseudo_label):
            correct_num = 0
            for pred_symbol, symbol in zip(pred_z, z):
                if pred_symbol == symbol:
                    correct_num += 1
            self.results.append(correct_num / len(z))

    def compute_metrics(self, results: list) -> dict:
        # Average the per-example accuracies collected by process()
        metrics = dict()
        metrics["character_accuracy"] = sum(results) / len(results)
        return metrics
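To try the SymbolAccuracy logic without installing ABLkit, the sketch below copies it into a standalone class and feeds it a SimpleNamespace that stands in for data_examples. It also adds SumReasoningAccuracy, a hypothetical reasoning-level metric that hard-codes addition as the reasoning step and reads the final label from a field named Y (an assumption for this sketch). It illustrates the kind of quantity a metric like ReasoningMetric measures, but it is not ABLkit's implementation.

```python
from types import SimpleNamespace
from typing import Any, Dict, List


class StandaloneSymbolAccuracy:
    """Same logic as SymbolAccuracy above, without the ABLkit base class."""

    def __init__(self) -> None:
        self.results: List[float] = []

    def process(self, data_examples: Any) -> None:
        for pred_z, z in zip(data_examples.pred_pseudo_label, data_examples.gt_pseudo_label):
            correct_num = sum(pred == gt for pred, gt in zip(pred_z, z))
            self.results.append(correct_num / len(z))

    def compute_metrics(self, results: List[float]) -> Dict[str, float]:
        return {"character_accuracy": sum(results) / len(results)}


class SumReasoningAccuracy:
    """Hypothetical reasoning-level metric for an addition task: an example is
    correct when the sum of its predicted symbols equals its label Y (assumed field)."""

    def __init__(self) -> None:
        self.results: List[int] = []

    def process(self, data_examples: Any) -> None:
        for pred_z, y in zip(data_examples.pred_pseudo_label, data_examples.Y):
            self.results.append(int(sum(pred_z) == y))

    def compute_metrics(self, results: List[int]) -> Dict[str, float]:
        return {"reasoning_accuracy": sum(results) / len(results)}


# Hypothetical batch; SimpleNamespace mimics attribute access on data_examples
batch = SimpleNamespace(
    pred_pseudo_label=[[1, 2], [2, 5], [0, 9], [4, 4]],
    gt_pseudo_label=[[1, 2], [3, 4], [0, 9], [4, 4]],
    Y=[3, 7, 9, 8],
)
symbol_metric, reasoning_metric = StandaloneSymbolAccuracy(), SumReasoningAccuracy()
for metric in (symbol_metric, reasoning_metric):
    metric.process(batch)
    print(metric.compute_metrics(metric.results))
# {'character_accuracy': 0.75}
# {'reasoning_accuracy': 1.0}
```

The second example shows why both metrics are worth tracking: predicting [2, 5] instead of [3, 4] gets both symbols wrong yet still sums to 7, so reasoning accuracy stays at 1.0 while character accuracy drops to 0.75.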