direct.utils package

Submodules

direct.utils.asserts module

direct.utils.asserts.assert_positive_integer(*variables, strict=False)

Assert that the given variables are positive integers.

Parameters:
  • *variables – Variables to check.

  • strict (bool) – If True, require strictly positive values, i.e. zero is rejected. If False, zero is allowed. Default: False.

Return type:

None

direct.utils.asserts.assert_same_shape(data_list)

Check if all tensors in the list have the same shape.

Parameters:

data_list (List[Tensor]) – List of tensors.

direct.utils.asserts.assert_complex(data, complex_axis=-1, complex_last=None)

Assert that a tensor is complex, i.e. that it has a complex dimension of size 2 corresponding to the real and imaginary channels.

Parameters:
  • data (Tensor) – Tensor to check.

  • complex_axis (int) – Complex dimension along which the assertion will be done. Default: -1 (last).

  • complex_last (Optional[bool]) – If True, will override complex_axis with -1 (last). Default: None.

Return type:

None

direct.utils.bbox module

direct.utils.bbox.crop_to_bbox(data, bbox, pad_value=0)

Extract a bounding box from an image. Coordinates can be negative.

Parameters:
  • data (Union[ndarray, Tensor]) – nD array or torch tensor.

  • bbox (List[int]) – Bounding box of the form (coordinates, size). For instance, (4, 4, 2, 1) is a patch starting at row 4, column 4 with height 2 and width 1.

  • pad_value (int) – Value used to pad the patch where the bounding box extends beyond the image. Default: 0.

Return type:

Union[ndarray, Tensor]

Returns:

Array or tensor (matching the type of data) cropped to the bounding box.
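To make the bounding-box convention concrete, the sketch below applies it to a 2D image stored as nested lists. crop_2d is a hypothetical helper, not part of DIRECT: it only handles the 2D case (row, column, height, width) and fills every cell that falls outside the image with pad_value, which is the behavior the parameter descriptions above suggest.

```python
from typing import List, Sequence


def crop_2d(image: List[List[int]], bbox: Sequence[int], pad_value: int = 0) -> List[List[int]]:
    """Hypothetical 2D stand-in for crop_to_bbox on nested lists.

    bbox = (row, col, height, width). row/col may be negative and the box may
    extend past the image; cells outside the image are filled with pad_value.
    """
    row0, col0, height, width = bbox
    n_rows = len(image)
    n_cols = len(image[0]) if image else 0
    patch = []
    for r in range(row0, row0 + height):
        patch_row = []
        for c in range(col0, col0 + width):
            inside = 0 <= r < n_rows and 0 <= c < n_cols
            patch_row.append(image[r][c] if inside else pad_value)
        patch.append(patch_row)
    return patch


image = [[1, 2, 3],
         [4, 5, 6],
         [7, 8, 9]]
print(crop_2d(image, (1, 1, 2, 1)))    # [[5], [8]]
print(crop_2d(image, (-1, -1, 2, 2)))  # [[0, 0], [0, 1]]: top-left corner padded
```

The second call shows why negative coordinates are allowed: the requested patch keeps its full size, and the part outside the image is padded instead of being clipped away.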

direct.utils.bbox.crop_to_largest(data, pad_value=0)

Given a list of arrays or tensors, return the same list with the data padded to the largest shape in the set.

This can be convenient, e.g., for logging and for tiling several images with torchvision's make_grid.

Parameters:
  • data (List[Union[ndarray, Tensor]]) – List of numpy arrays or torch tensors.

  • pad_value (int) – Padding value. Default: 0.

Return type:

List[Union[ndarray, Tensor]]

Returns:

List of padded arrays or tensors.

direct.utils.communication module

direct.utils.communication.synchronize()

Synchronize processes between GPUs.

Wait until all processes have reached this point (a barrier). Does nothing in a non-distributed setting.
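The barrier semantics can be illustrated without GPUs. The sketch below uses threading.Barrier from the Python standard library as a stand-in for a distributed barrier; it is an analogy, not DIRECT's implementation. No worker gets past wait() until every worker has arrived.

```python
import threading

NUM_WORKERS = 3
barrier = threading.Barrier(NUM_WORKERS)  # stand-in for a distributed barrier
lock = threading.Lock()
arrived = []          # ranks that reached the barrier
seen_after_wait = []  # how many arrivals each worker observes once released


def worker(rank: int) -> None:
    with lock:
        arrived.append(rank)
    barrier.wait()  # blocks until all NUM_WORKERS workers have called wait()
    with lock:
        seen_after_wait.append(len(arrived))


threads = [threading.Thread(target=worker, args=(rank,)) for rank in range(NUM_WORKERS)]
for thread in threads:
    thread.start()
for thread in threads:
    thread.join()

print(seen_after_wait)  # [3, 3, 3]: nobody continued before everyone arrived
```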

direct.utils.communication.get_rank()

Get rank of the process, even when torch.distributed is not initialized.

Return type:

int

Returns:

Process rank.

direct.utils.communication.get_local_rank()

Get rank of the process within the same machine, even when torch.distributed is not initialized.

Return type:

int

Returns:

The rank of the current process within the local (per-machine) process group.

direct.utils.communication.get_local_size()

Number of compute units on the local machine.

Return type:

int

Returns:

Number of compute units.

direct.utils.communication.is_main_process()

Check whether the current process is the main process (rank 0). Simple wrapper around get_rank().

Return type:

bool

Returns:

True if current process is the main process.

direct.utils.communication.get_world_size()

Get the number of compute devices in the world. Returns 1 if distributed training is not initialized.

Return type:

int

Returns:

World size.

direct.utils.communication.all_gather(data, group=None)

Run all_gather on arbitrary picklable data (not necessarily tensors).

Parameters:
  • data (object) – Any picklable object.

  • group (Optional[group]) – A torch process group. By default, a group which contains all ranks on the gloo backend is used.

Returns:

List of data gathered from each rank.

Return type:

list
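Because all_gather accepts objects that are not tensors, the data has to travel as bytes. A common way to do this (used, for example, in detectron2's communication utilities) is to pickle each object into a byte buffer, pad all buffers to the largest size so the collective sees equally sized inputs, gather them, then trim and unpickle. The sketch below simulates those steps in a single process without torch.distributed; simulate_all_gather is hypothetical, and whether DIRECT pads exactly this way is an assumption.

```python
import pickle
from typing import Any, List


def simulate_all_gather(per_rank_objects: List[Any]) -> List[List[Any]]:
    """Single-process simulation of all_gather for picklable objects.

    per_rank_objects[i] is the object contributed by rank i. Returns, for each
    rank, the list that rank would receive (identical on every rank).
    """
    # 1. Each rank serializes its object to bytes.
    payloads = [pickle.dumps(obj) for obj in per_rank_objects]
    # 2. Sizes are exchanged so every buffer can be padded to the largest one.
    sizes = [len(p) for p in payloads]
    max_size = max(sizes)
    padded = [p + b"\x00" * (max_size - len(p)) for p in payloads]
    # 3. After the collective, each rank trims the padding and unpickles.
    gathered = [pickle.loads(buf[:size]) for buf, size in zip(padded, sizes)]
    return [list(gathered) for _ in per_rank_objects]


result = simulate_all_gather([{"loss": 0.5}, "rank-1 string", [1, 2, 3]])
print(result[0])  # [{'loss': 0.5}, 'rank-1 string', [1, 2, 3]]
```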

direct.utils.communication.gather(data, destination_rank=0, group=None)

Run gather on arbitrary picklable data (not necessarily tensors).

Parameters:
  • data (object) – Any picklable object.

  • destination_rank (int) – Destination rank. Default: 0.

  • group (Optional[group]) – A torch process group. By default, a group which contains all ranks on the gloo backend is used.

Returns:

On destination_rank, a list of data gathered from each rank. Otherwise, an empty list.

Return type:

list[data]

direct.utils.communication.shared_random_seed()

Return a random seed that is shared across all workers. All workers must call this function, otherwise it will deadlock.

Return type:

int

Returns:

A random number that is the same across all workers. If workers need a shared RNG, they can use this shared seed to create one.
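One way to obtain such a seed is for every worker to draw a random integer, all-gather the draws, and keep the value from rank 0. The sketch below simulates this in a single process; the scheme is an assumption, not taken from DIRECT's source, and simulate_shared_seed is hypothetical. It also shows where the deadlock warning comes from: the all-gather step is a collective, so it can only complete once every worker has called it.

```python
import random
from typing import List


def simulate_shared_seed(num_workers: int) -> List[int]:
    """Single-process simulation: returns the seed each worker ends up with."""
    # Each worker draws its own random integer...
    local_draws = [random.randint(0, 2**31 - 1) for _ in range(num_workers)]
    # ...all_gather gives every worker the full list of draws...
    gathered_on_each_worker = [list(local_draws) for _ in range(num_workers)]
    # ...and every worker keeps the entry from rank 0.
    return [gathered[0] for gathered in gathered_on_each_worker]


seeds = simulate_shared_seed(num_workers=4)
shared_rng = random.Random(seeds[0])  # same RNG state on every worker
print(len(set(seeds)))  # 1
```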

direct.utils.communication.reduce_tensor_dict(tensors_dict)

Reduce the tensor dictionary from all processes so that the process with rank 0 has the averaged results. Returns a dict with the same fields as tensors_dict, after reduction.

Parameters:
  • tensors_dict (Dict[str, Tensor]) – Dictionary with str keys mapping to torch tensors.

Returns:

The reduced dict.

Return type:

dict
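The result on rank 0 is an element-wise mean over ranks. The sketch below computes that mean directly from a list of per-rank dictionaries, with plain floats standing in for tensors; simulate_reduce_dict is hypothetical. Iterating the keys in sorted order mirrors a common precaution in such helpers (values are often stacked into one tensor, so every rank must use the same key order), but treat that detail as an assumption.

```python
from typing import Dict, List


def simulate_reduce_dict(per_rank: List[Dict[str, float]]) -> Dict[str, float]:
    """Element-wise mean over ranks; plain floats stand in for tensors."""
    world_size = len(per_rank)
    # Fixed (sorted) key order so every rank contributes the same key at each position.
    keys = sorted(per_rank[0])
    return {key: sum(d[key] for d in per_rank) / world_size for key in keys}


print(simulate_reduce_dict([{"loss": 1.0, "ssim": 0.75}, {"loss": 3.0, "ssim": 0.25}]))
# {'loss': 2.0, 'ssim': 0.5}
```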

direct.utils.dataset module

direct.utils.dataset.get_filenames_for_datasets_from_config(cfg, files_root, data_root)

Given a configuration object, return a list of filenames.

Parameters:
  • cfg – Configuration object whose lists property contains paths relative to files_root.

  • files_root (Union[Path, str]) – Files root path.

  • data_root (Path) – Data root path.

Returns:

List of filenames or None.

direct.utils.dataset.get_filenames_for_datasets(lists, files_root, data_root)

Given lists of filenames of data points, concatenate these into a large list of full filenames.

Parameters:
  • lists (List[Union[Path, str]]) – List of filename list paths.

  • files_root (Union[Path, str]) – Files root path.

  • data_root (Path) – Data root path.

Returns:

List of filenames or None.

direct.utils.events module

direct.utils.events.get_event_storage()

Get the current event storage.

Returns:

The EventStorage object that’s currently being used.

Raises:

ValueError – If no EventStorage is currently enabled.

class direct.utils.events.EventWriter

Bases: object

Base class for writers that obtain events from EventStorage and process them.

write()

close()

class direct.utils.events.JSONWriter(json_file, window_size=2)

Bases: EventWriter

Write scalars to a json file.

It saves scalars as one JSON object per line (instead of one large JSON document) for easy parsing.

Examples of parsing such a JSON file:

$ cat metrics.json | jq -s '.[0:2]'
[
  {
    "data_time": 0.008433341979980469,
    "iteration": 20,
    "loss": 1.9228371381759644,
    "loss_box_reg": 0.050025828182697296,
    "loss_classifier": 0.5316952466964722,
    "loss_mask": 0.7236229181289673,
    "loss_rpn_box": 0.0856662318110466,
    "loss_rpn_cls": 0.48198649287223816,
    "lr": 0.007173333333333333,
    "time": 0.25401854515075684
  },
  {
    "data_time": 0.007216215133666992,
    "iteration": 40,
    "loss": 1.282649278640747,
    "loss_box_reg": 0.06222952902317047,
    "loss_classifier": 0.30682939291000366,
    "loss_mask": 0.6970193982124329,
    "loss_rpn_box": 0.038663312792778015,
    "loss_rpn_cls": 0.1471673548221588,
    "lr": 0.007706666666666667,
    "time": 0.2490077018737793
  }
]

$ cat metrics.json | jq '.loss_mask'
0.7126231789588928
0.689423680305481
0.6776131987571716
...
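The same files can be read from Python without jq, since each line is a standalone JSON object. The records written below are made up for illustration; only the one-object-per-line layout is taken from the description above.

```python
import json
import os
import tempfile

# Write two records in the one-JSON-object-per-line layout (illustrative values).
records = [{"iteration": 20, "loss": 1.92}, {"iteration": 40, "loss": 1.28}]
path = os.path.join(tempfile.mkdtemp(), "metrics.json")
with open(path, "w") as f:
    for record in records:
        f.write(json.dumps(record) + "\n")

# Equivalent of `jq '.loss'`: parse each non-empty line independently.
with open(path) as f:
    parsed = [json.loads(line) for line in f if line.strip()]
losses = [entry["loss"] for entry in parsed]
print(losses)  # [1.92, 1.28]
```

Because each line parses on its own, a file that is still being appended to, or was cut off mid-run, remains readable up to its last complete line.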
__init__(json_file, window_size=2)

Inits JSONWriter.

Parameters:
  • json_file (Union[Path, str]) – Path to the JSON file. Data will be appended if it exists.

  • window_size (int) – Window size of median smoothing for variables for which smoothing_hint is True. Default: 2.

write()

close()

class direct.utils.events.TensorboardWriter(log_dir, window_size=20, **kwargs)

Bases: EventWriter

Write all scalars to a tensorboard file.

__init__(log_dir, window_size=20, **kwargs)

Inits TensorboardWriter.

Parameters:
  • log_dir (Union[Path, str]) – The directory to save the output events.

  • window_size (int) – The scalars will be median-smoothed by this window size. Default: 20.

  • **kwargs – Other arguments passed to torch.utils.tensorboard.SummaryWriter.

write()

close()

class direct.utils.events.CommonMetricPrinter(max_iter)

Bases: EventWriter

Print common metrics to the terminal, including iteration time, ETA, memory, all losses, and the learning rate.

To print something different, implement a similar printer yourself.

__init__(max_iter)

Inits CommonMetricPrinter.

Parameters:

max_iter – The maximum number of iterations to train. Used to compute ETA.

write()

class direct.utils.events.EventStorage(start_iter=0)

Bases: object

The user-facing class that provides metric storage functionalities.

In the future we may add support for storing / logging other types of data if needed.

__init__(start_iter=0)
Parameters:

start_iter (int) – The iteration index to start with. Default: 0.

add_image(img_name, img_tensor)

Add an img_tensor to the _vis_data associated with img_name.

Parameters:
  • img_name (str) – The name of the image to put into TensorBoard.

  • img_tensor (torch.Tensor or numpy.array) – A uint8 or float tensor of shape [channel, height, width], where channel is 3. The image format should be RGB. The elements of img_tensor can have values in [0, 1] (float32) or [0, 255] (uint8).

clear_images()

Delete all the stored images for visualization.

This should be called after images are written to tensorboard.

add_scalar(name, value, smoothing_hint=True)

Add a scalar value to the HistoryBuffer associated with name.

Parameters:
  • name (str) – Name of the scalar.

  • value (float) – Scalar value.

  • smoothing_hint (bool) – A "hint" on whether this scalar is noisy and should be smoothed when logged. The hint will be accessible through EventStorage.smoothing_hints. A writer may ignore the hint and apply a custom smoothing rule. It defaults to True because most scalars we save need to be smoothed to provide any useful signal.

add_scalars(*, smoothing_hint=True, **kwargs)

Add multiple scalars from keyword arguments.

Examples

storage.add_scalars(loss=my_loss, accuracy=my_accuracy, smoothing_hint=True)

add_graph(img_name, img_tensor)

Add an img_tensor to the _vis_data associated with img_name.

Parameters:
  • img_name (str) – The name of the image to put into TensorBoard.

  • img_tensor (torch.Tensor or numpy.array) – A uint8 or float tensor of shape [channel, height, width], where channel is 3. The image format should be RGB. The elements of img_tensor can have values in [0, 1] (float32) or [0, 255] (uint8).

history(name)
Returns:

The scalar history for name.

Return type:

HistoryBuffer

histories()
Returns:

The HistoryBuffer for each scalar.

Return type:

dict[name -> HistoryBuffer]

latest()
Returns:

The scalars added in the current iteration.

Return type:

dict[name -> number]

latest_with_smoothing_hint(window_size=20)

Similar to latest(), but each returned value is either the un-smoothed latest value or the median over the given window_size, depending on whether its smoothing_hint is True.

This provides a default behavior that other writers can use.

smoothing_hints()
Returns:

The user-provided hint on whether each scalar is noisy and needs smoothing.

Return type:

dict[name -> bool]

step()

Call this function at the beginning of each iteration to notify the storage of the start of a new iteration.

The storage will then be able to associate the new data with the correct iteration number.

property vis_data

property iter

name_scope(name)
Yields:

A context within which all the events added to this storage will be prefixed by the name scope.

class direct.utils.events.HistoryBuffer(max_length=1000000)

Bases: object

Track a series of scalar values and provide access to smoothed values over a window or the global average of the series.

__init__(max_length=1000000)
Parameters:

max_length (int) – Maximum number of values that can be stored in the buffer. When the buffer is full, the oldest values are removed.

update(value, iteration=None)

Add a new scalar value produced at a certain iteration.

If the length of the buffer exceeds self._max_length, the oldest element will be removed from the buffer.

Return type:

None

latest()

Return the latest scalar value added to the buffer.

Return type:

float

median(window_size)

Return the median of the latest window_size values in the buffer.

Return type:

float

avg(window_size)

Return the mean of the latest window_size values in the buffer.

Return type:

float

global_avg()

Return the mean of all values ever added to the buffer.

Note that this includes values that have already been removed due to the limited buffer storage.

Return type:

float

values()
Returns:

The content of the current buffer.

Return type:

list[(number, iteration)]
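The bookkeeping described for HistoryBuffer can be sketched with a bounded deque for the stored window and running totals for the global mean. MiniHistoryBuffer is a hypothetical class, not DIRECT's implementation, and defaulting iteration to the update count is an assumption. It shows why global_avg() can differ from a mean over the current buffer: evicted values still count toward the global mean.

```python
import statistics
from collections import deque
from typing import Deque, List, Optional, Tuple


class MiniHistoryBuffer:
    """Hypothetical minimal stand-in for HistoryBuffer."""

    def __init__(self, max_length: int = 1_000_000) -> None:
        # deque(maxlen=...) drops the oldest entry automatically when full.
        self._data: Deque[Tuple[float, int]] = deque(maxlen=max_length)
        self._count = 0
        self._global_sum = 0.0

    def update(self, value: float, iteration: Optional[int] = None) -> None:
        if iteration is None:
            iteration = self._count
        self._data.append((value, iteration))
        # Running totals keep counting values that later fall out of the deque.
        self._count += 1
        self._global_sum += value

    def latest(self) -> float:
        return self._data[-1][0]

    def median(self, window_size: int) -> float:
        return statistics.median(v for v, _ in list(self._data)[-window_size:])

    def avg(self, window_size: int) -> float:
        return statistics.mean(v for v, _ in list(self._data)[-window_size:])

    def global_avg(self) -> float:
        return self._global_sum / self._count

    def values(self) -> List[Tuple[float, int]]:
        return list(self._data)


buf = MiniHistoryBuffer(max_length=3)
for v in [1.0, 2.0, 3.0, 10.0]:
    buf.update(v)
print(buf.values())  # [(2.0, 1), (3.0, 2), (10.0, 3)]: 1.0 was evicted
print(buf.median(3), buf.avg(2), buf.global_avg())  # 3.0 6.5 4.0
```

The median over the window ignores the outlier 10.0, which is why median smoothing is a reasonable default for noisy training losses.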

direct.utils.imports module

General utilities for module imports.

direct.utils.io module

direct.utils.io.read_json(fn)

Read a JSON file and return its content as a dict, or return the input unchanged if it is already a dict.

Parameters:

fn (Union[Dict, str, Path]) – File path or dictionary.

Return type:

Dict

Returns:

Dictionary content.

class direct.utils.io.ArrayEncoder(*, skipkeys=False, ensure_ascii=True, check_circular=True, allow_nan=True, sort_keys=False, indent=None, separators=None, default=None)

Bases: JSONEncoder

default(obj)

Implement this method in a subclass such that it returns a serializable object for obj, or calls the base implementation (to raise a TypeError).

For example, to support arbitrary iterators, you could implement default like this:

def default(self, o):
    try:
        iterable = iter(o)
    except TypeError:
        pass
    else:
        return list(iterable)
    # Let the base class default method raise the TypeError
    return super().default(o)

direct.utils.io.write_json(fn, data, indent=2)

Write a dictionary to a file as JSON.

Parameters:
  • fn (Union[str, Path]) – File path.

  • data (Dict) – Dictionary to write.

  • indent – Indentation level. Default: 2.

Return type:

None

direct.utils.io.read_list(fn)

Read a text file and return its lines as a list, or return the input unchanged if it is already a list. Can also read data from URLs.

Parameters:

fn (Union[List, str, Path]) – Input text file or list, or a URL to a text file.

Return type:

List

Returns:

List of lines read from the text file.

direct.utils.io.write_list(fn, data)

Write a list to a file, one item per line.

Parameters:
  • fn (Union[str, Path]) – Output file path.

  • data – List or tuple to write.

Return type:

None
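write_list and read_list are meant to round-trip simple text lists, such as the filename lists used by the dataset helpers. A local-file sketch of that round trip with the hypothetical helpers write_lines and read_lines; skipping blank lines and converting items with str() are assumptions of this sketch, and URL reading is not shown.

```python
import os
import tempfile
from typing import Iterable, List


def write_lines(fn: str, data: Iterable) -> None:
    # One item per line, converted with str().
    with open(fn, "w") as f:
        for item in data:
            f.write(f"{item}\n")


def read_lines(fn: str) -> List[str]:
    # Strip trailing newlines; blank lines are skipped in this sketch.
    with open(fn) as f:
        return [line.rstrip("\n") for line in f if line.strip()]


path = os.path.join(tempfile.mkdtemp(), "filenames.lst")
write_lines(path, ("file_001.h5", "file_002.h5"))
print(read_lines(path))  # ['file_001.h5', 'file_002.h5']
```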

direct.utils.io.gen_bar_updater()
Return type:

Callable[[int, int, int], None]

direct.utils.io.calculate_md5(fpath, chunk_size=1024 * 1024)
Return type:

str

direct.utils.io.check_md5(fpath, md5, **kwargs)
Return type:

bool

direct.utils.io.check_integrity(fpath, md5=None)
Return type:

bool
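calculate_md5, check_md5 and check_integrity have no descriptions here. Judging from the chunk_size parameter, calculate_md5 presumably hashes the file incrementally so that large downloads need not fit in memory. The sketch below shows that technique with hashlib (md5_of_file is a hypothetical name); a check in the spirit of check_md5 would then compare the hex digest with the expected md5 string.

```python
import hashlib
import os
import tempfile


def md5_of_file(fpath: str, chunk_size: int = 1024 * 1024) -> str:
    """Hash a file incrementally, chunk_size bytes at a time."""
    md5 = hashlib.md5()
    with open(fpath, "rb") as f:
        # iter(callable, sentinel) keeps reading until f.read returns b"" at end of file.
        for chunk in iter(lambda: f.read(chunk_size), b""):
            md5.update(chunk)
    return md5.hexdigest()


path = os.path.join(tempfile.mkdtemp(), "example.bin")
with open(path, "wb") as f:
    f.write(b"hello")

digest = md5_of_file(path, chunk_size=2)  # tiny chunks give the same digest as one pass
print(digest)  # 5d41402abc4b2a76b9719d911017c592
print(digest == "5d41402abc4b2a76b9719d911017c592")  # a check_md5-style comparison
```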

direct.utils.io.download_url(url, root, filename=None, md5=None, max_redirect_hops=3)

Download a file from a URL and place it in root.

Parameters:
  • url (str) – URL to download the file from.

  • root (Union[Path, str]) – Directory to place the downloaded file in.

  • filename (Optional[str]) – Name to save the file under. If None, use the basename of the URL. Default: None.

  • md5 (Optional[str]) – MD5 checksum of the download. If None, do not check. Default: None.

  • max_redirect_hops (int) – Maximum number of redirect hops allowed. Default: 3.

Return type:

None

direct.utils.io.extract_archive(from_path, to_path=None, remove_finished=False)

Extract an archive.

The archive type and a possible compression are automatically detected from the file name. If the file is compressed but not an archive, the call is dispatched to decompress().

Parameters:
  • from_path (str) – Path to the file to be extracted.

  • to_path (Optional[str]) – Path to the directory the file will be extracted to. If omitted, the directory of the file is used.

  • remove_finished (bool) – If True, remove the file after the extraction.

Return type:

str

Returns:

Path to the directory the file was extracted to.

direct.utils.io.download_and_extract_archive(url, download_root, extract_root=None, filename=None, md5=None, remove_finished=False)
Return type:

None

direct.utils.io.read_text_from_url(url, chunk_size=1024)

Read a text file from a URL, e.g. a config file.

Parameters:
  • url (str) – URL of the text file.

  • chunk_size (int) – Size of the chunks in which the file is read. Default: 1024.

Returns:

Data read from the URL.

direct.utils.io.check_is_valid_url(path)

Check if the given path is a valid URL.

Parameters:

path (Union[Path, str]) – Path or string to check.

Return type:

bool

Returns:

True if path is a valid URL, False otherwise.

direct.utils.io.upload_to_s3(filename, to_filename, aws_access_key_id, aws_secret_access_key, endpoint_url='https://s3.aiforoncology.nl', bucket='direct-project', verbose=True)

Upload a file to an S3 bucket.

Parameters:
  • filename (Path) – Filename to upload.

  • to_filename (str) – Where to store the file.

  • aws_access_key_id (str) – AWS access key ID.

  • aws_secret_access_key (str) – AWS secret access key.

  • endpoint_url (str) – S3 endpoint URL. Default: "https://s3.aiforoncology.nl".

  • bucket (str) – Bucket name. Default: "direct-project".

  • verbose (bool) – Show upload progress. Default: True.

Return type:

None

direct.utils.logging module

direct.utils.logging.setup(use_stdout=True, filename=None, log_level='INFO')

Set up logging for DIRECT.

Parameters:
  • use_stdout (Optional[bool]) – Write output to standard out. Default: True.

  • filename (Optional[PathLike]) – Filename to write log to. Default: None.

  • log_level (Union[int, str]) – Logging level as in the python.logging library. Default: "INFO".

Return type:

None

direct.utils.models module

direct.utils.models.fix_state_dict_module_prefix(state_dict)

If a model is saved after being wrapped in e.g. DataParallel, the keys of its state dict are prefixed with "module.".

This function removes this prefix.

Parameters:

state_dict – State dict of a network module.

Returns:

State dict with fixed keys.
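Removing the prefix amounts to renaming keys. A minimal sketch with plain integers standing in for tensors; strip_module_prefix is a hypothetical name, and leaving keys that lack the prefix untouched is an assumption about DIRECT's behavior.

```python
from typing import Any, Dict


def strip_module_prefix(state_dict: Dict[str, Any]) -> Dict[str, Any]:
    """Remove a leading "module." from every key; other keys are kept as-is."""
    prefix = "module."
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


wrapped = {"module.conv.weight": 1, "module.conv.bias": 2, "head.fc.weight": 3}
print(strip_module_prefix(wrapped))
# {'conv.weight': 1, 'conv.bias': 2, 'head.fc.weight': 3}
```

Only the leading prefix is stripped, so a key such as "encoder.module.weight" keeps its inner "module." segment.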

direct.utils.writers module

direct.utils.writers.write_output_to_h5(output, output_directory, volume_processing_func=None, output_key='reconstruction', create_dirs_if_needed=True)

Write a dictionary mapping filenames to torch tensors to h5 files.

Parameters:
  • output (Union[Dict, DefaultDict]) – Dictionary mapping filenames to torch.Tensor objects of shape [depth, num_channels, ...], where num_channels is typically 1 for MRI.

  • output_directory (Path) – Output directory path.

  • volume_processing_func (Optional[Callable]) – Function which postprocesses the volume array before saving. Default: None.

  • output_key (str) – Name of key to save the output to. Default: "reconstruction".

  • create_dirs_if_needed (bool) – If True, the output directory and all its parents will be created. Default: True.

Return type:

None

Note

Currently only num_channels = 1 is supported. If you run this function with more channels, the first one will be used.

Module contents

direct.utils module.

direct.utils.is_complex_data(data, complex_axis=-1)

Returns True if data is a complex tensor at a specified dimension.

That is, the complex_axis dimension of data has size 2, corresponding to the real and imaginary channels.

Parameters:
  • data (Tensor) – Input tensor.

  • complex_axis (int) – Complex dimension along which the check will be done. Default: -1 (last).

Return type:

bool

Returns:

True if data is a complex tensor.
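The check only inspects the size of one axis. The sketch below works on shape tuples rather than tensors (looks_complex is a hypothetical name). It also makes a limitation visible: any tensor with a size-2 axis at complex_axis passes, whether or not that axis actually holds real and imaginary parts.

```python
from typing import Sequence


def looks_complex(shape: Sequence[int], complex_axis: int = -1) -> bool:
    """True if the dimension at complex_axis has size 2 (real, imaginary)."""
    return shape[complex_axis] == 2


print(looks_complex((1, 320, 320, 2)))     # True: last axis holds (real, imag)
print(looks_complex((1, 2, 320, 320), 1))  # True: axis 1 holds (real, imag)
print(looks_complex((1, 320, 320, 3)))     # False
```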

direct.utils.is_power_of_two(number)

Check if input is a power of 2.

Parameters:

number (int) – Integer to check.

Return type:

bool

Returns:

True if number is a power of 2.
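A common way to test for a power of two uses the fact that a positive power of two has exactly one bit set. This sketch (power_of_two is hypothetical) shows that bit trick; DIRECT's implementation may differ, for example in how it treats 0 or negative numbers.

```python
def power_of_two(number: int) -> bool:
    # Clearing the lowest set bit (number & (number - 1)) leaves 0 only if one bit was set.
    return number > 0 and (number & (number - 1)) == 0


print([n for n in range(1, 20) if power_of_two(n)])  # [1, 2, 4, 8, 16]
```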

direct.utils.ensure_list(data)

Ensure input is a list.

Parameters:

data (Any) – Object to convert to list.

Return type:

List

Returns:

List representation of input.

direct.utils.cast_as_path(data)

Ensure the input is a path.

Parameters:

data (Union[Path, str, None]) – String or path to convert.

Return type:

Optional[Path]

Returns:

Path object or None if input is None.

direct.utils.str_to_class(module_name, function_name)

Convert a string to a class or function.

Based on: https://stackoverflow.com/a/1176180/576363.

Also supports function arguments, e.g. ifft(dim=2) will be parsed as a partial and return ifft with dim set to 2.

Parameters:
  • module_name (str) – Module name, e.g. "direct.data.transforms".

  • function_name (str) – Function or class name, e.g. "Identity".

Return type:

Callable

Returns:

The requested class or function object.

Example

>>> def mult(f, mul=2):
...    return f*mul
>>> str_to_class(".", "mult(mul=4)")

Will return a function which multiplies the input by 4, while

>>> str_to_class(".", "mult")

just returns the function itself.
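A specification such as "name(kw=value)" can be parsed with the standard library's ast module. The sketch below (resolve is a hypothetical name) handles a bare name or a call with literal keyword arguments and wraps the latter in functools.partial. DIRECT's own parser may accept other forms, and its handling of module names such as "." is not reproduced here.

```python
import ast
import functools
import importlib
from typing import Callable


def resolve(module_name: str, spec: str) -> Callable:
    """Resolve "name" or "name(kw=value, ...)" to an object from module_name."""
    node = ast.parse(spec, mode="eval").body
    if isinstance(node, ast.Call):
        if node.args:
            raise ValueError("Only keyword arguments are supported in this sketch.")
        name = node.func.id
        # Keyword values must be Python literals (numbers, strings, tuples, ...).
        kwargs = {kw.arg: ast.literal_eval(kw.value) for kw in node.keywords}
    else:
        name, kwargs = node.id, {}
    obj = getattr(importlib.import_module(module_name), name)
    return functools.partial(obj, **kwargs) if kwargs else obj


sqrt = resolve("math", "sqrt")
print(sqrt(16.0))  # 4.0
round_1 = resolve("builtins", "round(ndigits=1)")
print(round_1(3.14159))  # 3.1
```

Using ast.literal_eval instead of eval keeps the parsing restricted to literals, so a configuration string cannot execute arbitrary code.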

direct.utils.dict_to_device(data, device, keys=None)

Copy tensor-valued dictionary to device.

Only torch.Tensor is copied.

Parameters:
  • data (Dict[str, Tensor]) – Dictionary with tensor values.

  • device (Union[device, str, None]) – Target device.

  • keys (Union[List, Tuple, KeysView, None]) – Subselection of keys to copy. Default: None.

Return type:

Dict

Returns:

Dictionary at the new device.

direct.utils.detach_dict(data, keys=None)

Return a detached copy of a dictionary.

Only torch.Tensor values are detached.

Parameters:
  • data (Dict[str, Tensor]) – Dictionary with tensor values.

  • keys (Union[List, Tuple, KeysView, None]) – Subselection of keys to detach. Default: None.

Return type:

Dict

Returns:

Detached dictionary.

direct.utils.reduce_list_of_dicts(data, mode='average', divisor=None)

Reduce a list of dictionaries mapping keys to tensors, by averaging or summing.

Parameters:
  • data (List[Dict[str, Tensor]]) – List of dictionaries with tensor values.

  • mode – Reduction mode: "average" computes the average, "sum" just adds. Default: "average".

  • divisor – If given, values are divided by this factor. Default: None.

Return type:

Dict[str, Tensor]

Returns:

Reduced dictionary.
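The two modes differ only in whether the key-wise sum is divided by the number of dictionaries. The sketch below uses floats in place of tensors (reduce_dicts is hypothetical). Applying divisor after the averaging step, and rejecting unknown modes with ValueError, are assumptions; DIRECT may combine mode and divisor differently.

```python
from typing import Dict, List, Optional


def reduce_dicts(data: List[Dict[str, float]], mode: str = "average",
                 divisor: Optional[float] = None) -> Dict[str, float]:
    """Key-wise sum, optionally averaged over len(data), optionally divided by divisor."""
    if mode not in ("average", "sum"):
        raise ValueError(f"Unknown reduction mode {mode!r}.")
    result = {key: sum(d[key] for d in data) for key in data[0]}
    if mode == "average":
        result = {key: value / len(data) for key, value in result.items()}
    if divisor is not None:
        result = {key: value / divisor for key, value in result.items()}
    return result


batches = [{"loss": 1.0}, {"loss": 3.0}]
print(reduce_dicts(batches))                        # {'loss': 2.0}
print(reduce_dicts(batches, mode="sum"))            # {'loss': 4.0}
print(reduce_dicts(batches, mode="sum", divisor=4)) # {'loss': 1.0}
```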

direct.utils.merge_list_of_dicts(list_of_dicts)

Merge a list of dictionaries into one dictionary.

Parameters:

list_of_dicts (List[Dict]) – List of dictionaries to merge.

Return type:

Dict

Returns:

Merged dictionary.

direct.utils.evaluate_dict(fns_dict, source, target, reduction='mean')

Evaluate a dictionary of functions.

Parameters:
  • fns_dict (Dict[str, Callable]) – Dictionary mapping names to callable functions.

  • source (Tensor) – Source tensor.

  • target (Tensor) – Target tensor.

  • reduction (str) – Reduction mode. Default: "mean".

Return type:

Dict

Returns:

Evaluated dictionary.

Example

>>> evaluate_dict({'l1_loss': F.l1_loss, 'l2_loss': F.mse_loss}, a, b)

Will return

{'l1_loss': F.l1_loss(a, b, reduction='mean'), 'l2_loss': F.mse_loss(a, b, reduction='mean')}

direct.utils.prefix_dict_keys(data, prefix)

Add a prefix to the keys of a dictionary.

Parameters:
  • data (Dict[str, Any]) – Dictionary to prefix.

  • prefix (str) – Prefix string to add.

Return type:

Dict[str, Any]

Returns:

Dictionary with prefixed keys.

direct.utils.git_hash()

Returns the current git hash.

Return type:

str

Returns:

The current git hash.

direct.utils.normalize_image(image, eps=0.00001)

Normalize image to range [0,1] for visualization.

Given image \(x\) and \(\epsilon\), it returns:

\[\frac{x - \min{x}}{\max{(x - \min{x})} + \epsilon}.\]

Parameters:
  • image (Tensor) – Image to scale.

  • eps (float) – Epsilon value to avoid division by zero. Default: 0.00001.

Return type:

Tensor

Returns:

Scaled image.
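A pure-Python sketch of the scaling on a list of floats, assuming the maximum is taken after the minimum has been subtracted, which is what makes the output span [0, 1]. min_max_scale is a hypothetical name; eps keeps the division defined when all values are equal.

```python
from typing import List


def min_max_scale(values: List[float], eps: float = 1e-5) -> List[float]:
    lowest = min(values)
    shifted = [v - lowest for v in values]   # smallest value becomes 0
    top = max(shifted)                       # equals max(x) - min(x)
    return [v / (top + eps) for v in shifted]  # largest value becomes top / (top + eps), close to 1


print(min_max_scale([-1.0, 0.0, 1.0], eps=0.0))  # [0.0, 0.5, 1.0]
```

With a negative minimum, dividing by max(x) instead of max(x) - min(x) would push values above 1, which is why the shift matters for the [0, 1] range.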

direct.utils.multiply_function(multiplier, func)

Create a function that multiplies the output of another function by a multiplier.

Parameters:
  • multiplier (float) – Number to multiply with.

  • func (Callable) – Function to multiply.

Return type:

Callable

Returns:

Multiplied function.
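The wrapping can be sketched as a closure that forwards all arguments and scales the result. scale_output and l1 are hypothetical names, and weighting a loss term is only an illustrative use, not something the description above states.

```python
from typing import Callable


def scale_output(multiplier: float, func: Callable) -> Callable:
    """Return a new function computing multiplier * func(*args, **kwargs)."""
    def wrapped(*args, **kwargs):
        return multiplier * func(*args, **kwargs)
    return wrapped


def l1(a: float, b: float) -> float:
    return abs(a - b)


weighted_l1 = scale_output(0.5, l1)  # e.g. a loss term with weight 0.5
print(weighted_l1(3.0, 1.0))  # 1.0
```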

class direct.utils.SpatialDims(TWO_D, THREE_D)

Bases: tuple

THREE_D

Alias for field number 1

TWO_D

Alias for field number 0

class direct.utils.DirectTransform

Bases: object

Base class for DIRECT transforms.

Defines the __repr__() method for DIRECT transforms.

__init__()

Inits DirectTransform.

__repr__()

Representation of DirectTransform.

class direct.utils.DirectModule

Bases: DirectTransform, ABC, Module

abstract __init__()

Inits DirectModule.

forward(sample)

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

direct.utils.count_parameters(models)

Count the number of parameters of a dictionary of models.

Parameters:

models (Dict) – Dictionary mapping model name to model.

Return type:

None

direct.utils.set_all_seeds(seed)

Sets seed for deterministic runs.

Parameters:

seed (int) – Seed for random module.

Return type:

None

direct.utils.chunks(list_to_chunk, number_of_chunks)

Yield number_of_chunks sequential chunks from list_to_chunk.

Adapted from [1].

Parameters:
  • list_to_chunk (List) – List to chunk.

  • number_of_chunks (int) – Number of chunks to create.
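One common way to produce sequential chunks of near-equal size is to compute a base size with divmod and give the leftover elements to the first chunks. The sketch below (split_sequential is hypothetical) shows that approach; DIRECT's adapted version may distribute the leftover elements differently.

```python
from typing import Iterator, List


def split_sequential(items: List, number_of_chunks: int) -> Iterator[List]:
    """Yield number_of_chunks consecutive slices whose lengths differ by at most one."""
    size, remainder = divmod(len(items), number_of_chunks)
    start = 0
    for i in range(number_of_chunks):
        # The first `remainder` chunks receive one extra element.
        end = start + size + (1 if i < remainder else 0)
        yield items[start:end]
        start = end


print(list(split_sequential(list(range(7)), 3)))  # [[0, 1, 2], [3, 4], [5, 6]]
```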

References

direct.utils.remove_keys(input_dict, keys)

Removes keys from input_dict.

Parameters:
  • input_dict (Dict) – Input dictionary.

  • keys (Union[str, List[str], Tuple[str]]) – Key or keys to remove.

Return type:

Dict

Returns:

Dictionary with keys removed.

direct.utils.dict_flatten(in_dict, dict_out=None)

Flattens a nested dictionary (or DictConfig) and returns a new flattened dictionary.

If a dict_out is provided, the flattened dictionary will be added to it.

Parameters:
  • in_dict (Union[dict, DictConfig]) – The nested dictionary or DictConfig to flatten.

  • dict_out (Union[dict, DictConfig, None]) – An existing dictionary to add the flattened dictionary to. Default: None.

Return type:

Dict[str, Any]

Returns:

The flattened dictionary.

Note

This function only keeps the final keys, and discards the intermediate ones.

Example

>>> dictA = {"a": 1, "b": {"c": 2, "d": 3, "e": {"f": 4, 6: "a", 5: {"g": 6}, "l": [1, "two"]}}}
>>> dict_flatten(dictA)
{'a': 1, 'c': 2, 'd': 3, 'f': 4, 6: 'a', 'g': 6, 'l': [1, 'two']}