direct package

Contents

direct package#

Subpackages#

Submodules#

direct.checkpointer module#

Checkpointer module.

Handles all logic related to checkpointing.

class direct.checkpointer.Checkpointer(save_directory, save_to_disk=True, model_regex='^.*model$', **checkpointables)[source]#

Bases: object

Main Checkpointer module.

Handles writing and restoring from checkpoints of modules and submodules.

__init__(save_directory, save_to_disk=True, model_regex='^.*model$', **checkpointables)[source]#
load(iteration, checkpointable_objects=None)[source]#
Return type:

Dict

load_from_path(checkpoint_path, checkpointable_objects=None, only_models=False)[source]#

Load a checkpoint from a path.

Parameters:
  • checkpoint_path (Union[Path, str]) – Path to the checkpoint: either a path to a local file or a URL from which the file can be downloaded.

  • checkpointable_objects (Optional[Dict[str, Module]]) – Dictionary mapping names to nn.Module objects. Default: None.

  • only_models (bool) – If True will only load the models and no other objects in the checkpoint. Default: False.

Return type:

Dict

Returns:

Dictionary with loaded models.

load_models_from_file(checkpoint_path)[source]#
Return type:

None

save(iteration, **kwargs)[source]#
Return type:

None

direct.constants module#

direct.engine module#

Main engine of DIRECT.

Implements all the main training, testing and validation logic.

class direct.engine.DoIterationOutput(output_image, sensitivity_map, data_dict)#

Bases: tuple

data_dict#

Alias for field number 2

output_image#

Alias for field number 0

sensitivity_map#

Alias for field number 1

class direct.engine.DataDimensionality[source]#

Bases: object

__init__()[source]#
property ndim#
class direct.engine.Engine(cfg, model, device, forward_operator=None, backward_operator=None, mixed_precision=False, **models)[source]#

Bases: ABC, DataDimensionality

__init__(cfg, model, device, forward_operator=None, backward_operator=None, mixed_precision=False, **models)[source]#

Inits Engine.

Parameters:
  • cfg (BaseConfig) – Configuration object.

  • model (Module) – Model.

  • device (str) – Device. Can be "cuda" or "cpu".

  • forward_operator (Optional[Callable]) – The forward operator. Default: None.

  • backward_operator (Optional[Callable]) – The backward operator. Default: None.

  • mixed_precision (bool) – Use mixed precision. Default: False.

  • **models (Module) – Additional models.

abstract build_loss()[source]#
Return type:

Dict

build_metrics(metrics_list)[source]#
Return type:

Dict

build_regularizers(regularizers_list)[source]#
Return type:

Dict

predict(dataset, experiment_directory, checkpoint=-1, num_workers=6, batch_size=1, crop=None)[source]#
Return type:

List[ndarray]

static build_loader(dataset, batch_sampler=None, num_workers=6)[source]#
Return type:

DataLoader

static build_batch_sampler(dataset, batch_size, sampler_type, **kwargs)[source]#
Return type:

Sampler

training_loop(training_datasets, start_iter, validation_datasets=None, experiment_directory=None, num_workers=6, start_with_validation=False)[source]#
validate_model_at_interval(func, iter_idx, total_iter)[source]#
checkpoint_model_at_interval(iter_idx, total_iter)[source]#
write_to_logs_at_interval(iter_idx, total_iter)[source]#
checkpoint_and_write_to_logs(iter_idx)[source]#
validation_loop(validation_datasets, loss_fns, experiment_directory, iter_idx, num_workers=6)[source]#
process_slices_for_visualization(visualize_slices, visualize_target)[source]#
models_training_mode()[source]#
models_validation_mode()[source]#
models_to_device()[source]#
train(optimizer, lr_scheduler, training_datasets, experiment_directory, validation_datasets=None, resume=False, start_with_validation=False, initialization=None, num_workers=6)[source]#
Return type:

None

