ablkit.data

structures

class ablkit.data.structures.ListData(*, metainfo: dict | None = None, **kwargs)

Bases: BaseDataElement

Abstract data interface used throughout ABLkit.

ListData is the underlying data structure of ABLkit. It is designed to manage the diverse forms of data generated dynamically throughout the Abductive Learning (ABL) framework, including raw data, predicted pseudo-labels, abduced pseudo-labels, and pseudo-label indices.

As a fundamental data structure in ABL, ListData carries data between the components of the ABL framework, such as prediction, abductive reasoning, and training. It provides a unified data format across these stages, so the components stay compatible with one another while handling diverse forms of data.

The attributes in ListData are divided into two parts: the metainfo and the data.

  • metainfo: Usually stores basic information about the data examples, such as the number of symbols or the image size. These attributes can be accessed or modified through object-like or dict-like operations, such as . (for access and modification), in, del, pop(str), get(str), metainfo_keys(), metainfo_values(), metainfo_items(), and set_metainfo() (for setting or changing key-value pairs in metainfo).

  • data: Stores raw data, labels, predictions, and abduced results. These attributes can be accessed or modified through object-like or dict-like operations, such as ., in, del, pop(str), get(str), keys(), values(), and items(). Users can also apply tensor-like methods, such as .cuda(), .cpu(), .numpy(), .to(), .to_tensor(), and .detach(), to all torch.Tensor objects in the data_fields.

ListData supports indexing and slicing of its data fields. The value of a data field can be either None or a list of base data structures such as torch.Tensor, numpy.ndarray, list, str, and tuple.
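The indexing behaviour described above can be illustrated with a minimal, self-contained sketch. `MiniListData` is a hypothetical stand-in written for this page, not ABLkit's implementation. The sketch makes three assumptions. First, every non-None data field holds one entry per example. Second, an integer index or a slice is applied to every data field at once, and an integer index still yields one-element lists. Third, metainfo is copied unchanged into the result.

```python
from typing import Any, Dict, Optional


class MiniListData:
    """Hypothetical, simplified stand-in for ListData (not the ABLkit class)."""

    def __init__(self, metainfo: Optional[Dict[str, Any]] = None) -> None:
        # Bypass our own __setattr__ so these two stay internal attributes.
        object.__setattr__(self, "_metainfo", dict(metainfo or {}))
        object.__setattr__(self, "_data", {})

    def __setattr__(self, name: str, value: Optional[list]) -> None:
        # Data fields must be None or a list with one entry per example.
        if value is not None and not isinstance(value, list):
            raise TypeError(f"data field {name!r} must be a list or None")
        self._data[name] = value

    def __getattr__(self, name: str) -> Any:
        # Only called when normal lookup fails: check data, then metainfo.
        if name in self._data:
            return self._data[name]
        if name in self._metainfo:
            return self._metainfo[name]
        raise AttributeError(name)

    def __len__(self) -> int:
        # Number of examples = length of the first non-None data field.
        for value in self._data.values():
            if value is not None:
                return len(value)
        return 0

    def __getitem__(self, index: Any) -> "MiniListData":
        # Apply the same int or slice to every data field; keep metainfo as-is.
        new = MiniListData(metainfo=self._metainfo)
        for name, value in self._data.items():
            if value is None:
                setattr(new, name, None)
            elif isinstance(index, int):
                setattr(new, name, [value[index]])  # stay list-shaped
            else:
                setattr(new, name, value[index])
        return new


batch = MiniListData(metainfo={"symbol_num": 10})
batch.Y = [1, 2, 3]
batch.gt_pseudo_label = [[1, 2], [3, 4], [5, 6]]
first = batch[:1]
print(len(first), first.Y, first.symbol_num)  # 1 [1] 10
```

Because each data field is a per-example list, one slice selects the matching entries of Y and gt_pseudo_label together.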

This design is inspired by and extends the functionalities of the BaseDataElement class implemented in MMEngine.

Examples

>>> from ablkit.data.structures import ListData
>>> import numpy as np
>>> import torch
>>> data_examples = ListData()
>>> data_examples.X = [list(torch.randn(2)) for _ in range(3)]
>>> data_examples.Y = [1, 2, 3]
>>> data_examples.gt_pseudo_label = [[1, 2], [3, 4], [5, 6]]
>>> len(data_examples)
3
>>> print(data_examples)
<ListData(
    META INFORMATION
    DATA FIELDS
    Y: [1, 2, 3]
    gt_pseudo_label: [[1, 2], [3, 4], [5, 6]]
    X: [[tensor(1.1949), tensor(-0.9378)], [tensor(0.7414), tensor(0.7603)], [tensor(1.0587), tensor(1.9697)]]
) at 0x7f3bbf1991c0>
>>> print(data_examples[:1])
<ListData(
    META INFORMATION
    DATA FIELDS
    Y: [1]
    gt_pseudo_label: [[1, 2]]
    X: [[tensor(1.1949), tensor(-0.9378)]]
) at 0x7f3bbf1a3580>
>>> print(data_examples.elements_num("X"))
6
>>> print(data_examples.flatten("gt_pseudo_label"))
[1, 2, 3, 4, 5, 6]
>>> print(data_examples.to_tuple("Y"))
(1, 2, 3)

elements_num(item: str) → int

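Return the total number of elements in the attribute specified by item, counted across all of its nested lists (for example, 6 for three lists of two elements each).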

Parameters:

item (str) – Name of the attribute for which the number of elements is to be determined.

Returns:

The number of elements in the attribute specified by item.

Return type:

int

flatten(item: str) → List

Flatten the nested list stored in the attribute specified by item into a single list.

Parameters:

item (str) – Name of the attribute to be flattened.

Returns:

The flattened list of the attribute specified by item.

Return type:

list
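The following is a minimal sketch of the flattening. It assumes that nested lists are flattened recursively and that anything other than a list (a number, a str, a tensor) is a leaf. `flatten_nested` and `count_elements` are hypothetical helpers, not ABLkit functions. On the gt_pseudo_label field from the example above, they reproduce the flatten output, and the count follows the same rule that gives elements_num("X") == 6.

```python
from typing import Any, List


def flatten_nested(value: Any) -> List[Any]:
    """Recursively flatten nested lists into one flat list (hypothetical helper)."""
    if not isinstance(value, list):
        return [value]  # a leaf element
    flat: List[Any] = []
    for item in value:
        flat.extend(flatten_nested(item))
    return flat


def count_elements(value: Any) -> int:
    """Count leaf elements, i.e. the length of the flattened list."""
    return len(flatten_nested(value))


gt_pseudo_label = [[1, 2], [3, 4], [5, 6]]
print(flatten_nested(gt_pseudo_label))  # [1, 2, 3, 4, 5, 6]
print(count_elements(gt_pseudo_label))  # 6
```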

to_tuple(item: str) → tuple

Convert the attribute specified by item to a tuple.

Parameters:

item (str) – Name of the attribute to be converted.

Returns:

The attribute after conversion to a tuple.

Return type:

tuple

evaluation

class ablkit.data.evaluation.BaseMetric(prefix: str | None = None)

Bases: object

Base class for a metric.

A metric first processes each batch of data_examples and appends the processed results to self.results. After all batches have been processed, it computes the metrics over the entire dataset.

Parameters:

prefix (str, optional) – The prefix that will be added to the metric names to disambiguate homonymous metrics of different tasks. If prefix is not provided, self.default_prefix is used instead. Defaults to None.

abstract compute_metrics() → dict

Compute the metrics from processed results.

Returns:

The computed metrics. The keys are the names of the metrics, and the values are the corresponding results.

Return type:

dict

evaluate() → dict

Evaluate the model performance on the whole dataset after all batches have been processed.

Returns:

A dict of evaluation metrics on the validation dataset. The keys are the names of the metrics, and the values are the corresponding results.

Return type:

dict

abstract process(data_examples: ListData) → None

Process one batch of data examples. The processed results should be stored in self.results, which will be used to compute the metrics when all batches have been processed.

Parameters:

data_examples (ListData) – A batch of data examples.
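The following sketch shows the process, compute_metrics, and evaluate lifecycle. `SketchMetric` and `MeanOfBatchValues` are hypothetical names. Two behaviours are assumptions borrowed from the MMEngine convention rather than stated on this page. First, evaluate joins the prefix and the metric name with "/". Second, it clears self.results so that the metric can be reused for the next evaluation run.

```python
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional


class SketchMetric(ABC):
    """Hypothetical stand-in for the BaseMetric lifecycle (not ABLkit's code)."""

    default_prefix: Optional[str] = None

    def __init__(self, prefix: Optional[str] = None) -> None:
        self.results: List[Any] = []
        # Fall back to the class-level default prefix when none is given.
        self.prefix = prefix or self.default_prefix

    @abstractmethod
    def process(self, data_examples: Any) -> None:
        """Append per-batch results to self.results."""

    @abstractmethod
    def compute_metrics(self) -> Dict[str, float]:
        """Reduce self.results to a dict of metric values."""

    def evaluate(self) -> Dict[str, float]:
        metrics = self.compute_metrics()
        if self.prefix:
            # Assumed naming scheme: "<prefix>/<metric name>".
            metrics = {f"{self.prefix}/{k}": v for k, v in metrics.items()}
        self.results.clear()  # assumed: reset before the next evaluation run
        return metrics


class MeanOfBatchValues(SketchMetric):
    """Toy subclass: each batch is a plain list of numbers."""

    default_prefix = "toy"

    def process(self, data_examples: List[float]) -> None:
        self.results.extend(data_examples)

    def compute_metrics(self) -> Dict[str, float]:
        return {"mean": sum(self.results) / len(self.results)}


metric = MeanOfBatchValues()
for batch in ([1.0, 0.0], [1.0, 1.0]):
    metric.process(batch)
print(metric.evaluate())  # {'toy/mean': 0.75}
```

Keeping per-batch work in process and the reduction in compute_metrics means that a subclass only decides what to store and how to aggregate it.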

class ablkit.data.evaluation.ReasoningMetric(kb: KBBase, prefix: str | None = None)

Bases: BaseMetric

A metric class for evaluating model performance on tasks that require reasoning.

This class calculates the accuracy of reasoning results. Reasoning results are generated by first using the learning part to predict pseudo-labels and then using a knowledge base (KB) to perform logical reasoning on them. The reasoning results are then compared with the ground truth to calculate the accuracy.

Parameters:
  • kb (KBBase) – An instance of a knowledge base, used for logical reasoning and validation.

  • prefix (str, optional) – The prefix that will be added to the metrics names to disambiguate homonymous metrics of different tasks. Inherits from BaseMetric. Defaults to None.

Notes

The ReasoningMetric expects data_examples to have the attributes pred_pseudo_label, Y, and X, corresponding to the predicted pseudo-labels, the ground truth of the reasoning results, and the input data, respectively.

compute_metrics() → dict

Compute the reasoning accuracy metrics from self.results. The accuracy is the fraction of examples whose reasoning result is correct.

Returns:

A dictionary containing the computed metrics. It includes the key ‘reasoning_accuracy’ which maps to the calculated reasoning accuracy, represented as a float between 0 and 1.

Return type:

dict

process(data_examples: ListData) → None

Process a batch of data examples.

This method takes in a batch of data examples, each containing predicted pseudo-labels (pred_pseudo_label), the ground truth of the reasoning result (Y), and input data (X). For each example, it uses the knowledge base to derive a reasoning result from the predicted pseudo-labels and compares that result against Y. The outcome of this comparison (1 for correct reasoning, 0 for incorrect) is appended to self.results.

Parameters:

data_examples (ListData) – A batch of data examples.
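The following sketch shows the per-example check. A hypothetical addition knowledge base replaces a real KBBase subclass: `AdditionKB.logic_forward` returns the sum of the pseudo-labels. `SketchReasoningMetric` is a simplified stand-in for ReasoningMetric, and a plain SimpleNamespace replaces ListData. X is not consulted because this toy KB does not need it.

```python
from types import SimpleNamespace
from typing import Any, Dict, List


class AdditionKB:
    """Hypothetical toy KB: the reasoning result is the sum of the pseudo-labels."""

    def logic_forward(self, pseudo_label: List[int]) -> int:
        return sum(pseudo_label)


class SketchReasoningMetric:
    """Hypothetical stand-in for ReasoningMetric's process/compute_metrics logic."""

    def __init__(self, kb: AdditionKB) -> None:
        self.kb = kb
        self.results: List[int] = []

    def process(self, data_examples: Any) -> None:
        # One 1/0 entry per example: does reasoning over the predictions give Y?
        for pred, y in zip(data_examples.pred_pseudo_label, data_examples.Y):
            self.results.append(1 if self.kb.logic_forward(pred) == y else 0)

    def compute_metrics(self) -> Dict[str, float]:
        return {"reasoning_accuracy": sum(self.results) / len(self.results)}


batch = SimpleNamespace(
    X=[None, None, None],  # raw inputs are unused by this toy KB
    pred_pseudo_label=[[1, 2], [3, 3], [0, 5]],
    Y=[3, 7, 5],
)
metric = SketchReasoningMetric(AdditionKB())
metric.process(batch)
print(metric.compute_metrics())  # {'reasoning_accuracy': 0.6666666666666666}
```

Only the reasoning result is checked. The third example counts as correct because 0 + 5 = 5, whether or not 0 and 5 are the true symbols. Symbol-level correctness is what SymbolAccuracy measures.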

class ablkit.data.evaluation.SymbolAccuracy(prefix: str | None = None)

Bases: BaseMetric

A metric class for evaluating symbol-level accuracy.

This class assesses the accuracy of symbol prediction. Symbol accuracy is calculated by comparing the predicted pseudo-labels with their ground truth.

Parameters:

prefix (str, optional) – The prefix that will be added to the metrics names to disambiguate homonymous metrics of different tasks. Inherits from BaseMetric. Defaults to None.

compute_metrics() → dict

Compute the symbol accuracy metrics from self.results. The accuracy is the fraction of correctly predicted pseudo-labels over all pseudo-labels.

Returns:

A dictionary containing the computed metrics. It includes the key ‘character_accuracy’ which maps to the calculated symbol-level accuracy, represented as a float between 0 and 1.

Return type:

dict

process(data_examples: ListData) → None

Process a batch of data examples.

This method takes in a batch of data examples, each containing a list of predicted pseudo-labels (pred_pseudo_label) and their ground truth (gt_pseudo_label). It compares the two lists symbol by symbol and appends a tuple of (correct symbol count, total symbol count) to self.results.

Parameters:

data_examples (ListData) – A batch of data examples, each containing:

  • pred_pseudo_label: List of predicted pseudo-labels.

  • gt_pseudo_label: List of ground truth pseudo-labels.

Raises:

ValueError – If the lengths of predicted and ground truth symbol lists are not equal.
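The following sketch shows the counting logic. `SketchSymbolAccuracy` is a hypothetical stand-in, and SimpleNamespace replaces ListData. Two details are assumptions. First, the per-example lists of a batch are flattened before comparison, so the length check applies to the whole batch. Second, compute_metrics pools the counts of all batches before dividing.

```python
from types import SimpleNamespace
from typing import Any, Dict, List, Tuple


class SketchSymbolAccuracy:
    """Hypothetical stand-in for SymbolAccuracy's process/compute_metrics logic."""

    def __init__(self) -> None:
        # One (correct_count, total_count) tuple per processed batch.
        self.results: List[Tuple[int, int]] = []

    def process(self, data_examples: Any) -> None:
        # Flatten the per-example lists so symbols are compared one to one.
        pred = [s for example in data_examples.pred_pseudo_label for s in example]
        gt = [s for example in data_examples.gt_pseudo_label for s in example]
        if len(pred) != len(gt):
            raise ValueError("predicted and ground truth symbol lists differ in length")
        correct = sum(p == g for p, g in zip(pred, gt))
        self.results.append((correct, len(gt)))

    def compute_metrics(self) -> Dict[str, float]:
        # Pool the counts over all batches before dividing.
        correct = sum(c for c, _ in self.results)
        total = sum(t for _, t in self.results)
        return {"character_accuracy": correct / total}


metric = SketchSymbolAccuracy()
metric.process(SimpleNamespace(pred_pseudo_label=[[1, 2], [3, 0]], gt_pseudo_label=[[1, 2], [3, 4]]))
metric.process(SimpleNamespace(pred_pseudo_label=[[7]], gt_pseudo_label=[[9]]))
print(metric.results)            # [(3, 4), (0, 1)]
print(metric.compute_metrics())  # {'character_accuracy': 0.6}
```

Pooling counts weights every symbol equally. The two batches in the sketch give 3/5 = 0.6, whereas averaging the per-batch accuracies (0.75 and 0.0) would give 0.375.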