direct.utils package#
Submodules#
direct.utils.asserts module#
- direct.utils.asserts.assert_complex(data, complex_axis=-1, complex_last=None)[source][source]#
- Assert if a tensor is complex (has complex dimension of size 2 corresponding to real and imaginary channels). - Parameters:
- data: torch.Tensor
- complex_axis: int
- Complex dimension along which the assertion will be done. Default: -1 (last). 
- complex_last: Optional[bool]
- If true, will override complex_axis with -1 (last). Default: None. 
 
- Return type:
- None
 
direct.utils.bbox module#
- direct.utils.bbox.crop_to_bbox(data, bbox, pad_value=0)[source][source]#
- Extract a bounding box from images; coordinates can be negative. - Parameters:
- data: np.ndarray or torch.tensor
- nD array or torch tensor. 
- bbox: list or tuple
- bbox of the form (coordinates, size), for instance (4, 4, 2, 1) is a patch starting at row 4, col 4 with height 2 and width 1. 
- pad_value: number
- If the bounding box extends outside the image, this is the value the patch is padded with. 
 
- Returns:
- ndarray or torch.Tensor
- Data cropped to the bounding box. 
 
- Return type:
- Union[ndarray, Tensor]
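To illustrate the (coordinates, size) convention and the padding behaviour described above, here is a minimal pure-Python sketch for 2D nested lists. The helper name crop_2d is hypothetical and this is not the library's implementation, which works on nD arrays and tensors.

```python
def crop_2d(data, bbox, pad_value=0):
    """Crop a 2D nested list to bbox = (row, col, height, width).

    Hypothetical helper: positions that fall outside the image are
    filled with pad_value, so coordinates may be negative.
    """
    row, col, height, width = bbox
    n_rows, n_cols = len(data), len(data[0])
    out = [[pad_value] * width for _ in range(height)]
    for i in range(height):
        for j in range(width):
            r, c = row + i, col + j
            if 0 <= r < n_rows and 0 <= c < n_cols:
                out[i][j] = data[r][c]
    return out


image = [[1, 2], [3, 4]]
# The box starts one pixel above and to the left of the image.
top_left = crop_2d(image, (-1, -1, 2, 2))
# The box starts at the last pixel and runs past the bottom-right corner.
bottom_right = crop_2d(image, (1, 1, 2, 2), pad_value=9)
```

Only the overlapping region is copied; everything else keeps the pad value.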
 
- direct.utils.bbox.crop_to_largest(data, pad_value=0)[source][source]#
- Given a list of arrays or tensors, return the same list with the data padded to the largest in the set. Can be convenient for e.g. logging and tiling several images as with torchvision’s make_grid. - Parameters:
- data: List[Union[np.ndarray, torch.Tensor]]
- pad_value: int
 
- Returns:
- List[Union[np.ndarray, torch.Tensor]]
 
- Return type:
- List[Union[ndarray, Tensor]]
 
direct.utils.communication module#
- direct.utils.communication.all_gather(data, group=None)[source][source]#
- Run all_gather on arbitrary picklable data (not necessarily tensors). - Parameters:
- data: object
- Any picklable object. 
- group
- A torch process group. By default, will use a group which contains all ranks on gloo backend. 
 
- Returns:
- list: list of data gathered for each rank
 
 
- direct.utils.communication.gather(data, destination_rank=0, group=None)[source][source]#
- Run gather on arbitrary picklable data (not necessarily tensors). - Parameters:
- data: object
- Any picklable object. 
- destination_rank: int
- Destination rank 
- group
- A torch process group. By default, will use a group which contains all ranks on gloo backend. 
 
- Returns:
- list[data]: on destination_rank, a list of data gathered from each rank. Otherwise, an empty list.
 
- Return type:
- List
 
- direct.utils.communication.get_local_rank()[source][source]#
- Get rank of the process within the same machine, even when torch.distributed is not initialized. - Returns:
- int: The rank of the current process within the local (per-machine) process group.
 
- Return type:
- int
 
- direct.utils.communication.get_local_size()[source][source]#
- Number of compute units in local machine. - Returns:
- int
 
- Return type:
- int
 
- direct.utils.communication.get_rank()[source][source]#
- Get rank of the process, even when torch.distributed is not initialized. - Returns:
- int
 
- Return type:
- int
 
- direct.utils.communication.get_world_size()[source][source]#
- Get the number of compute devices in the world; returns 1 if multi-device is not initialized. - Returns:
- int
 
- Return type:
- int
 
- direct.utils.communication.is_main_process()[source][source]#
- Simple wrapper around get_rank() that reports whether the current process is the main process. - Returns:
- bool
 
- Return type:
- bool
 
- direct.utils.communication.reduce_tensor_dict(tensors_dict)[source][source]#
- Reduce the tensor dictionary from all processes so that process with rank 0 has the averaged results. Returns a dict with the same fields as tensors_dict, after reduction. - Parameters:
- tensors_dict: dict
- dictionary with str keys mapping to torch tensors. 
- Returns:
- dict: the reduced dict.
 
- Return type:
- Dict[str, Tensor]
 
- All workers must call this function, otherwise it will deadlock. - Returns:
- A random number that is the same across all workers. If workers need a shared RNG, they can use this shared seed to
- create one.
 
- Return type:
- int
 
direct.utils.dataset module#
- direct.utils.dataset.get_filenames_for_datasets(lists, files_root, data_root)[source][source]#
- Given lists of filenames of data points, concatenate these into a large list of full filenames. - Parameters:
- lists: List[PathOrString]
- files_root: PathOrString
- data_root: pathlib.Path
 
- Returns:
- list of filenames or None
 
 
- direct.utils.dataset.get_filenames_for_datasets_from_config(cfg, files_root, data_root)[source][source]#
- Given a configuration object it returns a list of filenames. - Parameters:
- cfg: cfg-object
- Configuration object with a lists property containing paths relative to files_root. 
- files_root: Union[str, pathlib.Path]
- data_root: pathlib.Path
 
- Returns:
- list of filenames or None
 
 
direct.utils.events module#
- class direct.utils.events.CommonMetricPrinter(max_iter)[source][source]#
- Bases: EventWriter
- Print common metrics to the terminal, including iteration time, ETA, memory, all losses, and the learning rate.
- To print something different, implement a similar printer yourself. 
- class direct.utils.events.EventStorage(start_iter=0)[source][source]#
- Bases: object
- The user-facing class that provides metric storage functionalities.
- In the future we may add support for storing / logging other types of data if needed.
 - add_graph(img_name, img_tensor)[source][source]#
- Add an img_tensor to the _vis_data associated with img_name. - Parameters:
- img_name: str
- The name of the input_image to put into tensorboard. 
- img_tensor: torch.Tensor or numpy.array
- A uint8 or float Tensor of shape [channel, height, width] where channel is 3. The input_image format should be RGB. The elements in img_tensor can either have values in [0, 1] (float32) or [0, 255] (uint8). The img_tensor will be visualized in tensorboard. 
 
 
 - add_image(img_name, img_tensor)[source][source]#
- Add an img_tensor to the _vis_data associated with img_name. - Parameters:
- img_name: str
- The name of the input_image to put into tensorboard. 
- img_tensor: torch.Tensor or numpy.array
- A uint8 or float Tensor of shape [channel, height, width] where channel is 3. The input_image format should be RGB. The elements in img_tensor can either have values in [0, 1] (float32) or [0, 255] (uint8). The img_tensor will be visualized in tensorboard. 
 
 
 - add_scalar(name, value, smoothing_hint=True)[source][source]#
- Add a scalar value to the HistoryBuffer associated with name. - Parameters:
- name: str
- value: float
- smoothing_hint: bool
- A ‘hint’ on whether this scalar is noisy and should be smoothed when logged. The hint will be accessible through EventStorage.smoothing_hints. A writer may ignore the hint and apply a custom smoothing rule. It defaults to True because most scalars we save need to be smoothed to provide any useful signal. 
 
- Returns:
 
 - add_scalars(*, smoothing_hint=True, **kwargs)[source][source]#
- Put multiple scalars from keyword arguments.
 - Examples
    storage.add_scalars(loss=my_loss, accuracy=my_accuracy, smoothing_hint=True)
 - clear_images()[source][source]#
- Delete all the stored images for visualization. - This should be called after images are written to tensorboard. 
 - property iter#
 - latest()[source][source]#
- Returns:
- dict[name -> number]: the scalars that were added in the current iteration. 
 
 - latest_with_smoothing_hint(window_size=20)[source][source]#
- Similar to latest(), but the returned values are either the unsmoothed latest value or a median over the given window_size, depending on whether the smoothing_hint is True. - This provides a default behavior that other writers can use. 
 - name_scope(name)[source][source]#
- Yields:
- A context within which all the events added to this storage will be prefixed by the name scope. 
 
 - smoothing_hints()[source][source]#
- Returns:
- dict[name -> bool]: the user-provided hint on whether the scalar
- is noisy and needs smoothing. 
 
 
 - step()[source][source]#
- User should call this function at the beginning of each iteration, to notify the storage of the start of a new iteration. - The storage will then be able to associate the new data with the correct iteration number. 
 - property vis_data#
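
The interplay of step(), add_scalar(), latest() and name_scope() described above can be sketched with a toy class. TinyStorage is a hypothetical stand-in written for illustration, not the library's EventStorage; in particular, joining the scope name and the scalar name with "/" is an assumption.

```python
from collections import defaultdict
from contextlib import contextmanager


class TinyStorage:
    """Hypothetical, heavily simplified stand-in for a metric storage."""

    def __init__(self, start_iter=0):
        self.iter = start_iter
        self._history = defaultdict(list)  # name -> [(value, iteration), ...]
        self._latest = {}                  # scalars added in the current iteration
        self._prefix = ""

    def add_scalar(self, name, value):
        name = self._prefix + name
        self._history[name].append((float(value), self.iter))
        self._latest[name] = float(value)

    def add_scalars(self, **kwargs):
        for name, value in kwargs.items():
            self.add_scalar(name, value)

    def latest(self):
        return dict(self._latest)

    def step(self):
        # Advance the iteration counter and start a fresh set of "latest" scalars.
        self.iter += 1
        self._latest = {}

    @contextmanager
    def name_scope(self, name):
        # Assumption: scope and scalar name are joined with "/".
        old_prefix = self._prefix
        self._prefix = name.rstrip("/") + "/"
        try:
            yield
        finally:
            self._prefix = old_prefix


storage = TinyStorage()
storage.step()
storage.add_scalars(loss=0.5, accuracy=0.9)
with storage.name_scope("val"):
    storage.add_scalar("loss", 0.7)
first = storage.latest()
storage.step()
second = storage.latest()
```

After the second step() call the set of latest scalars is empty again, which is why data added afterwards is associated with the new iteration.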
 
- class direct.utils.events.EventWriter[source][source]#
- Bases: object
- Base class for writers that obtain events from EventStorage and process them.
- class direct.utils.events.HistoryBuffer(max_length=1000000)[source][source]#
- Bases: object
- Track a series of scalar values and provide access to smoothed values over a window or the global average of the series.
 - avg(window_size)[source][source]#
- Return the mean of the latest window_size values in the buffer. - Return type:
- float
 
 - global_avg()[source][source]#
- Return the mean of all the elements in the buffer. - Note that this includes those getting removed due to limited buffer storage. - Return type:
- float
 
 - median(window_size)[source][source]#
- Return the median of the latest window_size values in the buffer. - Return type:
- float
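
A minimal sketch of the three statistics above, assuming a bounded list plus running totals. TinyHistoryBuffer and its update() method are hypothetical names; the documented class may store and update its values differently.

```python
import statistics


class TinyHistoryBuffer:
    """Hypothetical sketch of a bounded scalar history."""

    def __init__(self, max_length=1000000):
        self._max_length = max_length
        self._values = []
        self._count = 0    # number of values ever added
        self._total = 0.0  # running sum of all values ever added

    def update(self, value):
        if len(self._values) == self._max_length:
            self._values.pop(0)  # drop the oldest value
        self._values.append(float(value))
        self._count += 1
        self._total += float(value)

    def avg(self, window_size):
        window = self._values[-window_size:]
        return sum(window) / len(window)

    def median(self, window_size):
        return statistics.median(self._values[-window_size:])

    def global_avg(self):
        # Uses the running totals, so evicted values still contribute.
        return self._total / self._count


buffer = TinyHistoryBuffer(max_length=3)
for value in [1.0, 2.0, 3.0, 10.0]:
    buffer.update(value)
```

With max_length=3 the value 1.0 has been evicted, yet it still counts towards global_avg(), matching the note that the global average includes removed elements.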
 
 
- class direct.utils.events.JSONWriter(json_file, window_size=2)[source][source]#
- Bases: EventWriter
- Write scalars to a json file.
- It saves scalars as one json per line (instead of a big json) for easy parsing.
- Examples parsing such a json file:
    $ cat metrics.json | jq -s '.[0:2]'
    [
      {
        "data_time": 0.008433341979980469,
        "iteration": 20,
        "loss": 1.9228371381759644,
        "loss_box_reg": 0.050025828182697296,
        "loss_classifier": 0.5316952466964722,
        "loss_mask": 0.7236229181289673,
        "loss_rpn_box": 0.0856662318110466,
        "loss_rpn_cls": 0.48198649287223816,
        "lr": 0.007173333333333333,
        "time": 0.25401854515075684
      },
      {
        "data_time": 0.007216215133666992,
        "iteration": 40,
        "loss": 1.282649278640747,
        "loss_box_reg": 0.06222952902317047,
        "loss_classifier": 0.30682939291000366,
        "loss_mask": 0.6970193982124329,
        "loss_rpn_box": 0.038663312792778015,
        "loss_rpn_cls": 0.1471673548221588,
        "lr": 0.007706666666666667,
        "time": 0.2490077018737793
      }
    ]
    $ cat metrics.json | jq '.loss_mask'
    0.7126231789588928
    0.689423680305481
    0.6776131987571716
    ...
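The one-json-per-line layout can also be read back with the Python standard library. The sketch below writes a small file in that layout and parses it again; the records and the temporary file are made up for illustration.

```python
import json
import os
import tempfile

records = [
    {"iteration": 20, "loss": 1.92, "lr": 0.0072},
    {"iteration": 40, "loss": 1.28, "lr": 0.0077},
]

# Write one JSON object per line, the layout described above.
with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as handle:
    for record in records:
        handle.write(json.dumps(record, sort_keys=True) + "\n")
    path = handle.name

# Read the file back; each non-empty line is a complete JSON document.
with open(path) as handle:
    parsed = [json.loads(line) for line in handle if line.strip()]
os.remove(path)

losses = [entry["loss"] for entry in parsed]
```

Because every line is self-contained, a partially written file can still be parsed up to the last complete line.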
- class direct.utils.events.TensorboardWriter(log_dir, window_size=20, **kwargs)[source][source]#
- Bases: EventWriter
- Write all scalars to a tensorboard file. 
- direct.utils.events.get_event_storage()[source][source]#
- Returns:
- The EventStorage object that’s currently being used. Throws an error if no EventStorage is currently enabled.
 
direct.utils.imports module#
General utilities for module imports.
direct.utils.io module#
- class direct.utils.io.ArrayEncoder(*, skipkeys=False, ensure_ascii=True, check_circular=True, allow_nan=True, sort_keys=False, indent=None, separators=None, default=None)[source][source]#
- Bases: JSONEncoder
 - default(obj)[source][source]#
- Implement this method in a subclass such that it returns a serializable object for o, or calls the base implementation (to raise a TypeError).
- For example, to support arbitrary iterators, you could implement default like this:
    def default(self, o):
        try:
            iterable = iter(o)
        except TypeError:
            pass
        else:
            return list(iterable)
        # Let the base class default method raise the TypeError
        return super().default(o)
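As a sketch of how such an encoder might handle array-like values, the example below serializes any object that exposes a tolist() method. ToListEncoder and FakeArray are hypothetical names used to keep the example dependency-free; which types ArrayEncoder itself supports is not stated here.

```python
import json


class ToListEncoder(json.JSONEncoder):
    """Hypothetical encoder: serialize any object exposing a tolist() method."""

    def default(self, obj):
        if hasattr(obj, "tolist"):
            return obj.tolist()
        # Let the base class raise the TypeError for unsupported objects.
        return super().default(obj)


class FakeArray:
    """Stand-in for an array type, used only to avoid external dependencies."""

    def __init__(self, values):
        self._values = list(values)

    def tolist(self):
        return list(self._values)


encoded = json.dumps({"shape": FakeArray([2, 3])}, cls=ToListEncoder)
```

The encoder is passed through the cls argument of json.dumps, so existing call sites only need one extra keyword.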
 
- direct.utils.io.check_is_valid_url(path)[source][source]#
- Check if the given path is a valid url. - Parameters:
- path: PathOrString
 
- Returns:
- Bool describing if this is a URL or not.
 
- Return type:
- bool
 
- direct.utils.io.download_and_extract_archive(url, download_root, extract_root=None, filename=None, md5=None, remove_finished=False)[source][source]#
- Return type:
- None
 
- direct.utils.io.download_url(url, root, filename=None, md5=None, max_redirect_hops=3)[source][source]#
- Download a file from a url and place it in root. - Parameters:
- url: str
- URL to download file from 
- root: PathOrString
- Directory to place downloaded file in 
- filename: str, optional
- Name to save the file under. If None, use the basename of the URL 
- md5: str, optional
- MD5 checksum of the download. If None, do not check 
- max_redirect_hops: int, optional
- Maximum number of redirect hops allowed 
 
- Return type:
- None
 
- direct.utils.io.extract_archive(from_path, to_path=None, remove_finished=False)[source][source]#
- Extract an archive. - The archive type and a possible compression are automatically detected from the file name. If the file is compressed but not an archive, the call is dispatched to decompress(). - Parameters:
- from_path: str
- Path to the file to be extracted. 
- to_path: str
- Path to the directory the file will be extracted to. If omitted, the directory of the file is used. 
- remove_finished: bool
- If True, remove the file after the extraction.
 
- Returns:
- str
- Path to the directory the file was extracted to. 
 
- Return type:
- str
 
- direct.utils.io.read_json(fn)[source][source]#
- Read file and output dict, or take dict and output dict. - Parameters:
- fn: Union[Dict, str, pathlib.Path]
 
- Returns:
- dict
 
- Return type:
- Dict
 
- direct.utils.io.read_list(fn)[source][source]#
- Read file and output list, or take list and output list. Can read data from URLs. - Parameters:
- fn: Union[list, str, pathlib.Path]
- Input text file or list, or a URL to a text file. 
 
- Returns:
- list
- Text file read line by line. 
 
- Return type:
- List
 
- direct.utils.io.read_text_from_url(url, chunk_size=1024)[source][source]#
- Read a text file from a URL, e.g. a config file. - Parameters:
- url: str
- chunk_size: int
 
- Returns:
- str
- Data from URL 
 
 
- direct.utils.io.upload_to_s3(filename, to_filename, aws_access_key_id, aws_secret_access_key, endpoint_url='https://s3.aiforoncology.nl', bucket='direct-project', verbose=True)[source][source]#
- Upload file to an s3 bucket. - Parameters:
- filename: pathlib.Path
- Filename to upload 
- to_filename: str
- Where to store the file 
- aws_access_key_id: str
- aws_secret_access_key: str
- endpoint_url: str
- AWS endpoint url 
- bucket: str
- Bucket name 
- verbose: bool
- Show upload progress 
 
- Returns:
- None
 
- Return type:
- None
 
direct.utils.logging module#
- direct.utils.logging.setup(use_stdout=True, filename=None, log_level='INFO')[source][source]#
- Setup logging for DIRECT. - Parameters:
- use_stdout: bool
- Write output to standard out. 
- filename: PathLike
- Filename to write log to. 
- log_level: str
- Logging level as in the Python logging library. 
 
- Returns:
- None
 
- Return type:
- None
 
direct.utils.models module#
- direct.utils.models.fix_state_dict_module_prefix(state_dict)[source][source]#
- If models are saved after being wrapped in e.g. DataParallel, the keys of the state dict are prefixed with module.. This function removes this prefix. - Parameters:
- state_dict: dict
- state_dict of a network module 
- Returns:
- dict
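
The prefix removal described above can be sketched on a plain dictionary. strip_module_prefix is a hypothetical helper for illustration; the integer values stand in for the tensors a real state dict would hold.

```python
def strip_module_prefix(state_dict, prefix="module."):
    """Return a copy of state_dict with a leading prefix removed from each key."""
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


# Keys as they might look after saving a wrapped model; values are placeholders.
wrapped = {"module.conv.weight": 1, "module.conv.bias": 2, "head.weight": 3}
cleaned = strip_module_prefix(wrapped)
```

Keys without the prefix are passed through unchanged, so the helper is safe to apply to unwrapped checkpoints as well.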
 
 
direct.utils.writers module#
- direct.utils.writers.write_output_to_h5(output, output_directory, volume_processing_func=None, output_key='reconstruction', create_dirs_if_needed=True)[source][source]#
- Write dictionary with keys filenames and values torch tensors to h5 files. - Parameters:
- output: dict
- Dictionary with keys filenames and values torch.Tensor’s with shape [depth, num_channels, …] where num_channels is typically 1 for MRI. 
- output_directory: pathlib.Path
- volume_processing_func: callable
- Function which postprocesses the volume array before saving. 
- output_key: str
- Name of key to save the output to. 
- create_dirs_if_needed: bool
- If true, the output directory and all its parents will be created. 
 
- Return type:
- None
 - Notes - Currently only num_channels = 1 is supported. If you run this function with more channels the first one will be used. 
Module contents#
direct.utils module.
- class direct.utils.DirectModule[source][source]#
- Bases: DirectTransform, ABC, Module
 - forward(sample)[source][source]#
- Define the computation performed at every call. - Should be overridden by all subclasses. - Note - Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
 
- class direct.utils.DirectTransform[source][source]#
- Bases: object
- Direct transform class. - Defines __repr__() method for Direct transforms.
- class direct.utils.SpatialDims(TWO_D, THREE_D)[source]#
- Bases: tuple
 - THREE_D#
- Alias for field number 1 
 - TWO_D#
- Alias for field number 0 
 
- direct.utils.cast_as_path(data)[source][source]#
- Ensure the input is a path. - Parameters:
- data: str or pathlib.Path
 
- Returns:
- pathlib.Path
 
- Return type:
- Optional[Path]
 
- direct.utils.chunks(list_to_chunk, number_of_chunks)[source][source]#
- Yield number_of_chunks number of sequential chunks from list_to_chunk. Adapted from [1]. - Parameters:
- list_to_chunk: List
- number_of_chunks: int
 
 - References 
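
One way to yield a fixed number of sequential chunks is sketched below. sequential_chunks is a hypothetical helper, and giving the earlier chunks the extra items when the length does not divide evenly is an assumption about the behaviour.

```python
def sequential_chunks(items, number_of_chunks):
    """Yield number_of_chunks consecutive slices whose lengths differ by at most one."""
    base, extra = divmod(len(items), number_of_chunks)
    start = 0
    for index in range(number_of_chunks):
        # Assumption: the first `extra` chunks receive one additional item.
        size = base + (1 if index < extra else 0)
        yield items[start:start + size]
        start += size


result = list(sequential_chunks(list(range(7)), 3))
```

Concatenating the chunks in order reproduces the original list.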
- direct.utils.count_parameters(models)[source][source]#
- Count the number of parameters of a dictionary of models. - Parameters:
- models: Dict
- Dictionary mapping model name to model. 
 
- Return type:
- None
 
- direct.utils.detach_dict(data, keys=None)[source][source]#
- Return a detached copy of a dictionary. Only torch.Tensor’s are detached. - Parameters:
- data: Dict[str, torch.Tensor]
- keys: List, Tuple
- Subselection of keys to detach 
 
- Returns:
- Dictionary with the selected tensors detached.
 
- Return type:
- Dict
 
- direct.utils.dict_flatten(in_dict, dict_out=None)[source][source]#
- Flattens a nested dictionary (or DictConfig) and returns a new flattened dictionary. - If a dict_out is provided, the flattened dictionary will be added to it. - Parameters:
- in_dict: DictOrDictConfig
- The nested dictionary or DictConfig to flatten. 
- dict_out: Optional[DictOrDictConfig], optional
- An existing dictionary to add the flattened dictionary to. Default: None. 
 
- Returns:
- Dict[str, Any]
- The flattened dictionary. 
 
- Return type:
- Dict[str, Any]
 - Notes - This function only keeps the final keys, and discards the intermediate ones. 
 - Examples
    >>> dictA = {"a": 1, "b": {"c": 2, "d": 3, "e": {"f": 4, 6: "a", 5: {"g": 6}, "l": [1, "two"]}}}
    >>> dict_flatten(dictA)
    {'a': 1, 'c': 2, 'd': 3, 'f': 4, 6: 'a', 'g': 6, 'l': [1, 'two']}
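The flattening rule in the notes (keep the final keys, discard the intermediate ones) can be sketched recursively for plain dictionaries. flatten is a hypothetical helper, not the library's implementation, and it does not handle DictConfig objects.

```python
def flatten(in_dict, dict_out=None):
    """Flatten nested dictionaries, keeping only the innermost keys."""
    if dict_out is None:
        dict_out = {}
    for key, value in in_dict.items():
        if isinstance(value, dict):
            flatten(value, dict_out)  # the intermediate key is dropped
        else:
            dict_out[key] = value
    return dict_out


nested = {"a": 1, "b": {"c": 2, "d": 3, "e": {"f": 4, 6: "a", 5: {"g": 6}, "l": [1, "two"]}}}
flat = flatten(nested)
```

Since intermediate keys are discarded, two leaves that share a name in different branches end up under one key in this sketch, with the later one overwriting the earlier one.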
- direct.utils.dict_to_device(data, device, keys=None)[source][source]#
- Copy tensor-valued dictionary to device. Only torch.Tensor is copied. - Parameters:
- data: Dict[str, torch.Tensor]
- device: torch.device, str
- keys: List, Tuple
- Subselection of keys to copy. 
 
- Returns:
- Dictionary at the new device.
 
- Return type:
- Dict
 
- direct.utils.ensure_list(data)[source][source]#
- Ensure input is a list. - Parameters:
- data: object
 
- Returns:
- list
 
- Return type:
- List
 
- direct.utils.evaluate_dict(fns_dict, source, target, reduction='mean')[source][source]#
- Evaluate a dictionary of functions. - Parameters:
- fns_dict: Dict[str, Callable]
- source: torch.Tensor
- target: torch.Tensor
- reduction: str
 
- Returns:
- Dict[str, torch.Tensor]
- Evaluated dictionary. 
 
- Return type:
- Dict
 - Examples
    > evaluate_dict({'l1_loss': F.l1_loss, 'l2_loss': F.mse_loss}, a, b)
 - Will return
    > {'l1_loss': F.l1_loss(a, b, reduction='mean'), 'l2_loss': F.mse_loss(a, b, reduction='mean')}
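The same pattern can be tried without tensors. In the sketch below, evaluate_functions, l1 and l2 are hypothetical stand-ins operating on lists of floats; they only illustrate how each function is called with source, target and the reduction keyword.

```python
def evaluate_functions(fns_dict, source, target, reduction="mean"):
    """Call every function in fns_dict on (source, target) with the given reduction."""
    return {name: fn(source, target, reduction=reduction) for name, fn in fns_dict.items()}


def l1(source, target, reduction="mean"):
    """Toy L1 loss on lists of floats (stand-in for a tensor loss)."""
    errors = [abs(s - t) for s, t in zip(source, target)]
    return sum(errors) / len(errors) if reduction == "mean" else sum(errors)


def l2(source, target, reduction="mean"):
    """Toy squared-error loss on lists of floats."""
    errors = [(s - t) ** 2 for s, t in zip(source, target)]
    return sum(errors) / len(errors) if reduction == "mean" else sum(errors)


scores = evaluate_functions({"l1_loss": l1, "l2_loss": l2}, [1.0, 2.0], [2.0, 4.0])
```

The result keeps the keys of the input dictionary, so metric names carry through to logging unchanged.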
- direct.utils.git_hash()[source][source]#
- Returns the current git hash. - Returns:
- _git_hash: str
- The current git hash. 
 
- Return type:
- str
 
- direct.utils.is_complex_data(data, complex_axis=-1)[source][source]#
- Returns True if data is a complex tensor at a specified dimension, i.e. complex_axis of data is of size 2, corresponding to real and imaginary channels. - Parameters:
- data: torch.Tensor
- complex_axis: int
- Complex dimension along which the check will be done. Default: -1 (last). 
 
- Returns:
- bool
- True if data is a complex tensor. 
 
- Return type:
- bool
 
- direct.utils.is_power_of_two(number)[source][source]#
- Check if input is a power of 2. - Parameters:
- number: int
 
- Returns:
- bool
 
- Return type:
- bool
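
A common way to perform this check is a bit trick: a positive power of two has exactly one bit set. power_of_two below is a hypothetical helper; how the library treats 0 or negative inputs is not stated above, so returning False for them is an assumption.

```python
def power_of_two(number):
    """Return True if number is a positive integer power of 2 (1, 2, 4, 8, ...)."""
    # number & (number - 1) clears the lowest set bit; the result is 0
    # only when a single bit was set.
    return number > 0 and (number & (number - 1)) == 0


checks = {n: power_of_two(n) for n in (0, 1, 6, 8, 1024)}
```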
 
- direct.utils.merge_list_of_dicts(list_of_dicts)[source][source]#
- A list of dictionaries is merged into one dictionary. - Parameters:
- list_of_dicts: List[Dict]
 
- Returns:
- Dict
 
- Return type:
- Dict
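
A minimal sketch of merging, assuming that a key appearing in several dictionaries takes the value from the last one. merge_dicts is a hypothetical helper; the conflict rule is not specified above.

```python
def merge_dicts(list_of_dicts):
    """Merge dictionaries left to right into a new dictionary."""
    merged = {}
    for item in list_of_dicts:
        merged.update(item)  # assumption: on duplicate keys, the later dictionary wins
    return merged


merged = merge_dicts([{"a": 1, "b": 2}, {"b": 3}, {"c": 4}])
```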
 
- direct.utils.multiply_function(multiplier, func)[source][source]#
- Create a function which multiplies the output of another function by a multiplier. - Parameters:
- multiplier: float
- Number to multiply with. 
- func: callable
- Function to multiply. 
 
- Returns:
- return_func: Callable
 
- Return type:
- Callable
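
The idea can be sketched as a closure that forwards all arguments to the wrapped function and scales the result. scaled and squared_error are hypothetical names used for illustration.

```python
def scaled(multiplier, func):
    """Return a function that calls func and multiplies its result by multiplier."""

    def return_func(*args, **kwargs):
        return multiplier * func(*args, **kwargs)

    return return_func


def squared_error(a, b):
    return (a - b) ** 2


weighted_loss = scaled(0.5, squared_error)
value = weighted_loss(3.0, 1.0)
```

This is a convenient way to weight individual terms when several loss functions are combined.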
 
- direct.utils.normalize_image(image, eps=1e-05)[source][source]#
- Normalize image to range [0,1] for visualization. - Given image \(x\) and \(\epsilon\), it returns:
- \[\frac{x - \min{x}}{\max{x} + \epsilon}.\]
- Parameters:
- image: torch.Tensor
- Image to scale. 
- eps: float
 
- Returns:
- image: torch.Tensor
- Scaled image. 
 
- Return type:
- Tensor
 
 
- direct.utils.prefix_dict_keys(data, prefix)[source][source]#
- Add a prefix to the keys of a dictionary. - Parameters:
- data: Dict[str, Any]
- prefix: str
 
- Returns:
- Dict[str, Any]
 
- Return type:
- Dict[str, Any]
 
- direct.utils.reduce_list_of_dicts(data, mode='average', divisor=None)[source][source]#
- Average or sum a list of dictionaries mapping keys to Tensors. - Parameters:
- data: List[Dict[str, torch.Tensor]]
- mode: str
- Which reduction mode to use: sum adds the values, while average computes their average. 
- divisor: None or int
- If given values are divided by this factor. 
 
- Returns:
- Dict[str, torch.Tensor]: Reduced dictionary.
 
- Return type:
- Dict[str, Tensor]
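
The two reduction modes can be sketched on plain floats. reduce_dicts is a hypothetical helper, and the way divisor combines with mode (an extra division applied on top of the chosen reduction) is an assumption made for this example.

```python
from collections import defaultdict


def reduce_dicts(data, mode="average", divisor=None):
    """Sum or average the values of a list of dictionaries, key by key."""
    if mode not in ("average", "sum"):
        raise ValueError(f"Unsupported mode: {mode}")
    totals = defaultdict(float)
    for entry in data:
        for key, value in entry.items():
            totals[key] += value
    scale = 1.0
    if mode == "average":
        scale *= len(data)
    if divisor is not None:
        scale *= divisor  # assumption: applied in addition to the mode
    return {key: total / scale for key, total in totals.items()}


batches = [{"loss": 1.0}, {"loss": 3.0}]
averaged = reduce_dicts(batches)
summed = reduce_dicts(batches, mode="sum", divisor=4)
```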
 
- direct.utils.remove_keys(input_dict, keys)[source][source]#
- Removes keys from input_dict. - Parameters:
- input_dict: Dict
- keys: Union[str, List[str], Tuple[str]]
 
- Returns:
- Dict
 
- Return type:
- Dict
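
A minimal sketch that accepts either a single key or a collection of keys. without_keys is a hypothetical helper; returning a new dictionary instead of modifying the input in place is an assumption.

```python
def without_keys(input_dict, keys):
    """Return a copy of input_dict without the given key or keys."""
    if isinstance(keys, str):
        keys = [keys]  # allow a single key to be passed as a string
    return {key: value for key, value in input_dict.items() if key not in keys}


sample = {"image": 1, "mask": 2, "filename": "a.h5"}
single = without_keys(sample, "mask")
several = without_keys(sample, ("mask", "filename"))
```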
 
- direct.utils.set_all_seeds(seed)[source][source]#
- Sets seed for deterministic runs. - Parameters:
- seed: int
- Seed for random module. 
 
- Returns:
- Return type:
- None
 
- direct.utils.str_to_class(module_name, function_name)[source][source]#
- Convert a string to a class. Based on: https://stackoverflow.com/a/1176180/576363. - Also supports function arguments, e.g. ifft(dim=2) will be parsed as a partial and return ifft where dim has been set to 2. - Parameters:
- module_name: str
- e.g. direct.data.transforms 
- function_name: str
- e.g. Identity 
- Returns:
- object
 
- Return type:
- Callable
 - Examples
    >>> def mult(f, mul=2):
    >>>     return f*mul
    >>> str_to_class(".", "mult(mul=4)")
   will return a function which multiplies the input by 4, while
    >>> str_to_class(".", "mult")
   just returns the function itself.
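The parsing step (a name, optionally followed by keyword arguments that become a partial) can be sketched with the standard library. str_to_callable is a hypothetical helper; it handles literal keyword arguments only and is not the library's implementation.

```python
import ast
import functools
import importlib
import math
import re


def str_to_callable(module_name, function_name):
    """Resolve 'name' or 'name(key=value, ...)' inside module_name."""
    match = re.fullmatch(r"(\w+)(?:\((.*)\))?", function_name.strip())
    if match is None:
        raise ValueError(f"Cannot parse: {function_name!r}")
    name, arg_string = match.groups()
    target = getattr(importlib.import_module(module_name), name)
    if not arg_string:
        return target
    # Parse "a=1, b='x'" by wrapping it in a dummy call expression.
    call = ast.parse(f"f({arg_string})", mode="eval").body
    if call.args:
        raise ValueError("Only keyword arguments are supported in this sketch.")
    kwargs = {kw.arg: ast.literal_eval(kw.value) for kw in call.keywords}
    return functools.partial(target, **kwargs)


plain = str_to_callable("math", "sqrt")
binary_int = str_to_callable("builtins", "int(base=2)")
parsed_value = binary_int("101")
```

Using ast.literal_eval restricts argument values to Python literals, which avoids evaluating arbitrary code from a configuration string.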