abstract reconstruct_volumes(*args, **kwargs)[source]#
abstract evaluate(*args, **kwargs)[source]#
log_process(idx, total)[source]#
log_first_training_example_and_model(data)[source]#
write_to_logs()[source]#

direct.environment module#

direct.environment.resolve_cache_dir()[source]#
Return type:

Path

direct.environment.collect_env_info()[source]#

Collects environment information.

Return type:

str

Returns:

Environment information as a formatted string.

direct.environment.load_model_config_from_name(model_name)[source]#

Load specific configuration module for models based on their name.

Parameters:

model_name (str) – Path to model relative to direct.nn.

Return type:

Callable

Returns:

Model configuration.

direct.environment.load_model_from_name(model_name)[source]#

Load model based on model_name.

Parameters:

model_name (str) – Model name as in direct.nn.

Return type:

Callable

Returns:

Model class.

direct.environment.load_dataset_config(dataset_name)[source]#

Load specific dataset configuration for dataset based on dataset_name.

Parameters:

dataset_name (str) – Name of dataset.

Return type:

Callable

Returns:

Dataset configuration.

direct.environment.build_operators(cfg)[source]#

Builds operators from configuration.

Return type:

Tuple[Callable, Callable]

direct.environment.setup_logging(machine_rank, output_directory, run_name, cfg_filename, cfg, debug)[source]#

Logs environment information.

Parameters:
  • machine_rank (int) – Machine rank.

  • output_directory (Path) – Path to output directory.

  • run_name (str) – Name of run.

  • cfg_filename (Union[Path, str]) – Name of configuration file.

  • cfg (DefaultConfig) – Configuration object.

  • debug (bool) – Whether the debug mode is enabled.

Return type:

None

direct.environment.load_models_into_environment_config(cfg_from_file)[source]#

Load the configuration for the models.

Parameters:

cfg_from_file (DictConfig) – Omegaconf configuration.

Return type:

Tuple[dict, DictConfig]

Returns:

Tuple of models dictionary and models configuration dictionary.

direct.environment.initialize_models_from_config(cfg, models, forward_operator, backward_operator, device)[source]#

Creates models from config.

Parameters:
  • cfg (DictConfig) – Configuration object.

  • models (dict) – Models dictionary including configurations.

  • forward_operator (Callable) – Forward operator.

  • backward_operator (Callable) – Backward operator.

  • device (str) – Type of device.

Return type:

Tuple[Module, Dict]

Returns:

Tuple of model and additional models dictionary.

direct.environment.setup_engine(cfg, device, model, additional_models, forward_operator=None, backward_operator=None, mixed_precision=False)[source]#

Sets up the engine.

Parameters:
  • cfg (DictConfig) – Configuration object.

  • device (str) – Type of device.

  • model (Module) – Model.

  • additional_models (dict) – Additional models.

  • forward_operator (Union[Callable, object, None]) – Forward operator. Default: None.

  • backward_operator (Union[Callable, object, None]) – Backward operator. Default: None.

  • mixed_precision (bool) – Whether to enable mixed precision or not. Default: False.

Returns:

Experiment Engine.

direct.environment.extract_names(cfg)[source]#
direct.environment.setup_common_environment(run_name, base_directory, cfg_pathname, device, machine_rank, mixed_precision, debug=False)[source]#

Set up the common environment.

Parameters:
  • run_name (str) – Run name.

  • base_directory (Path) – Base directory path.

  • cfg_pathname (Union[Path, str]) – Path or url to configuration file.

  • device (str) – Device type.

  • machine_rank (int) – Machine rank.

  • mixed_precision (bool) – Whether to enable mixed precision or not. Default: False.

  • debug (bool) – Whether the debug mode is enabled. Default: False.

Returns:

Common Environment.

direct.environment.setup_training_environment(run_name, base_directory, cfg_filename, device, machine_rank, mixed_precision, debug=False)[source]#

Set up the training environment.

