direct.nn.recurrent package

Submodules

direct.nn.recurrent.recurrent module

class direct.nn.recurrent.recurrent.Conv2dGRU(in_channels, hidden_channels, out_channels=None, num_layers=2, gru_kernel_size=1, orthogonal_initialization=True, instance_norm=False, dense_connect=0, replication_padding=True)

Bases: Module

2D Convolutional GRU Network.

forward(cell_input, previous_state)

Computes Conv2dGRU forward pass given tensors cell_input and previous_state.

Parameters:
cell_input: torch.Tensor

Input tensor.

previous_state: torch.Tensor

Tensor of previous hidden state.

Returns:
out, new_states: (torch.Tensor, torch.Tensor)

Output and new states.

Return type:

Tuple[Tensor, Tensor]
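
The (cell_input, previous_state) to (out, new_states) contract can be illustrated with a minimal convolutional GRU cell written in plain PyTorch. This is a sketch only, not the Conv2dGRU implementation: the class name TinyConvGRUCell, the single-layer design, the gate layout, and the tensor shapes used here are all assumptions made for illustration.

```python
import torch
from torch import nn


class TinyConvGRUCell(nn.Module):
    """Hypothetical single-layer convolutional GRU cell (illustration only)."""

    def __init__(self, in_channels, hidden_channels, out_channels, kernel_size=1):
        super().__init__()
        padding = kernel_size // 2  # keeps spatial size for odd kernel sizes
        stacked_channels = in_channels + hidden_channels
        # One convolution yields both the reset and the update gate.
        self.gates = nn.Conv2d(stacked_channels, 2 * hidden_channels, kernel_size, padding=padding)
        self.candidate = nn.Conv2d(stacked_channels, hidden_channels, kernel_size, padding=padding)
        # Maps the hidden state to the requested number of output channels.
        self.out = nn.Conv2d(hidden_channels, out_channels, kernel_size=1)

    def forward(self, cell_input, previous_state):
        stacked = torch.cat([cell_input, previous_state], dim=1)
        reset, update = torch.sigmoid(self.gates(stacked)).chunk(2, dim=1)
        candidate = torch.tanh(
            self.candidate(torch.cat([cell_input, reset * previous_state], dim=1))
        )
        # Convex combination of the old state and the candidate state.
        new_state = (1 - update) * previous_state + update * candidate
        return self.out(new_state), new_state


torch.manual_seed(0)
cell = TinyConvGRUCell(in_channels=2, hidden_channels=8, out_channels=2)
cell_input = torch.randn(1, 2, 16, 16)  # assumed layout: (batch, channels, height, width)
state = torch.zeros(1, 8, 16, 16)       # assumed initial hidden state of zeros
for _ in range(3):
    # The returned state is fed back in as previous_state on the next step.
    out, state = cell(cell_input, state)
```

In this sketch the hidden state is carried across iterations by the caller, which mirrors how forward returns the new state alongside the output.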

training: bool

class direct.nn.recurrent.recurrent.NormConv2dGRU(in_channels, hidden_channels, out_channels=None, num_layers=2, gru_kernel_size=1, orthogonal_initialization=True, instance_norm=False, dense_connect=0, replication_padding=True, norm_groups=2)

Bases: Module

Normalized 2D Convolutional GRU Network.

Normalization methods adapted from NormUnet of [1].

References

forward(cell_input, previous_state)

Computes NormConv2dGRU forward pass given tensors cell_input and previous_state.

Applies group normalization to the input before passing it to the Conv2dGRU, then un-normalizes the output of the Conv2dGRU.

Parameters:
cell_input: torch.Tensor

Input tensor.

previous_state: torch.Tensor

Tensor of previous hidden state.

Returns:
out, new_states: (torch.Tensor, torch.Tensor)

Output and new states.

Return type:

Tuple[Tensor, Tensor]

static norm(input_data, num_groups)

Performs group normalization on input_data using num_groups groups.

Return type:

Tuple[Tensor, Tensor, Tensor]
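
A group normalization that returns three tensors could look like the sketch below. It assumes a (batch, channels, height, width) input, assumes the three returned tensors are the normalized data followed by the per-group mean and standard deviation, and uses the hypothetical name group_norm_sketch. The library's actual implementation may differ.

```python
from typing import Tuple

import torch


def group_norm_sketch(
    input_data: torch.Tensor, num_groups: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Normalize each channel group to zero mean and unit std (illustration only)."""
    b, c, h, w = input_data.shape
    if c % num_groups != 0:
        raise ValueError("channels must be divisible by num_groups")
    # Flatten every group of c // num_groups channels into one vector.
    grouped = input_data.reshape(b, num_groups, -1)
    mean = grouped.mean(-1, keepdim=True)  # shape (b, num_groups, 1)
    std = grouped.std(-1, keepdim=True)    # shape (b, num_groups, 1)
    normalized = ((grouped - mean) / std).reshape(b, c, h, w)
    return normalized, mean, std


torch.manual_seed(0)
x = torch.randn(2, 4, 8, 8) * 3.0 + 5.0
normalized, mean, std = group_norm_sketch(x, num_groups=2)
```

Returning the statistics alongside the normalized tensor is what allows the normalization to be undone after the recurrent network has run.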

training: bool

static unnorm(input_data, mean, std, num_groups)

Reverses group normalization using the given mean and standard deviation.

Return type:

Tensor
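
Undoing a group normalization amounts to rescaling each group by its standard deviation and shifting it by its mean. The sketch below shows this with hand-built statistics. The name unnorm_sketch and the (batch, num_groups, 1) shape of mean and std are assumptions for illustration, not confirmed details of the library.

```python
import torch


def unnorm_sketch(
    input_data: torch.Tensor, mean: torch.Tensor, std: torch.Tensor, num_groups: int
) -> torch.Tensor:
    """Rescale and shift each channel group back to its original statistics."""
    b, c, h, w = input_data.shape
    grouped = input_data.reshape(b, num_groups, -1)
    return (grouped * std + mean).reshape(b, c, h, w)


# Four channels split into two groups: channels 0-1 hold zeros, channels 2-3 hold ones.
normalized = torch.zeros(1, 4, 2, 2)
normalized[:, 2:] = 1.0
mean = torch.tensor([[[10.0], [20.0]]])  # one mean per group, shape (1, 2, 1)
std = torch.tensor([[[2.0], [3.0]]])     # one std per group, shape (1, 2, 1)

restored = unnorm_sketch(normalized, mean, std, num_groups=2)
# Group 0 becomes 0 * 2 + 10 = 10; group 1 becomes 1 * 3 + 20 = 23.
```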

Module contents