direct package

Contents

direct package#

Subpackages#

Submodules#

direct.checkpointer module#

Checkpointer module.

Handles all logic related to checkpointing.

class direct.checkpointer.Checkpointer(save_directory, save_to_disk=True, model_regex='^.*model$', **checkpointables)[source][source]#

Bases: object

Main Checkpointer module.

Handles writing and restoring from checkpoints of modules and submodules.

load(iteration, checkpointable_objects=None)[source][source]#
Return type:

Dict

load_from_path(checkpoint_path, checkpointable_objects=None, only_models=False)[source][source]#

Load a checkpoint from a path.

Parameters:
checkpoint_path: Path or str

Path to the checkpoint: either a local file path or a URL from which the file can be downloaded.

checkpointable_objects: dict

Dictionary mapping names to nn.Module instances.

only_models: bool

If True, only the models are loaded and no other objects in the checkpoint.

Returns:
Dictionary with loaded models.
Return type:

Dict

load_models_from_file(checkpoint_path)[source][source]#
Return type:

None

save(iteration, **kwargs)[source][source]#
Return type:

None

direct.constants module#

direct.engine module#

Main engine of DIRECT.

Implements all the main training, testing and validation logic.

class direct.engine.DataDimensionality[source][source]#

Bases: object

property ndim#
class direct.engine.DoIterationOutput(output_image, sensitivity_map, data_dict)[source]#

Bases: tuple

data_dict#

Alias for field number 2

output_image#

Alias for field number 0

sensitivity_map#

Alias for field number 1

class direct.engine.Engine(cfg, model, device, forward_operator=None, backward_operator=None, mixed_precision=False, **models)[source][source]#

Bases: ABC, DataDimensionality

static build_batch_sampler(dataset, batch_size, sampler_type, **kwargs)[source][source]#
Return type:

Sampler

static build_loader(dataset, batch_sampler=None, num_workers=6)[source][source]#
Return type:

DataLoader

abstract build_loss()[source][source]#
Return type:

Dict

build_metrics(metrics_list)[source][source]#
Return type:

Dict

build_regularizers(regularizers_list)[source][source]#
Return type:

Dict

checkpoint_and_write_to_logs(iter_idx)[source][source]#
checkpoint_model_at_interval(iter_idx, total_iter)[source][source]#
abstract evaluate(*args, **kwargs)[source][source]#
log_first_training_example_and_model(data)[source][source]#
log_process(idx, total)[source][source]#
models_to_device()[source][source]#
models_training_mode()[source][source]#
models_validation_mode()[source][source]#
predict(dataset, experiment_directory, checkpoint=-1, num_workers=6, batch_size=1, crop=None)[source][source]#
Return type:

List[ndarray]

process_slices_for_visualization(visualize_slices, visualize_target)[source][source]#
abstract reconstruct_volumes(*args, **kwargs)[source][source]#
train(optimizer, lr_scheduler, training_datasets, experiment_directory, validation_datasets=None, resume=False, start_with_validation=False, initialization=None, num_workers=6)[source][source]#
Return type:

None

training_loop(training_datasets, start_iter, validation_datasets=None, experiment_directory=None, num_workers=6, start_with_validation=False)[source][source]#
validate_model_at_interval(func, iter_idx, total_iter)[source][source]#
validation_loop(validation_datasets, loss_fns, experiment_directory, iter_idx, num_workers=6)[source][source]#
write_to_logs()[source][source]#
write_to_logs_at_interval(iter_idx, total_iter)[source][source]#

direct.environment module#

class direct.environment.Args(epilog=None, add_help=True, **overrides)[source][source]#

Bases: ArgumentParser

Defines global default arguments.

direct.environment.build_operators(cfg)[source][source]#

Builds operators from configuration.

Return type:

Tuple[Callable, Callable]

direct.environment.extract_names(cfg)[source][source]#
direct.environment.initialize_models_from_config(cfg, models, forward_operator, backward_operator, device)[source][source]#

Creates models from config.

Parameters:
cfg: DictConfig

Configuration.

models: dict

Models dictionary including configurations.

forward_operator: Callable

Forward operator.

backward_operator: Callable

Backward operator.

device: str

Type of device.

Returns:
model: torch.nn.Module