Parameters:
  • run_name (str) – Run name.

  • base_directory (Path) – Base directory path.

  • cfg_filename (Union[Path, str]) – Path or url to configuration file.

  • device (str) – Device type.

  • machine_rank (int) – Machine rank.

  • mixed_precision (bool) – Whether to enable mixed precision or not. Default: False.

  • debug (bool) – Whether the debug mode is enabled. Default: False.

Returns:

Training Environment.

direct.environment.setup_testing_environment(run_name, base_directory, device, machine_rank, mixed_precision, cfg_pathname=None, debug=False)[source]#

Set up the testing environment.

Parameters:
  • run_name (str) – Run name.

  • base_directory (Path) – Base directory path.

  • device (str) – Device type.

  • machine_rank (int) – Machine rank.

  • mixed_precision (bool) – Whether to enable mixed precision or not. Default: False.

  • cfg_pathname (Union[Path, str, None]) – Path or url to configuration file. Default: None.

  • debug (bool) – Whether the debug mode is enabled. Default: False.

Returns:

Testing Environment.

direct.environment.setup_inference_environment(run_name, base_directory, device, machine_rank, mixed_precision, cfg_file=None, debug=False)[source]#

Set up the inference environment.

Parameters:
  • run_name (str) – Run name.

  • base_directory (Path) – Base directory path.

  • device (str) – Device type.

  • machine_rank (int) – Machine rank.

  • mixed_precision (bool) – Whether to enable mixed precision or not. Default: False.

  • cfg_file (Union[Path, str, None]) – Path or url to configuration file. Default: None.

  • debug (bool) – Whether the debug mode is enabled. Default: False.

Returns:

Inference Environment.

class direct.environment.Args(epilog=None, add_help=True, **overrides)[source]#

Bases: ArgumentParser

Defines global default arguments.

__init__(epilog=None, add_help=True, **overrides)[source]#

Inits Args.

Parameters:
  • epilog – Text to display after the argument help. Default: None.

  • add_help – Add a -h/--help option to the parser. Default: True.

  • **overrides – Keyword arguments used to override default argument values.

direct.exceptions module#

exception direct.exceptions.DirectException(*args, **kwargs)[source]#

Bases: BaseException

__init__(*args, **kwargs)[source]#
exception direct.exceptions.ProcessKilledException(signal_id, signal_name)[source]#

Bases: DirectException

The process received a SIGINT signal.

__init__(signal_id, signal_name)[source]#
Parameters:
  • signal_id (int) – Numeric identifier of the received signal.

  • signal_name (str) – Name of the received signal.

exception direct.exceptions.TrainingException(message=None)[source]#

Bases: DirectException

__init__(message=None)[source]#
exception direct.exceptions.ItemNotFoundException(item_name, message=None)[source]#

Bases: DirectException

__init__(item_name, message=None)[source]#

direct.inference module#

direct.inference.setup_inference_save_to_h5(get_inference_settings, run_name, data_root, base_directory, output_directory, filenames_filter, checkpoint, device, num_workers, machine_rank, cfg_file=None, process_per_chunk=None, mixed_precision=False, debug=False, is_validation=False)[source]#

This function contains most of the logic in DIRECT required to launch a multi-GPU / multi-node inference process.

It saves predictions as .h5 files.

Parameters:
  • get_inference_settings (Callable) – Callable object to create inference dataset and environment.

  • run_name (str) – Experiment run name. Can be an empty string.

  • data_root (Union[Path, str, None]) – Path of the directory of the data if applicable for dataset. Can be None.

  • base_directory (Union[Path, str]) – Path to the directory where inference logs will be stored. If run_name is not an empty string, base_directory / run_name will be used.

  • output_directory (Union[Path, str]) – Path to directory where output data will be saved.

  • filenames_filter (Optional[List[Union[Path, str]]]) – List of filenames to include in the dataset (if applicable). Can be None.

  • checkpoint (NewType(FileOrUrl, Union[Path, str])) – Checkpoint to a model. This can be a path to a local file or a URL.

  • device (str) – Device name.

  • num_workers (int) – Number of workers.

  • machine_rank (int) – Machine rank.

  • cfg_file (Union[Path, str, None]) – Path to configuration file. If None, will search in base_directory. Default: None.

  • process_per_chunk (Optional[int]) – Number of processes per chunk. Default: None.

  • mixed_precision (bool) – If True, mixed precision will be allowed. Default: False.

  • debug (bool) – If True, debug information will be displayed. Default: False.

  • is_validation (bool) – If True, will use settings (e.g. batch_size & crop) of validation in config. Otherwise it will use inference settings. Default: False.

