Source code for ablkit.learning.torch_dataset.multi_label_classification_dataset

"""
Implementation of PyTorch dataset class used for multi-label classification.

Copyright (c) 2024 LAMDA.  All rights reserved.
"""

from typing import Any, Callable, List, Optional

import numpy as np
import torch

from .classification_dataset import ClassificationDataset


[docs] class MultiLabelClassificationDataset(ClassificationDataset): """ Dataset used for multi-label classification, where each target ``Y[i]`` is a binary indicator vector (one entry per label) rather than a single class index. ``Y`` is stored as a ``float32`` tensor so it can be fed directly into ``BCEWithLogitsLoss`` and similar losses. Parameters ---------- X : List[Any] The input data. Y : List[Any] The per-sample label vectors. Each entry is converted via ``numpy.stack`` and stored as a ``FloatTensor``. transform : Callable[..., Any], optional A function/transform that takes an object and returns a transformed version. Defaults to None. """ def __init__( self, X: List[Any], Y: List[Any], transform: Optional[Callable[..., Any]] = None, ) -> None: if (not isinstance(X, list)) or (not isinstance(Y, list)): raise ValueError("X and Y should be of type list.") if len(X) != len(Y): raise ValueError("Length of X and Y must be equal.") self.X = X self.Y = torch.FloatTensor(np.stack(Y, axis=0)) self.transform = transform