Model.

additional_models: Dict

Additional models.

Return type:

Tuple[Module, Dict]

direct.environment.load_dataset_config(dataset_name)[source][source]#

Load the configuration for the dataset specified by dataset_name.

Parameters:
dataset_name: str

Name of dataset.

Returns:
dataset_config: Callable

Dataset configuration.

Return type:

Callable

direct.environment.load_model_config_from_name(model_name)[source][source]#

Load specific configuration module for models based on their name.

Parameters:
model_name: str

Path to model relative to direct.nn.

Returns:
model_cfg: Callable

Model configuration.

Return type:

Callable

direct.environment.load_model_from_name(model_name)[source][source]#

Load model based on model_name.

Parameters:
model_name: str

Model name as in direct.nn.

Returns:
model: Callable

Model class.

Return type:

Callable

direct.environment.load_models_into_environment_config(cfg_from_file)[source][source]#

Load the configuration for the models.

Parameters:
cfg_from_file: DictConfig

Omegaconf configuration.

Returns:
(models, models_config): (dict, DictConfig)

Models dictionary and models configuration dictionary.

Return type:

Tuple[dict, DictConfig]

direct.environment.setup_common_environment(run_name, base_directory, cfg_pathname, device, machine_rank, mixed_precision, debug=False)[source][source]#

Set up the common environment.

Parameters:
run_name: str

Run name.

base_directory: pathlib.Path

Base directory path.

cfg_pathname: Union[pathlib.Path, str]

Path or URL to the configuration file.

device: str

Device type.

machine_rank: int

Machine rank.

mixed_precision: bool

Whether to enable mixed precision or not.

debug: bool

Whether the debug mode is enabled.

Returns:
environment

Common Environment.

direct.environment.setup_engine(cfg, device, model, additional_models, forward_operator=None, backward_operator=None, mixed_precision=False)[source][source]#

Sets up the engine.

Parameters:
cfg: DictConfig

Configuration.

device: str

Type of device.

model: torch.nn.Module

Model.

additional_models: dict

Additional models.

forward_operator: Callable

Forward operator.

backward_operator: Callable

Backward operator.

mixed_precision: bool

Whether to enable mixed precision or not. Default: False.

Returns:
engine

Experiment Engine.

direct.environment.setup_inference_environment(run_name, base_directory, device, machine_rank, mixed_precision, cfg_file=None, debug=False)[source][source]#

Set up the inference environment.

Parameters:
run_name: str

Run name.

base_directory: pathlib.Path

Base directory path.

device: str

Device type.

machine_rank: int

Machine rank.

mixed_precision: bool

Whether to enable mixed precision or not.

cfg_file: Union[pathlib.Path, str], optional

Path or url to configuration file.

debug: bool

Whether the debug mode is enabled.

Returns:
environment

Inference Environment.

direct.environment.setup_logging(machine_rank, output_directory, run_name, cfg_filename, cfg, debug)[source][source]#

Logs environment information.

Parameters:
machine_rank: int

Machine rank.

output_directory: pathlib.Path

Path to output directory.

run_name: str

Name of run.

cfg_filename: Union[pathlib.Path, str]

Name of configuration file.

cfg: DefaultConfig

Configuration file.

debug: bool

Whether the debug mode is enabled.

Return type:

None

direct.environment.setup_testing_environment(run_name, base_directory, device, machine_rank, mixed_precision, cfg_pathname=None, debug=False)[source][source]#

Set up the testing environment.

Parameters:
run_name: str

Run name.

base_directory: pathlib.Path

Base directory path.

device: str

Device type.

machine_rank: int

Machine rank.

mixed_precision: bool

Whether to enable mixed precision or not.

cfg_pathname: Union[pathlib.Path, str], optional

Path or url to configuration file.

debug: bool

Whether the debug mode is enabled.

Returns:
environment

Testing Environment.

direct.environment.setup_training_environment(run_name, base_directory, cfg_filename, device, machine_rank, mixed_precision, debug=False)[source][source]#

Set up the training environment.

Parameters:
run_name: str

Run name.

base_directory: pathlib.Path

Base directory path.

cfg_filename: Union[pathlib.Path, str]

Path or URL to the configuration file.

device: str

Device type.

machine_rank: int

