direct package
Subpackages
- direct.algorithms package
- direct.cli package
- direct.common package
- Submodules
- direct.common.subsample module
BaseMaskFunc
CIRCUSMaskFunc
CIRCUSSamplingMode
CalgaryCampinasMaskFunc
CartesianEquispacedMaskFunc
CartesianMagicMaskFunc
CartesianRandomMaskFunc
CartesianVerticalMaskFunc
EquispacedMaskFunc
FastMRIEquispacedMaskFunc
FastMRIMagicMaskFunc
FastMRIRandomMaskFunc
Gaussian1DMaskFunc
Gaussian2DMaskFunc
KtBaseMaskFunc
KtGaussian1DMaskFunc
KtRadialMaskFunc
KtUniformMaskFunc
MagicMaskFunc
MaskFuncMode
RadialMaskFunc
RandomMaskFunc
SpiralMaskFunc
VariableDensityPoissonMaskFunc
build_masking_function()
centered_disk_mask()
integerize_seed()
- direct.common.subsample_config module
- Module contents
- direct.config package
- Submodules
- direct.config.defaults module
CheckpointerConfig
DefaultConfig
FunctionConfig
InferenceConfig
LoggingConfig
LossConfig
ModelConfig
PhysicsConfig
TensorboardConfig
TrainingConfig
TrainingConfig.batch_size
TrainingConfig.checkpointer
TrainingConfig.datasets
TrainingConfig.gradient_clipping
TrainingConfig.gradient_debug
TrainingConfig.gradient_steps
TrainingConfig.loss
TrainingConfig.lr
TrainingConfig.lr_gamma
TrainingConfig.lr_step_size
TrainingConfig.lr_warmup_iter
TrainingConfig.metrics
TrainingConfig.model_checkpoint
TrainingConfig.num_iterations
TrainingConfig.optimizer
TrainingConfig.regularizers
TrainingConfig.swa_start_iter
TrainingConfig.validation_steps
TrainingConfig.weight_decay
ValidationConfig
- Module contents
- direct.data package
- Submodules
- direct.data.bbox module
- direct.data.datasets module
- direct.data.datasets_config module
AugmentationTransformConfig
CMRxReconConfig
CalgaryCampinasConfig
CropTransformConfig
DatasetConfig
FakeMRIBlobsConfig
FastMRIConfig
H5SliceConfig
NormalizationTransformConfig
RandomAugmentationTransformsConfig
SensitivityMapEstimationTransformConfig
SensitivityMapEstimationTransformConfig.estimate_sensitivity_maps
SensitivityMapEstimationTransformConfig.sensitivity_maps_espirit_crop
SensitivityMapEstimationTransformConfig.sensitivity_maps_espirit_kernel_size
SensitivityMapEstimationTransformConfig.sensitivity_maps_espirit_max_iters
SensitivityMapEstimationTransformConfig.sensitivity_maps_espirit_threshold
SensitivityMapEstimationTransformConfig.sensitivity_maps_gaussian
SensitivityMapEstimationTransformConfig.sensitivity_maps_type
SheppLoganDatasetConfig
SheppLoganProtonConfig
SheppLoganT1Config
SheppLoganT2Config
TransformsConfig
TransformsConfig.augmentation
TransformsConfig.compress_coils
TransformsConfig.cropping
TransformsConfig.delete_acs_mask
TransformsConfig.delete_kspace
TransformsConfig.estimate_body_coil_image
TransformsConfig.image_recon_type
TransformsConfig.mask_split_acs_region
TransformsConfig.mask_split_gaussian_std
TransformsConfig.mask_split_half_direction
TransformsConfig.mask_split_keep_acs
TransformsConfig.mask_split_ratio
TransformsConfig.mask_split_type
TransformsConfig.masking
TransformsConfig.normalization
TransformsConfig.pad_coils
TransformsConfig.padding_eps
TransformsConfig.random_augmentations
TransformsConfig.sensitivity_map_estimation
TransformsConfig.transforms_type
TransformsConfig.use_seed
- direct.data.fake module
- direct.data.h5_data module
- direct.data.lr_scheduler module
- direct.data.mri_transforms module
AddBooleanKeysModule
ApplyMaskModule
ApplyZeroPadding
Compose
CompressCoilModule
ComputeImageModule
ComputeScalingFactorModule
ComputeZeroPadding
CreateSamplingMask
CropKspace
DeleteKeysModule
EstimateBodyCoilImage
EstimateSensitivityMapModule
ModuleWrapper
NormalizeModule
PadCoilDimensionModule
PadKspace
RandomFlip
RandomFlipType
RandomReverse
RandomRotation
ReconstructionType
RenameKeysModule
RescaleKspace
RescaleMode
SensitivityMapType
ToTensor
TransformsType
WhitenDataModule
build_mri_transforms()
build_post_mri_transforms()
build_pre_mri_transforms()
build_supervised_mri_transforms()
- direct.data.samplers module
- direct.data.sens module
- direct.data.transforms module
apply_mask()
apply_padding()
center_crop()
complex_bmm()
complex_center_crop()
complex_division()
complex_dot_product()
complex_image_resize()
complex_mm()
complex_multiplication()
complex_random_crop()
conjugate()
crop_to_acs()
expand_operator()
fft2()
fftshift()
ifft2()
ifftshift()
modulus()
modulus_if_complex()
pad_tensor()
reduce_operator()
roll()
roll_one_dim()
root_sum_of_squares()
safe_divide()
tensor_to_complex_numpy()
to_tensor()
verify_fft_dtype_possible()
view_as_complex()
view_as_real()
- Module contents
- direct.functionals package
- Submodules
- direct.functionals.challenges module
- direct.functionals.grad module
- direct.functionals.hfen module
- direct.functionals.nmae module
- direct.functionals.nmse module
- direct.functionals.psnr module
- direct.functionals.snr module
- direct.functionals.ssim module
- Module contents
- direct.nn package
- Subpackages
- direct.nn.cirim package
- direct.nn.conjgradnet package
- direct.nn.conv package
- direct.nn.crossdomain package
- direct.nn.didn package
- direct.nn.iterdualnet package
- direct.nn.jointicnet package
- direct.nn.kikinet package
- direct.nn.lpd package
- direct.nn.mobilenet package
- direct.nn.multidomainnet package
- direct.nn.mwcnn package
- direct.nn.recurrent package
- direct.nn.recurrentvarnet package
- direct.nn.resnet package
- direct.nn.rim package
- direct.nn.ssl package
- direct.nn.unet package
- direct.nn.varnet package
- direct.nn.varsplitnet package
- direct.nn.vsharp package
- direct.nn.xpdnet package
- Submodules
- direct.nn.get_nn_model_config module
- direct.nn.mri_models module
- direct.nn.types module
ActivationType
InitType
LossFunType
LossFunType.GRAD_L1_LOSS
LossFunType.GRAD_L2_LOSS
LossFunType.HFEN_L1_LOSS
LossFunType.HFEN_L1_NORM_LOSS
LossFunType.HFEN_L2_LOSS
LossFunType.HFEN_L2_NORM_LOSS
LossFunType.KSPACE_L1_LOSS
LossFunType.KSPACE_L2_LOSS
LossFunType.KSPACE_NMAE_LOSS
LossFunType.KSPACE_NMSE_LOSS
LossFunType.KSPACE_NRMSE_LOSS
LossFunType.L1_LOSS
LossFunType.L2_LOSS
LossFunType.NMAE_LOSS
LossFunType.NMSE_LOSS
LossFunType.NRMSE_LOSS
LossFunType.PSNR_LOSS
LossFunType.SNR_LOSS
LossFunType.SSIM_3D_LOSS
LossFunType.SSIM_LOSS
ModelName
- Module contents
- Subpackages
- direct.ssl package
- direct.utils package
- Submodules
- direct.utils.asserts module
- direct.utils.bbox module
- direct.utils.communication module
- direct.utils.dataset module
- direct.utils.events module
CommonMetricPrinter
EventStorage
EventStorage.add_graph()
EventStorage.add_image()
EventStorage.add_scalar()
EventStorage.add_scalars()
EventStorage.clear_images()
EventStorage.histories()
EventStorage.history()
EventStorage.iter
EventStorage.latest()
EventStorage.latest_with_smoothing_hint()
EventStorage.name_scope()
EventStorage.smoothing_hints()
EventStorage.step()
EventStorage.vis_data
EventWriter
HistoryBuffer
JSONWriter
TensorboardWriter
get_event_storage()
- direct.utils.imports module
- direct.utils.io module
- direct.utils.logging module
- direct.utils.models module
- direct.utils.writers module
- Module contents
DirectModule
DirectTransform
SpatialDims
cast_as_path()
chunks()
count_parameters()
detach_dict()
dict_flatten()
dict_to_device()
ensure_list()
evaluate_dict()
git_hash()
is_complex_data()
is_power_of_two()
merge_list_of_dicts()
multiply_function()
normalize_image()
prefix_dict_keys()
reduce_list_of_dicts()
remove_keys()
set_all_seeds()
str_to_class()
Submodules
direct.checkpointer module
Checkpointer module.
Handles all logic related to checkpointing.
- class direct.checkpointer.Checkpointer(save_directory, save_to_disk=True, model_regex='^.*model$', **checkpointables)
Bases: object
Main Checkpointer module.
Handles writing and restoring from checkpoints of modules and submodules.
- load_from_path(checkpoint_path, checkpointable_objects=None, only_models=False)
Load a checkpoint from a path.
- Parameters:
- checkpoint_path: Path or str
Path to the checkpoint: either a path to a local file or a URL from which the file can be downloaded.
- checkpointable_objects: dict
Dictionary mapping names to nn.Module objects.
- only_models: bool
If True, only the models are loaded and no other objects in the checkpoint.
- Returns:
- Dictionary with loaded models.
- Return type:
Dict
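The default model_regex of '^.*model$' suggests that, when only_models=True, checkpointable objects are selected by matching their names against this pattern. The sketch below illustrates that filtering, together with a check for whether checkpoint_path points to a URL. Both helpers (select_checkpointables, is_url) are hypothetical and are not part of DIRECT's API; the real Checkpointer may apply model_regex differently.

```python
import re
from typing import Dict
from urllib.parse import urlsplit


def is_url(checkpoint_path) -> bool:
    """Hypothetical helper: treat http(s) locations as remote checkpoints."""
    return urlsplit(str(checkpoint_path)).scheme in ("http", "https")


def select_checkpointables(
    checkpointables: Dict[str, object],
    only_models: bool = False,
    model_regex: str = r"^.*model$",
) -> Dict[str, object]:
    """Hypothetical helper: keep only names matching model_regex when only_models is True."""
    if not only_models:
        return dict(checkpointables)
    pattern = re.compile(model_regex)
    return {name: obj for name, obj in checkpointables.items() if pattern.match(name)}


# Placeholder objects standing in for a model, an extra model, an optimizer and a scheduler.
objects = {"model": "unet", "sensitivity_model": "cnn", "optimizer": "adam", "lr_scheduler": "step"}
print(sorted(select_checkpointables(objects, only_models=True)))  # ['model', 'sensitivity_model']
print(is_url("https://example.com/model_1000.pt"), is_url("/tmp/model_1000.pt"))  # True False
```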
direct.constants module
direct.engine module
Main engine of DIRECT.
Implements all the main training, testing and validation logic.
- class direct.engine.DoIterationOutput(output_image, sensitivity_map, data_dict)
Bases: tuple
- data_dict
Alias for field number 2
- output_image
Alias for field number 0
- sensitivity_map
Alias for field number 1
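Because DoIterationOutput is a named tuple, each field can be read either by name or by the field number listed above. A minimal stand-in built with collections.namedtuple (DoIterationOutputSketch is a hypothetical name, and the values are placeholders rather than real tensors) shows the correspondence:

```python
from collections import namedtuple

# Stand-in with the same field order as direct.engine.DoIterationOutput.
DoIterationOutputSketch = namedtuple(
    "DoIterationOutputSketch", ["output_image", "sensitivity_map", "data_dict"]
)

out = DoIterationOutputSketch(output_image="image", sensitivity_map=None, data_dict={"loss": 0.1})
assert out[0] is out.output_image     # field number 0
assert out[1] is out.sensitivity_map  # field number 1
assert out[2] is out.data_dict        # field number 2

# Plain tuple unpacking also works, in field order.
output_image, sensitivity_map, data_dict = out
```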
- class direct.engine.Engine(cfg, model, device, forward_operator=None, backward_operator=None, mixed_precision=False, **models)
Bases: ABC, DataDimensionality
- static build_batch_sampler(dataset, batch_size, sampler_type, **kwargs)
- Return type:
Sampler
- static build_loader(dataset, batch_sampler=None, num_workers=6)
- Return type:
DataLoader
- predict(dataset, experiment_directory, checkpoint=-1, num_workers=6, batch_size=1, crop=None)
- Return type:
List[ndarray]
- train(optimizer, lr_scheduler, training_datasets, experiment_directory, validation_datasets=None, resume=False, start_with_validation=False, initialization=None, num_workers=6)
- Return type:
None
- training_loop(training_datasets, start_iter, validation_datasets=None, experiment_directory=None, num_workers=6, start_with_validation=False)
direct.environment module
- class direct.environment.Args(epilog=None, add_help=True, **overrides)
Bases: ArgumentParser
Defines global default arguments.
- direct.environment.build_operators(cfg)
Builds operators from configuration.
- Return type:
Tuple[Callable, Callable]
- direct.environment.initialize_models_from_config(cfg, models, forward_operator, backward_operator, device)
Creates models from config.
- Parameters:
- cfg: DictConfig
Configuration.
- models: dict
Models dictionary including configurations.
- forward_operator: Callable
Forward operator.
- backward_operator: Callable
Backward operator.
- device: str
Type of device.
- Returns:
- model: torch.nn.Module
Model.
- additional_models: Dict
Additional models.
- Return type:
Tuple[Module, Dict]
- direct.environment.load_dataset_config(dataset_name)
Load the dataset configuration based on dataset_name.
- Parameters:
- dataset_name: str
Name of dataset.
- Returns:
- dataset_config: Callable
Dataset configuration.
- Return type:
Callable
- direct.environment.load_model_config_from_name(model_name)
Load the model configuration module based on the model name.
- Parameters:
- model_name: str
Path to model relative to direct.nn.
- Returns:
- model_cfg: Callable
Model configuration.
- Return type:
Callable
- direct.environment.load_model_from_name(model_name)
Load model based on model_name.
- Parameters:
- model_name: str
Model name as in direct.nn.
- Returns:
- model: Callable
Model class.
- Return type:
Callable
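load_model_from_name returns a model class given its name, and direct.utils also lists a str_to_class() helper. A common way to resolve a dotted name to a class is importlib.import_module followed by getattr. The sketch below shows that technique with a standard-library class. The helper name is hypothetical, and DIRECT's own lookup, including how names relative to direct.nn are expanded, may differ.

```python
import importlib


def load_class_from_dotted_name(dotted_name: str):
    """Hypothetical helper: resolve "package.module.ClassName" to the class object."""
    module_name, _, class_name = dotted_name.rpartition(".")
    if not module_name:
        raise ValueError(f"Expected a dotted name like 'module.ClassName', got {dotted_name!r}.")
    module = importlib.import_module(module_name)
    try:
        return getattr(module, class_name)
    except AttributeError as exc:
        raise ImportError(f"{class_name!r} not found in module {module_name!r}.") from exc


ordered_dict_cls = load_class_from_dotted_name("collections.OrderedDict")
print(ordered_dict_cls)  # <class 'collections.OrderedDict'>
```

Splitting on the last dot keeps the module path and the attribute name separate, so nested packages resolve without extra parsing.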
- direct.environment.load_models_into_environment_config(cfg_from_file)
Load the configuration for the models.
- Parameters:
- cfg_from_file: DictConfig
Omegaconf configuration.
- Returns:
- (models, models_config): (dict, DictConfig)
Models dictionary and models configuration dictionary.
- Return type:
Tuple[dict, DictConfig]
- direct.environment.setup_common_environment(run_name, base_directory, cfg_pathname, device, machine_rank, mixed_precision, debug=False)
Set up the common environment.
- Parameters:
- run_name: str
Run name.
- base_directory: pathlib.Path
Base directory path.
- cfg_pathname: Union[pathlib.Path, str]
Path or URL to the configuration file.
- device: str
Device type.
- machine_rank: int
Machine rank.
- mixed_precision: bool
Whether to enable mixed precision or not. Default: False.
- debug: bool
Whether the debug mode is enabled.
- Returns:
- environment
Common Environment.
- direct.environment.setup_engine(cfg, device, model, additional_models, forward_operator=None, backward_operator=None, mixed_precision=False)
Set up the engine.
- Parameters:
- cfg: DictConfig
Configuration.
- device: str
Type of device.
- model: torch.nn.Module
Model.
- additional_models: dict
Additional models.
- forward_operator: Callable
Forward operator.
- backward_operator: Callable
Backward operator.
- mixed_precision: bool
Whether to enable mixed precision or not. Default: False.
- Returns:
- engine
Experiment Engine.
- direct.environment.setup_inference_environment(run_name, base_directory, device, machine_rank, mixed_precision, cfg_file=None, debug=False)
Set up the inference environment.
- Parameters:
- run_name: str
Run name.
- base_directory: pathlib.Path
Base directory path.
- device: str
Device type.
- machine_rank: int
Machine rank.
- mixed_precision: bool
Whether to enable mixed precision or not. Default: False.
- cfg_file: Union[pathlib.Path, str], optional
Path or url to configuration file.
- debug: bool
Whether the debug mode is enabled.
- Returns:
- environment
Inference Environment.
- direct.environment.setup_logging(machine_rank, output_directory, run_name, cfg_filename, cfg, debug)
Logs environment information.
- Parameters:
- machine_rank: int
Machine rank.
- output_directory: pathlib.Path
Path to output directory.
- run_name: str
Name of run.
- cfg_filename: Union[pathlib.Path, str]
Name of configuration file.
- cfg: DefaultConfig
Configuration file.
- debug: bool
Whether the debug mode is enabled.
- Return type:
None
- direct.environment.setup_testing_environment(run_name, base_directory, device, machine_rank, mixed_precision, cfg_pathname=None, debug=False)
Set up the testing environment.
- Parameters:
- run_name: str
Run name.
- base_directory: pathlib.Path
Base directory path.
- device: str
Device type.
- machine_rank: int
Machine rank.
- mixed_precision: bool
Whether to enable mixed precision or not. Default: False.
- cfg_pathname: Union[pathlib.Path, str], optional
Path or url to configuration file.
- debug: bool
Whether the debug mode is enabled.
- Returns:
- environment
Testing Environment.
- direct.environment.setup_training_environment(run_name, base_directory, cfg_filename, device, machine_rank, mixed_precision, debug=False)
Set up the training environment.
- Parameters:
- run_name: str
Run name.
- base_directory: pathlib.Path
Base directory path.
- cfg_filename: Union[pathlib.Path, str]
Path or URL to the configuration file.
- device: str
Device type.
- machine_rank: int
Machine rank.
- mixed_precision: bool
Whether to enable mixed precision or not. Default: False.
- debug: bool
Whether the debug mode is enabled.
- Returns:
- environment
Training Environment.
direct.exceptions module
- exception direct.exceptions.ItemNotFoundException(item_name, message=None)
Bases: DirectException
- exception direct.exceptions.ProcessKilledException(signal_id, signal_name)
Bases: DirectException
The process received a SIGINT signal.
- exception direct.exceptions.TrainingException(message=None)
Bases: DirectException
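ProcessKilledException is constructed from a signal id and a signal name. One way such an exception can be produced is a signal handler that converts an incoming signal into a Python exception, so that training code can clean up in an except block. The sketch below uses a stand-in exception class and handler; ProcessKilledSketch, signal_to_exception, raise_on_signal and install_handler are hypothetical names, not DIRECT's code.

```python
import signal


class ProcessKilledSketch(Exception):
    """Stand-in mirroring the (signal_id, signal_name) arguments of ProcessKilledException."""

    def __init__(self, signal_id: int, signal_name: str):
        super().__init__(f"Process received signal {signal_name} ({signal_id}).")
        self.signal_id = signal_id
        self.signal_name = signal_name


def signal_to_exception(signum: int) -> ProcessKilledSketch:
    """Build the exception, looking up the symbolic name (e.g. "SIGINT") of the signal number."""
    return ProcessKilledSketch(signum, signal.Signals(signum).name)


def raise_on_signal(signum, frame):
    """Signal handler that turns the received signal into an exception."""
    raise signal_to_exception(signum)


def install_handler() -> None:
    # signal.signal() may only be called from the main thread of the main interpreter.
    signal.signal(signal.SIGINT, raise_on_signal)
```

After install_handler() has been called, pressing Ctrl+C raises ProcessKilledSketch instead of the default KeyboardInterrupt.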
direct.inference module
- direct.inference.build_inference_transforms(env, mask_func, dataset_cfg)
Builds inference transforms.
- Return type:
Callable
- direct.inference.inference_on_environment(env, data_root, dataset_cfg, transforms, experiment_path, checkpoint, num_workers=0, filenames_filter=None, batch_size=1, crop=None)
Performs inference on the environment.
- Parameters:
- env: Environment
Environment.
- data_root: Union[PathOrString, None]
Path of the directory of the data if applicable for dataset. Can be None.
- dataset_cfg: DictConfig
Configuration containing inference dataset settings.
- transforms: Callable
Dataset transformations object.
- experiment_path: PathOrString
Path to the directory where inference logs will be stored.
- checkpoint: FileOrUrl
Checkpoint to a model. This can be a path to a local file or a URL.
- num_workers: int
Number of workers.
- filenames_filter: Union[List[PathOrString], None]
List of filenames to include in the dataset (if applicable). Can be None. Default: None.
- batch_size: int
Inference batch size. Default: 1.
- crop: Optional[str]
Inference crop type. Can be header or None. Default: None.
- Returns:
- output: Union[Dict, DefaultDict]
- Return type:
Union[Dict, DefaultDict]
- direct.inference.setup_inference_save_to_h5(get_inference_settings, run_name, data_root, base_directory, output_directory, filenames_filter, checkpoint, device, num_workers, machine_rank, cfg_file=None, process_per_chunk=None, mixed_precision=False, debug=False, is_validation=False)
This function contains most of the logic in DIRECT required to launch a multi-GPU / multi-node inference process.
It saves predictions as .h5 files.
- Parameters:
- get_inference_settings: Callable
Callable object to create inference dataset and environment.
- run_name: str
Experiment run name. Can be an empty string.
- data_root: Union[PathOrString, None]
Path of the directory of the data if applicable for dataset. Can be None.
- base_directory: PathOrString
Path to the directory where inference logs will be stored. If run_name is not an empty string, base_directory / run_name will be used.
- output_directory: PathOrString
Path to the directory where output data will be saved.
- filenames_filter: Union[List[PathOrString], None]
List of filenames to include in the dataset (if applicable). Can be None.
- checkpoint: FileOrUrl
Checkpoint to a model. This can be a path to a local file or a URL.
- device: str
Device name.
- num_workers: int
Number of workers.
- machine_rank: int
Machine rank.
- cfg_file: Union[PathOrString, None]
Path to configuration file. If None, will search in base_directory.
- process_per_chunk: int
Number of processes per chunk.
- mixed_precision: bool
If True, mixed precision will be allowed. Default: False.
- debug: bool
If True, debug information will be displayed. Default: False.
- is_validation: bool
If True, will use settings (e.g. batch_size & crop) of validation in config. Otherwise it will use inference settings. Default: False.
- Returns:
- None
- Return type:
None
direct.launch module
- direct.launch.launch(func, num_machines, num_gpus, machine_rank, dist_url, *args)
Launch the training. If only one GPU is available, the function can be called directly.
- Parameters:
- func: Callable
Function to launch.
- num_machines: int
The number of machines.
- num_gpus: int
The number of GPUs.
- machine_rank: int
The machine rank.
- dist_url: str
URL to connect to for distributed training, including protocol.
- args: Tuple
Arguments to pass to func.
- Return type:
None
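When more than one process takes part, each process needs a global rank. A common convention, used by torch.distributed launch utilities, is world_size = num_machines × num_gpus and global_rank = machine_rank × num_gpus + local_rank. Assuming DIRECT follows this convention, the ranks hosted on one machine can be computed as below; compute_ranks is a hypothetical helper, not part of direct.launch.

```python
from typing import List, Tuple


def compute_ranks(num_machines: int, num_gpus: int, machine_rank: int) -> Tuple[int, List[int]]:
    """Hypothetical helper: return the world size and the global ranks hosted on one machine."""
    if not 0 <= machine_rank < num_machines:
        raise ValueError("machine_rank must lie in [0, num_machines).")
    world_size = num_machines * num_gpus
    global_ranks = [machine_rank * num_gpus + local_rank for local_rank in range(num_gpus)]
    return world_size, global_ranks


print(compute_ranks(num_machines=2, num_gpus=4, machine_rank=1))  # (8, [4, 5, 6, 7])
# A world size of 1 is the single-GPU case, in which no extra processes are needed.
print(compute_ranks(num_machines=1, num_gpus=1, machine_rank=0))  # (1, [0])
```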
- direct.launch.launch_distributed(main_func, num_gpus_per_machine, num_machines=1, machine_rank=0, dist_url='auto', args=(), timeout=datetime.timedelta(seconds=1800))
Launch multi-GPU or distributed training.
This function must be called on all machines involved in the training and it will spawn child processes (defined by num_gpus_per_machine) on each machine.
- Parameters:
- main_func: Callable
A function that will be called by main_func(*args).
- num_gpus_per_machine: int
The number of GPUs per machine.
- num_machines: int
The number of machines.
- machine_rank: int
The rank of this machine (one per machine).
- dist_url: str
URL to connect to for distributed training, including the protocol, e.g. "tcp://127.0.0.1:8686". Can be set to "auto" to automatically select a free port on localhost.
- args: Tuple
Arguments passed to main_func.
- timeout: timedelta
Timeout of the distributed workers.
- Return type:
None
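For dist_url="auto", a standard way to find a free port on localhost is to bind a socket to port 0 and let the operating system choose one. The sketch below shows that technique; resolve_dist_url is a hypothetical helper, and DIRECT's own implementation may differ. The chosen port is only guaranteed free at the moment of the check, so another process could still claim it before the workers connect.

```python
import socket


def resolve_dist_url(dist_url: str) -> str:
    """Hypothetical helper: expand "auto" into a tcp:// URL with a free localhost port."""
    if dist_url != "auto":
        return dist_url
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.bind(("127.0.0.1", 0))  # port 0 asks the OS for any free port
        port = sock.getsockname()[1]
    return f"tcp://127.0.0.1:{port}"


print(resolve_dist_url("tcp://127.0.0.1:8686"))  # tcp://127.0.0.1:8686 (unchanged)
print(resolve_dist_url("auto"))  # e.g. tcp://127.0.0.1:<free port>
```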
direct.predict module
direct.train module
- direct.train.build_training_datasets_from_environment(env, datasets_config, lists_root=None, data_root=None, initial_images=None, initial_kspaces=None, pass_text_description=True, pass_dictionaries=None)
- direct.train.build_transforms_from_environment(env, dataset_config)
- Return type:
Callable
- direct.train.get_root_of_file(filename)
Get the root directory of the file or URL to file.
- Parameters:
- filename: pathlib.Path or str
- Returns:
- pathlib.Path or str
Examples
>>> get_root_of_file('/mnt/archive/data.txt')
/mnt/archive
>>> get_root_of_file('https://aiforoncology.nl/people')
https://aiforoncology.nl/
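The examples suggest that the root of a local path is its parent directory, while the root of a URL is the URL with its last path segment removed. A minimal sketch that reproduces both examples under that assumption follows; root_of_file_sketch is a hypothetical name, not DIRECT's implementation, and the local-path output shown assumes a POSIX system.

```python
from pathlib import Path
from urllib.parse import urlsplit, urlunsplit


def root_of_file_sketch(filename):
    """Hypothetical re-implementation: strip the last component of a path or URL."""
    parts = urlsplit(str(filename))
    if parts.scheme in ("http", "https"):
        # Drop everything after the final "/" of the URL path, keeping the trailing slash.
        parent_path = parts.path.rsplit("/", 1)[0] + "/"
        return urlunsplit((parts.scheme, parts.netloc, parent_path, "", ""))
    return Path(filename).parent


print(root_of_file_sketch("/mnt/archive/data.txt"))  # /mnt/archive
print(root_of_file_sketch("https://aiforoncology.nl/people"))  # https://aiforoncology.nl/
```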
direct.types module
direct.types module.
- class direct.types.DirectEnum(value)
Bases: str, Enum
Base enumerator type whose comparison with strings is case-insensitive.
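Case-insensitive comparison on a str-backed Enum can be obtained by overriding the comparison dunders. The sketch below is an assumption about how such a base class might look, not DIRECT's source; CaseInsensitiveStrEnum and KspaceKeySketch are hypothetical names, and the members mirror those of KspaceKey. Because str defines its own __ne__, it has to be overridden as well, and __hash__ must be redefined because overriding __eq__ otherwise disables hashing.

```python
from enum import Enum


class CaseInsensitiveStrEnum(str, Enum):
    """Sketch of a str-backed enum whose comparison with strings ignores case."""

    def __eq__(self, other):
        if isinstance(other, str):
            return self.value.lower() == other.lower()
        return NotImplemented

    def __ne__(self, other):
        # str.__ne__ is case-sensitive, so derive != from the overridden __eq__.
        result = self.__eq__(other)
        return result if result is NotImplemented else not result

    def __hash__(self):
        return hash(self.value.lower())


class KspaceKeySketch(CaseInsensitiveStrEnum):
    # Same members and values as direct.types.KspaceKey.
    KSPACE = "kspace"
    MASKED_KSPACE = "masked_kspace"


print(KspaceKeySketch.KSPACE == "KSPACE")  # True
print(KspaceKeySketch.KSPACE != "Kspace")  # False
print(KspaceKeySketch.KSPACE == "masked_kspace")  # False
```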
- class direct.types.IntegerListOrTupleString(string)
Bases: object
IntegerListOrTupleString class represents a list or tuple of integers based on a string representation.
Examples
s1 = "[1, 2, 45, -1, 0]"
print(isinstance(s1, IntegerListOrTupleString))  # True
print(IntegerListOrTupleString(s1))  # [1, 2, 45, -1, 0]
print(type(IntegerListOrTupleString(s1)))  # <class 'list'>
print(type(IntegerListOrTupleString(s1)[0]))  # <class 'int'>

s2 = "(10, -9, 20)"
print(isinstance(s2, IntegerListOrTupleString))  # True
print(IntegerListOrTupleString(s2))  # (10, -9, 20)
print(type(IntegerListOrTupleString(s2)))  # <class 'tuple'>
print(type(IntegerListOrTupleString(s2)[0]))  # <class 'int'>

s3 = "[a, we, 2]"
print(isinstance(s3, IntegerListOrTupleString))  # False

s4 = "(1, 2, 3]"
print(isinstance(s4, IntegerListOrTupleString))  # False
- class direct.types.IntegerListOrTupleStringMeta
Bases: type
Metaclass for the IntegerListOrTupleString class.
- Returns:
- bool
True if the instance is a valid representation of IntegerListOrTupleString, False otherwise.
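The examples show two behaviours: isinstance() validates the syntax of a string, and calling the class returns a list or tuple of ints. A metaclass __instancecheck__ can provide the first, and a __new__ that returns a plain list or tuple can provide the second. The regex-based sketch below illustrates that mechanism; IntSeqString and IntSeqStringMeta are hypothetical stand-ins, and DIRECT's own parsing may differ.

```python
import re
from typing import List, Tuple, Union

# One optionally signed integer surrounded by optional whitespace.
_INT = r"\s*-?\d+\s*"
_LIST_PATTERN = re.compile(rf"^\[(?:{_INT}(?:,{_INT})*)?\]$")
_TUPLE_PATTERN = re.compile(rf"^\((?:{_INT}(?:,{_INT})*)?\)$")


class IntSeqStringMeta(type):
    """Metaclass making isinstance(s, IntSeqString) validate the string's syntax."""

    def __instancecheck__(cls, instance) -> bool:
        if not isinstance(instance, str):
            return False
        text = instance.strip()
        return bool(_LIST_PATTERN.match(text) or _TUPLE_PATTERN.match(text))


class IntSeqString(metaclass=IntSeqStringMeta):
    """Calling the class parses the string and returns a list or tuple of ints."""

    def __new__(cls, string: str) -> Union[List[int], Tuple[int, ...]]:
        if not isinstance(string, cls):  # delegates to IntSeqStringMeta.__instancecheck__
            raise ValueError(f"Invalid integer list or tuple string: {string!r}")
        text = string.strip()
        inner = text[1:-1].strip()
        values = [int(item) for item in inner.split(",")] if inner else []
        # Square brackets give a list, parentheses give a tuple.
        return values if text.startswith("[") else tuple(values)


print(IntSeqString("[1, 2, 45, -1, 0]"))  # [1, 2, 45, -1, 0]
print(IntSeqString("(10, -9, 20)"))  # (10, -9, 20)
print(isinstance("(1, 2, 3]", IntSeqString))  # False
```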
- class direct.types.KspaceKey(value)
Bases: DirectEnum
An enumeration.
- KSPACE = 'kspace'
- MASKED_KSPACE = 'masked_kspace'