direct.nn.recurrent package#
Submodules#
direct.nn.recurrent.recurrent module#
- class direct.nn.recurrent.recurrent.Conv2dGRU(in_channels, hidden_channels, out_channels=None, num_layers=2, gru_kernel_size=1, orthogonal_initialization=True, instance_norm=False, dense_connect=0, replication_padding=True)[source]#
Bases: Module
2D Convolutional GRU Network. A minimal usage sketch follows this class entry.
- forward(cell_input, previous_state)[source]#
Computes Conv2dGRU forward pass given tensors cell_input and previous_state.
- Parameters:
- cell_input: torch.Tensor
Input tensor.
- previous_state: torch.Tensor
Tensor of previous hidden state.
- Returns:
- out, new_states: (torch.Tensor, torch.Tensor)
Output and new states.
- Return type:
Tuple[Tensor, Tensor]
- training: bool#
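The following is a minimal usage sketch, not taken from the library's documentation: the tensor sizes are hypothetical, and it assumes that passing previous_state=None makes the cell initialize a zero hidden state internally (otherwise a zero tensor of the expected hidden-state shape would have to be supplied for the first call).

```python
import torch

from direct.nn.recurrent.recurrent import Conv2dGRU

# Hypothetical sizes, chosen only for illustration.
batch, in_channels, height, width = 2, 4, 32, 32

gru = Conv2dGRU(in_channels=in_channels, hidden_channels=16, out_channels=in_channels)

cell_input = torch.rand(batch, in_channels, height, width)

# Assumption: previous_state=None triggers zero-state initialization inside the cell.
out, state = gru(cell_input, None)

# The returned hidden state can be carried into the next call, e.g. when the
# cell is unrolled over the iterations of an iterative reconstruction scheme.
out, state = gru(torch.rand_like(cell_input), state)

print(out.shape)  # expected to match (batch, out_channels, height, width)
```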
- class direct.nn.recurrent.recurrent.NormConv2dGRU(in_channels, hidden_channels, out_channels=None, num_layers=2, gru_kernel_size=1, orthogonal_initialization=True, instance_norm=False, dense_connect=0, replication_padding=True, norm_groups=2)[source]#
Bases: Module
Normalized 2D Convolutional GRU Network.
Normalization methods adapted from the NormUnet of [1]. A minimal usage sketch follows this class entry.
References
- [1] Sriram, Anuroop, et al. "End-to-End Variational Networks for Accelerated MRI Reconstruction." arXiv:2004.06688, 2020.
- forward(cell_input, previous_state)[source]#
Computes the NormConv2dGRU forward pass given tensors cell_input and previous_state. It performs group normalization on the input before the forward pass to the Conv2dGRU; the output of the Conv2dGRU is then un-normalized.
- Parameters:
- cell_input: torch.Tensor
Input tensor.
- previous_state: torch.Tensor
Tensor of previous hidden state.
- Returns:
- out, new_states: (torch.Tensor, torch.Tensor)
Output and new states.
- Return type:
Tuple[Tensor, Tensor]
- static norm(input_data, num_groups)[source]#
Performs group normalization.
- Return type:
Tuple[Tensor, Tensor, Tensor]
- training: bool#
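A corresponding sketch for NormConv2dGRU is given below, under the same assumptions (hypothetical tensor sizes, previous_state=None accepted for the first call). The unpacking order of the tuple returned by the static norm helper, normalized data followed by the per-group statistics, is an assumption and is not specified above.

```python
import torch

from direct.nn.recurrent.recurrent import NormConv2dGRU

batch, in_channels, height, width = 2, 4, 32, 32  # hypothetical sizes

norm_gru = NormConv2dGRU(
    in_channels=in_channels,
    hidden_channels=16,
    out_channels=in_channels,
    norm_groups=2,
)

cell_input = torch.rand(batch, in_channels, height, width)

# Group normalization and un-normalization happen inside forward, so the call
# looks the same as for Conv2dGRU.
out, state = norm_gru(cell_input, None)  # assumption: None means "no previous state"

# The static helper can also be called directly; the three returned tensors are
# assumed to be the normalized data and the per-group statistics.
normed, stat_a, stat_b = NormConv2dGRU.norm(cell_input, num_groups=2)
```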