Machine rank.

mixed_precision: bool

Whether to enable mixed precision or not.

debug: bool

Whether the debug mode is enabled.

Returns:
environment

Training Environment.

direct.exceptions module#

exception direct.exceptions.DirectException(*args, **kwargs)[source][source]#

Bases: BaseException

exception direct.exceptions.ItemNotFoundException(item_name, message=None)[source][source]#

Bases: DirectException

exception direct.exceptions.ProcessKilledException(signal_id, signal_name)[source][source]#

Bases: DirectException

The process received SIGINT signal.

exception direct.exceptions.TrainingException(message=None)[source][source]#

Bases: DirectException

direct.inference module#

direct.inference.build_inference_transforms(env, mask_func, dataset_cfg)[source][source]#

Builds inference transforms.

Return type:

Callable

direct.inference.inference_on_environment(env, data_root, dataset_cfg, transforms, experiment_path, checkpoint, num_workers=0, filenames_filter=None, batch_size=1, crop=None)[source][source]#

Performs inference on environment.

Parameters:
env: Environment.
data_root: Union[PathOrString, None]

Path of the directory of the data if applicable for dataset. Can be None.

dataset_cfg: DictConfig

Configuration containing inference dataset settings.

transforms: Callable

Dataset transformations object.

experiment_path: PathOrString

Path to the directory where inference logs will be stored.

checkpoint: FileOrUrl

Checkpoint of a model. This can be a path to a local file or a URL.

num_workers: int

Number of workers.

filenames_filter: Union[List[PathOrString], None]

List of filenames to include in the dataset (if applicable). Can be None. Default: None.

batch_size: int

Inference batch size. Default: 1.

crop: Optional[str]

Inference crop type. Can be "header" or None. Default: None.

Returns:
output: Union[Dict, DefaultDict]
Return type:

Union[Dict, DefaultDict]

direct.inference.setup_inference_save_to_h5(get_inference_settings, run_name, data_root, base_directory, output_directory, filenames_filter, checkpoint, device, num_workers, machine_rank, cfg_file=None, process_per_chunk=None, mixed_precision=False, debug=False, is_validation=False)[source][source]#

This function contains most of the logic in DIRECT required to launch a multi-gpu / multi-node inference process.

It saves predictions as .h5 files.

Parameters:
get_inference_settings: Callable

Callable object to create inference dataset and environment.

run_name: str

Experiment run name. Can be an empty string.

data_root: Union[PathOrString, None]

Path of the directory of the data if applicable for dataset. Can be None.

base_directory: PathOrString

Path to the directory where inference logs will be stored. If run_name is not an empty string, base_directory / run_name will be used.

output_directory: PathOrString

Path to directory where output data will be saved.

filenames_filter: Union[List[PathOrString], None]

List of filenames to include in the dataset (if applicable). Can be None.

checkpoint: FileOrUrl

Checkpoint of a model. This can be a path to a local file or a URL.

device: str

Device name.

num_workers: int

Number of workers.

machine_rank: int

Machine rank.

cfg_file: Union[PathOrString, None]

Path to configuration file. If None, will search in base_directory.

process_per_chunk: int

Number of processes per chunk.

mixed_precision: bool

If True, mixed precision will be allowed. Default: False.

debug: bool

If True, debug information will be displayed. Default: False.

is_validation: bool

If True, will use settings (e.g. batch_size & crop) of validation in config. Otherwise it will use inference settings. Default: False.

Returns:
None
Return type:

None

direct.launch module#

direct.launch.launch(func, num_machines, num_gpus, machine_rank, dist_url, *args)[source][source]#

Launch the training. If only one GPU is available, the function can be called directly.

Parameters:
func: Callable

Function to launch.

num_machines: int

The number of machines.

num_gpus: int

The number of GPUs.

machine_rank: int

The machine rank.

dist_url: str

URL to connect to for distributed training, including protocol.

args: Tuple

Arguments to pass to func.

Return type:

None

direct.launch.launch_distributed(main_func, num_gpus_per_machine, num_machines=1, machine_rank=0, dist_url='auto', args=(), timeout=datetime.timedelta(seconds=1800))[source][source]#

Launch multi-gpu or distributed training.

