"""
This module contains the class SymbolAccuracy, which is used for evaluating symbol-level accuracy.
Copyright (c) 2024 LAMDA. All rights reserved.
"""
import numpy as np
from ..structures import ListData
from .base_metric import BaseMetric
class SymbolAccuracy(BaseMetric):
    """
    A metrics class for evaluating symbol-level accuracy.

    This class is designed to assess the accuracy of symbol prediction. Symbol accuracy
    is calculated by comparing predicted pseudo-labels and their ground truth, and is
    accumulated over every batch passed to ``process``.

    Parameters
    ----------
    prefix : str, optional
        The prefix that will be added to the metrics names to disambiguate homonymous
        metrics of different tasks. Inherits from BaseMetric. Defaults to None.
    """

    def process(self, data_examples: ListData) -> None:
        """
        Process a batch of data examples.

        The predicted pseudo-labels (``pred_pseudo_label``) and the ground truth
        pseudo-labels (``gt_pseudo_label``) of the batch are flattened via
        ``data_examples.flatten`` and compared element by element. A tuple of
        ``(correct symbol count, total symbol count)`` is then appended to
        ``self.results``.

        Parameters
        ----------
        data_examples : ListData
            A batch of data examples, each containing:

            - ``pred_pseudo_label``: List of predicted pseudo-labels.
            - ``gt_pseudo_label``: List of ground truth pseudo-labels.

        Raises
        ------
        ValueError
            If the lengths of predicted and ground truth symbol lists are not equal.
        """
        pred_pseudo_label_list = data_examples.flatten("pred_pseudo_label")
        gt_pseudo_label_list = data_examples.flatten("gt_pseudo_label")

        # Element-wise comparison is only meaningful when both flattened lists
        # line up one-to-one, so reject mismatched lengths up front.
        if len(pred_pseudo_label_list) != len(gt_pseudo_label_list):
            raise ValueError("lengthes of pred_pseudo_label and gt_pseudo_label should be equal")

        correct_num = np.sum(np.array(pred_pseudo_label_list) == np.array(gt_pseudo_label_list))
        self.results.append((correct_num, len(pred_pseudo_label_list)))

    def compute_metrics(self) -> dict:
        """
        Compute the symbol accuracy metrics from ``self.results``.

        The accuracy is the number of correctly predicted pseudo-labels divided by the
        total number of pseudo-labels, summed over all processed batches.

        Returns
        -------
        dict
            A dictionary containing the computed metrics. It includes the key
            'character_accuracy' which maps to the calculated symbol-level accuracy,
            represented as a float between 0 and 1. When no symbols have been
            processed (``self.results`` is empty or every batch was empty), the
            accuracy is reported as 0.0.
        """
        results = self.results
        correct_total = sum(correct for correct, _ in results)
        symbol_total = sum(count for _, count in results)

        # Guard the division: with nothing processed the ratio is undefined, and an
        # unguarded division would raise ZeroDivisionError (or yield nan for NumPy
        # scalars) instead of returning a usable metric.
        if symbol_total == 0:
            return {"character_accuracy": 0.0}

        return {"character_accuracy": correct_total / symbol_total}