Handwritten Equation Decipherment (HED)

For the detailed code implementation, please see the example on GitHub.

This example shows an implementation of Handwritten Equation Decipherment (HED). The task takes handwritten equations as input, each given as a sequence of character images. The equations are generated from images of four symbols (‘0’, ‘1’, ‘+’ and ‘=’) under unknown operation rules, and each equation carries a label indicating whether it is correct (i.e., positive) or not (i.e., negative). We are also given a knowledge base that describes the structure of the equations and a recursive definition of bit-wise operations. The goal is to learn from a training set of such equations and then predict the labels of unseen equations.

Intuitively, we first use a machine learning model (the learning part) to obtain pseudo-labels (‘0’, ‘1’, ‘+’ and ‘=’) for the observed pictures. We then use the knowledge base (the reasoning part) to perform abductive reasoning, which yields ground hypotheses as possible explanations for the observed facts and suggests which pseudo-labels should be revised. This process enables us to further update the machine learning model.

# Import necessary libraries and modules
import os.path as osp

import matplotlib.pyplot as plt
import torch
import torch.nn as nn

from ablkit.data.evaluation import ReasoningMetric, SymbolAccuracy
from ablkit.learning import ABLModel, BasicNN
from ablkit.utils import ABLLogger, print_log

from bridge import HedBridge
from consistency_metric import ConsistencyMetric
from datasets import get_dataset, split_equation
from models.nn import SymbolNet
from reasoning import HedKB, HedReasoner

Working with Data

First, we get the datasets of handwritten equations:

total_train_data = get_dataset(train=True)
train_data, val_data = split_equation(total_train_data, 3, 1)
test_data = get_dataset(train=False)

The datasets are shown below:

true_train_equation = train_data[1]
false_train_equation = train_data[0]
print(f"Equations in the dataset are organized by equation length, " +
      f"from {min(train_data[0].keys())} to {max(train_data[0].keys())}.")
print()

true_train_equation_with_length_5 = true_train_equation[5]
false_train_equation_with_length_5 = false_train_equation[5]
print(f"For each equation length, there are {len(true_train_equation_with_length_5)} " +
      f"true equations and {len(false_train_equation_with_length_5)} false equations " +
      f"in the training set.")

true_val_equation = val_data[1]
false_val_equation = val_data[0]
true_val_equation_with_length_5 = true_val_equation[5]
false_val_equation_with_length_5 = false_val_equation[5]
print(f"For each equation length, there are {len(true_val_equation_with_length_5)} " +
      f"true equations and {len(false_val_equation_with_length_5)} false equations " +
      f"in the validation set.")

true_test_equation = test_data[1]
false_test_equation = test_data[0]
true_test_equation_with_length_5 = true_test_equation[5]
false_test_equation_with_length_5 = false_test_equation[5]
print(f"For each equation length, there are {len(true_test_equation_with_length_5)} " +
      f"true equations and {len(false_test_equation_with_length_5)} false equations " +
      f"in the test set.")
Out:
Equations in the dataset are organized by equation length, from 5 to 26.

For each equation length, there are 225 true equations and 225 false equations in the training set.
For each equation length, there are 75 true equations and 75 false equations in the validation set.
For each equation length, there are 300 true equations and 300 false equations in the test set.
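
The nested layout used above (label first, then equation length, then a list of equations) can be mimicked with a toy dictionary. The sketch below is illustrative only: symbols are plain characters instead of images, the "true" equations assume ordinary binary addition, and toy_data and count_by_length are hypothetical names that are not part of the example code.

```python
# Toy stand-in for the dataset layout: data[label][length] -> list of equations.
# Label 1 holds true equations and label 0 holds false ones. In the real dataset
# each symbol is an image; here it is a plain character (an assumption).
toy_data = {
    1: {5: [list("1+0=1"), list("0+0=0")], 6: [list("1+1=10")]},
    0: {5: [list("1+0=0")], 6: [list("1+1=11")]},
}


def count_by_length(data, label):
    """Return {equation_length: number_of_equations} for one label."""
    return {length: len(eqs) for length, eqs in sorted(data[label].items())}


print(count_by_length(toy_data, 1))
print(count_by_length(toy_data, 0))
```

Indexing by label and then by length, as in train_data[1][5], therefore returns the list of true equations of length 5.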

As illustrations, we show four equations in the training dataset:

true_train_equation_with_length_5 = true_train_equation[5]
true_train_equation_with_length_8 = true_train_equation[8]
print(f"First true equation with length 5 in the training dataset:")
for i, x in enumerate(true_train_equation_with_length_5[0]):
    plt.subplot(1, 5, i+1)
    plt.axis('off')
    plt.imshow(x.squeeze(), cmap='gray')
plt.show()
print(f"First true equation with length 8 in the training dataset:")
for i, x in enumerate(true_train_equation_with_length_8[0]):
    plt.subplot(1, 8, i+1)
    plt.axis('off')
    plt.imshow(x.squeeze(), cmap='gray')
plt.show()

false_train_equation_with_length_5 = false_train_equation[5]
false_train_equation_with_length_8 = false_train_equation[8]
print(f"First false equation with length 5 in the training dataset:")
for i, x in enumerate(false_train_equation_with_length_5[0]):
    plt.subplot(1, 5, i+1)
    plt.axis('off')
    plt.imshow(x.squeeze(), cmap='gray')
plt.show()
print(f"First false equation with length 8 in the training dataset:")
for i, x in enumerate(false_train_equation_with_length_8[0]):
    plt.subplot(1, 8, i+1)
    plt.axis('off')
    plt.imshow(x.squeeze(), cmap='gray')
plt.show()
Out:
First true equation with length 5 in the training dataset:
../_images/hed_dataset1.png
First true equation with length 8 in the training dataset:
../_images/hed_dataset2.png
First false equation with length 5 in the training dataset:
../_images/hed_dataset3.png
First false equation with length 8 in the training dataset:
../_images/hed_dataset4.png

Building the Learning Part

To build the learning part, we need to first build a machine learning base model. We use SymbolNet, and encapsulate it within a BasicNN object to create the base model. BasicNN is a class that encapsulates a PyTorch model, transforming it into a base model with an sklearn-style interface.

# class of symbol may be one of ['0', '1', '+', '='], total of 4 classes
net = SymbolNet(num_classes=4)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.RMSprop(net.parameters(), lr=0.001, weight_decay=1e-4)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

base_model = BasicNN(
    net,
    loss_fn,
    optimizer,
    device=device,
    batch_size=32,
    num_epochs=1,
    stop_loss=None,
)

However, the base model built above deals with instance-level data (i.e., individual images) and cannot directly handle example-level data (i.e., the list of images that makes up an equation). Therefore, we wrap the base model in ABLModel, which enables the learning part to train, test, and predict on example-level data.

model = ABLModel(base_model)
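
To make the distinction between instance-level and example-level data concrete, the following sketch flattens a list of equations into individual symbols, classifies each symbol, and regroups the predictions per equation. It only illustrates the idea and is not ABLModel's actual implementation; flatten, regroup and toy_classifier are hypothetical helpers, and the "images" are plain characters.

```python
from itertools import chain


def flatten(examples):
    """Concatenate per-example lists into one instance-level list."""
    return list(chain.from_iterable(examples))


def regroup(flat, examples):
    """Split a flat list back into sublists matching the original example lengths."""
    grouped, start = [], 0
    for example in examples:
        grouped.append(flat[start:start + len(example)])
        start += len(example)
    return grouped


def toy_classifier(instance):
    """Hypothetical instance-level model: map a symbol character to a class index."""
    return "01+=".index(instance)


examples = [list("1+0=1"), list("1+1=10")]  # example-level data
flat_predictions = [toy_classifier(x) for x in flatten(examples)]  # instance-level
grouped_predictions = regroup(flat_predictions, examples)  # back to example-level
print(grouped_predictions)
```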

Building the Reasoning Part

In the reasoning part, we first build a knowledge base. As mentioned before, the knowledge base in this task involves the structure of the equations and a recursive definition of bit-wise operations, which are defined in the Prolog files examples/hed/reasoning/BK.pl and examples/hed/reasoning/learn_add.pl, respectively. Specifically, the knowledge about the structure of equations is a set of DCG rules that recursively define a digit as a sequence of ‘0’ and ‘1’ and state that every equation has the structure X+Y=Z, where the lengths of X, Y and Z may vary. The knowledge about bit-wise operations is a recursive logic program that calculates X+Y in reverse order, i.e., it operates on X and Y digit by digit, from the last digit to the first.
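
The two kinds of knowledge can be approximated in plain Python. This is a rough sketch and not a translation of the Prolog files: parse_equation, add_with_rules, is_consistent and BINARY_RULES are hypothetical names, the rule table is assumed to be ordinary binary addition, and ‘0’ is assumed to be safe for padding the shorter operand and for an empty carry.

```python
# Assumed rule table: ordinary binary addition on single digits.
BINARY_RULES = {("0", "0"): "0", ("0", "1"): "1", ("1", "0"): "1", ("1", "1"): "10"}


def parse_equation(symbols):
    """Return (X, Y, Z) if symbols has the form X+Y=Z with non-empty 0/1 operands."""
    s = "".join(symbols)
    if s.count("+") != 1 or s.count("=") != 1:
        return None
    left, z = s.split("=")
    if "+" not in left:
        return None
    x, y = left.split("+")
    if all(part and set(part) <= {"0", "1"} for part in (x, y, z)):
        return x, y, z
    return None


def split_carry(value):
    """Split a one- or two-digit result into (carry, digit)."""
    value = value.rjust(2, "0")
    return value[0], value[1]


def add_with_rules(x, y, rules):
    """Add two digit strings from the last digit to the first using a rule table."""
    width = max(len(x), len(y))
    x, y = x.rjust(width, "0"), y.rjust(width, "0")
    digits, carry = [], "0"
    for a, b in zip(reversed(x), reversed(y)):
        c1, d = split_carry(rules[(a, b)])      # add the two digits
        c2, d = split_carry(rules[(d, carry)])  # add the incoming carry
        carry = split_carry(rules[(c1, c2)])[1]  # combine both carries
        digits.append(d)
    if carry != "0":
        digits.append(carry)
    return "".join(reversed(digits))


def is_consistent(symbols, rules):
    """Check both the X+Y=Z structure and the calculation under the given rules."""
    parsed = parse_equation(symbols)
    if parsed is None:
        return False
    x, y, z = parsed
    return add_with_rules(x, y, rules) == z


print(is_consistent(list("11+1=100"), BINARY_RULES))
print(is_consistent(list("1+0=0"), BINARY_RULES))
```

Passing the rule table as an argument mirrors the fact that the calculation rules are not fixed in advance in this task.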

The knowledge base is already built in HedKB. HedKB is derived from the class PrologKB and is built upon the aforementioned Prolog files.

kb = HedKB()

Note

Note that the specific rules for calculating the operations are undefined in the knowledge base, i.e., the results of ‘0+0’, ‘0+1’ and ‘1+1’ could be ‘0’, ‘1’, ‘00’, ‘01’ or even ‘10’. The missing calculation rules must be learned from the data. Therefore, HedKB incorporates methods for abducing rules from data. Interested users can refer to the specific implementation of HedKB in examples/hed/reasoning/reasoning.py.
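
As a simplified illustration of abducing rules from data, the sketch below enumerates every candidate rule table for single-digit operands and keeps those that agree with a set of equations known to be true. HedKB performs abduction through Prolog, so this brute-force search is only an analogy; abduce_rules, PAIRS and RESULTS are hypothetical names, and the set of candidate results is an assumption.

```python
from itertools import product

PAIRS = [("0", "0"), ("0", "1"), ("1", "0"), ("1", "1")]
RESULTS = ["0", "1", "00", "01", "10", "11"]  # assumed candidate outcomes


def abduce_rules(positive_equations):
    """Return every rule table consistent with the given true equations.

    Each equation is a triple (x, y, z) of a single-digit x and y and a result z.
    """
    consistent = []
    for outputs in product(RESULTS, repeat=len(PAIRS)):
        rules = dict(zip(PAIRS, outputs))
        if all(rules[(x, y)] == z for x, y, z in positive_equations):
            consistent.append(rules)
    return consistent


observed = [("0", "0", "0"), ("0", "1", "1"), ("1", "1", "10")]
candidates = abduce_rules(observed)
print(len(candidates))  # the rule for 1+0 is still undetermined

# One more observation pins down the remaining rule.
print(abduce_rules(observed + [("1", "0", "1")]))
```

Each additional true equation narrows the set of compatible rule tables, which is the intuition behind learning the missing rules from labelled equations.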

Then, we create a reasoner. Because abductive reasoning is nondeterministic, there may be multiple candidates compatible with the knowledge base. When this happens, the reasoner minimizes the inconsistency between the knowledge base and the pseudo-labels predicted by the learning part, and returns only the candidate with the highest consistency.

In this task, we create the reasoner by instantiating the class HedReasoner, which is derived from Reasoner and tailored specifically for this task. HedReasoner leverages the ZOOpt library for acceleration and uses a dedicated strategy to better harness ZOOpt’s capabilities. Additionally, it incorporates methods for abducing rules from data. Interested users can refer to the specific implementation of HedReasoner in reasoning/reasoning.py.

reasoner = HedReasoner(kb, dist_func="hamming", use_zoopt=True, max_revision=10)
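
The idea of returning the candidate closest to the predicted pseudo-labels can be sketched with an exhaustive search. This is not how HedReasoner works internally (it uses ZOOpt rather than enumeration); is_valid, hamming and revise are hypothetical helpers, and validity is assumed to mean ordinary binary addition.

```python
from itertools import product

SYMBOLS = "01+="


def is_valid(symbols):
    """Check X+Y=Z under ordinary binary addition (an assumed rule set)."""
    s = "".join(symbols)
    if s.count("+") != 1 or s.count("=") != 1:
        return False
    left, z = s.split("=")
    if "+" not in left:
        return False
    x, y = left.split("+")
    if not (x and y and z):
        return False
    return int(x, 2) + int(y, 2) == int(z, 2)


def hamming(a, b):
    """Number of positions at which two equal-length sequences differ."""
    return sum(p != q for p, q in zip(a, b))


def revise(pred, max_revision=2):
    """Return the valid sequence closest to pred in Hamming distance, or None."""
    best = None
    for cand in product(SYMBOLS, repeat=len(pred)):
        dist = hamming(pred, cand)
        if dist > max_revision or not is_valid(cand):
            continue
        if best is None or dist < best[0]:
            best = (dist, list(cand))
    return None if best is None else best[1]


# The last symbol was misclassified; one revision restores a valid equation.
print("".join(revise(list("1+1=11"))))
```

Exhaustive enumeration grows exponentially with the equation length, which is why an optimizer and a bound such as max_revision matter in practice.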

Building Evaluation Metrics

Next, we set up evaluation metrics. These metrics will be used to evaluate the model performance during training and testing. Specifically, we use SymbolAccuracy and ReasoningMetric, which are used to evaluate the accuracy of the machine learning model’s predictions and the accuracy of the final reasoning results, respectively.

# Set up metrics
metric_list = [SymbolAccuracy(prefix="hed"), ReasoningMetric(kb=kb, prefix="hed")]
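
The difference between the two kinds of accuracy can be illustrated with a small sketch. It does not reproduce the ablkit metric classes; symbol_accuracy, reasoning_accuracy and holds_under_binary_addition are hypothetical functions, and the knowledge base is replaced by an assumed binary-addition check.

```python
def holds_under_binary_addition(symbols):
    """Assumed stand-in for the knowledge base: X+Y=Z with binary addition."""
    s = "".join(symbols)
    if s.count("+") != 1 or s.count("=") != 1 or s.index("+") > s.index("="):
        return False
    left, z = s.split("=")
    x, y = left.split("+")
    return bool(x and y and z) and int(x, 2) + int(y, 2) == int(z, 2)


def symbol_accuracy(pred_seqs, true_seqs):
    """Fraction of individual symbols that were predicted correctly."""
    total = sum(len(true) for true in true_seqs)
    correct = sum(
        p == t
        for pred, true in zip(pred_seqs, true_seqs)
        for p, t in zip(pred, true)
    )
    return correct / total


def reasoning_accuracy(pred_seqs, labels, check):
    """Fraction of equations whose reasoned label matches the given label."""
    hits = sum(int(check(pred)) == label for pred, label in zip(pred_seqs, labels))
    return hits / len(labels)


predicted = [list("1+1=10"), list("1+0=0")]
ground_truth = [list("1+1=10"), list("1+0=1")]
labels = [1, 1]  # both equations are labelled as correct

print(symbol_accuracy(predicted, ground_truth))
print(reasoning_accuracy(predicted, labels, holds_under_binary_addition))
```

A single wrong symbol lowers symbol accuracy only slightly but flips the reasoning result for the whole equation, so the two numbers can differ considerably.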

Bridging Learning and Reasoning

Now, the last step is to bridge the learning and reasoning parts. We do this by creating an instance of HedBridge, which is derived from SimpleBridge and tailored specifically for this task.

bridge = HedBridge(model, reasoner, metric_list)

Perform pretraining, training and testing by invoking the pretrain, train and test methods of HedBridge.

# Build logger
print_log("Abductive Learning on the HED example.", logger="current")

# Retrieve the directory of the log file and define the directory for saving the model weights.
log_dir = ABLLogger.get_current_instance().log_dir
weights_dir = osp.join(log_dir, "weights")

bridge.pretrain("./weights")
bridge.train(train_data, val_data, save_dir=weights_dir)
bridge.test(test_data)