direct package
Subpackages
- direct.algorithms package
- direct.cli package
- direct.common package
- Submodules
- direct.common.subsample module
BaseMaskFunc, CIRCUSMaskFunc, CIRCUSSamplingMode, CalgaryCampinasMaskFunc, CartesianEquispacedMaskFunc, CartesianMagicMaskFunc, CartesianRandomMaskFunc, CartesianVerticalMaskFunc, EquispacedMaskFunc, FastMRIEquispacedMaskFunc, FastMRIMagicMaskFunc, FastMRIRandomMaskFunc, Gaussian1DMaskFunc, Gaussian2DMaskFunc, KtBaseMaskFunc, KtGaussian1DMaskFunc, KtRadialMaskFunc, KtUniformMaskFunc, MagicMaskFunc, MaskFuncMode, RadialMaskFunc, RandomMaskFunc, SpiralMaskFunc, VariableDensityPoissonMaskFunc, build_masking_function(), centered_disk_mask(), integerize_seed()
- direct.common.subsample_config module
- Module contents
- direct.config package
- Submodules
- direct.config.defaults module
CheckpointerConfig, DefaultConfig, FunctionConfig, InferenceConfig, LoggingConfig, LossConfig, ModelConfig, PhysicsConfig, TensorboardConfig, TrainingConfig (batch_size, checkpointer, datasets, gradient_clipping, gradient_debug, gradient_steps, loss, lr, lr_gamma, lr_step_size, lr_warmup_iter, metrics, model_checkpoint, num_iterations, optimizer, regularizers, swa_start_iter, validation_steps, weight_decay), ValidationConfig
- Module contents
- direct.data package
- Submodules
- direct.data.bbox module
- direct.data.datasets module
- direct.data.datasets_config module
AugmentationTransformConfig, CMRxReconConfig, CalgaryCampinasConfig, CropTransformConfig, DatasetConfig, FakeMRIBlobsConfig, FastMRIConfig, H5SliceConfig, NormalizationTransformConfig, RandomAugmentationTransformsConfig, SensitivityMapEstimationTransformConfig (estimate_sensitivity_maps, sensitivity_maps_espirit_crop, sensitivity_maps_espirit_kernel_size, sensitivity_maps_espirit_max_iters, sensitivity_maps_espirit_threshold, sensitivity_maps_gaussian, sensitivity_maps_type), SheppLoganDatasetConfig, SheppLoganProtonConfig, SheppLoganT1Config, SheppLoganT2Config, TransformsConfig (augmentation, compress_coils, cropping, delete_acs_mask, delete_kspace, estimate_body_coil_image, image_recon_type, mask_split_acs_region, mask_split_gaussian_std, mask_split_half_direction, mask_split_keep_acs, mask_split_ratio, mask_split_type, masking, normalization, pad_coils, padding_eps, random_augmentations, sensitivity_map_estimation, transforms_type, use_seed)
- direct.data.fake module
- direct.data.h5_data module
- direct.data.lr_scheduler module
- direct.data.mri_transforms module
AddBooleanKeysModule, ApplyMaskModule, ApplyZeroPadding, Compose, CompressCoilModule, ComputeImageModule, ComputeScalingFactorModule, ComputeZeroPadding, CreateSamplingMask, CropKspace, DeleteKeysModule, EstimateBodyCoilImage, EstimateSensitivityMapModule, ModuleWrapper, NormalizeModule, PadCoilDimensionModule, PadKspace, RandomFlip, RandomFlipType, RandomReverse, RandomRotation, ReconstructionType, RenameKeysModule, RescaleKspace, RescaleMode, SensitivityMapType, ToTensor, TransformsType, WhitenDataModule, build_mri_transforms(), build_post_mri_transforms(), build_pre_mri_transforms(), build_supervised_mri_transforms()
- direct.data.samplers module
- direct.data.sens module
- direct.data.transforms module
apply_mask(), apply_padding(), center_crop(), complex_bmm(), complex_center_crop(), complex_division(), complex_dot_product(), complex_image_resize(), complex_mm(), complex_multiplication(), complex_random_crop(), conjugate(), crop_to_acs(), expand_operator(), fft2(), fftshift(), ifft2(), ifftshift(), modulus(), modulus_if_complex(), pad_tensor(), reduce_operator(), roll(), roll_one_dim(), root_sum_of_squares(), safe_divide(), tensor_to_complex_numpy(), to_tensor(), verify_fft_dtype_possible(), view_as_complex(), view_as_real()
- Module contents
- direct.functionals package
- Submodules
- direct.functionals.challenges module
- direct.functionals.grad module
- direct.functionals.hfen module
- direct.functionals.nmae module
- direct.functionals.nmse module
- direct.functionals.psnr module
- direct.functionals.snr module
- direct.functionals.ssim module
- Module contents
- direct.nn package
- Subpackages
- direct.nn.cirim package
- direct.nn.conjgradnet package
- direct.nn.conv package
- direct.nn.crossdomain package
- direct.nn.didn package
- direct.nn.iterdualnet package
- direct.nn.jointicnet package
- direct.nn.kikinet package
- direct.nn.lpd package
- direct.nn.mobilenet package
- direct.nn.multidomainnet package
- direct.nn.mwcnn package
- direct.nn.recurrent package
- direct.nn.recurrentvarnet package
- direct.nn.resnet package
- direct.nn.rim package
- direct.nn.ssl package
- direct.nn.transformers package
- direct.nn.unet package
- direct.nn.varnet package
- direct.nn.varsplitnet package
- direct.nn.vsharp package
- direct.nn.xpdnet package
- Submodules
- direct.nn.get_nn_model_config module
- direct.nn.mri_models module
- direct.nn.types module
ActivationType, InitType, LossFunType (GRAD_L1_LOSS, GRAD_L2_LOSS, HFEN_L1_LOSS, HFEN_L1_NORM_LOSS, HFEN_L2_LOSS, HFEN_L2_NORM_LOSS, KSPACE_L1_LOSS, KSPACE_L2_LOSS, KSPACE_NMAE_LOSS, KSPACE_NMSE_LOSS, KSPACE_NRMSE_LOSS, L1_LOSS, L2_LOSS, NMAE_LOSS, NMSE_LOSS, NRMSE_LOSS, PSNR_LOSS, SNR_LOSS, SSIM_3D_LOSS, SSIM_LOSS), ModelName
- Module contents
- Subpackages
- direct.ssl package
- direct.utils package
- Submodules
- direct.utils.asserts module
- direct.utils.bbox module
- direct.utils.communication module
- direct.utils.dataset module
- direct.utils.events module
CommonMetricPrinter, EventStorage (add_graph(), add_image(), add_scalar(), add_scalars(), clear_images(), histories(), history(), iter, latest(), latest_with_smoothing_hint(), name_scope(), smoothing_hints(), step(), vis_data), EventWriter, HistoryBuffer, JSONWriter, TensorboardWriter, get_event_storage()
- direct.utils.imports module
- direct.utils.io module
- direct.utils.logging module
- direct.utils.models module
- direct.utils.writers module
- Module contents
DirectModule, DirectTransform, SpatialDims, cast_as_path(), chunks(), count_parameters(), detach_dict(), dict_flatten(), dict_to_device(), ensure_list(), evaluate_dict(), git_hash(), is_complex_data(), is_power_of_two(), merge_list_of_dicts(), multiply_function(), normalize_image(), prefix_dict_keys(), reduce_list_of_dicts(), remove_keys(), set_all_seeds(), str_to_class()
Submodules
direct.checkpointer module
Checkpointer module.
Handles all logic related to checkpointing.
- class direct.checkpointer.Checkpointer(save_directory, save_to_disk=True, model_regex='^.*model$', **checkpointables)
Bases: object
Main Checkpointer module.
Handles writing and restoring from checkpoints of modules and submodules.
- load_from_path(checkpoint_path, checkpointable_objects=None, only_models=False)
Load a checkpoint from a path.
- Parameters:
- checkpoint_path: Path or str
Path to the checkpoint: either a local file path or a URL from which the file can be downloaded.
- checkpointable_objects: dict
Dictionary mapping names to nn.Module objects.
- only_models: bool
If True, only the models are loaded and all other objects in the checkpoint are ignored.
- Returns:
- Dictionary with loaded models.
- Return type:
Dict
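The constructor's default `model_regex='^.*model$'` suggests that checkpointable names ending in "model" are treated as models, which is presumably what `only_models=True` relies on. The following is a minimal, stdlib-only sketch of that selection under this assumption; `select_checkpointables` is a hypothetical helper, not part of DIRECT's API, and plain strings stand in for the real `nn.Module` and optimizer objects.

```python
import re
from typing import Any, Dict


def select_checkpointables(
    checkpointables: Dict[str, Any],
    model_regex: str = "^.*model$",
    only_models: bool = False,
) -> Dict[str, Any]:
    """Hypothetical helper: pick the objects a checkpoint load would restore.

    Assumption: names matching ``model_regex`` are treated as models.
    """
    if not only_models:
        # Restore everything that was registered.
        return dict(checkpointables)
    pattern = re.compile(model_regex)
    # Keep only entries whose name matches the model pattern.
    return {name: obj for name, obj in checkpointables.items() if pattern.match(name)}


objects = {
    "model": "unet",
    "sensitivity_model": "sens_unet",
    "optimizer": "adam",
    "lr_scheduler": "step",
}
print(sorted(select_checkpointables(objects, only_models=True)))  # ['model', 'sensitivity_model']
```

Matching on names keeps the selection rule configurable: a project that registers, say, `"generator"` could pass a different `model_regex` instead of renaming its objects.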
direct.constants module
direct.engine module
Main engine of DIRECT.
Implements all the main training, testing and validation logic.
- class direct.engine.DoIterationOutput(output_image, sensitivity_map, data_dict)
Bases: tuple
- data_dict
Alias for field number 2
- output_image#
Alias for field number 0
- sensitivity_map#
Alias for field number 1
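"Alias for field number N" is the text Python generates for named-tuple fields, so `DoIterationOutput` can be read as a named tuple whose positional order is `(output_image, sensitivity_map, data_dict)`. The stand-in below mirrors that field order only; `DoIterationOutputSketch` is hypothetical and uses placeholder values where DIRECT would hold tensors.

```python
from typing import Any, Dict, NamedTuple


class DoIterationOutputSketch(NamedTuple):
    """Stand-in mirroring the documented field order of DoIterationOutput."""

    output_image: Any  # field number 0
    sensitivity_map: Any  # field number 1
    data_dict: Dict[str, Any]  # field number 2


result = DoIterationOutputSketch(output_image="image", sensitivity_map=None, data_dict={"loss": 0.5})
image, sens, data = result  # positional unpacking follows the field order
print(result[2]["loss"])  # 0.5
```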
- class direct.engine.Engine(cfg, model, device, forward_operator=None, backward_operator=None, mixed_precision=False, **models)
Bases: ABC, DataDimensionality
- static build_batch_sampler(dataset, batch_size, sampler_type, **kwargs)
- Return type:
Sampler
- static build_loader(dataset, batch_sampler=None, num_workers=6)
- Return type:
DataLoader
- predict(dataset, experiment_directory, checkpoint=-1, num_workers=6, batch_size=1, crop=None)
- Return type:
List[ndarray]
- train(optimizer, lr_scheduler, training_datasets, experiment_directory, validation_datasets=None, resume=False, start_with_validation=False, initialization=None, num_workers=6)
- Return type:
None
- training_loop(training_datasets, start_iter, validation_datasets=None, experiment_directory=None, num_workers=6, start_with_validation=False)
direct.environment module
- class direct.environment.Args(epilog=None, add_help=True, **overrides)
Bases: ArgumentParser
Defines global default arguments.
- direct.environment.build_operators(cfg)
Builds operators from configuration.
- Return type:
Tuple[Callable,Callable]
- direct.environment.collect_env_info()
Collects environment information.
- Returns:
- env_info: str
Environment information as a formatted string.
- Return type:
str
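The docstring only states that a formatted string is returned. As a hedged illustration, here is a stdlib-only sketch of what such a report could look like; `collect_env_info_sketch` is hypothetical, the fields are assumptions, and DIRECT's own function may report different or additional details.

```python
import platform
import sys


def collect_env_info_sketch() -> str:
    """Hypothetical stand-in: gather a few environment facts into one string."""
    rows = [
        ("Python", sys.version.split()[0]),
        ("Platform", platform.platform()),
        ("Machine", platform.machine()),
    ]
    width = max(len(key) for key, _ in rows)
    # Left-align the keys so the values line up in a simple two-column table.
    return "\n".join(f"{key.ljust(width)}  {value}" for key, value in rows)


print(collect_env_info_sketch())
```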
- direct.environment.initialize_models_from_config(cfg, models, forward_operator, backward_operator, device)
Creates models from config.
- Parameters:
- cfg: DictConfig
Configuration.
- models: dict
Models dictionary including configurations.
- forward_operator: Callable
Forward operator.
- backward_operator: Callable
Backward operator.
- device: str
Type of device.
- Returns:
- model: torch.nn.Module
Model.
- additional_models: Dict
Additional models.
- Return type:
Tuple[Module,Dict]
- direct.environment.load_dataset_config(dataset_name)
Load specific dataset configuration for dataset based on dataset_name.
- Parameters:
- dataset_name: str
Name of dataset.
- Returns:
- dataset_config: Callable
Dataset configuration.
- Return type:
Callable
- direct.environment.load_model_config_from_name(model_name)
Load specific configuration module for models based on their name.
- Parameters:
- model_name: str
Path to model relative to direct.nn.
- Returns:
- model_cfg: Callable
Model configuration.
- Return type:
Callable
- direct.environment.load_model_from_name(model_name)
Load model based on model_name.
- Parameters:
- model_name: str
Model name as in direct.nn.
- Returns:
- model: Callable
Model class.
- Return type:
Callable
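`load_model_from_name` and `load_model_config_from_name` both turn a string into a Python object (`direct.utils` also lists a `str_to_class()` helper, whose exact behavior is not documented here). A common way to do this is with `importlib`; the sketch below assumes a fully dotted `"package.module.Attribute"` string, which may differ from the name format DIRECT expects, since its docs describe the name as relative to `direct.nn`. `resolve_by_name` is hypothetical.

```python
import importlib
from typing import Any


def resolve_by_name(dotted_name: str) -> Any:
    """Hypothetical helper: import "package.module.Attr" and return Attr."""
    module_name, _, attribute = dotted_name.rpartition(".")
    if not module_name:
        raise ValueError(f"Expected a dotted name, got {dotted_name!r}.")
    module = importlib.import_module(module_name)
    try:
        return getattr(module, attribute)
    except AttributeError as exc:
        raise ValueError(f"{attribute!r} not found in {module_name!r}.") from exc


ordered_dict_cls = resolve_by_name("collections.OrderedDict")
print(ordered_dict_cls.__name__)  # OrderedDict
```

Raising `ValueError` for both a missing module prefix and a missing attribute gives callers one exception type to handle for a bad name, while `importlib.import_module` still raises `ModuleNotFoundError` for a module that does not exist.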
- direct.environment.load_models_into_environment_config(cfg_from_file)
Load the configuration for the models.
- Parameters:
- cfg_from_file: DictConfig
Omegaconf configuration.
- Returns:
- (models, models_config): (dict, DictConfig)
Models dictionary and models configuration dictionary.
- Return type:
Tuple[dict,DictConfig]
- direct.environment.setup_common_environment(run_name, base_directory, cfg_pathname, device, machine_rank, mixed_precision, debug=False)
Sets up the common environment.
- Parameters:
- run_name: str
Run name.
- base_directory: pathlib.Path
Base directory path.
- cfg_pathname: Union[pathlib.Path, str]
Path or URL to the configuration file.
- device: str
Device type.
- machine_rank: int
Machine rank.
- mixed_precision: bool
Whether to enable mixed precision.
- debug: bool
Whether the debug mode is enabled.
- Returns:
- environment
Common Environment.
- direct.environment.setup_engine(cfg, device, model, additional_models, forward_operator=None, backward_operator=None, mixed_precision=False)
Sets up the engine.
- Parameters:
- cfg: DictConfig
Configuration.
- device: str
Type of device.
- model: torch.nn.Module
Model.
- additional_models: dict
Additional models.
- forward_operator: Callable
Forward operator.
- backward_operator: Callable
Backward operator.
- mixed_precision: bool
Whether to enable mixed precision or not. Default: False.
- Returns:
- engine
Experiment Engine.
- direct.environment.setup_inference_environment(run_name, base_directory, device, machine_rank, mixed_precision, cfg_file=None, debug=False)
Sets up the inference environment.
- Parameters:
- run_name: str
Run name.
- base_directory: pathlib.Path
Base directory path.
- device: str
Device type.
- machine_rank: int
Machine rank.
- mixed_precision: bool
Whether to enable mixed precision.
- cfg_file: Union[pathlib.Path, str], optional
Path or url to configuration file.
- debug: bool
Whether the debug mode is enabled.
- Returns:
- environment
Inference Environment.
- direct.environment.setup_logging(machine_rank, output_directory, run_name, cfg_filename, cfg, debug)
Logs environment information.
- Parameters:
- machine_rank: int
Machine rank.
- output_directory: pathlib.Path
Path to output directory.
- run_name: str
Name of run.
- cfg_filename: Union[pathlib.Path, str]
Name of configuration file.
- cfg: DefaultConfig
Configuration object.
- debug: bool
Whether the debug mode is enabled.
- Return type:
None
- direct.environment.setup_testing_environment(run_name, base_directory, device, machine_rank, mixed_precision, cfg_pathname=None, debug=False)
Sets up the testing environment.
- Parameters:
- run_name: str
Run name.
- base_directory: pathlib.Path
Base directory path.
- device: str
Device type.
- machine_rank: int
Machine rank.
- mixed_precision: bool
Whether to enable mixed precision.
- cfg_pathname: Union[pathlib.Path, str], optional
Path or url to configuration file.
- debug: bool
Whether the debug mode is enabled.
- Returns:
- environment
Testing Environment.
- direct.environment.setup_training_environment(run_name, base_directory, cfg_filename, device, machine_rank, mixed_precision, debug=False)
Sets up the training environment.
- Parameters:
- run_name: str
Run name.
- base_directory: pathlib.Path
Base directory path.
- cfg_filename: Union[pathlib.Path, str]
Path or URL to the configuration file.
- device: str
Device type.
- machine_rank: int
Machine rank.
- mixed_precision: bool
Whether to enable mixed precision.
- debug: bool
Whether the debug mode is enabled.
- Returns:
- environment
Training Environment.
direct.exceptions module
- exception direct.exceptions.ItemNotFoundException(item_name, message=None)
Bases: DirectException
- exception direct.exceptions.ProcessKilledException(signal_id, signal_name)
Bases: DirectException
The process received a SIGINT signal.
- exception direct.exceptions.TrainingException(message=None)
Bases: DirectException
direct.inference module
- direct.inference.build_inference_transforms(env, mask_func, dataset_cfg)
Builds inference transforms.
- Return type:
Callable
- direct.inference.inference_on_environment(env, data_root, dataset_cfg, transforms, experiment_path, checkpoint, num_workers=0, filenames_filter=None, batch_size=1, crop=None)
Performs inference on environment.
- Parameters:
- env: Environment.
- data_root: Union[PathOrString, None]
Path of the directory of the data if applicable for dataset. Can be None.
- dataset_cfg: DictConfig
Configuration containing inference dataset settings.
- transforms: Callable
Dataset transformations object.
- experiment_path: PathOrString
Path to the directory where inference logs will be stored.
- checkpoint: FileOrUrl
Model checkpoint. This can be a path to a local file or a URL.
- num_workers: int
Number of workers.
- filenames_filter: Union[List[PathOrString], None]
List of filenames to include in the dataset (if applicable). Can be None. Default: None.
- batch_size: int
Inference batch size. Default: 1.
- crop: Optional[str]
Inference crop type. Can be "header" or None. Default: None.
- Returns:
- output: Union[Dict, DefaultDict]
- Return type:
Union[Dict,DefaultDict]
- direct.inference.setup_inference_save_to_h5(get_inference_settings, run_name, data_root, base_directory, output_directory, filenames_filter, checkpoint, device, num_workers, machine_rank, cfg_file=None, process_per_chunk=None, mixed_precision=False, debug=False, is_validation=False)
This function contains most of the logic in DIRECT required to launch a multi-GPU / multi-node inference process.
It saves predictions as .h5 files.
- Parameters:
- get_inference_settings: Callable
Callable object to create inference dataset and environment.
- run_name: str
Experiment run name. Can be an empty string.
- data_root: Union[PathOrString, None]
Path of the directory of the data if applicable for dataset. Can be None.
- base_directory: PathOrString
Path to the directory where inference logs will be stored. If run_name is not an empty string, base_directory / run_name will be used.
- output_directory: PathOrString
Path to directory where output data will be saved.
- filenames_filter: Union[List[PathOrString], None]
List of filenames to include in the dataset (if applicable). Can be None.
- checkpoint: FileOrUrl
Model checkpoint. This can be a path to a local file or a URL.
- device: str
Device name.
- num_workers: int
Number of workers.
- machine_rank: int
Machine rank.
- cfg_file: Union[PathOrString, None]
Path to configuration file. If None, will search in base_directory.
- process_per_chunk: int
Number of processes per chunk.
- mixed_precision: bool
If True, mixed precision will be allowed. Default: False.
- debug: bool
If True, debug information will be displayed. Default: False.
- is_validation: bool
If True, will use settings (e.g. batch_size & crop) of validation in config. Otherwise it will use inference settings. Default: False.
- Returns:
- None
- Return type:
None
direct.launch module
- direct.launch.launch(func, num_machines, num_gpus, machine_rank, dist_url, *args)
Launch the training. If only one GPU is available, the function can be called directly.
- Parameters:
- func: Callable
Function to launch.
- num_machines: int
The number of machines.
- num_gpus: int
The number of GPUs.
- machine_rank: int
The machine rank.
- dist_url: str
URL to connect to for distributed training, including protocol.
- args: Tuple
Arguments to pass to func.
- Return type:
None
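The rule described for `launch` (call the function directly when only one GPU is available, otherwise start distributed workers) can be sketched as follows. This is an illustration, not DIRECT's implementation: `launch_sketch` and its `distributed_launcher` hook are hypothetical, the hook stands in for something like `launch_distributed`, and computing the world size as `num_machines * num_gpus` is an assumption.

```python
from typing import Any, Callable, Optional


def launch_sketch(
    func: Callable[..., Any],
    num_machines: int,
    num_gpus: int,
    machine_rank: int,
    dist_url: str,
    *args: Any,
    distributed_launcher: Optional[Callable[..., None]] = None,
) -> None:
    """Hypothetical dispatcher: run in-process for one GPU, else delegate."""
    world_size = num_machines * num_gpus
    if world_size > 1:
        if distributed_launcher is None:
            raise RuntimeError("A distributed launcher is required when world_size > 1.")
        distributed_launcher(
            func,
            num_gpus,
            num_machines=num_machines,
            machine_rank=machine_rank,
            dist_url=dist_url,
            args=args,
        )
        return
    # Single process: no process group is needed, so call the function directly.
    func(*args)


calls = []
launch_sketch(calls.append, 1, 1, 0, "auto", "run-1")
print(calls)  # ['run-1']
```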
- direct.launch.launch_distributed(main_func, num_gpus_per_machine, num_machines=1, machine_rank=0, dist_url='auto', args=(), timeout=datetime.timedelta(seconds=1800))
Launch multi-gpu or distributed training.
This function must be called on all machines involved in the training and it will spawn child processes (defined by num_gpus_per_machine) on each machine.
- Parameters:
- main_func: Callable
A function that will be called by main_func(*args).
- num_gpus_per_machine: int
The number of GPUs per machine.
- num_machines: int
The number of machines.
- machine_rank: int
The rank of this machine (one per machine).
- dist_url: str
URL to connect to for distributed training, including the protocol, e.g. "tcp://127.0.0.1:8686". Can be set to "auto" to automatically select a free port on localhost.
- args: Tuple
Arguments passed to main_func.
- timeout: timedelta
Timeout of the distributed workers.
- Return type:
None
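With `dist_url='auto'`, a free port on localhost is selected. A common way to find one is to bind a socket to port 0 and let the operating system choose; the sketch below assumes that approach and the `tcp://127.0.0.1:<port>` format shown in the parameter description. `find_free_port` and `resolve_dist_url` are hypothetical helpers.

```python
import socket


def find_free_port() -> int:
    """Ask the OS for an unused TCP port by binding to port 0."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.bind(("127.0.0.1", 0))
        return sock.getsockname()[1]


def resolve_dist_url(dist_url: str) -> str:
    """Hypothetical helper: expand "auto" into a concrete localhost URL."""
    if dist_url == "auto":
        # The port is released when the socket closes, so another process
        # could take it before the workers connect (a small, accepted race).
        return f"tcp://127.0.0.1:{find_free_port()}"
    return dist_url


print(resolve_dist_url("tcp://127.0.0.1:8686"))  # tcp://127.0.0.1:8686
print(resolve_dist_url("auto"))
```

Because every worker has to reach the same address, an automatically chosen localhost port is presumably only meaningful when `num_machines=1`; multi-machine runs would pass an explicit, reachable URL.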
direct.predict module
direct.train module
- direct.train.build_training_datasets_from_environment(env, datasets_config, lists_root=None, data_root=None, initial_images=None, initial_kspaces=None, pass_text_description=True, pass_dictionaries=None)
- direct.train.build_transforms_from_environment(env, dataset_config)
- Return type:
Callable
- direct.train.get_root_of_file(filename)
Get the root directory of the file or URL to file.
- Parameters:
- filename: pathlib.Path or str
File path or URL to a file.
- Returns:
- pathlib.Path or str
Root directory of the file or URL.
Examples
>>> get_root_of_file('/mnt/archive/data.txt')
/mnt/archive
>>> get_root_of_file('https://aiforoncology.nl/people')
https://aiforoncology.nl/
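The two examples suggest that a local path loses its last component while a URL keeps its scheme and host and ends with a trailing slash. A stdlib-only sketch that reproduces those two examples follows; `get_root_of_file_sketch` is hypothetical, and other cases (query strings, non-HTTP schemes, Windows paths) may be handled differently by DIRECT.

```python
import pathlib
from urllib.parse import urlparse, urlunparse


def get_root_of_file_sketch(filename):
    """Hypothetical stand-in: parent directory of a local path or HTTP(S) URL."""
    parsed = urlparse(str(filename))
    if parsed.scheme in ("http", "https"):
        # Drop the last path component and keep a trailing slash, as in the examples.
        parent = str(pathlib.PurePosixPath(parsed.path or "/").parent)
        if not parent.endswith("/"):
            parent += "/"
        return urlunparse((parsed.scheme, parsed.netloc, parent, "", "", ""))
    return pathlib.Path(filename).parent


print(get_root_of_file_sketch("/mnt/archive/data.txt"))  # /mnt/archive
print(get_root_of_file_sketch("https://aiforoncology.nl/people"))  # https://aiforoncology.nl/
```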
direct.types module
direct.types module.
- class direct.types.DirectEnum(value, names=None, *, module=None, qualname=None, type=None, start=1, boundary=None)
Bases: str, Enum
Base type for enumerators that can be compared with strings, ignoring case.
- classmethod from_str(value)
- Return type:
DirectEnum | None
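A minimal sketch of how a `str`-backed enum can compare with strings case-insensitively and offer a `from_str` lookup that returns `None` on failure. `CaseInsensitiveEnum` and `KspaceKeySketch` are hypothetical names; this illustrates the documented behavior and is not DIRECT's source.

```python
from enum import Enum
from typing import Optional


class CaseInsensitiveEnum(str, Enum):
    """Sketch of a str-backed enum that compares to strings ignoring case."""

    def __eq__(self, other: object) -> bool:
        if isinstance(other, str):
            return self.value.lower() == other.lower()
        return NotImplemented

    def __ne__(self, other: object) -> bool:
        # str defines its own __ne__, so override it to stay consistent with __eq__.
        result = self.__eq__(other)
        return result if result is NotImplemented else not result

    def __hash__(self) -> int:
        # Defining __eq__ would otherwise set __hash__ to None; hash the
        # lower-cased value so lookups with the exact lower-case value still work.
        return hash(self.value.lower())

    @classmethod
    def from_str(cls, value: str) -> Optional["CaseInsensitiveEnum"]:
        """Return the member matching ``value`` in any case, or None."""
        for member in cls:
            if member == value:
                return member
        return None


class KspaceKeySketch(CaseInsensitiveEnum):
    KSPACE = "kspace"
    MASKED_KSPACE = "masked_kspace"


print(KspaceKeySketch.KSPACE == "KSPACE")  # True
print(KspaceKeySketch.from_str("Masked_KSpace") is KspaceKeySketch.MASKED_KSPACE)  # True
```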
- class direct.types.IntegerListOrTupleString(string)
Bases: object
IntegerListOrTupleString class represents a list or tuple of integers based on a string representation.
Examples
s1 = "[1, 2, 45, -1, 0]"
print(isinstance(s1, IntegerListOrTupleString))  # True
print(IntegerListOrTupleString(s1))  # [1, 2, 45, -1, 0]
print(type(IntegerListOrTupleString(s1)))  # <class 'list'>
print(type(IntegerListOrTupleString(s1)[0]))  # <class 'int'>
s2 = "(10, -9, 20)"
print(isinstance(s2, IntegerListOrTupleString))  # True
print(IntegerListOrTupleString(s2))  # (10, -9, 20)
print(type(IntegerListOrTupleString(s2)))  # <class 'tuple'>
print(type(IntegerListOrTupleString(s2)[0]))  # <class 'int'>
s3 = "[a, we, 2]"
print(isinstance(s3, IntegerListOrTupleString))  # False
s4 = "(1, 2, 3]"
print(isinstance(s4, IntegerListOrTupleString))  # False
- class direct.types.IntegerListOrTupleStringMeta
Bases: type
Metaclass for the IntegerListOrTupleString class.
- Returns:
- bool
True if the instance is a valid representation of IntegerListOrTupleString, False otherwise.
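The examples above rely on a metaclass that overrides `isinstance` checks and a class whose constructor returns a plain `list` or `tuple`. The sketch below shows one way to get that behavior with `__instancecheck__` and `__new__`; `IntegerSequenceStringMeta`, `IntegerSequenceString`, and `_parse_int_sequence` are hypothetical names, and rejecting empty containers such as `"[]"` is a choice made for this sketch only.

```python
from typing import List, Optional, Tuple, Union

ParsedInts = Union[List[int], Tuple[int, ...]]


def _parse_int_sequence(value: object) -> Optional[ParsedInts]:
    """Parse "[1, 2]" into a list or "(1, 2)" into a tuple; return None if invalid."""
    if not isinstance(value, str):
        return None
    text = value.strip()
    containers = {"[": ("]", list), "(": (")", tuple)}
    if len(text) < 2 or text[0] not in containers:
        return None
    closing, container = containers[text[0]]
    if text[-1] != closing:
        return None  # mismatched brackets, e.g. "(1, 2, 3]"
    try:
        # int() rejects non-integer and empty items, so "[]" is rejected too.
        return container(int(item) for item in text[1:-1].split(","))
    except ValueError:
        return None


class IntegerSequenceStringMeta(type):
    """Metaclass whose isinstance() check validates the string's contents."""

    def __instancecheck__(cls, instance: object) -> bool:
        return _parse_int_sequence(instance) is not None


class IntegerSequenceString(metaclass=IntegerSequenceStringMeta):
    """Calling the class returns the parsed list or tuple, not an instance."""

    def __new__(cls, string: str):
        parsed = _parse_int_sequence(string)
        if parsed is None:
            raise ValueError(f"Not an integer list or tuple string: {string!r}")
        return parsed


print(isinstance("[1, 2, 45, -1, 0]", IntegerSequenceString))  # True
print(IntegerSequenceString("(10, -9, 20)"))  # (10, -9, 20)
print(isinstance("(1, 2, 3]", IntegerSequenceString))  # False
```

Routing both the `isinstance` check and construction through one parser keeps the two behaviors consistent: any string that passes the check can also be converted.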
- class direct.types.KspaceKey(value, names=None, *, module=None, qualname=None, type=None, start=1, boundary=None)
Bases: DirectEnum
- KSPACE = 'kspace'
- MASKED_KSPACE = 'masked_kspace'
- class direct.types.MaskFuncMode(value, names=None, *, module=None, qualname=None, type=None, start=1, boundary=None)
Bases: DirectEnum
- DYNAMIC = 'dynamic'
- MULTISLICE = 'multislice'
- STATIC = 'static'
- class direct.types.TransformKey(value, names=None, *, module=None, qualname=None, type=None, start=1, boundary=None)
Bases: DirectEnum
- ACS_MASK = 'acs_mask'
- KSPACE = 'kspace'
- MASKED_KSPACE = 'masked_kspace'
- SAMPLING_MASK = 'sampling_mask'
- SCALING_FACTOR = 'scaling_factor'
- SENSITIVITY_MAP = 'sensitivity_map'
- TARGET = 'target'