direct.utils package#

Submodules#

direct.utils.asserts module#

direct.utils.asserts.assert_complex(data, complex_axis=-1, complex_last=None)[source][source]#

Assert that a tensor is complex, i.e. that it has a complex dimension of size 2 corresponding to the real and imaginary channels.

Parameters:
data: torch.Tensor
complex_axis: int

Complex dimension along which the assertion will be done. Default: -1 (last).

complex_last: Optional[bool]

If true, will override complex_axis with -1 (last). Default: None.

Return type:

None
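
A minimal usage sketch with the default complex_axis; the tensor shapes are hypothetical:

import torch
from direct.utils.asserts import assert_complex

# Last dimension of size 2 encodes the real and imaginary channels.
kspace = torch.randn(8, 320, 320, 2)
assert_complex(kspace)  # passes silently

real_only = torch.randn(8, 320, 320)
assert_complex(real_only)  # raises, since the complex axis does not have size 2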

direct.utils.asserts.assert_positive_integer(*variables, strict=False)[source][source]#

Assert that the given variables are positive integers.

Parameters:
variables: Any
strict: bool

If true, will allow zero values.

Return type:

None

direct.utils.asserts.assert_same_shape(data_list)[source][source]#

Check if all tensors in the list have the same shape.

Parameters:
data_list: list

List of tensors

direct.utils.bbox module#

direct.utils.bbox.crop_to_bbox(data, bbox, pad_value=0)[source][source]#

Extract a bounding box from an image; coordinates can be negative.

Parameters:
data: np.ndarray or torch.tensor

nD array or torch tensor.

bbox: list or tuple

bbox of the form (coordinates, size), for instance (4, 4, 2, 1) is a patch starting at row 4, col 4 with height 2 and width 1.

pad_value: number

If the bounding box extends outside the image, this is the value the patch will be padded with.

Returns:
ndarray or torch.Tensor

Data cropped to the bounding box.

Return type:

Union[ndarray, Tensor]
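
A small usage sketch following the (coordinates, size) convention described above; the array contents are arbitrary:

import numpy as np
from direct.utils.bbox import crop_to_bbox

data = np.arange(36).reshape(6, 6)
# (row, col, height, width): a 2 x 3 patch starting at row 1, column 2.
patch = crop_to_bbox(data, bbox=(1, 2, 2, 3))
# Coordinates may be negative; the part outside the array is filled with pad_value.
padded = crop_to_bbox(data, bbox=(-1, -1, 3, 3), pad_value=0)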

direct.utils.bbox.crop_to_largest(data, pad_value=0)[source][source]#

Given a list of arrays or tensors, return the same list with the data padded to the largest shape in the set. Can be convenient for e.g. logging and tiling several images, as with torchvision’s make_grid.

Parameters:
data: List[Union[np.ndarray, torch.Tensor]]
pad_value: int
Returns:
List[Union[np.ndarray, torch.Tensor]]
Return type:

List[Union[ndarray, Tensor]]
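
For example, padding a list of differently-sized images so they can be tiled with torchvision’s make_grid (a sketch; the shapes are arbitrary and torchvision is assumed to be installed):

import torch
from torchvision.utils import make_grid
from direct.utils.bbox import crop_to_largest

images = [torch.rand(1, 100, 120), torch.rand(1, 90, 140)]
padded = crop_to_largest(images, pad_value=0)  # all entries now share the largest shape
grid = make_grid(padded, nrow=2)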

direct.utils.communication module#

direct.utils.communication.all_gather(data, group=None)[source][source]#

Run all_gather on arbitrary picklable data (not necessarily tensors).

Parameters:
data: object

Any pickleable object.

group

A torch process group. By default, will use a group which contains all ranks on gloo backend.

Returns:
list: list of data gathered for each rank
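
A sketch of how this could be used inside a distributed run (e.g. launched with torchrun); in a non-distributed run the gathered list is expected to contain only the local object:

from direct.utils.communication import all_gather, get_rank

# Any picklable object can be gathered, not only tensors.
per_rank_stats = {"rank": get_rank(), "num_samples": 128}
gathered = all_gather(per_rank_stats)  # list with one entry per rank, identical on every rank
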
direct.utils.communication.gather(data, destination_rank=0, group=None)[source][source]#

Run gather on arbitrary picklable data (not necessarily tensors).

Parameters:
data: object

Any pickleable object

destination_rank: int

Destination rank

group

A torch process group. By default, will use a group which contains all ranks on gloo backend.

Returns:
list[data]: on destination_rank, a list of data gathered from each rank. Otherwise, an empty list.
Return type:

List

direct.utils.communication.get_local_rank()[source][source]#

Get rank of the process within the same machine, even when torch.distributed is not initialized.

Returns:
int: The rank of the current process within the local (per-machine) process group.
Return type:

int

direct.utils.communication.get_local_size()[source][source]#

Number of compute units on the local machine.

Returns:
int
Return type:

int

direct.utils.communication.get_rank()[source][source]#

Get rank of the process, even when torch.distributed is not initialized.

Returns:
int
Return type:

int

direct.utils.communication.get_world_size()[source][source]#

Get the number of compute devices in the world; returns 1 if distributed mode is not initialized.

Returns:
int
Return type:

int

direct.utils.communication.is_main_process()[source][source]#

Simple wrapper around get_rank().

Returns:
bool
Return type:

bool

direct.utils.communication.reduce_tensor_dict(tensors_dict)[source][source]#

Reduce the tensor dictionary from all processes so that process with rank 0 has the averaged results. Returns a dict with the same fields as tensors_dict, after reduction.

Parameters:
tensors_dict: dict

dictionary with str keys mapping to torch tensors.

Returns:
dict: the reduced dict.
Return type:

Dict[str, Tensor]
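
A sketch of reducing per-rank loss values during distributed training; only rank 0 is guaranteed to hold the averaged results:

import torch
from direct.utils.communication import is_main_process, reduce_tensor_dict

loss_dict = {"l1_loss": torch.tensor(0.3), "ssim_loss": torch.tensor(0.1)}
reduced = reduce_tensor_dict(loss_dict)
if is_main_process():
    print({name: value.item() for name, value in reduced.items()})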

direct.utils.communication.shared_random_seed()[source][source]#

All workers must call this function, otherwise it will deadlock.

Returns:
A random number that is the same across all workers. If workers need a shared RNG, they can use this shared seed to
create one.
Return type:

int

direct.utils.communication.synchronize()[source][source]#

Synchronize processes between GPUs.

Wait until all devices are available. Does nothing in a non-distributed setting.

direct.utils.dataset module#

direct.utils.dataset.get_filenames_for_datasets(lists, files_root, data_root)[source][source]#

Given lists of filenames of data points, concatenate these into a large list of full filenames.

Parameters:
lists: List[PathOrString]
files_root: PathOrString
data_root: pathlib.Path
Returns:
list of filenames or None
direct.utils.dataset.get_filenames_for_datasets_from_config(cfg, files_root, data_root)[source][source]#

Given a configuration object, return a list of filenames.

Parameters:
cfg: cfg-object

cfg object with a property lists containing the paths relative to files_root.

files_root: Union[str, pathlib.Path]
data_root: pathlib.Path
Returns:
list of filenames or None

direct.utils.events module#

class direct.utils.events.CommonMetricPrinter(max_iter)[source][source]#

Bases: EventWriter

Print common metrics to the terminal, including iteration time, ETA, memory, all losses, and the learning rate.

To print something different, please implement a similar printer by yourself.

write()[source][source]#
class direct.utils.events.EventStorage(start_iter=0)[source][source]#

Bases: object

The user-facing class that provides metric storage functionalities.

In the future we may add support for storing / logging other types of data if needed.

add_graph(img_name, img_tensor)[source][source]#

Add an img_tensor to the _vis_data associated with img_name.

Parameters:
img_name: str

The name of the input_image to put into tensorboard.

img_tensor: torch.Tensor or numpy.array

A uint8 or float Tensor of shape [channel, height, width], where channel is 3. The input_image format should be RGB. The elements in img_tensor can either have values in [0, 1] (float32) or [0, 255] (uint8). The img_tensor will be visualized in tensorboard.

add_image(img_name, img_tensor)[source][source]#

Add an img_tensor to the _vis_data associated with img_name.

Parameters:
img_name: str

The name of the input_image to put into tensorboard.

img_tensor: torch.Tensor or numpy.array

A uint8 or float Tensor of shape [channel, height, width], where channel is 3. The input_image format should be RGB. The elements in img_tensor can either have values in [0, 1] (float32) or [0, 255] (uint8). The img_tensor will be visualized in tensorboard.

add_scalar(name, value, smoothing_hint=True)[source][source]#

Add a scalar value to the HistoryBuffer associated with name.

Parameters:
name: str
value: float
smoothing_hint: bool

A ‘hint’ on whether this scalar is noisy and should be smoothed when logged. The hint will be accessible through EventStorage.smoothing_hints. A writer may ignore the hint and apply a custom smoothing rule. It defaults to True because most scalars we save need to be smoothed to provide any useful signal.

add_scalars(*, smoothing_hint=True, **kwargs)[source][source]#

Put multiple scalars from keyword arguments.

Examples

storage.add_scalars(loss=my_loss, accuracy=my_accuracy, smoothing_hint=True)

clear_images()[source][source]#

Delete all the stored images for visualization.

This should be called after images are written to tensorboard.

histories()[source][source]#
Returns:

dict[name -> HistoryBuffer]: the HistoryBuffer for all scalars

history(name)[source][source]#
Returns:

HistoryBuffer: the scalar history for name

property iter#
latest()[source][source]#
Returns:

dict[name -> number]: the scalars added in the current iteration.

latest_with_smoothing_hint(window_size=20)[source][source]#

Similar to latest(), but the returned values are either the un-smoothed original latest value or a median over the given window_size, depending on whether the smoothing_hint is True.

This provides a default behavior that other writers can use.

name_scope(name)[source][source]#
Yields:

A context within which all the events added to this storage will be prefixed by the name scope.

smoothing_hints()[source][source]#
Returns:
dict[name -> bool]: the user-provided hint on whether the scalar is noisy and needs smoothing.

step()[source][source]#

Users should call this function at the beginning of each iteration to notify the storage of the start of a new iteration.

The storage will then be able to associate the new data with the correct iteration number.

property vis_data#
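
Putting the methods above together, a minimal training-loop sketch that only uses the documented EventStorage API; the loss value is a stand-in:

from direct.utils.events import EventStorage

storage = EventStorage(start_iter=0)
for iteration in range(100):
    loss = 1.0 / (iteration + 1)  # stand-in for a real training loss
    storage.add_scalar("loss", loss, smoothing_hint=True)
    storage.add_scalars(lr=1e-3, data_time=0.01)
    storage.step()  # associate subsequent data with the next iteration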
class direct.utils.events.EventWriter[source][source]#

Bases: object

Base class for writers that obtain events from EventStorage and process them.

close()[source][source]#
write()[source][source]#
class direct.utils.events.HistoryBuffer(max_length=1000000)[source][source]#

Bases: object

Track a series of scalar values and provide access to smoothed values over a window or the global average of the series.

avg(window_size)[source][source]#

Return the mean of the latest window_size values in the buffer.

Return type:

float

global_avg()[source][source]#

Return the mean of all the elements in the buffer.

Note that this includes elements that have been removed due to the limited buffer length.

Return type:

float

latest()[source][source]#

Return the latest scalar value added to the buffer.

Return type:

float

median(window_size)[source][source]#

Return the median of the latest window_size values in the buffer.

Return type:

float

update(value, iteration=None)[source][source]#

Add a new scalar value produced at a certain iteration.

If the length of the buffer exceeds self._max_length, the oldest element will be removed from the buffer.

Return type:

None

values()[source][source]#
Return type:

List[Tuple[float, float]]

Returns:

list[(number, iteration)]: content of the current buffer.
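
A small example of the buffer’s smoothing helpers; the expected values are shown in comments:

from direct.utils.events import HistoryBuffer

buffer = HistoryBuffer(max_length=1000)
for iteration, value in enumerate([0.9, 0.7, 0.8, 0.4, 0.5]):
    buffer.update(value, iteration)

buffer.latest()      # 0.5
buffer.median(3)     # 0.5, median of the last three values
buffer.avg(2)        # 0.45, mean of the last two values
buffer.global_avg()  # 0.66, mean over all values seen so far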

class direct.utils.events.JSONWriter(json_file, window_size=2)[source][source]#

Bases: EventWriter

Write scalars to a json file.

It saves scalars as one json per line (instead of a big json) for easy parsing.

Examples parsing such a json file:

$ cat metrics.json | jq -s '.[0:2]'
[
  {
    "data_time": 0.008433341979980469,
    "iteration": 20,
    "loss": 1.9228371381759644,
    "loss_box_reg": 0.050025828182697296,
    "loss_classifier": 0.5316952466964722,
    "loss_mask": 0.7236229181289673,
    "loss_rpn_box": 0.0856662318110466,
    "loss_rpn_cls": 0.48198649287223816,
    "lr": 0.007173333333333333,
    "time": 0.25401854515075684
  },
  {
    "data_time": 0.007216215133666992,
    "iteration": 40,
    "loss": 1.282649278640747,
    "loss_box_reg": 0.06222952902317047,
    "loss_classifier": 0.30682939291000366,
    "loss_mask": 0.6970193982124329,
    "loss_rpn_box": 0.038663312792778015,
    "loss_rpn_cls": 0.1471673548221588,
    "lr": 0.007706666666666667,
    "time": 0.2490077018737793
  }
]

$ cat metrics.json | jq '.loss_mask'
0.7126231789588928
0.689423680305481
0.6776131987571716
...
close()[source][source]#
write()[source][source]#
class direct.utils.events.TensorboardWriter(log_dir, window_size=20, **kwargs)[source][source]#

Bases: EventWriter

Write all scalars to a tensorboard file.

close()[source][source]#
write()[source][source]#
direct.utils.events.get_event_storage()[source][source]#
Returns:

The EventStorage object that is currently being used. Throws an error if no EventStorage is currently enabled.

direct.utils.imports module#

General utilities for module imports.

direct.utils.io module#

class direct.utils.io.ArrayEncoder(*, skipkeys=False, ensure_ascii=True, check_circular=True, allow_nan=True, sort_keys=False, indent=None, separators=None, default=None)[source][source]#

Bases: JSONEncoder

default(obj)[source][source]#

Implement this method in a subclass such that it returns a serializable object for o, or calls the base implementation (to raise a TypeError).

For example, to support arbitrary iterators, you could implement default like this:

def default(self, o):
    try:
        iterable = iter(o)
    except TypeError:
        pass
    else:
        return list(iterable)
    # Let the base class default method raise the TypeError
    return JSONEncoder.default(self, o)
direct.utils.io.calculate_md5(fpath, chunk_size=1048576)[source][source]#
Return type:

str

direct.utils.io.check_integrity(fpath, md5=None)[source][source]#
Return type:

bool

direct.utils.io.check_is_valid_url(path)[source][source]#

Check if the given path is a valid URL.

Parameters:
path: PathOrString
Returns:
Bool describing whether this is a URL or not.
Return type:

bool

direct.utils.io.check_md5(fpath, md5, **kwargs)[source][source]#
Return type:

bool

direct.utils.io.download_and_extract_archive(url, download_root, extract_root=None, filename=None, md5=None, remove_finished=False)[source][source]#
Return type:

None

direct.utils.io.download_url(url, root, filename=None, md5=None, max_redirect_hops=3)[source][source]#

Download a file from a url and place it in root.

Parameters:
url: str

URL to download file from

root: PathOrString

Directory to place downloaded file in

filename: str, optional:

Name to save the file under. If None, use the basename of the URL

md5: str, optional

MD5 checksum of the download. If None, do not check

max_redirect_hops: int, optional

Maximum number of redirect hops allowed

Return type:

None

direct.utils.io.extract_archive(from_path, to_path=None, remove_finished=False)[source][source]#

Extract an archive.

The archive type and a possible compression are automatically detected from the file name. If the file is compressed but not an archive, the call is dispatched to decompress().

Parameters:
from_path: str

Path to the file to be extracted.

to_path: str

Path to the directory the file will be extracted to. If omitted, the directory of the file is used.

remove_finished: bool

If True, remove the file after the extraction.

Returns:
str

Path to the directory the file was extracted to.

Return type:

str
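
For instance (hypothetical paths), extracting a compressed archive next to where it was downloaded:

from direct.utils.io import extract_archive

# Extract ./downloads/dataset.tar.gz into ./data and keep the archive afterwards.
extracted_to = extract_archive("./downloads/dataset.tar.gz", to_path="./data", remove_finished=False)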

direct.utils.io.gen_bar_updater()[source][source]#
Return type:

Callable[[int, int, int], None]

direct.utils.io.read_json(fn)[source][source]#

Read a file and output a dict, or take a dict and return it unchanged.

Parameters:
fn: Union[Dict, str, pathlib.Path]
Returns:
dict
Return type:

Dict

direct.utils.io.read_list(fn)[source][source]#

Read a file and output a list, or take a list and return it unchanged. Can read data from URLs.

Parameters:
fn: Union[list, str, pathlib.Path]

Input text file or list, or a URL to a text file.

Returns:
list

Text file read line by line.

Return type:

List

direct.utils.io.read_text_from_url(url, chunk_size=1024)[source][source]#

Read a text file from a URL, e.g. a config file.

Parameters:
url: str
chunk_size: int
Returns:
str

Data from URL

direct.utils.io.upload_to_s3(filename, to_filename, aws_access_key_id, aws_secret_access_key, endpoint_url='https://s3.aiforoncology.nl', bucket='direct-project', verbose=True)[source][source]#

Upload file to an s3 bucket.

Parameters:
filename: pathlib.Path

Filename to upload

to_filename: str

Where to store the file

aws_access_key_id: str
aws_secret_access_key: str
endpoint_url: str

AWS endpoint url

bucket: str

Bucket name

verbose: bool

Show upload progress

Returns:
None
Return type:

None

direct.utils.io.write_json(fn, data, indent=2)[source][source]#

Write dict data to fn.

Parameters:
fn: Path or str
data: dict
indent: int
Returns:
None
Return type:

None
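
A round-trip sketch together with read_json above; the filename and contents are hypothetical:

from direct.utils.io import read_json, write_json

metrics = {"ssim": 0.92, "psnr": 34.1}
write_json("metrics.json", metrics, indent=2)
restored = read_json("metrics.json")  # {'ssim': 0.92, 'psnr': 34.1}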

direct.utils.io.write_list(fn, data)[source][source]#

Write list line by line to file.

Parameters:
fn: Union[list, str, pathlib.Path]

Input text file or list

data: list or tuple
Returns:
None
Return type:

None

direct.utils.logging module#

direct.utils.logging.setup(use_stdout=True, filename=None, log_level='INFO')[source][source]#

Set up logging for DIRECT.

Parameters:
use_stdout: bool

Write output to standard out.

filename: PathLike

Filename to write log to.

log_level: str

Logging level as in the python.logging library.

Returns:
None
Return type:

None
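
A typical call at the start of a script; the log path is hypothetical and its directory is assumed to exist:

import pathlib
from direct.utils.logging import setup

setup(use_stdout=True, filename=pathlib.Path("logs/train.log"), log_level="DEBUG")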

direct.utils.models module#

direct.utils.models.fix_state_dict_module_prefix(state_dict)[source][source]#

If models are saved after being wrapped in e.g. DataParallel, the keys of the state dict are prefixed with “module.”. This function removes that prefix.

Parameters:
state_dict: dict

state_dict of a network module

Returns:
dict
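
A self-contained sketch showing the prefix being stripped, using a tiny linear layer as a stand-in model:

import torch
from direct.utils.models import fix_state_dict_module_prefix

model = torch.nn.Linear(4, 2)
wrapped = torch.nn.DataParallel(model)
state_dict = wrapped.state_dict()                       # keys look like "module.weight", "module.bias"
state_dict = fix_state_dict_module_prefix(state_dict)   # keys become "weight", "bias"
model.load_state_dict(state_dict)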

direct.utils.writers module#

direct.utils.writers.write_output_to_h5(output, output_directory, volume_processing_func=None, output_key='reconstruction', create_dirs_if_needed=True)[source][source]#

Write a dictionary with filenames as keys and torch tensors as values to h5 files.

Parameters:
output: dict

Dictionary with filenames as keys and torch.Tensors as values, of shape [depth, num_channels, …] where num_channels is typically 1 for MRI.

output_directory: pathlib.Path
volume_processing_func: callable

Function which postprocesses the volume array before saving.

output_key: str

Name of key to save the output to.

create_dirs_if_needed: bool

If true, the output directory and all its parents will be created.

Return type:

None

Notes

Currently only num_channels = 1 is supported. If you run this function with more channels, only the first one will be used.
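
A sketch following the shape convention described above; the filename and output directory are hypothetical:

import pathlib
import torch
from direct.utils.writers import write_output_to_h5

# One reconstructed volume of shape [depth, num_channels, height, width].
output = {"file_brain_0001.h5": torch.rand(16, 1, 320, 320)}
write_output_to_h5(output, pathlib.Path("predictions"), output_key="reconstruction")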

Module contents#

direct.utils module.

class direct.utils.DirectModule[source][source]#

Bases: DirectTransform, ABC, Module

forward(sample)[source][source]#

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

training: bool#
class direct.utils.DirectTransform[source][source]#

Bases: object

Direct transform class.

Defines __repr__() method for Direct transforms.

class direct.utils.SpatialDims(TWO_D, THREE_D)[source]#

Bases: tuple

THREE_D#

Alias for field number 1

TWO_D#

Alias for field number 0

direct.utils.cast_as_path(data)[source][source]#

Ensure that the input is a path.

Parameters:
data: str or pathlib.Path
Returns:
pathlib.Path
Return type:

Optional[Path]

direct.utils.chunks(list_to_chunk, number_of_chunks)[source][source]#

Yield number_of_chunks sequential chunks from list_to_chunk. Adapted from [1].

Parameters:
list_to_chunk: List
number_of_chunks: int

References

direct.utils.count_parameters(models)[source][source]#

Count the number of parameters of a dictionary of models.

Parameters:
models: Dict

Dictionary mapping model name to model.

Return type:

None

direct.utils.detach_dict(data, keys=None)[source][source]#

Return a detached copy of a dictionary. Only torch.Tensors are detached.

Parameters:
data: Dict[str, torch.Tensor]
keys: List, Tuple

Subselection of keys to detach

Returns:
Detached copy of the dictionary.
Return type:

Dict

direct.utils.dict_flatten(in_dict, dict_out=None)[source][source]#

Flattens a nested dictionary (or DictConfig) and returns a new flattened dictionary.

If a dict_out is provided, the flattened dictionary will be added to it.

Parameters:
in_dictDictOrDictConfig

The nested dictionary or DictConfig to flatten.

dict_outOptional[DictOrDictConfig], optional

An existing dictionary to add the flattened dictionary to. Default: None.

Returns:
Dict[str, Any]

The flattened dictionary.

Return type:

Dict[str, Any]

Notes

  • This function only keeps the final keys, and discards the intermediate ones.

Examples

>>> dictA = {"a": 1, "b": {"c": 2, "d": 3, "e": {"f": 4, 6: "a", 5: {"g": 6}, "l": [1, "two"]}}}
>>> dict_flatten(dictA)
{'a': 1, 'c': 2, 'd': 3, 'f': 4, 6: 'a', 'g': 6, 'l': [1, 'two']}
direct.utils.dict_to_device(data, device, keys=None)[source][source]#

Copy a tensor-valued dictionary to a device. Only torch.Tensor values are copied.

Parameters:
data: Dict[str, torch.Tensor]
device: torch.device, str
keys: List, Tuple

Subselection of keys to copy.

Returns:
Dictionary at the new device.
Return type:

Dict
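
For example, moving a batch of tensors to the available device (a minimal sketch; the keys and shapes are arbitrary):

import torch
from direct.utils import dict_to_device

batch = {"masked_kspace": torch.rand(1, 320, 320, 2), "sensitivity_map": torch.rand(1, 320, 320, 2)}
device = "cuda" if torch.cuda.is_available() else "cpu"
batch = dict_to_device(batch, device)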

direct.utils.ensure_list(data)[source][source]#

Ensure input is a list.

Parameters:
data: object
Returns:
list
Return type:

List

direct.utils.evaluate_dict(fns_dict, source, target, reduction='mean')[source][source]#

Evaluate a dictionary of functions.

Parameters:
fns_dict: Dict[str, Callable]
source: torch.Tensor
target: torch.Tensor
reduction: str
Returns:
Dict[str, torch.Tensor]

Evaluated dictionary.

Return type:

Dict

Examples

>>> evaluate_dict({'l1_loss': F.l1_loss, 'l2_loss': F.l2_loss}, a, b)

will return

{'l1_loss': F.l1_loss(a, b, reduction='mean'), 'l2_loss': F.l2_loss(a, b, reduction='mean')}

direct.utils.git_hash()[source][source]#

Returns the current git hash.

Returns:
_git_hash: str

The current git hash.

Return type:

str

direct.utils.is_complex_data(data, complex_axis=-1)[source][source]#

Returns True if data is a complex tensor at the specified dimension, i.e. the complex_axis dimension of data has size 2, corresponding to real and imaginary channels.

Parameters:
data: torch.Tensor
complex_axis: int

Complex dimension along which the check will be done. Default: -1 (last).

Returns:
bool

True if data is a complex tensor.

Return type:

bool

direct.utils.is_power_of_two(number)[source][source]#

Check if input is a power of 2.

Parameters:
number: int
Returns:
bool
Return type:

bool

direct.utils.merge_list_of_dicts(list_of_dicts)[source][source]#

Merge a list of dictionaries into one dictionary.

Parameters:
list_of_dicts: List[Dict]
Returns:
Dict
Return type:

Dict

direct.utils.multiply_function(multiplier, func)[source][source]#

Create a function which multiplies the output of another function by a multiplier.

Parameters:
multiplier: float

Number to multiply with.

func: callable

Function to multiply.

Returns:
return_func: Callable
Return type:

Callable

direct.utils.normalize_image(image, eps=1e-05)[source][source]#

Normalize image to range [0,1] for visualization.

Given image \(x\) and \(\epsilon\), it returns:

\[\frac{x - \min{x}}{\max{x} + \epsilon}.\]
Parameters:
image: torch.Tensor

Image to scale.

eps: float
Returns:
image: torch.Tensor

Scaled image.

direct.utils.prefix_dict_keys(data, prefix)[source][source]#

Prepend a prefix to a dictionary's keys.

Parameters:
data: Dict[str, Any]
prefix: str
Returns:
Dict[str, Any]
Return type:

Dict[str, Any]

direct.utils.reduce_list_of_dicts(data, mode='average', divisor=None)[source][source]#

Average or sum a list of dictionaries mapping keys to Tensors.

Parameters:
data: List[Dict[str, torch.Tensor]])
mode: str

Which reduction mode to use: 'sum' adds the values, while 'average' computes their mean.

divisor: None or int

If given, values are divided by this factor.

Returns:
Dict[str, torch.Tensor]: Reduced dictionary.
Return type:

Dict[str, Tensor]
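
For instance, averaging per-batch losses collected during validation; approximate expected results are shown in comments:

import torch
from direct.utils import reduce_list_of_dicts

per_batch = [{"l1_loss": torch.tensor(0.4)}, {"l1_loss": torch.tensor(0.2)}]
reduce_list_of_dicts(per_batch, mode="average")  # {'l1_loss': tensor(0.3)}
reduce_list_of_dicts(per_batch, mode="sum")      # {'l1_loss': tensor(0.6)}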

direct.utils.remove_keys(input_dict, keys)[source][source]#

Removes keys from input_dict.

Parameters:
input_dict: Dict
keys: Union[str, List[str], Tuple[str]]
Returns:
Dict
Return type:

Dict

direct.utils.set_all_seeds(seed)[source][source]#

Sets seed for deterministic runs.

Parameters:
seed: int

Seed for random module.

Return type:

None

direct.utils.str_to_class(module_name, function_name)[source][source]#

Convert a string to a class. Based on: https://stackoverflow.com/a/1176180/576363.

Also supports function arguments, e.g. ifft(dim=2) will be parsed as a partial and return ifft with dim set to 2.

Parameters:
module_name: str

e.g. direct.data.transforms

function_name: str

e.g. Identity

Returns:
object
Return type:

Callable

Examples

>>> def mult(f, mul=2):
...     return f * mul
>>> str_to_class(".", "mult(mul=4)")
will return a function which multiplies the input by 4, while
>>> str_to_class(".", "mult")
just returns the function itself.