Method: Torch::NN::LSTM#forward_impl

Defined in:
lib/torch/nn/lstm.rb

#forward_impl(input, hx, batch_sizes, max_batch_size, sorted_indices) ⇒ Object



24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
# File 'lib/torch/nn/lstm.rb', line 24

# Runs the recurrent computation through Torch.lstm and splits its result
# into the output and the trailing state values.
#
# @param input [Object] the input tensor; must respond to +dtype+ and
#   +device+ (used to build the default zero state). Presumably a
#   Torch::Tensor — TODO confirm against callers.
# @param hx [Array, nil] initial state. When nil, a zero state is created;
#   otherwise it is reordered with +permute_hidden+ using +sorted_indices+.
# @param batch_sizes [Object, nil] nil selects the Torch.lstm overload that
#   takes +@batch_first+; any non-nil value is forwarded as the second
#   argument of the other overload (presumably packed-sequence input).
# @param max_batch_size [Integer] second dimension of the default zero state.
# @param sorted_indices [Object, nil] forwarded to +permute_hidden+; unused
#   when +hx+ is nil.
# @return [Array] two elements: the first element of Torch.lstm's result,
#   and an array of all remaining elements.
def forward_impl(input, hx, batch_sizes, max_batch_size, sorted_indices)
  hx =
    if hx.nil?
      # No state supplied: one zero tensor of shape
      # (layers * directions, max_batch_size, hidden_size), reused for both
      # state entries.
      directions = @bidirectional ? 2 : 1
      initial = Torch.zeros(@num_layers * directions, max_batch_size, @hidden_size,
                            dtype: input.dtype, device: input.device)
      [initial, initial]
    else
      # Reorder the caller-provided state so each batch entry lines up with
      # the (possibly re-sorted) input sequences the caller passed in.
      permute_hidden(hx, sorted_indices)
    end

  check_forward_args(input, hx, batch_sizes)

  weights = _get_flat_weights
  result =
    if batch_sizes.nil?
      Torch.lstm(input, hx, weights, @bias, @num_layers,
                 @dropout, @training, @bidirectional, @batch_first)
    else
      # This overload takes batch_sizes before hx and no batch_first flag.
      Torch.lstm(input, batch_sizes, hx, weights, @bias,
                 @num_layers, @dropout, @training, @bidirectional)
    end

  # Leading element is the output; everything after it is the state.
  [result[0], result[1..-1]]
end