direct.utils package
Submodules
direct.utils.asserts module
- direct.utils.asserts.assert_complex(data, complex_axis=-1, complex_last=None)
Assert that a tensor is complex, i.e. that it has a complex dimension of size 2 holding the real and imaginary channels.
- Parameters:
- data: torch.Tensor
- complex_axis: int
Complex dimension along which the assertion will be done. Default: -1 (last).
- complex_last: Optional[bool]
If true, will override complex_axis with -1 (last). Default: None.
- Return type:
None
direct.utils.bbox module
- direct.utils.bbox.crop_to_bbox(data, bbox, pad_value=0)
Extract a bounding box from an image; coordinates can be negative.
- Parameters:
- data: np.ndarray or torch.Tensor
nD array or torch tensor.
- bbox: list or tuple
Bounding box of the form (coordinates, size); for instance, (4, 4, 2, 1) is a patch starting at row 4, column 4 with height 2 and width 1.
- pad_value: number
If the bounding box extends beyond the image, the patch is padded with this value.
- Returns:
- np.ndarray or torch.Tensor
Data cropped to the bounding box.
- Return type:
Union[ndarray, Tensor]
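The cropping rule above can be sketched in NumPy as follows. This is a minimal illustration, not the library implementation: `crop_to_bbox_sketch` is a hypothetical name, it assumes an n-D bbox laid out as all start coordinates followed by all sizes (generalizing the (4, 4, 2, 1) example), and it handles only NumPy arrays, not torch tensors.

```python
import numpy as np


def crop_to_bbox_sketch(data, bbox, pad_value=0):
    """Crop `data` to `bbox`, padding with `pad_value` where the box leaves the array.

    Hypothetical helper: bbox = (start per axis..., size per axis...).
    """
    ndim = data.ndim
    starts = np.asarray(bbox[:ndim])
    sizes = np.asarray(bbox[ndim:])
    stops = starts + sizes

    # Output patch, pre-filled with the padding value.
    out = np.full(tuple(sizes), pad_value, dtype=data.dtype)

    # Part of the box that actually overlaps the array (array coordinates).
    src_lo = np.clip(starts, 0, data.shape)
    src_hi = np.clip(stops, 0, data.shape)
    if np.any(src_hi <= src_lo):
        return out  # the box lies completely outside the array

    # The same region expressed in patch coordinates.
    dst_lo = src_lo - starts
    dst_hi = dst_lo + (src_hi - src_lo)

    src = tuple(slice(lo, hi) for lo, hi in zip(src_lo, src_hi))
    dst = tuple(slice(lo, hi) for lo, hi in zip(dst_lo, dst_hi))
    out[dst] = data[src]
    return out


image = np.arange(16).reshape(4, 4)
print(crop_to_bbox_sketch(image, (1, 1, 2, 2)).tolist())                  # [[5, 6], [9, 10]]
print(crop_to_bbox_sketch(image, (-1, -1, 2, 2), pad_value=-1).tolist())  # [[-1, -1], [-1, 0]]
```

A negative start simply shifts where the overlapping region lands inside the padded patch, which is why negative coordinates are allowed.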
- direct.utils.bbox.crop_to_largest(data, pad_value=0)
Given a list of arrays or tensors, return the same list with the data padded to the largest shape in the set. This can be convenient for, e.g., logging and tiling several images with torchvision's make_grid.
- Parameters:
- data: List[Union[np.ndarray, torch.Tensor]]
- pad_value: int
- Returns:
- List[Union[np.ndarray, torch.Tensor]]
- Return type:
List[Union[ndarray, Tensor]]
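A minimal NumPy sketch of padding a list of arrays to a common shape. The hypothetical `crop_to_largest_sketch` takes the elementwise maximum over all shapes and pads each array with `np.pad`. Centring the data, and putting the extra element of an odd padding at the end, are assumptions of this sketch rather than documented behaviour.

```python
import numpy as np


def crop_to_largest_sketch(data, pad_value=0):
    """Pad every array in `data` to the elementwise-largest shape (hypothetical helper)."""
    if not data:
        return data
    max_shape = np.max([array.shape for array in data], axis=0)
    padded = []
    for array in data:
        missing = max_shape - np.asarray(array.shape)
        before = missing // 2     # centre the data...
        after = missing - before  # ...and put any odd extra element at the end
        pad_width = [(int(b), int(a)) for b, a in zip(before, after)]
        padded.append(np.pad(array, pad_width, mode="constant", constant_values=pad_value))
    return padded


images = [np.ones((2, 2)), np.ones((4, 3))]
print([image.shape for image in crop_to_largest_sketch(images)])  # [(4, 3), (4, 3)]
```

All arrays must have the same number of dimensions; only the size along each axis may differ.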
direct.utils.communication module
- direct.utils.communication.all_gather(data, group=None)
Run all_gather on arbitrary picklable data (not necessarily tensors).
- Parameters:
- data: object
Any picklable object.
- group
A torch process group. By default, a group containing all ranks on the gloo backend is used.
- Returns:
- list: list of data gathered for each rank
- direct.utils.communication.gather(data, destination_rank=0, group=None)
Run gather on arbitrary picklable data (not necessarily tensors).
- Parameters:
- data: object
Any picklable object.
- destination_rank: int
Destination rank.
- group
A torch process group. By default, a group containing all ranks on the gloo backend is used.
- Returns:
- list[data]: on destination_rank, a list of data gathered from each rank. Otherwise, an empty list.
- Return type:
List
- direct.utils.communication.get_local_rank()
Get the rank of the process within the local machine, even when torch.distributed is not initialized.
- Returns:
- int: The rank of the current process within the local (per-machine) process group.
- Return type:
int
- direct.utils.communication.get_local_size()
Get the number of compute units on the local machine.
- Returns:
- int
- Return type:
int
- direct.utils.communication.get_rank()
Get the rank of the process, even when torch.distributed is not initialized.
- Returns:
- int
- Return type:
int
- direct.utils.communication.get_world_size()
Get the number of compute devices in the world. Returns 1 if multi-device (distributed) mode is not initialized.
- Returns:
- int
- Return type:
int
- direct.utils.communication.is_main_process()
Check whether the current process is the main process (rank 0). Simple wrapper around get_rank().
- Returns:
- bool
- Return type:
bool
- direct.utils.communication.reduce_tensor_dict(tensors_dict)
Reduce the tensor dictionary from all processes so that the process with rank 0 holds the averaged results.
- Parameters:
- tensors_dict: dict
Dictionary with str keys mapping to torch tensors.
- Returns:
- dict: the reduced dict, with the same fields as tensors_dict.
- Return type:
Dict[str, Tensor]
All workers must call this function, otherwise it will deadlock.
- Returns:
- A random number that is the same across all workers. If workers need a shared RNG, they can use this shared seed to create one.
- Return type:
int
direct.utils.dataset module
- direct.utils.dataset.get_filenames_for_datasets(lists, files_root, data_root)
Given lists of filenames of data points, concatenate these into a large list of full filenames.
- Parameters:
- lists: List[PathOrString]
- files_root: PathOrString
- data_root: pathlib.Path
- Returns:
- list of filenames or None
- direct.utils.dataset.get_filenames_for_datasets_from_config(cfg, files_root, data_root)
Given a configuration object, return a list of filenames.
- Parameters:
- cfg: cfg-object
Configuration object with a lists property holding the paths relative to files_root.
- files_root: Union[str, pathlib.Path]
- data_root: pathlib.Path
- Returns:
- list of filenames or None
direct.utils.events module
- class direct.utils.events.CommonMetricPrinter(max_iter)
Bases:
EventWriter
Print common metrics to the terminal, including iteration time, ETA, memory, all losses, and the learning rate.
To print something different, please implement a similar printer by yourself.
- class direct.utils.events.EventStorage(start_iter=0)
Bases:
object
The user-facing class that provides metric storage functionalities.
In the future we may add support for storing / logging other types of data if needed.
- add_graph(img_name, img_tensor)
Add an img_tensor to the _vis_data associated with img_name.
- Parameters:
- img_name: str
The name of the input_image to put into tensorboard.
- img_tensor: torch.Tensor or numpy.array
A uint8 or float Tensor of shape [channel, height, width] where channel is 3. The input_image format should be RGB. The elements in img_tensor can either have values in [0, 1] (float32) or [0, 255] (uint8). The img_tensor will be visualized in tensorboard.
- add_image(img_name, img_tensor)
Add an img_tensor to the _vis_data associated with img_name.
- Parameters:
- img_name: str
The name of the input_image to put into tensorboard.
- img_tensor: torch.Tensor or numpy.array
A uint8 or float Tensor of shape [channel, height, width] where channel is 3. The input_image format should be RGB. The elements in img_tensor can either have values in [0, 1] (float32) or [0, 255] (uint8). The img_tensor will be visualized in tensorboard.
- add_scalar(name, value, smoothing_hint=True)
Add a scalar value to the HistoryBuffer associated with name.
- Parameters:
- name: str
- value: float
- smoothing_hint: bool
A ‘hint’ on whether this scalar is noisy and should be smoothed when logged. The hint will be accessible through EventStorage.smoothing_hints. A writer may ignore the hint and apply custom smoothing rule. It defaults to True because most scalars we save need to be smoothed to provide any useful signal.
- add_scalars(*, smoothing_hint=True, **kwargs)
Put multiple scalars from keyword arguments.
Examples
storage.add_scalars(loss=my_loss, accuracy=my_accuracy, smoothing_hint=True)
- clear_images()
Delete all the stored images for visualization.
This should be called after images are written to tensorboard.
- property iter
- latest()
- Returns:
dict[name -> number]: the scalars that were added in the current iteration.
- latest_with_smoothing_hint(window_size=20)
Similar to latest(), but the returned values are either the un-smoothed original latest value or the median over the given window_size, depending on whether smoothing_hint is True. This provides a default behavior that other writers can use.
- name_scope(name)
- Yields:
A context within which all the events added to this storage will be prefixed by the name scope.
- smoothing_hints()
- Returns:
- dict[name -> bool]: the user-provided hint on whether the scalar is noisy and needs smoothing.
- step()
Users should call this function at the beginning of each iteration to notify the storage of the start of a new iteration.
The storage will then be able to associate the new data with the correct iteration number.
- property vis_data
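How add_scalar, add_scalars, step, name_scope and latest_with_smoothing_hint fit together can be illustrated with a deliberately simplified stand-in. `MiniEventStorage` is hypothetical: it keeps plain Python lists instead of HistoryBuffer objects, keeps the first smoothing hint given for a name, and joins scopes with "/"; all of these are assumptions, not the library's internals.

```python
import statistics
from contextlib import contextmanager


class MiniEventStorage:
    """Hypothetical, simplified stand-in for EventStorage (not the library class)."""

    def __init__(self, start_iter=0):
        self.iter = start_iter
        self._history = {}          # name -> all values logged so far
        self._smoothing_hints = {}  # name -> bool
        self._prefix = ""

    def add_scalar(self, name, value, smoothing_hint=True):
        name = self._prefix + name
        self._history.setdefault(name, []).append(float(value))
        # Keep the first hint given for a name.
        self._smoothing_hints.setdefault(name, smoothing_hint)

    def add_scalars(self, *, smoothing_hint=True, **kwargs):
        for name, value in kwargs.items():
            self.add_scalar(name, value, smoothing_hint=smoothing_hint)

    def latest_with_smoothing_hint(self, window_size=20):
        # Median over the window if the scalar is marked noisy, else the raw latest value.
        return {
            name: statistics.median(values[-window_size:]) if self._smoothing_hints[name] else values[-1]
            for name, values in self._history.items()
        }

    def step(self):
        self.iter += 1

    @contextmanager
    def name_scope(self, name):
        old_prefix = self._prefix
        self._prefix = name.rstrip("/") + "/"
        try:
            yield
        finally:
            self._prefix = old_prefix


storage = MiniEventStorage()
for i, loss in enumerate([1.0, 5.0, 3.0]):
    storage.add_scalars(loss=loss)
    storage.add_scalar("lr", float(i), smoothing_hint=False)
    with storage.name_scope("val"):
        storage.add_scalar("psnr", 30.0 + i)
    storage.step()
print(storage.latest_with_smoothing_hint(window_size=3))  # {'loss': 3.0, 'lr': 2.0, 'val/psnr': 31.0}
```

The learning rate is logged with smoothing_hint=False, so its latest raw value is reported instead of a windowed median.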
- class direct.utils.events.EventWriter
Bases:
object
Base class for writers that obtain events from EventStorage and process them.
- class direct.utils.events.HistoryBuffer(max_length=1000000)
Bases:
object
Track a series of scalar values and provide access to smoothed values over a window or the global average of the series.
- avg(window_size)
Return the mean of the latest window_size values in the buffer.
- Return type:
float
- global_avg()
Return the mean of all the elements in the buffer.
Note that this includes values that have been evicted because of the limited buffer size.
- Return type:
float
- median(window_size)
Return the median of the latest window_size values in the buffer.
- Return type:
float
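A minimal sketch of the three statistics, assuming a bounded window plus a running sum and count for the all-time average. `HistoryBufferSketch` and its `update` method are hypothetical names; the methods that add values to the real HistoryBuffer are not listed in this section.

```python
import statistics
from collections import deque


class HistoryBufferSketch:
    """Hypothetical, simplified HistoryBuffer: bounded window plus an all-time average."""

    def __init__(self, max_length=1_000_000):
        self._values = deque(maxlen=max_length)  # oldest values are dropped automatically
        self._count = 0    # number of values ever added
        self._total = 0.0  # sum of all values ever added, including dropped ones

    def update(self, value):
        self._values.append(value)
        self._count += 1
        self._total += value

    def _latest(self, window_size):
        return list(self._values)[-window_size:]

    def median(self, window_size):
        return statistics.median(self._latest(window_size))

    def avg(self, window_size):
        window = self._latest(window_size)
        return sum(window) / len(window)

    def global_avg(self):
        return self._total / self._count


buffer = HistoryBufferSketch(max_length=3)
for value in [1.0, 2.0, 3.0, 4.0]:
    buffer.update(value)
print(buffer.median(3), buffer.avg(2), buffer.global_avg())  # 3.0 3.5 2.5
```

The value 1.0 has already been evicted from the window, yet it still counts towards global_avg, matching the note above.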
- class direct.utils.events.JSONWriter(json_file, window_size=2)
Bases:
EventWriter
Write scalars to a JSON file.
Scalars are saved as one JSON object per line (instead of one big JSON document) for easy parsing.
Example of parsing such a JSON file:
$ cat metrics.json | jq -s '.[0:2]'
[
  {
    "data_time": 0.008433341979980469,
    "iteration": 20,
    "loss": 1.9228371381759644,
    "loss_box_reg": 0.050025828182697296,
    "loss_classifier": 0.5316952466964722,
    "loss_mask": 0.7236229181289673,
    "loss_rpn_box": 0.0856662318110466,
    "loss_rpn_cls": 0.48198649287223816,
    "lr": 0.007173333333333333,
    "time": 0.25401854515075684
  },
  {
    "data_time": 0.007216215133666992,
    "iteration": 40,
    "loss": 1.282649278640747,
    "loss_box_reg": 0.06222952902317047,
    "loss_classifier": 0.30682939291000366,
    "loss_mask": 0.6970193982124329,
    "loss_rpn_box": 0.038663312792778015,
    "loss_rpn_cls": 0.1471673548221588,
    "lr": 0.007706666666666667,
    "time": 0.2490077018737793
  }
]
$ cat metrics.json | jq '.loss_mask'
0.7126231789588928
0.689423680305481
0.6776131987571716
...
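The one-object-per-line layout can also be produced and parsed with only the standard json module. In this sketch `write_scalars` and `read_scalar` are hypothetical helpers and an in-memory io.StringIO stands in for metrics.json; it is roughly the Python equivalent of the `jq '.loss_mask'` query, not JSONWriter's own code.

```python
import io
import json


def write_scalars(fp, iteration, scalars):
    """Append one JSON object, on its own line, per call."""
    fp.write(json.dumps({"iteration": iteration, **scalars}, sort_keys=True) + "\n")


def read_scalar(fp, key):
    """Collect `key` from every line that has it; each line parses independently."""
    records = (json.loads(line) for line in fp if line.strip())
    return [record[key] for record in records if key in record]


log = io.StringIO()  # stands in for metrics.json
write_scalars(log, 20, {"loss": 1.92, "lr": 0.0072})
write_scalars(log, 40, {"loss": 1.28, "lr": 0.0077})
log.seek(0)
print(read_scalar(log, "loss"))  # [1.92, 1.28]
```

Because every line is a complete document, a crashed run still leaves all fully written lines parseable.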
- class direct.utils.events.TensorboardWriter(log_dir, window_size=20, **kwargs)
Bases:
EventWriter
Write all scalars to a tensorboard file.
- direct.utils.events.get_event_storage()
- Returns:
The EventStorage object that is currently being used. Throws an error if no EventStorage is currently enabled.
direct.utils.imports module
General utilities for module imports.
direct.utils.io module
- class direct.utils.io.ArrayEncoder(*, skipkeys=False, ensure_ascii=True, check_circular=True, allow_nan=True, sort_keys=False, indent=None, separators=None, default=None)
Bases:
JSONEncoder
- default(obj)
Implement this method in a subclass such that it returns a serializable object for obj, or calls the base implementation (to raise a TypeError).
For example, to support arbitrary iterators, you could implement default like this:
def default(self, o):
    try:
        iterable = iter(o)
    except TypeError:
        pass
    else:
        return list(iterable)
    # Let the base class default method raise the TypeError
    return JSONEncoder.default(self, o)
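This section does not show ArrayEncoder's own default method. A minimal sketch, assuming its purpose is to make array-like values JSON-serializable, relies on duck typing: anything with a callable tolist() (NumPy arrays and scalars, torch tensors, the standard library's array.array) is converted to a list. `ArrayEncoderSketch` is a hypothetical name.

```python
import json
from array import array


class ArrayEncoderSketch(json.JSONEncoder):
    """JSON encoder that serializes array-like objects via their tolist() method."""

    def default(self, obj):
        # Duck typing: numpy arrays, numpy scalars, torch tensors and array.array all provide tolist().
        tolist = getattr(obj, "tolist", None)
        if callable(tolist):
            return tolist()
        # Fall back to the base class, which raises TypeError.
        return super().default(obj)


print(json.dumps({"kspace_shape": array("i", [640, 368])}, cls=ArrayEncoderSketch))
# {"kspace_shape": [640, 368]}
```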
- direct.utils.io.check_is_valid_url(path)
Check if the given path is a valid URL.
- Parameters:
- path: PathOrString
- Returns:
- bool describing whether this is a URL or not.
- Return type:
bool
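One plausible way to implement such a check uses urllib.parse from the standard library. This is an assumption about the approach, not the library's code: the hypothetical `looks_like_url` accepts a path only if it has one of a few assumed schemes and a network location.

```python
from pathlib import Path
from urllib.parse import urlparse


def looks_like_url(path):
    """Hypothetical check: treat a path as a URL if it has a known scheme and a host."""
    parsed = urlparse(str(path))
    return parsed.scheme in {"http", "https", "ftp", "s3"} and bool(parsed.netloc)


print(looks_like_url("https://example.com/lists/train.lst"))  # True
print(looks_like_url(Path("/data/lists/train.lst")))          # False
```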
- direct.utils.io.download_and_extract_archive(url, download_root, extract_root=None, filename=None, md5=None, remove_finished=False)
- Return type:
None
- direct.utils.io.download_url(url, root, filename=None, md5=None, max_redirect_hops=3)
Download a file from a URL and place it in root.
- Parameters:
- url: str
URL to download the file from.
- root: PathOrString
Directory to place the downloaded file in.
- filename: str, optional
Name to save the file under. If None, the basename of the URL is used.
- md5: str, optional
MD5 checksum of the download. If None, the checksum is not verified.
- max_redirect_hops: int, optional
Maximum number of redirect hops allowed.
- Return type:
None
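Verifying the md5 parameter against a downloaded file typically means hashing it in chunks. A minimal sketch with hashlib; `md5_matches` is a hypothetical helper, and whether download_url reads the file this way is an assumption.

```python
import hashlib
import tempfile


def md5_matches(path, expected_md5, chunk_size=1024 * 1024):
    """Return True if the file's MD5 hex digest equals `expected_md5`."""
    digest = hashlib.md5()
    with open(path, "rb") as f:
        # Read in chunks so large downloads need not fit in memory.
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest() == expected_md5.lower()


# Example file standing in for a finished download (left on disk for illustration).
with tempfile.NamedTemporaryFile(delete=False) as tmp:
    tmp.write(b"hello")
print(md5_matches(tmp.name, hashlib.md5(b"hello").hexdigest()))  # True
```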
- direct.utils.io.extract_archive(from_path, to_path=None, remove_finished=False)
Extract an archive.
The archive type and a possible compression are automatically detected from the file name. If the file is compressed but not an archive, the call is dispatched to decompress().
- Parameters:
- from_path: str
Path to the file to be extracted.
- to_path: str
Path to the directory the file will be extracted to. If omitted, the directory of the file is used.
- remove_finished: bool
If True, remove the file after the extraction.
- Returns:
- str
Path to the directory the file was extracted to.
- Return type:
str
- direct.utils.io.read_json(fn)
Read a JSON file and output a dict, or take a dict and output a dict.
- Parameters:
- fn: Union[Dict, str, pathlib.Path]
- Returns:
- dict
- Return type:
Dict
- direct.utils.io.read_list(fn)
Read a file and output a list, or take a list and output a list. Can read data from URLs.
- Parameters:
- fn: Union[list, str, pathlib.Path]
Input text file or list, or a URL to a text file.
- Returns:
- list
Text file read line by line.
- Return type:
List
- direct.utils.io.read_text_from_url(url, chunk_size=1024)
Read a text file from a URL, e.g. a config file.
- Parameters:
- url: str
- chunk_size: int
- Returns:
- str
Data from URL
- direct.utils.io.upload_to_s3(filename, to_filename, aws_access_key_id, aws_secret_access_key, endpoint_url='https://s3.aiforoncology.nl', bucket='direct-project', verbose=True)
Upload a file to an S3 bucket.
- Parameters:
- filename: pathlib.Path
Filename to upload.
- to_filename: str
Where to store the file.
- aws_access_key_id: str
- aws_secret_access_key: str
- endpoint_url: str
AWS endpoint URL.
- bucket: str
Bucket name.
- verbose: bool
Show upload progress.
- Returns:
- None
- Return type:
None
direct.utils.logging module
- direct.utils.logging.setup(use_stdout=True, filename=None, log_level='INFO')
Set up logging for DIRECT.
- Parameters:
- use_stdout: bool
Write output to standard out.
- filename: PathLike
Filename to write log to.
- log_level: str
Logging level as in the python.logging library.
- Returns:
- None
- Return type:
None
direct.utils.models module
- direct.utils.models.fix_state_dict_module_prefix(state_dict)
If models are saved after being wrapped in e.g. DataParallel, the keys of the state dict are prefixed with "module.". This function removes this prefix.
- Parameters:
- state_dict: dict
state_dict of a network module
- Returns:
- dict
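The prefix removal amounts to rewriting the keys of the dictionary. In this minimal sketch `strip_module_prefix` is a hypothetical name and plain integers stand in for tensors; leaving keys without the prefix unchanged is an assumption about the library's behaviour.

```python
def strip_module_prefix(state_dict, prefix="module."):
    """Drop a leading `prefix` (added by wrappers such as DataParallel) from every key."""
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


wrapped = {"module.conv.weight": 1, "module.conv.bias": 2, "head.weight": 3}
print(strip_module_prefix(wrapped))  # {'conv.weight': 1, 'conv.bias': 2, 'head.weight': 3}
```

The cleaned dictionary can then be passed to the unwrapped model's load_state_dict.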
direct.utils.writers module
- direct.utils.writers.write_output_to_h5(output, output_directory, volume_processing_func=None, output_key='reconstruction', create_dirs_if_needed=True)
Write a dictionary mapping filenames to torch tensors to h5 files.
- Parameters:
- output: dict
Dictionary mapping filenames to torch.Tensors of shape [depth, num_channels, …], where num_channels is typically 1 for MRI.
- output_directory: pathlib.Path
- volume_processing_func: callable
Function which postprocesses the volume array before saving.
- output_key: str
Name of key to save the output to.
- create_dirs_if_needed: bool
If true, the output directory and all its parents will be created.
- Return type:
None
Notes
Currently only num_channels = 1 is supported. If the output has more channels, only the first one is used.
Module contents
direct.utils module.
- class direct.utils.DirectModule
Bases:
DirectTransform, ABC, Module
- forward(sample)
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- training: bool
- class direct.utils.DirectTransform
Bases:
object
Direct transform class.
Defines the __repr__() method for Direct transforms.
- class direct.utils.SpatialDims(TWO_D, THREE_D)
Bases:
tuple
- THREE_D
Alias for field number 1
- TWO_D
Alias for field number 0
- direct.utils.cast_as_path(data)
Ensure the input is a path.
- Parameters:
- data: str or pathlib.Path
- Returns:
- pathlib.Path
- Return type:
Optional[Path]
- direct.utils.chunks(list_to_chunk, number_of_chunks)
Yield number_of_chunks sequential chunks from list_to_chunk. Adapted from [1].
- Parameters:
- list_to_chunk: List
- number_of_chunks: int
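Reference [1] is not reproduced in this section, so the splitting rule below is an assumption: a common approach makes the chunk lengths differ by at most one, giving the first len % number_of_chunks chunks one extra element. `chunks_sketch` is a hypothetical name.

```python
def chunks_sketch(list_to_chunk, number_of_chunks):
    """Yield `number_of_chunks` contiguous chunks whose lengths differ by at most one."""
    size, remainder = divmod(len(list_to_chunk), number_of_chunks)
    for i in range(number_of_chunks):
        # The first `remainder` chunks get one extra element.
        start = i * size + min(i, remainder)
        stop = (i + 1) * size + min(i + 1, remainder)
        yield list_to_chunk[start:stop]


print(list(chunks_sketch(list(range(10)), 3)))  # [[0, 1, 2, 3], [4, 5, 6], [7, 8, 9]]
```

When the list is shorter than number_of_chunks, the trailing chunks are simply empty.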
References
- direct.utils.count_parameters(models)
Count the number of parameters of a dictionary of models.
- Parameters:
- models: Dict
Dictionary mapping model name to model.
- Return type:
None
- direct.utils.detach_dict(data, keys=None)
Return a detached copy of a dictionary. Only torch.Tensors are detached.
- Parameters:
- data: Dict[str, torch.Tensor]
- keys: List, Tuple
Subselection of keys to detach.
- Returns:
- Dictionary with the selected tensors detached.
- Return type:
Dict
- direct.utils.dict_flatten(in_dict, dict_out=None)
Flattens a nested dictionary (or DictConfig) and returns a new flattened dictionary.
If a dict_out is provided, the flattened dictionary will be added to it.
- Parameters:
- in_dict: DictOrDictConfig
The nested dictionary or DictConfig to flatten.
- dict_out: Optional[DictOrDictConfig], optional
An existing dictionary to add the flattened dictionary to. Default: None.
- Returns:
- Dict[str, Any]
The flattened dictionary.
- Return type:
Dict[str, Any]
Notes
This function only keeps the final keys, and discards the intermediate ones.
Examples
>>> dictA = {"a": 1, "b": {"c": 2, "d": 3, "e": {"f": 4, 6: "a", 5: {"g": 6}, "l": [1, "two"]}}}
>>> dict_flatten(dictA)
{'a': 1, 'c': 2, 'd': 3, 'f': 4, 6: 'a', 'g': 6, 'l': [1, 'two']}
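The recursion behind the example above can be sketched as follows. This assumes that every collections.abc.Mapping (the sketch treats a DictConfig like any other mapping) is descended into, and that a repeated leaf key overwrites the earlier value; `dict_flatten_sketch` is a hypothetical name.

```python
from collections.abc import Mapping


def dict_flatten_sketch(in_dict, dict_out=None):
    """Flatten nested mappings, keeping only leaf keys (hypothetical helper)."""
    if dict_out is None:
        dict_out = {}
    for key, value in in_dict.items():
        if isinstance(value, Mapping):
            # Descend; the intermediate key itself is discarded.
            dict_flatten_sketch(value, dict_out)
        else:
            dict_out[key] = value  # a repeated leaf key overwrites the earlier value
    return dict_out


dictA = {"a": 1, "b": {"c": 2, "d": 3, "e": {"f": 4, 6: "a", 5: {"g": 6}, "l": [1, "two"]}}}
print(dict_flatten_sketch(dictA))  # {'a': 1, 'c': 2, 'd': 3, 'f': 4, 6: 'a', 'g': 6, 'l': [1, 'two']}
```

Lists such as "l" are leaves, so they are kept as-is rather than flattened.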
- direct.utils.dict_to_device(data, device, keys=None)
Copy tensor-valued dictionary to device. Only torch.Tensor is copied.
- Parameters:
- data: Dict[str, torch.Tensor]
- device: torch.device, str
- keys: List, Tuple
Subselection of keys to copy.
- Returns:
- Dictionary at the new device.
- Return type:
Dict
- direct.utils.ensure_list(data)
Ensure input is a list.
- Parameters:
- data: object
- Returns:
- list
- Return type:
List
- direct.utils.evaluate_dict(fns_dict, source, target, reduction='mean')
Evaluate a dictionary of functions.
- Parameters:
- fns_dict: Dict[str, Callable]
- source: torch.Tensor
- target: torch.Tensor
- reduction: str
- Returns:
- Dict[str, torch.Tensor]
Evaluated dictionary.
- Return type:
Dict
Examples
>>> evaluate_dict({"l1_loss": F.l1_loss, "l2_loss": F.mse_loss}, a, b)
will return {"l1_loss": F.l1_loss(a, b, reduction="mean"), "l2_loss": F.mse_loss(a, b, reduction="mean")}.
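The evaluation pattern is a single dictionary comprehension. In this sketch `evaluate_dict_sketch` is a hypothetical name, and the list-based `l1_loss` and `mse_loss` are hypothetical stand-ins for torch.nn.functional.l1_loss and mse_loss, so that the example runs without torch.

```python
def evaluate_dict_sketch(fns_dict, source, target, reduction="mean"):
    """Apply every function to (source, target) and key the results by name."""
    return {name: fn(source, target, reduction=reduction) for name, fn in fns_dict.items()}


# Hypothetical list-based losses standing in for torch.nn.functional.l1_loss / mse_loss.
def l1_loss(source, target, reduction="mean"):
    errors = [abs(s - t) for s, t in zip(source, target)]
    return sum(errors) / len(errors) if reduction == "mean" else sum(errors)


def mse_loss(source, target, reduction="mean"):
    errors = [(s - t) ** 2 for s, t in zip(source, target)]
    return sum(errors) / len(errors) if reduction == "mean" else sum(errors)


print(evaluate_dict_sketch({"l1_loss": l1_loss, "l2_loss": mse_loss}, [1.0, 2.0], [2.0, 4.0]))
# {'l1_loss': 1.5, 'l2_loss': 2.5}
```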
- direct.utils.git_hash()
Returns the current git hash.
- Returns:
- _git_hash: str
The current git hash.
- Return type:
str
- direct.utils.is_complex_data(data, complex_axis=-1)
Returns True if data is a complex tensor at a specified dimension, i.e. complex_axis of data is of size 2, corresponding to the real and imaginary channels.
- Parameters:
- data: torch.Tensor
- complex_axis: int
Complex dimension along which the check will be done. Default: -1 (last).
- Returns:
- bool
True if data is a complex tensor.
- Return type:
bool
- direct.utils.is_power_of_two(number)
Check if input is a power of 2.
- Parameters:
- number: int
- Returns:
- bool
- Return type:
bool
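A common way to implement this check is the bit trick below: a positive power of two has exactly one bit set, and number & (number - 1) clears the lowest set bit. Whether the library uses this exact expression is an assumption; `is_power_of_two_sketch` is a hypothetical name.

```python
def is_power_of_two_sketch(number):
    """True for 1, 2, 4, 8, ...: a positive power of two has exactly one bit set."""
    # number & (number - 1) clears the lowest set bit, leaving 0 only if one bit was set.
    return number > 0 and (number & (number - 1)) == 0


print([n for n in range(-2, 20) if is_power_of_two_sketch(n)])  # [1, 2, 4, 8, 16]
```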
- direct.utils.merge_list_of_dicts(list_of_dicts)
Merge a list of dictionaries into one dictionary.
- Parameters:
- list_of_dicts: List[Dict]
- Returns:
- Dict
- Return type:
Dict
- direct.utils.multiply_function(multiplier, func)
Create a function that multiplies the output of another function by a multiplier.
- Parameters:
- multiplier: float
Number to multiply with.
- func: callable
Function to multiply.
- Returns:
- return_func: Callable
- Return type:
Callable
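A closure is enough to build such a function, for example to weight one loss term. In this minimal sketch, `multiply_function_sketch` and `squared_error` are hypothetical names, and preserving the wrapped function's name with functools.wraps is a design choice of the sketch, not documented behaviour.

```python
import functools


def multiply_function_sketch(multiplier, func):
    """Return a function computing multiplier * func(*args, **kwargs) (hypothetical helper)."""

    @functools.wraps(func)
    def return_func(*args, **kwargs):
        return multiplier * func(*args, **kwargs)

    return return_func


def squared_error(a, b):
    return (a - b) ** 2


weighted_error = multiply_function_sketch(0.5, squared_error)
print(weighted_error(3.0, 1.0))  # 2.0
print(weighted_error.__name__)   # squared_error (kept by functools.wraps)
```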
- direct.utils.normalize_image(image, eps=1e-05)
Normalize image to range [0,1] for visualization.
Given image \(x\) and \(\epsilon\), it returns:
\[\frac{x - \min{x}}{\max{x} - \min{x} + \epsilon}.\]
- Parameters:
- image: torch.Tensor
Image to scale.
- eps: float
- Returns:
- image: torch.Tensor
Scaled image.
- Return type:
Tensor
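The formula can be checked on a flat list of numbers standing in for a tensor. In this sketch `normalize_values` is a hypothetical name: it shifts by the minimum and then divides by the maximum of the shifted values plus eps. The eps term keeps a constant image from causing a division by zero and leaves the largest output slightly below 1.

```python
def normalize_values(values, eps=1e-5):
    """Scale a flat list of numbers following (x - min x) / (max x - min x + eps)."""
    lowest = min(values)
    shifted = [v - lowest for v in values]  # x - min(x)
    peak = max(shifted)                     # equals max(x) - min(x)
    return [v / (peak + eps) for v in shifted]


print(normalize_values([-1.0, 0.0, 1.0]))
print(normalize_values([3.0, 3.0]))  # [0.0, 0.0]: eps avoids dividing by zero
```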
- direct.utils.prefix_dict_keys(data, prefix)
Add a prefix to the keys of a dictionary.
- Parameters:
- data: Dict[str, Any]
- prefix: str
- Returns:
- Dict[str, Any]
- Return type:
Dict[str, Any]
- direct.utils.reduce_list_of_dicts(data, mode='average', divisor=None)
Reduce (by default, average) a list of dictionaries mapping keys to tensors.
- Parameters:
- data: List[Dict[str, torch.Tensor]]
- mode: str
Reduction mode: "sum" adds the values, while "average" computes their average.
- divisor: None or int
If given, values are divided by this factor.
- Returns:
- Dict[str, torch.Tensor]: Reduced dictionary.
- Return type:
Dict[str, Tensor]
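A minimal sketch of the two reduction modes, with floats standing in for tensors (the same code works for tensors, since sum() only needs addition). How divisor interacts with "average" is not stated above. `reduce_list_of_dicts_sketch` is a hypothetical name, and it assumes that a given divisor replaces the default, which is len(data) for "average" and 1 for "sum".

```python
def reduce_list_of_dicts_sketch(data, mode="average", divisor=None):
    """Sum or average values key-wise across a list of dicts (hypothetical helper)."""
    if mode not in ("average", "sum"):
        raise ValueError(f"Unsupported mode {mode!r}; expected 'average' or 'sum'.")
    if not data:
        return {}
    if divisor is None:
        # Assumption: average divides by the number of dicts, sum by 1.
        divisor = len(data) if mode == "average" else 1
    return {key: sum(d[key] for d in data) / divisor for key in data[0]}


batches = [{"loss": 1.0, "ssim": 2.0}, {"loss": 3.0, "ssim": 4.0}]
print(reduce_list_of_dicts_sketch(batches))              # {'loss': 2.0, 'ssim': 3.0}
print(reduce_list_of_dicts_sketch(batches, mode="sum"))  # {'loss': 4.0, 'ssim': 6.0}
```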
- direct.utils.remove_keys(input_dict, keys)
Remove keys from input_dict.
- Parameters:
- input_dict: Dict
- keys: Union[str, List[str], Tuple[str]]
- Returns:
- Dict
- Return type:
Dict
- direct.utils.set_all_seeds(seed)
Set seeds for deterministic runs.
- Parameters:
- seed: int
Seed for random module.
- Return type:
None
- direct.utils.str_to_class(module_name, function_name)
Convert a string to a class. Based on: https://stackoverflow.com/a/1176180/576363.
Also supports function arguments, e.g. ifft(dim=2) will be parsed as a partial and return ifft where dim has been set to 2.
- Parameters:
- module_name: str
e.g. direct.data.transforms
- function_name: str
e.g. Identity
- Returns:
- object
- Return type:
Callable
Examples
>>> def mult(f, mul=2):
...     return f*mul
>>> str_to_class(".", "mult(mul=4)")
will return a function which multiplies the input by 4, while
>>> str_to_class(".", "mult")
just returns the function itself.
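The "name(kw=value)" parsing can be sketched with the standard ast, functools and importlib modules. This is not necessarily how str_to_class parses its input. `str_to_callable_sketch` is a hypothetical helper that accepts only literal keyword arguments and needs an importable module name, so the "." used in the example above is not supported here.

```python
import ast
import functools
import importlib


def str_to_callable_sketch(module_name, function_name):
    """Resolve "name" or "name(kw=literal, ...)" inside `module_name` (hypothetical helper)."""
    node = ast.parse(function_name, mode="eval").body
    if isinstance(node, ast.Call):
        # "ifft(dim=2)": keep the name and turn literal keyword arguments into a partial.
        name = node.func.id
        kwargs = {kw.arg: ast.literal_eval(kw.value) for kw in node.keywords}
    else:
        name = node.id
        kwargs = {}
    obj = getattr(importlib.import_module(module_name), name)
    return functools.partial(obj, **kwargs) if kwargs else obj


parse_binary = str_to_callable_sketch("builtins", "int(base=2)")
print(parse_binary("101"))                               # 5
print(str_to_callable_sketch("math", "floor").__name__)  # floor
```

Using ast.literal_eval for the argument values avoids executing arbitrary code from configuration strings.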