Return type:

None

direct.inference.build_inference_transforms(env, mask_func, dataset_cfg)[source]#

Builds inference transforms.

Return type:

Callable

direct.inference.inference_on_environment(env, data_root, dataset_cfg, transforms, experiment_path, checkpoint, num_workers=0, filenames_filter=None, batch_size=1, crop=None)[source]#

Performs inference on environment.

Parameters:
  • env – Environment.

  • data_root (Union[Path, str, None]) – Path of the directory of the data if applicable for dataset. Can be None.

  • dataset_cfg (DictConfig) – Configuration containing inference dataset settings.

  • transforms (Callable) – Dataset transformations object.

  • experiment_path (Union[Path, str]) – Path to the directory where inference logs will be stored.

  • checkpoint (NewType(FileOrUrl, Union[Path, str])) – Checkpoint to a model. This can be a path to a local file or a URL.

  • num_workers (int) – Number of workers. Default: 0.

  • filenames_filter (Optional[List[Union[Path, str]]]) – List of filenames to include in the dataset (if applicable). Can be None. Default: None.

  • batch_size (int) – Inference batch size. Default: 1.

  • crop (Optional[str]) – Inference crop type. Can be "header" or None. Default: None.

Return type:

Union[Dict, DefaultDict]

Returns:

Output data.

direct.launch module#

direct.launch.launch(func, num_machines, num_gpus, machine_rank, dist_url, *args)[source]#

Launch the training. If only one GPU is available, the function can be called directly.

Parameters:
  • func (Callable) – Function to launch.

  • num_machines (int) – The number of machines.

  • num_gpus (int) – The number of GPUs.

  • machine_rank (int) – The machine rank.

  • dist_url (str) – URL to connect to for distributed training, including protocol.

  • *args (Tuple) – Arguments to pass to func.

Return type:

None

direct.launch.launch_distributed(main_func, num_gpus_per_machine, num_machines=1, machine_rank=0, dist_url='auto', args=(), timeout=DEFAULT_TIMEOUT)[source]#

Launch multi-gpu or distributed training.

This function must be called on all machines involved in the training and it will spawn child processes (defined by num_gpus_per_machine) on each machine.

Parameters:
  • main_func (Callable) – A function that will be called by main_func(*args).

  • num_gpus_per_machine (int) – The number of GPUs per machine.

  • num_machines (int) – The number of machines. Default: 1.

  • machine_rank (int) – The rank of this machine (one per machine). Default: 0.

  • dist_url (str) – URL to connect to for distributed training, including protocol e.g. "tcp://127.0.0.1:8686". Can be set to "auto" to automatically select a free port on localhost. Default: "auto".

  • args (Tuple) – Arguments passed to main_func. Default: ().

  • timeout (timedelta) – Timeout of the distributed workers. Default: DEFAULT_TIMEOUT.

Return type:

None

direct.predict module#

direct.predict.predict_from_argparse(args)[source]#

direct.train module#

direct.train.parse_noise_dict(noise_dict, percentile=1.0, multiplier=1.0)[source]#
direct.train.get_root_of_file(filename)[source]#

Get the root directory of a file path, or of a URL pointing to a file.

Parameters:

filename (Union[Path, str]) – File path or URL.

Returns:

Root directory path.

Example

