Adding your own dataset

Transforms in DIRECT currently support only gridded data (data acquired on an equispaced grid). Any compatible dataset should inherit from PyTorch's dataset class torch.utils.data.Dataset.

Follow the steps below:

Implement your custom dataset in direct/data/datasets.py, following the template:
```python
import logging
import pathlib
from typing import Any, Callable, Dict, List, Optional

from torch.utils.data import Dataset

from direct.types import PathOrString

logger = logging.getLogger(__name__)


class MyNewDataset(Dataset):
    """
    Information about the Dataset.
    """

    def __init__(
        self,
        root: pathlib.Path,
        transform: Optional[Callable] = None,
        filenames_filter: Optional[List[PathOrString]] = None,
        text_description: Optional[str] = None,
        ...
    ) -> None:
        """
        Initialize the dataset.

        Parameters
        ----------
        root : pathlib.Path
            Root directory of the saved data.
        transform : Optional[Callable]
            Callable function that transforms the loaded data.
        filenames_filter : Optional[List[PathOrString]]
            List of filenames to include in the dataset.
        text_description : Optional[str]
            Description of the dataset, can be useful for logging.
        ...
        """
        super().__init__()
        self.logger = logging.getLogger(type(self).__name__)
        self.root = root
        self.transform = transform
        if filenames_filter:
            self.logger.info(f"Attempting to load {len(filenames_filter)} filenames from list.")
            filenames = filenames_filter
        else:
            self.logger.info(f"Parsing directory {self.root} for <data_type> files.")
            filenames = list(self.root.glob("*.<data_type>"))
        self.filenames_filter = filenames_filter
        self.text_description = text_description
        self.ndim = ...  # 2 or 3
        self.volume_indices = self.set_volume_indices(...)
        ...

    def set_volume_indices(self, ...):
        ...

    def get_dataset_len(self):
        ...

    def __len__(self):
        return self.get_dataset_len()

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        ...
        sample = ...
        ...
        if self.transform:
            sample = self.transform(sample)
        return sample
```
Note that the __getitem__ method should output dictionaries whose values are either torch.Tensors or other metadata. The currently implemented models and transforms work with multi-coil two-dimensional data; therefore, new datasets should split three-dimensional data into slices of two-dimensional data.
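The 3D-to-2D splitting described above can be sketched as follows. This is a stdlib-only toy with hypothetical names (the filenames, the "kspace" key, and the plain-list volumes are illustrative assumptions): in DIRECT the class would subclass torch.utils.data.Dataset and the sample values would be torch.Tensors.

```python
# Hypothetical sketch: flatten 3D volumes into per-slice 2D samples.
class SliceDataset:
    def __init__(self, volumes):
        # volumes: dict mapping filename -> 3D array-like of shape (slices, H, W)
        self.volumes = volumes
        # Flat index mapping sample idx -> (filename, slice number), so
        # __getitem__ can serve 2D slices out of 3D volumes.
        self.volume_indices = [
            (filename, slice_no)
            for filename, volume in volumes.items()
            for slice_no in range(len(volume))
        ]

    def __len__(self):
        return len(self.volume_indices)

    def __getitem__(self, idx):
        filename, slice_no = self.volume_indices[idx]
        # Output a dictionary, as DIRECT's transforms expect.
        return {
            "filename": filename,
            "slice_no": slice_no,
            "kspace": self.volumes[filename][slice_no],
        }


volumes = {
    "a.h5": [[[0.0]], [[1.0]], [[2.0]]],  # 3 slices of shape 1x1
    "b.h5": [[[3.0]], [[4.0]]],           # 2 slices of shape 1x1
}
dataset = SliceDataset(volumes)
print(len(dataset))            # 5 samples: one per 2D slice
print(dataset[3]["filename"])  # b.h5
print(dataset[3]["slice_no"])  # 0
```

The flat volume_indices mapping is what lets a DataLoader treat a collection of volumes as one long sequence of 2D samples.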
Register a configuration for the new dataset in direct/data/datasets_config.py:

```python
...

@dataclass
class MyDatasetConfig(BaseConfig):
    ...
    name: str = "MyNew"
    transforms: BaseConfig = TransformsConfig()
    text_description: Optional[str] = None
    ...
```
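To illustrate how such a dataclass config carries its defaults, here is a self-contained sketch. BaseConfig and TransformsConfig below are simplified stand-ins defined locally, not DIRECT's real classes; note that a plain dataclass default must use default_factory for a dataclass-valued field.

```python
from dataclasses import dataclass, field
from typing import Optional

# Simplified stand-ins for DIRECT's BaseConfig / TransformsConfig,
# defined here only so the sketch runs on its own.
@dataclass
class BaseConfig:
    pass

@dataclass
class TransformsConfig(BaseConfig):
    crop: Optional[str] = None

@dataclass
class MyDatasetConfig(BaseConfig):
    name: str = "MyNew"
    # default_factory avoids sharing one TransformsConfig across instances.
    transforms: TransformsConfig = field(default_factory=TransformsConfig)
    text_description: Optional[str] = None

config = MyDatasetConfig(text_description="training set")
print(config.name)             # MyNew
print(config.transforms.crop)  # None
```

Fields left at their defaults (such as name here) are what the YAML configuration overrides at run time.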
To use your dataset, you have to request it in the config.yaml file. The following shows an example for training:

```yaml
training:
    datasets:
        -   name: MyNew
            lists:
                - <list_1_name>.lst
                - <list_2_name>.lst
                - ...
            transforms:
                ...
            masking:
                ...
    ...
```
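Once the YAML above is parsed, training.datasets is simply a list of per-dataset configurations. The sketch below uses a plain-dict stand-in for the parsed file (the structure mirrors the YAML; how DIRECT actually consumes it is not shown here) to make the nesting concrete:

```python
# Stand-in for the parsed config.yaml: nested dicts and lists that
# mirror the YAML structure above.
parsed = {
    "training": {
        "datasets": [
            {
                "name": "MyNew",
                "lists": ["<list_1_name>.lst", "<list_2_name>.lst"],
                "transforms": {},
                "masking": {},
            },
        ],
    },
}

# Each entry selects a dataset by its registered name and carries
# its filename lists, transform, and masking settings.
for dataset_cfg in parsed["training"]["datasets"]:
    print(dataset_cfg["name"], len(dataset_cfg["lists"]))  # MyNew 2
```

Because datasets is a list, several datasets can be combined in one training run by adding more entries under it.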