24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
|
# File 'lib/torch/nn/lstm.rb', line 24
# Runs Torch.lstm on +input+ after resolving the initial hidden state.
#
# If +hx+ is nil, a single zero tensor of shape
# (@num_layers * directions, max_batch_size, @hidden_size), with the dtype and
# device of +input+, is created and used for both entries of the state pair
# (directions is 2 when @bidirectional, else 1). Otherwise +hx+ is reordered
# via permute_hidden(hx, sorted_indices).
#
# The resolved state is passed to check_forward_args before the kernel runs.
# A nil +batch_sizes+ selects the Torch.lstm call that also takes @batch_first;
# a non-nil one selects the overload that takes +batch_sizes+ right after
# +input+ (presumably the packed-sequence path -- TODO confirm).
#
# @param input tensor-like object; #dtype and #device are read when +hx+ is nil
# @param hx [Array, nil] initial state pair, or nil to start from zeros
# @param batch_sizes [Object, nil] forwarded to check_forward_args and Torch.lstm
# @param max_batch_size [Integer] batch dimension of the zero state
# @param sorted_indices forwarded to permute_hidden when +hx+ is given
# @return [Array] +[output, hidden]+ where output is the first element of the
#   Torch.lstm result and hidden is an array of the remaining elements
def forward_impl(input, hx, batch_sizes, max_batch_size, sorted_indices)
  state =
    if hx.nil?
      directions = @bidirectional ? 2 : 1
      # NOTE(review): h_0 and c_0 share one tensor object (as upstream does);
      # this relies on Torch.lstm not mutating its state arguments in place.
      zero_state = Torch.zeros(
        @num_layers * directions, max_batch_size, @hidden_size,
        dtype: input.dtype, device: input.device
      )
      [zero_state, zero_state]
    else
      permute_hidden(hx, sorted_indices)
    end

  check_forward_args(input, state, batch_sizes)
  flat_weights = _get_flat_weights

  result =
    if batch_sizes.nil?
      Torch.lstm(input, state, flat_weights, @bias, @num_layers,
                 @dropout, @training, @bidirectional, @batch_first)
    else
      Torch.lstm(input, batch_sizes, state, flat_weights, @bias,
                 @num_layers, @dropout, @training, @bidirectional)
    end

  output, *hidden = result
  [output, hidden]
end
|