This function must be called on all machines involved in the training and it will spawn child processes (defined by num_gpus_per_machine) on each machine.

Parameters:
main_func: Callable

A function that will be called by main_func(*args).

num_gpus_per_machine: int

The number of GPUs per machine.

num_machines: int

The number of machines.

machine_rank: int

The rank of this machine (one per machine).

dist_url: str

URL to connect to for distributed training, including protocol e.g. “tcp://127.0.0.1:8686”. Can be set to auto to automatically select a free port on localhost

args: Tuple

arguments passed to main_func.

timeout: timedelta

Timeout of the distributed workers.

Return type:

None

direct.predict module#

direct.predict.predict_from_argparse(args)[source][source]#

direct.train module#

direct.train.build_training_datasets_from_environment(env, datasets_config, lists_root=None, data_root=None, initial_images=None, initial_kspaces=None, pass_text_description=True, pass_dictionaries=None)[source][source]#
direct.train.build_transforms_from_environment(env, dataset_config)[source][source]#
Return type:

Callable

direct.train.get_root_of_file(filename)[source][source]#

Get the root directory of the file or URL to file.

Parameters:
filename: pathlib.Path or str
Returns:
pathlib.Path or str

Examples

>>> get_root_of_file('/mnt/archive/data.txt')
/mnt/archive
>>> get_root_of_file('https://aiforoncology.nl/people')
https://aiforoncology.nl/
direct.train.parse_noise_dict(noise_dict, percentile=1.0, multiplier=1.0)[source][source]#
direct.train.setup_train(run_name, training_root, validation_root, base_directory, cfg_filename, force_validation, initialization_checkpoint, initial_images, initial_kspace, noise, device, num_workers, resume, machine_rank, mixed_precision, debug)[source][source]#
direct.train.train_from_argparse(args)[source][source]#

direct.types module#

direct.types module.

class direct.types.DirectEnum(value)[source][source]#

Bases: str, Enum

Base enumeration type whose members can be compared to strings case-insensitively.

classmethod from_str(value)[source][source]#
Return type:

DirectEnum | None

class direct.types.IntegerListOrTupleString(string)[source][source]#

Bases: object

IntegerListOrTupleString class represents a list or tuple of integers based on a string representation.

Examples

s1 = "[1, 2, 45, -1, 0]"
print(isinstance(s1, IntegerListOrTupleString))  # True
print(IntegerListOrTupleString(s1))  # [1, 2, 45, -1, 0]
print(type(IntegerListOrTupleString(s1)))  # <class 'list'>
print(type(IntegerListOrTupleString(s1)[0]))  # <class 'int'>

s2 = "(10, -9, 20)"
print(isinstance(s2, IntegerListOrTupleString))  # True
print(IntegerListOrTupleString(s2))  # (10, -9, 20)
print(type(IntegerListOrTupleString(s2)))  # <class 'tuple'>
print(type(IntegerListOrTupleString(s2)[0]))  # <class 'int'>

s3 = "[a, we, 2]"
print(isinstance(s3, IntegerListOrTupleString))  # False

s4 = "(1, 2, 3]"
print(isinstance(s4, IntegerListOrTupleString))  # False

class direct.types.IntegerListOrTupleStringMeta[source][source]#

Bases: type

Metaclass for the IntegerListOrTupleString class.

Returns:
bool

True if the instance is a valid representation of IntegerListOrTupleString, False otherwise.

class direct.types.KspaceKey(value)[source][source]#

Bases: DirectEnum

An enumeration.

KSPACE = 'kspace'#
MASKED_KSPACE = 'masked_kspace'#
class direct.types.MaskFuncMode(value)[source][source]#

Bases: DirectEnum

An enumeration.

DYNAMIC = 'dynamic'#
MULTISLICE = 'multislice'#
STATIC = 'static'#
class direct.types.TransformKey(value)[source][source]#

Bases: DirectEnum

An enumeration.

ACS_MASK = 'acs_mask'#
KSPACE = 'kspace'#
MASKED_KSPACE = 'masked_kspace'#
SAMPLING_MASK = 'sampling_mask'#
SCALING_FACTOR = 'scaling_factor'#
SENSITIVITY_MAP = 'sensitivity_map'#
TARGET = 'target'#

Module contents#