# Source code for ablkit.learning.torch_dataset.prediction_dataset

"""
Implementation of PyTorch dataset class used for Prediction.

Copyright (c) 2024 LAMDA.  All rights reserved.
"""

from typing import Any, Callable, List, Tuple, Optional

import torch
from torch.utils.data import Dataset


class _IndexRangeError(IndexError, ValueError):
    """
    Error raised by ``PredictionDataset.__getitem__`` for an index past the end.

    It derives from ``ValueError`` so that callers catching ``ValueError`` keep
    working, and from ``IndexError`` so that Python's fallback iteration over
    ``__getitem__`` (``for x in dataset``, ``list(dataset)``) stops cleanly at
    the end of the data instead of propagating an error.
    """


class PredictionDataset(Dataset):
    """
    Dataset used for prediction.

    Parameters
    ----------
    X : List[Any]
        The input data.
    transform : Callable[..., Any], optional
        A function/transform that takes an object and returns a transformed version.
        Defaults to None.

    Raises
    ------
    ValueError
        If ``X`` is not a list.
    """

    def __init__(self, X: List[Any], transform: Optional[Callable[..., Any]] = None):
        if not isinstance(X, list):
            raise ValueError("X should be of type list.")
        self.X = X
        self.transform = transform

    def __len__(self) -> int:
        """
        Return the length of the dataset.

        Returns
        -------
        int
            The length of the dataset.
        """
        return len(self.X)

    def __getitem__(self, index: int) -> Any:
        """
        Get the item at the given index.

        Parameters
        ----------
        index : int
            The index of the item to get. Negative indices are passed straight
            to the underlying list.

        Returns
        -------
        Any
            The item stored at ``index``, passed through ``transform`` when one
            was given. No label is returned.

        Raises
        ------
        ValueError
            If ``index`` is greater than or equal to the length of the dataset.
            The raised exception is also an ``IndexError``.
        """
        if index >= len(self):
            # Must be an IndexError for sequence-style iteration to terminate;
            # it stays a ValueError for backward compatibility.
            raise _IndexRangeError("index range error")

        x = self.X[index]
        if self.transform is not None:
            x = self.transform(x)
        return x