>>> get_root_of_file('/mnt/archive/data.txt')
/mnt/archive
>>> get_root_of_file('https://aiforoncology.nl/people')
https://aiforoncology.nl/
direct.train.build_transforms_from_environment(env, dataset_config)[source]#
Return type:

Callable

direct.train.build_training_datasets_from_environment(env, datasets_config, lists_root=None, data_root=None, initial_images=None, initial_kspaces=None, pass_text_description=True, pass_dictionaries=None)[source]#
direct.train.setup_train(run_name, training_root, validation_root, base_directory, cfg_filename, force_validation, initialization_checkpoint, initial_images, initial_kspace, noise, device, num_workers, resume, machine_rank, mixed_precision, debug)[source]#
direct.train.train_from_argparse(args)[source]#

direct.types module#

direct.types module.

class direct.types.DirectEnum(value, names=_not_given, *values, module=None, qualname=None, type=None, start=1, boundary=None)[source]#

Bases: str, Enum

Base enumerator type whose members can be compared with strings case-insensitively.

classmethod from_str(value)[source]#
Return type:

DirectEnum | None

class direct.types.KspaceKey(value, names=_not_given, *values, module=None, qualname=None, type=None, start=1, boundary=None)[source]#

Bases: DirectEnum

KSPACE = 'kspace'#
MASKED_KSPACE = 'masked_kspace'#
class direct.types.TransformKey(value, names=_not_given, *values, module=None, qualname=None, type=None, start=1, boundary=None)[source]#

Bases: DirectEnum

SENSITIVITY_MAP = 'sensitivity_map'#
TARGET = 'target'#
KSPACE = 'kspace'#
MASKED_KSPACE = 'masked_kspace'#
SAMPLING_MASK = 'sampling_mask'#
ACS_MASK = 'acs_mask'#
SCALING_FACTOR = 'scaling_factor'#
class direct.types.MaskFuncMode(value, names=_not_given, *values, module=None, qualname=None, type=None, start=1, boundary=None)[source]#

Bases: DirectEnum

STATIC = 'static'#
DYNAMIC = 'dynamic'#
MULTISLICE = 'multislice'#
class direct.types.IntegerListOrTupleStringMeta[source]#

Bases: type

Metaclass for the IntegerListOrTupleString class.

Returns:

True if the instance is a valid representation of IntegerListOrTupleString, False otherwise.

__instancecheck__(instance)[source]#

Check if the given instance is a valid representation of an IntegerListOrTupleString.

Parameters:

instance – The object to check.

Returns:

True if the instance is a valid representation of IntegerListOrTupleString, False otherwise.

class direct.types.IntegerListOrTupleString(string)[source]#

Bases: object

IntegerListOrTupleString class represents a list or tuple of integers based on a string representation.

Example

>>> s1 = "[1, 2, 45, -1, 0]"
>>> print(isinstance(s1, IntegerListOrTupleString))  # True
>>> print(IntegerListOrTupleString(s1))  # [1, 2, 45, -1, 0]
>>> print(type(IntegerListOrTupleString(s1)))  # <class 'list'>
>>> print(type(IntegerListOrTupleString(s1)[0]))  # <class 'int'>
>>> s2 = "(10, -9, 20)"
>>> print(isinstance(s2, IntegerListOrTupleString))  # True
>>> print(IntegerListOrTupleString(s2))  # (10, -9, 20)
>>> print(type(IntegerListOrTupleString(s2)))  # <class 'tuple'>
>>> print(type(IntegerListOrTupleString(s2)[0]))  # <class 'int'>
>>> s3 = "[a, we, 2]"
>>> print(isinstance(s3, IntegerListOrTupleString))  # False
>>> s4 = "(1, 2, 3]"
>>> print(isinstance(s4, IntegerListOrTupleString))  # False
static __new__(cls, string)[source]#

Create a new instance of IntegerListOrTupleString based on the given string representation.

Parameters:

string – The string representation of the integer list or tuple.

Returns:

A new instance of IntegerListOrTupleString.

Module contents#