direct.utils package#
Submodules#
direct.utils.asserts module#
- direct.utils.asserts.assert_positive_integer(*variables, strict=False)[source]#
Assert that the given variables are positive integers.
- Parameters:
*variables – Variables to check.
strict (bool) – If True, will allow zero values. Default: False.
- Return type:
None
- direct.utils.asserts.assert_same_shape(data_list)[source]#
Check if all tensors in the list have the same shape.
- Parameters:
data_list (List[Tensor]) – List of tensors.
- direct.utils.asserts.assert_complex(data, complex_axis=-1, complex_last=None)[source]#
Assert that a tensor is complex, i.e. that it has a complex dimension of size 2 corresponding to the real and imaginary channels.
- Parameters:
data (Tensor) – Tensor to check.
complex_axis (int) – Complex dimension along which the assertion will be done. Default: -1 (last).
complex_last (Optional[bool]) – If True, will override complex_axis with -1 (last). Default: None.
- Return type:
None
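The following minimal sketch makes the complex_axis / complex_last semantics concrete by applying the documented rule to a plain shape tuple instead of a tensor. The helper assert_complex_shape is hypothetical and not part of direct.utils. Raising ValueError is also an assumption, because the actual function may signal failure differently.

```python
from typing import Optional, Sequence


def assert_complex_shape(
    shape: Sequence[int], complex_axis: int = -1, complex_last: Optional[bool] = None
) -> None:
    """Hypothetical shape-only sketch of the complex check described above."""
    # complex_last=True overrides complex_axis with the last dimension.
    if complex_last:
        complex_axis = -1
    # Real and imaginary parts are stored along a dimension of size 2.
    if shape[complex_axis] != 2:
        raise ValueError(
            f"Expected a complex dimension of size 2 at axis {complex_axis}, got shape {tuple(shape)}."
        )


assert_complex_shape((1, 320, 320, 2))                  # passes: last dimension has size 2
assert_complex_shape((1, 2, 320, 320), complex_axis=1)  # passes: axis 1 has size 2
```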
direct.utils.bbox module#
- direct.utils.bbox.crop_to_bbox(data, bbox, pad_value=0)[source]#
Extract a bounding box from an image; coordinates can be negative.
- Parameters:
data (Union[ndarray, Tensor]) – nD array or torch tensor.
bbox (List[int]) – Bounding box of the form (coordinates, size). For instance, (4, 4, 2, 1) is a patch starting at row 4, col 4 with height 2 and width 1.
pad_value (int) – If the bounding box extends outside the image, the patch is padded with this value. Default: 0.
- Return type:
Union[ndarray, Tensor]
- Returns:
Data cropped to the bounding box.
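A small pure-Python sketch can illustrate the bounding-box convention in 2D. It assumes bbox = (row, col, height, width), as in the example above. crop_to_bbox_2d is a hypothetical helper that works on nested lists and is not the library function. It shows how coordinates that are negative or beyond the image end up filled with pad_value.

```python
from typing import List, Sequence


def crop_to_bbox_2d(image: List[List[int]], bbox: Sequence[int], pad_value: int = 0) -> List[List[int]]:
    """Crop a 2D image (list of rows) to bbox = (row, col, height, width).

    Positions of the box that fall outside the image are filled with pad_value,
    so row and col may be negative.
    """
    row, col, height, width = bbox
    num_rows = len(image)
    num_cols = len(image[0]) if image else 0
    patch = []
    for r in range(row, row + height):
        patch_row = []
        for c in range(col, col + width):
            inside = 0 <= r < num_rows and 0 <= c < num_cols
            patch_row.append(image[r][c] if inside else pad_value)
        patch.append(patch_row)
    return patch


# Pixel value 10 * row + col makes the crop location easy to read.
image = [[10 * r + c for c in range(6)] for r in range(6)]
print(crop_to_bbox_2d(image, (4, 4, 2, 1)))                 # [[44], [54]]
print(crop_to_bbox_2d(image, (-1, -1, 2, 2), pad_value=-1))  # [[-1, -1], [-1, 0]]
```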
- direct.utils.bbox.crop_to_largest(data, pad_value=0)[source]#
Given a list of arrays or tensors, return the same list with the data padded to the largest in the set.
This can be convenient for, e.g., logging and tiling several images, as with torchvision’s make_grid.
- Parameters:
data (List[Union[ndarray, Tensor]]) – List of numpy arrays or torch tensors.
pad_value (int) – Padding value. Default: 0.
- Return type:
List[Union[ndarray, Tensor]]
- Returns:
List of padded arrays or tensors.
direct.utils.communication module#
- direct.utils.communication.synchronize()[source]#
Synchronize processes between GPUs.
Wait until all devices are available. Does nothing in a non-distributed setting.
- direct.utils.communication.get_rank()[source]#
Get rank of the process, even when torch.distributed is not initialized.
- Return type:
int
- Returns:
Process rank.
- direct.utils.communication.get_local_rank()[source]#
Get rank of the process within the same machine, even when torch.distributed is not initialized.
- Return type:
int
- Returns:
The rank of the current process within the local (per-machine) process group.
- direct.utils.communication.get_local_size()[source]#
Get the number of compute units in the local machine.
- Return type:
int
- Returns:
Number of compute units.
- direct.utils.communication.is_main_process()[source]#
Simple wrapper around get_rank().
- Return type:
bool
- Returns:
True if the current process is the main process.
- direct.utils.communication.get_world_size()[source]#
Get the number of compute devices in the world; returns 1 if multi-device training is not initialized.
- Return type:
int
- Returns:
World size.
- direct.utils.communication.all_gather(data, group=None)[source]#
Run all_gather on arbitrary picklable data (not necessarily tensors).
- Parameters:
data (object) – Any picklable object.
group (Optional[group]) – A torch process group. By default, will use a group which contains all ranks on the gloo backend.
- Returns:
List of data gathered from each rank.
- Return type:
list
- direct.utils.communication.gather(data, destination_rank=0, group=None)[source]#
Run gather on arbitrary picklable data (not necessarily tensors).
- Parameters:
data (object) – Any picklable object.
destination_rank (int) – Destination rank.
group (Optional[group]) – A torch process group. By default, will use a group which contains all ranks on the gloo backend.
- Returns:
On destination_rank, a list of data gathered from each rank. Otherwise, an empty list.
- Return type:
list[data]
All workers must call this function; otherwise it will deadlock.
- Return type:
int
- Returns:
A random number that is the same across all workers. If workers need a shared RNG, they can use this shared seed to create one.
- direct.utils.communication.reduce_tensor_dict(tensors_dict)[source]#
Reduce the tensor dictionary from all processes so that the process with rank 0 has the averaged results. Returns a dict with the same fields as tensors_dict, after reduction.
- Parameters:
tensors_dict (Dict[str, Tensor]) – Dictionary with str keys mapping to torch tensors.
- Returns:
The reduced dict.
- Return type:
dict
direct.utils.dataset module#
- direct.utils.dataset.get_filenames_for_datasets_from_config(cfg, files_root, data_root)[source]#
Given a configuration object, return a list of filenames.
- Parameters:
cfg – Configuration object with a property lists containing the paths relative to files_root.
files_root (Union[Path, str]) – Files root path.
data_root (Path) – Data root path.
- Returns:
List of filenames or None.
- direct.utils.dataset.get_filenames_for_datasets(lists, files_root, data_root)[source]#
Given lists of filenames of data points, concatenate these into a large list of full filenames.
- Parameters:
lists (List[Union[Path, str]]) – List of filename list paths.
files_root (Union[Path, str]) – Files root path.
data_root (Path) – Data root path.
- Returns:
List of filenames or None.
direct.utils.events module#
- direct.utils.events.get_event_storage()[source]#
Get the current event storage.
- Returns:
The EventStorage object that’s currently being used.
- Raises:
ValueError – If no EventStorage is currently enabled.
- class direct.utils.events.EventWriter[source]#
Bases: object
Base class for writers that obtain events from EventStorage and process them.
- class direct.utils.events.JSONWriter(json_file, window_size=2)[source]#
Bases: EventWriter
Write scalars to a JSON file.
It saves scalars as one JSON object per line (instead of one big JSON document) for easy parsing.
Example of parsing such a JSON file:
$ cat metrics.json | jq -s '.[0:2]'
[
  {
    "data_time": 0.008433341979980469,
    "iteration": 20,
    "loss": 1.9228371381759644,
    "loss_box_reg": 0.050025828182697296,
    "loss_classifier": 0.5316952466964722,
    "loss_mask": 0.7236229181289673,
    "loss_rpn_box": 0.0856662318110466,
    "loss_rpn_cls": 0.48198649287223816,
    "lr": 0.007173333333333333,
    "time": 0.25401854515075684
  },
  {
    "data_time": 0.007216215133666992,
    "iteration": 40,
    "loss": 1.282649278640747,
    "loss_box_reg": 0.06222952902317047,
    "loss_classifier": 0.30682939291000366,
    "loss_mask": 0.6970193982124329,
    "loss_rpn_box": 0.038663312792778015,
    "loss_rpn_cls": 0.1471673548221588,
    "lr": 0.007706666666666667,
    "time": 0.2490077018737793
  }
]
$ cat metrics.json | jq '.loss_mask'
0.7126231789588928
0.689423680305481
0.6776131987571716
...
- __init__(json_file, window_size=2)[source]#
Inits JSONWriter.
- Parameters:
json_file (Union[Path, str]) – Path to the JSON file. Data will be appended if it exists.
window_size (int) – Window size of median smoothing for variables for which smoothing_hint is True. Default: 2.
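The jq queries above can also be reproduced in Python. The following minimal sketch uses only the standard library. It assumes that metrics.json uses the one-object-per-line format described above. The helpers read_json_lines and load_metrics are hypothetical and not part of direct.utils.events.

```python
import json
from typing import Any, Dict, Iterable, List


def read_json_lines(lines: Iterable[str]) -> List[Dict[str, Any]]:
    """Parse one JSON object per non-empty line."""
    return [json.loads(line) for line in lines if line.strip()]


def load_metrics(path: str) -> List[Dict[str, Any]]:
    """Read a metrics file written in the one-object-per-line format."""
    with open(path, "r", encoding="utf-8") as f:
        return read_json_lines(f)


# In-memory stand-in for the first two lines of metrics.json.
sample_lines = [
    '{"iteration": 20, "loss_mask": 0.7236229181289673}\n',
    '{"iteration": 40, "loss_mask": 0.6970193982124329}\n',
]
records = read_json_lines(sample_lines)
first_two = records[:2]  # similar to: jq -s '.[0:2]'
# Similar to: jq '.loss_mask' (records without the key are skipped here, jq would print null).
loss_mask = [r["loss_mask"] for r in records if "loss_mask" in r]
print(loss_mask)
```

Writing one object per line means a crashed run still leaves every previously written line parseable, which a single large JSON document would not.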
- class direct.utils.events.TensorboardWriter(log_dir, window_size=20, **kwargs)[source]#
Bases: EventWriter
Write all scalars to a TensorBoard file.
- __init__(log_dir, window_size=20, **kwargs)[source]#
Inits TensorboardWriter.
- Parameters:
log_dir (Union[Path, str]) – The directory to save the output events.
window_size (int) – The scalars will be median-smoothed by this window size. Default: 20.
**kwargs – Other arguments passed to torch.utils.tensorboard.SummaryWriter.
- class direct.utils.events.CommonMetricPrinter(max_iter)[source]#
Bases: EventWriter
Print common metrics to the terminal, including iteration time, ETA, memory, all losses, and the learning rate.
To print something different, implement a similar printer yourself.
- __init__(max_iter)[source]#
Inits CommonMetricPrinter.
- Parameters:
max_iter – The maximum number of iterations to train. Used to compute ETA.
- class direct.utils.events.EventStorage(start_iter=0)[source]#
Bases: object
The user-facing class that provides metric storage functionalities.
In the future we may add support for storing / logging other types of data if needed.
- add_image(img_name, img_tensor)[source]#
Add an img_tensor to the _vis_data associated with img_name.
- Parameters:
img_name (str) – The name of the image to put into TensorBoard.
img_tensor (torch.Tensor or numpy.array) – A uint8 or float tensor of shape [channel, height, width] where channel is 3. The image format should be RGB. The elements in img_tensor can either have values in [0, 1] (float32) or [0, 255] (uint8).
- clear_images()[source]#
Delete all the stored images for visualization.
This should be called after images are written to tensorboard.
- add_scalar(name, value, smoothing_hint=True)[source]#
Add a scalar value to the HistoryBuffer associated with name.
- Parameters:
name (str) – Name of the scalar.
value (float) – Value of the scalar.
smoothing_hint (bool) – A 'hint' on whether this scalar is noisy and should be smoothed when logged. The hint will be accessible through EventStorage.smoothing_hints. A writer may ignore the hint and apply a custom smoothing rule. It defaults to True because most scalars we save need to be smoothed to provide any useful signal.
- add_scalars(*, smoothing_hint=True, **kwargs)[source]#
Put multiple scalars from keyword arguments.
Examples
storage.add_scalars(loss=my_loss, accuracy=my_accuracy, smoothing_hint=True)
- add_graph(img_name, img_tensor)[source]#
Add an img_tensor to the _vis_data associated with img_name.
- Parameters:
img_name (str) – The name of the image to put into TensorBoard.
img_tensor (torch.Tensor or numpy.array) – A uint8 or float tensor of shape [channel, height, width] where channel is 3. The image format should be RGB. The elements in img_tensor can either have values in [0, 1] (float32) or [0, 255] (uint8).
- histories()[source]#
- Returns:
The HistoryBuffer for all scalars.
- Return type:
dict[name -> HistoryBuffer]
- latest()[source]#
- Returns:
The scalars added in the current iteration.
- Return type:
dict[name -> number]
- latest_with_smoothing_hint(window_size=20)[source]#
Similar to latest(), but the returned values are either the unsmoothed original latest value or a median over the given window_size, depending on whether smoothing_hint is True.
This provides a default behavior that other writers can use.
- smoothing_hints()[source]#
- Returns:
The user-provided hint on whether the scalar is noisy and needs smoothing.
- Return type:
dict[name -> bool]
- step()[source]#
User should call this function at the beginning of each iteration, to notify the storage of the start of a new iteration.
The storage will then be able to associate the new data with the correct iteration number.
- property vis_data#
- property iter#
- class direct.utils.events.HistoryBuffer(max_length=1000000)[source]#
Bases: object
Track a series of scalar values and provide access to smoothed values over a window or the global average of the series.
- __init__(max_length=1000000)[source]#
- Parameters:
max_length (int) – Maximal number of values that can be stored in the buffer. When the capacity of the buffer is exhausted, old values will be removed.
- update(value, iteration=None)[source]#
Add a new scalar value produced at a certain iteration.
If the length of the buffer exceeds self._max_length, the oldest element will be removed from the buffer.
- Return type:
None
- median(window_size)[source]#
Return the median of the latest window_size values in the buffer.
- Return type:
float
- avg(window_size)[source]#
Return the mean of the latest window_size values in the buffer.
- Return type:
float
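The buffer behavior described above (bounded capacity, windowed median and mean) can be illustrated with a minimal stdlib sketch. SimpleHistoryBuffer is a hypothetical stand-in and not the actual HistoryBuffer. It leaves out the global average and stores the iteration only alongside each value.

```python
import statistics
from collections import deque
from typing import Deque, List, Optional, Tuple


class SimpleHistoryBuffer:
    """Illustrative stand-in for HistoryBuffer: a bounded history with windowed statistics."""

    def __init__(self, max_length: int = 1000000) -> None:
        # deque(maxlen=...) silently drops the oldest entry once the buffer is full.
        self._data: Deque[Tuple[float, Optional[int]]] = deque(maxlen=max_length)

    def update(self, value: float, iteration: Optional[int] = None) -> None:
        self._data.append((value, iteration))

    def _latest(self, window_size: int) -> List[float]:
        # Values of the last `window_size` entries, oldest first.
        return [value for value, _ in list(self._data)[-window_size:]]

    def median(self, window_size: int) -> float:
        return float(statistics.median(self._latest(window_size)))

    def avg(self, window_size: int) -> float:
        return float(statistics.mean(self._latest(window_size)))


buffer = SimpleHistoryBuffer(max_length=3)
for step, loss in enumerate([1.0, 100.0, 2.0, 3.0]):
    buffer.update(loss, iteration=step)

# Capacity is 3, so 1.0 has been evicted and the buffer holds 100.0, 2.0, 3.0.
print(buffer.median(3))  # 3.0 (robust to the 100.0 outlier)
print(buffer.avg(2))     # 2.5
```

The median is robust to single outliers, so it is a natural choice for smoothing noisy losses.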
direct.utils.imports module#
General utilities for module imports.
direct.utils.io module#
- direct.utils.io.read_json(fn)[source]#
Read a JSON file and output a dict, or take a dict and output a dict.
- Parameters:
fn (Union[Dict, str, Path]) – File path or dictionary.
- Return type:
Dict
- Returns:
Dictionary content.
- class direct.utils.io.ArrayEncoder(*, skipkeys=False, ensure_ascii=True, check_circular=True, allow_nan=True, sort_keys=False, indent=None, separators=None, default=None)[source]#
Bases: JSONEncoder
- default(obj)[source]#
Implement this method in a subclass such that it returns a serializable object for o, or calls the base implementation (to raise a TypeError).
For example, to support arbitrary iterators, you could implement default like this:
def default(self, o):
    try:
        iterable = iter(o)
    except TypeError:
        pass
    else:
        return list(iterable)
    # Let the base class default method raise the TypeError
    return super().default(o)
- direct.utils.io.write_json(fn, data, indent=2)[source]#
Write dict data to a JSON file.
- Parameters:
fn (Union[str, Path]) – File path.
data (Dict) – Dictionary to write.
indent – Indentation level. Default: 2.
- Return type:
None
- direct.utils.io.read_list(fn)[source]#
Read a file and output a list, or take a list and output a list. Can also read data from URLs.
- Parameters:
fn (Union[List, str, Path]) – Input text file or list, or a URL to a text file.
- Return type:
List
- Returns:
Text file read line by line.
- direct.utils.io.write_list(fn, data)[source]#
Write a list line by line to a file.
- Parameters:
fn (Union[str, Path]) – Output file path.
data – List or tuple to write.
- Return type:
None
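The line-per-item format shared by read_list and write_list can be shown with a minimal round-trip sketch. The functions write_lines and read_lines are hypothetical illustrations, not the library functions. URL support is omitted, and stripping only the trailing newline is an assumption about how lines are read.

```python
import tempfile
from pathlib import Path
from typing import List, Sequence, Union


def write_lines(fn: Union[str, Path], data: Sequence[str]) -> None:
    """Write each item on its own line."""
    with open(fn, "w", encoding="utf-8") as f:
        for item in data:
            f.write(f"{item}\n")


def read_lines(fn: Union[str, Path, List[str]]) -> List[str]:
    """Return a list unchanged; otherwise read the file line by line without trailing newlines."""
    if isinstance(fn, list):
        return fn
    with open(fn, "r", encoding="utf-8") as f:
        return [line.rstrip("\n") for line in f]


with tempfile.TemporaryDirectory() as tmp_dir:
    list_path = Path(tmp_dir) / "filenames.lst"
    write_lines(list_path, ["file_1.h5", "file_2.h5"])
    round_trip = read_lines(list_path)

print(round_trip)  # ['file_1.h5', 'file_2.h5']
```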
- direct.utils.io.download_url(url, root, filename=None, md5=None, max_redirect_hops=3)[source]#
Download a file from a URL and place it in root.
- Parameters:
url (str) – URL to download the file from.
root (Union[Path, str]) – Directory to place the downloaded file in.
filename (Optional[str]) – Name to save the file under. If None, use the basename of the URL.
md5 (Optional[str]) – MD5 checksum of the download. If None, do not check.
max_redirect_hops (int) – Maximum number of redirect hops allowed.
- Return type:
None
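The md5 parameter implies an integrity check after the download. The following minimal sketch shows one way such a check can work with hashlib, following the documented rule that md5=None skips the check. The helpers calculate_md5 and check_integrity are hypothetical and are not claimed to be what download_url calls internally.

```python
import hashlib
import tempfile
from pathlib import Path
from typing import Optional, Union


def calculate_md5(path: Union[str, Path], chunk_size: int = 1024 * 1024) -> str:
    """Hash a file in chunks so that large downloads need not fit in memory."""
    md5 = hashlib.md5()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            md5.update(chunk)
    return md5.hexdigest()


def check_integrity(path: Union[str, Path], md5: Optional[str] = None) -> bool:
    """Return False for a missing file; skip the checksum when md5 is None."""
    if not Path(path).is_file():
        return False
    if md5 is None:
        return True
    return calculate_md5(path) == md5


with tempfile.TemporaryDirectory() as tmp_dir:
    downloaded = Path(tmp_dir) / "download.bin"
    downloaded.write_bytes(b"hello")
    digest = calculate_md5(downloaded)
    matches = check_integrity(downloaded, "5d41402abc4b2a76b9719d911017c592")
    unchecked = check_integrity(downloaded)
    mismatch = check_integrity(downloaded, "0" * 32)

print(digest)  # 5d41402abc4b2a76b9719d911017c592
```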
- direct.utils.io.extract_archive(from_path, to_path=None, remove_finished=False)[source]#
Extract an archive.
The archive type and a possible compression are automatically detected from the file name. If the file is compressed but not an archive, the call is dispatched to decompress().
- Parameters:
from_path (str) – Path to the file to be extracted.
to_path (Optional[str]) – Path to the directory the file will be extracted to. If omitted, the directory of the file is used.
remove_finished (bool) – If True, remove the file after the extraction.
- Return type:
str
- Returns:
Path to the directory the file was extracted to.
- direct.utils.io.download_and_extract_archive(url, download_root, extract_root=None, filename=None, md5=None, remove_finished=False)[source]#
- Return type:
None
- direct.utils.io.read_text_from_url(url, chunk_size=1024)[source]#
Read a text file from a URL, e.g. a config file.
- Parameters:
url – str
chunk_size (
int) – int
- Returns:
Data from URL
- direct.utils.io.check_is_valid_url(path)[source]#
Check if the given path is a valid URL.
- Parameters:
path (Union[Path, str]) – Path or string to check.
- Return type:
bool
- Returns:
Bool describing whether this is a URL or not.
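The exact criterion check_is_valid_url applies is not documented here. The following sketch shows one common heuristic, built on urllib.parse: a string counts as a URL if it has both a scheme and a network location. The name looks_like_url is hypothetical, and the heuristic is an assumption, not the library's rule.

```python
from pathlib import Path
from typing import Union
from urllib.parse import urlparse


def looks_like_url(path: Union[Path, str]) -> bool:
    """Heuristic: a URL needs both a scheme (e.g. "https") and a network location."""
    parsed = urlparse(str(path))
    return bool(parsed.scheme) and bool(parsed.netloc)


print(looks_like_url("https://example.com/configs/base.yaml"))  # True
print(looks_like_url("/data/configs/base.yaml"))                # False
print(looks_like_url(Path("configs/base.yaml")))                 # False
```

Requiring a network location also keeps Windows drive paths such as C:\data from being mistaken for URLs with a one-letter scheme.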
- direct.utils.io.upload_to_s3(filename, to_filename, aws_access_key_id, aws_secret_access_key, endpoint_url='https://s3.aiforoncology.nl', bucket='direct-project', verbose=True)[source]#
Upload a file to an S3 bucket.
- Parameters:
filename (Path) – Filename to upload.
to_filename (str) – Where to store the file.
aws_access_key_id (str) – AWS access key ID.
aws_secret_access_key (str) – AWS secret access key.
endpoint_url (str) – AWS endpoint URL.
bucket (str) – Bucket name.
verbose (bool) – Show upload progress.
- Return type:
None
- Returns:
None
direct.utils.logging module#
- direct.utils.logging.setup(use_stdout=True, filename=None, log_level='INFO')[source]#
Set up logging for DIRECT.
- Parameters:
use_stdout (Optional[bool]) – Write output to standard out. Default: True.
filename (Optional[PathLike]) – Filename to write the log to. Default: None.
log_level (Union[int, str]) – Logging level as in the Python logging library. Default: "INFO".
- Return type:
None
direct.utils.models module#
- direct.utils.models.fix_state_dict_module_prefix(state_dict)[source]#
If models are saved after being wrapped in, e.g., DataParallel, the keys of the state dict are prefixed with "module.". This function removes this prefix.
- Parameters:
state_dict – State dict of a network module.
- Returns:
State dict with fixed keys.
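The prefix removal can be sketched in a few lines. This is one possible approach and not necessarily the library's implementation. strip_module_prefix is a hypothetical name, and plain integers stand in for tensors so the example runs without PyTorch.

```python
from collections import OrderedDict
from typing import Any, Dict

PREFIX = "module."


def strip_module_prefix(state_dict: Dict[str, Any]) -> "OrderedDict[str, Any]":
    """Return a copy of state_dict with a leading "module." removed from each key."""
    fixed: "OrderedDict[str, Any]" = OrderedDict()
    for key, value in state_dict.items():
        new_key = key[len(PREFIX):] if key.startswith(PREFIX) else key
        fixed[new_key] = value
    return fixed


# Integers stand in for tensors; keys look like those saved from a DataParallel-wrapped model.
state = {"module.conv.weight": 1, "module.conv.bias": 2, "head.weight": 3}
print(list(strip_module_prefix(state)))  # ['conv.weight', 'conv.bias', 'head.weight']
```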
direct.utils.writers module#
- direct.utils.writers.write_output_to_h5(output, output_directory, volume_processing_func=None, output_key='reconstruction', create_dirs_if_needed=True)[source]#
Write a dictionary with filenames as keys and torch tensors as values to h5 files.
- Parameters:
output (Union[Dict, DefaultDict]) – Dictionary with filenames as keys and torch.Tensor’s of shape [depth, num_channels, ...] as values, where num_channels is typically 1 for MRI.
output_directory (Path) – Output directory path.
volume_processing_func (Optional[Callable]) – Function which postprocesses the volume array before saving. Default: None.
output_key (str) – Name of the key to save the output to. Default: "reconstruction".
create_dirs_if_needed (bool) – If True, the output directory and all its parents will be created. Default: True.
- Return type:
None
Note
Currently only num_channels = 1 is supported. If you run this function with more channels, the first one will be used.
Module contents#
direct.utils module.
- direct.utils.is_complex_data(data, complex_axis=-1)[source]#
Returns True if data is a complex tensor at a specified dimension, i.e. complex_axis of data is of size 2, corresponding to the real and imaginary channels.
- Parameters:
data (Tensor) – Input tensor.
complex_axis (int) – Complex dimension along which the check will be done. Default: -1 (last).
- Return type:
bool
- Returns:
True if data is a complex tensor.
- direct.utils.is_power_of_two(number)[source]#
Check if input is a power of 2.
- Parameters:
number (int) – Integer to check.
- Return type:
bool
- Returns:
True if number is a power of 2.
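A common way to implement this check uses a bit trick. This is a sketch only, and the library's implementation may differ. is_pow2 is a hypothetical name, and here 1 = 2**0 counts as a power of two.

```python
def is_pow2(number: int) -> bool:
    """A positive integer is a power of two iff exactly one bit is set."""
    # number & (number - 1) clears the lowest set bit; the result is 0 only if a single bit was set.
    return number > 0 and (number & (number - 1)) == 0


print([n for n in range(1, 20) if is_pow2(n)])  # [1, 2, 4, 8, 16]
```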
- direct.utils.ensure_list(data)[source]#
Ensure input is a list.
- Parameters:
data (Any) – Object to convert to a list.
- Return type:
List
- Returns:
List representation of input.
- direct.utils.cast_as_path(data)[source]#
Ensure the input is a path.
- Parameters:
data (Union[Path, str, None]) – String or path to convert.
- Return type:
Optional[Path]
- Returns:
Path object or None if input is None.
- direct.utils.str_to_class(module_name, function_name)[source]#
Convert a string to a class.
Based on: https://stackoverflow.com/a/1176180/576363.
Also supports function arguments, e.g. ifft(dim=2) will be parsed as a partial and return ifft where dim has been set to 2.
- Parameters:
module_name (str) – Module name, e.g. "direct.data.transforms".
function_name (str) – Function or class name, e.g. "Identity".
- Return type:
Callable
- Returns:
The requested class or function object.
Example
>>> def mult(f, mul=2):
...     return f * mul
>>> str_to_class(".", "mult(mul=4)")
Will return a function which multiplies the input by 4, while
>>> str_to_class(".", "mult")
just returns the function itself.
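The "name(kw=value)" to functools.partial behavior can be sketched with the standard ast module. This illustrates the idea only. resolve_callable is hypothetical, and an explicit namespace dict replaces the module lookup that str_to_class performs. Keyword values are assumed to be Python literals.

```python
import ast
import functools
from typing import Callable, Dict


def resolve_callable(spec: str, namespace: Dict[str, Callable]) -> Callable:
    """Resolve "name" or "name(kw=value, ...)" against namespace using functools.partial."""
    node = ast.parse(spec, mode="eval").body
    if isinstance(node, ast.Name):
        return namespace[node.id]
    if (
        isinstance(node, ast.Call)
        and isinstance(node.func, ast.Name)
        and not node.args
        and all(kw.arg is not None for kw in node.keywords)
    ):
        func = namespace[node.func.id]
        # Only literal keyword values (numbers, strings, tuples, ...) are accepted.
        kwargs = {kw.arg: ast.literal_eval(kw.value) for kw in node.keywords}
        return functools.partial(func, **kwargs)
    raise ValueError(f"Unsupported specification: {spec!r}")


def mult(f, mul=2):
    return f * mul


times_four = resolve_callable("mult(mul=4)", {"mult": mult})
print(times_four(3))                                     # 12
print(resolve_callable("mult", {"mult": mult}) is mult)  # True
```

Using ast.literal_eval instead of eval keeps configuration strings from executing arbitrary code.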
- direct.utils.dict_to_device(data, device, keys=None)[source]#
Copy a tensor-valued dictionary to a device.
Only torch.Tensor values are copied.
- Parameters:
data (Dict[str, Tensor]) – Dictionary with tensor values.
device (Union[device, str, None]) – Target device.
keys (Union[List, Tuple, KeysView, None]) – Subselection of keys to copy. Default: None.
- Return type:
Dict
- Returns:
Dictionary at the new device.
- direct.utils.detach_dict(data, keys=None)[source]#
Return a detached copy of a dictionary.
Only torch.Tensor values are detached.
- Parameters:
data (Dict[str, Tensor]) – Dictionary with tensor values.
keys (Union[List, Tuple, KeysView, None]) – Subselection of keys to detach. Default: None.
- Return type:
Dict
- Returns:
Detached dictionary.
- direct.utils.reduce_list_of_dicts(data, mode='average', divisor=None)[source]#
Reduce a list of dictionaries mapping keys to tensors, either by averaging (default) or by summing.
- Parameters:
data (List[Dict[str, Tensor]]) – List of dictionaries with tensor values.
mode – Which reduction mode. "average" computes the average, "sum" just adds. Default: "average".
divisor – If given, values are divided by this factor. Default: None.
- Return type:
Dict[str, Tensor]
- Returns:
Reduced dictionary.
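The two reduction modes can be illustrated with plain floats instead of tensors. This is a sketch under assumptions: reduce_dicts is hypothetical, every dict is assumed to share the same keys, and with mode "average" plus a divisor both divisions are assumed to apply. The library may combine them differently.

```python
from typing import Dict, List, Optional


def reduce_dicts(
    data: List[Dict[str, float]], mode: str = "average", divisor: Optional[float] = None
) -> Dict[str, float]:
    """Sum values key-wise; "average" also divides by len(data); divisor (if truthy) divides further."""
    if mode not in ("average", "sum"):
        raise ValueError("mode must be 'average' or 'sum'.")
    if not data:
        return {}
    # Assumes every dictionary has the same keys as the first one.
    totals = {key: 0.0 for key in data[0]}
    for entry in data:
        for key, value in entry.items():
            totals[key] += value
    scale = float(divisor) if divisor else 1.0
    if mode == "average":
        scale *= len(data)
    return {key: value / scale for key, value in totals.items()}


losses = [{"l1": 1.0, "ssim": 0.25}, {"l1": 3.0, "ssim": 0.75}]
print(reduce_dicts(losses))                          # {'l1': 2.0, 'ssim': 0.5}
print(reduce_dicts(losses, mode="sum"))              # {'l1': 4.0, 'ssim': 1.0}
print(reduce_dicts(losses, mode="sum", divisor=4))   # {'l1': 1.0, 'ssim': 0.25}
```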
- direct.utils.merge_list_of_dicts(list_of_dicts)[source]#
Merge a list of dictionaries into one dictionary.
- Parameters:
list_of_dicts (List[Dict]) – List of dictionaries to merge.
- Return type:
Dict
- Returns:
Merged dictionary.
- direct.utils.evaluate_dict(fns_dict, source, target, reduction='mean')[source]#
Evaluate a dictionary of functions.
- Parameters:
fns_dict (Dict[str, Callable]) – Dictionary mapping names to callable functions.
source (Tensor) – Source tensor.
target (Tensor) – Target tensor.
reduction (str) – Reduction mode. Default: "mean".
- Return type:
Dict
- Returns:
Evaluated dictionary.
Example
>>> evaluate_dict({'l1_loss': F.l1_loss, 'l2_loss': F.mse_loss}, a, b)
Will return
>>> {'l1_loss': F.l1_loss(a, b, reduction='mean'), 'l2_loss': F.mse_loss(a, b, reduction='mean')}
- direct.utils.prefix_dict_keys(data, prefix)[source]#
Add a prefix to the keys of a dictionary.
- Parameters:
data (Dict[str, Any]) – Dictionary to prefix.
prefix (str) – Prefix string to add.
- Return type:
Dict[str, Any]
- Returns:
Dictionary with prefixed keys.
- direct.utils.git_hash()[source]#
Returns the current git hash.
- Return type:
str
- Returns:
The current git hash.
- direct.utils.normalize_image(image, eps=0.00001)[source]#
Normalize image to range [0,1] for visualization.
Given image \(x\) and \(\epsilon\), it returns:
\[\frac{x - \min{x}}{\max{x} - \min{x} + \epsilon}.\]
- Parameters:
image (Tensor) – Image to scale.
eps (float) – Epsilon value to avoid division by zero. Default: 0.00001.
- Return type:
Tensor
- Returns:
Scaled image.
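The following sketch applies min-max scaling to a plain list of floats. It assumes that the minimum is subtracted first and the shifted values are then divided by their new maximum plus eps. The new maximum equals max - min, so the output stays within [0, 1). normalize_values is a hypothetical helper that does not operate on tensors.

```python
from typing import List


def normalize_values(values: List[float], eps: float = 0.00001) -> List[float]:
    """Min-max scale values: subtract the minimum, then divide by the new maximum plus eps."""
    low = min(values)
    shifted = [v - low for v in values]
    high = max(shifted)  # equals max(values) - min(values)
    return [v / (high + eps) for v in shifted]


scaled = normalize_values([-1.0, 0.0, 1.0])
print(scaled)  # approximately [0.0, 0.5, 1.0]; eps keeps the top value just below 1
```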
- direct.utils.multiply_function(multiplier, func)[source]#
Create a function which multiplies the output of another function by a multiplier.
- Parameters:
multiplier (float) – Number to multiply with.
func (Callable) – Function to multiply.
- Return type:
Callable
- Returns:
Multiplied function.
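The following minimal sketch shows the wrapping pattern, assuming that the returned function forwards all arguments to func and scales its result. One possible use is weighting one term of a composite loss. scale_function and squared_error are hypothetical names.

```python
import functools
from typing import Any, Callable


def scale_function(multiplier: float, func: Callable[..., Any]) -> Callable[..., Any]:
    """Return a wrapper that calls func with the same arguments and multiplies the result."""

    @functools.wraps(func)  # keep func's __name__ and docstring on the wrapper
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        return multiplier * func(*args, **kwargs)

    return wrapper


def squared_error(a: float, b: float) -> float:
    return (a - b) ** 2


weighted = scale_function(0.5, squared_error)
print(weighted(3.0, 1.0))  # 2.0
print(weighted.__name__)   # squared_error
```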
- class direct.utils.SpatialDims(TWO_D, THREE_D)#
Bases: tuple
- THREE_D#
Alias for field number 1
- TWO_D#
Alias for field number 0
- class direct.utils.DirectTransform[source]#
Bases: object
Direct transform class.
Defines the __repr__() method for DIRECT transforms.
- __init__()[source]#
Inits DirectTransform.
- __repr__()[source]#
Representation of DirectTransform.
- class direct.utils.DirectModule[source]#
Bases: DirectTransform, ABC, Module
- abstract __init__()[source]#
Inits DirectTransform.
- forward(sample)[source]#
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- direct.utils.count_parameters(models)[source]#
Count the number of parameters of a dictionary of models.
- Parameters:
models (Dict) – Dictionary mapping model name to model.
- Return type:
None
- direct.utils.set_all_seeds(seed)[source]#
Set the seed for deterministic runs.
- Parameters:
seed (int) – Seed for the random module.
- Return type:
None
- direct.utils.chunks(list_to_chunk, number_of_chunks)[source]#
Yield number_of_chunks sequential chunks from list_to_chunk.
Adapted from [1].
- Parameters:
list_to_chunk (List) – List to chunk.
number_of_chunks (int) – Number of chunks to create.
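The reference [1] is not reproduced here, so the exact splitting rule cannot be confirmed. The following sketch shows one common approach: it yields consecutive slices whose lengths differ by at most one, with the longer slices first. split_into_chunks is a hypothetical name.

```python
from typing import Iterator, List, TypeVar

T = TypeVar("T")


def split_into_chunks(items: List[T], number_of_chunks: int) -> Iterator[List[T]]:
    """Yield number_of_chunks consecutive slices whose lengths differ by at most one."""
    size, remainder = divmod(len(items), number_of_chunks)
    for i in range(number_of_chunks):
        # The first `remainder` chunks receive one extra element.
        start = i * size + min(i, remainder)
        stop = (i + 1) * size + min(i + 1, remainder)
        yield items[start:stop]


print(list(split_into_chunks(list(range(10)), 3)))  # [[0, 1, 2, 3], [4, 5, 6], [7, 8, 9]]
```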
References
- direct.utils.remove_keys(input_dict, keys)[source]#
Removes keys from input_dict.
- Parameters:
input_dict (Dict) – Input dictionary.
keys (Union[str, List[str], Tuple[str]]) – Key or keys to remove.
- Return type:
Dict
- Returns:
Dictionary with keys removed.
- direct.utils.dict_flatten(in_dict, dict_out=None)[source]#
Flattens a nested dictionary (or DictConfig) and returns a new flattened dictionary.
If a dict_out is provided, the flattened dictionary will be added to it.
- Parameters:
in_dict (Union[dict, DictConfig]) – The nested dictionary or DictConfig to flatten.
dict_out (Union[dict, DictConfig, None]) – An existing dictionary to add the flattened dictionary to. Default: None.
- Return type:
Dict[str, Any]
- Returns:
The flattened dictionary.
Note
This function only keeps the final keys, and discards the intermediate ones.
Example
>>> dictA = {"a": 1, "b": {"c": 2, "d": 3, "e": {"f": 4, 6: "a", 5: {"g": 6}, "l": [1, "two"]}}}
>>> dict_flatten(dictA)
{'a': 1, 'c': 2, 'd': 3, 'f': 4, 6: 'a', 'g': 6, 'l': [1, 'two']}
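The leaf-only flattening in the example above can be reproduced with a short recursive sketch. flatten_leaves is a hypothetical helper, and DictConfig handling is omitted to keep it dependency-free. As in the documented example, lists are kept as leaf values, and a later leaf with a duplicate key would overwrite an earlier one.

```python
from typing import Any, Dict, Optional


def flatten_leaves(in_dict: Dict[Any, Any], dict_out: Optional[Dict[Any, Any]] = None) -> Dict[Any, Any]:
    """Collect the leaf key/value pairs of a nested dict, dropping intermediate keys."""
    if dict_out is None:
        dict_out = {}
    for key, value in in_dict.items():
        if isinstance(value, dict):
            flatten_leaves(value, dict_out)  # recurse; the intermediate key itself is discarded
        else:
            dict_out[key] = value  # a later duplicate leaf key overwrites an earlier one
    return dict_out


dictA = {"a": 1, "b": {"c": 2, "d": 3, "e": {"f": 4, 6: "a", 5: {"g": 6}, "l": [1, "two"]}}}
print(flatten_leaves(dictA))  # {'a': 1, 'c': 2, 'd': 3, 'f': 4, 6: 'a', 'g': 6, 'l': [1, 'two']}
```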