Class: Torch::NN::RNNBase

Inherits:
  Module < Object
Defined in:
lib/torch/nn/rnn_base.rb

Direct Known Subclasses

GRU, LSTM, RNN

Instance Attribute Summary

Attributes inherited from Module

#training

Instance Method Summary

Methods inherited from Module

#add_module, #apply, #buffers, #call, #children, #cpu, #cuda, #deep_dup, #double, #eval, #float, #half, #inspect, #load_state_dict, #method_missing, #modules, #named_buffers, #named_children, #named_modules, #named_parameters, #parameters, #register_buffer, #register_parameter, #requires_grad!, #respond_to?, #share_memory, #state_dict, #to, #train, #type, #zero_grad

Methods included from Utils

#_activation_fn, #_clones, #_ntuple, #_pair, #_quadrupal, #_single, #_triple

Constructor Details

#initialize(mode, input_size, hidden_size, num_layers: 1, bias: true, batch_first: false, dropout: 0.0, bidirectional: false) ⇒ RNNBase

Returns a new instance of RNNBase.



# File 'lib/torch/nn/rnn_base.rb', line 4

def initialize(mode, input_size, hidden_size, num_layers: 1, bias: true,
  batch_first: false, dropout: 0.0, bidirectional: false)

  super()
  @mode = mode
  @input_size = input_size
  @hidden_size = hidden_size
  @num_layers = num_layers
  @bias = bias
  @batch_first = batch_first
  @dropout = dropout.to_f
  @bidirectional = bidirectional
  num_directions = bidirectional ? 2 : 1

  if !dropout.is_a?(Numeric) || !(dropout >= 0 && dropout <= 1)
    raise ArgumentError, "dropout should be a number in range [0, 1] " +
                         "representing the probability of an element being " +
                         "zeroed"
  end
  if dropout > 0 && num_layers == 1
    warn "dropout option adds dropout after all but last " +
         "recurrent layer, so non-zero dropout expects " +
         "num_layers greater than 1, but got dropout=#{dropout} and " +
         "num_layers=#{num_layers}"
  end

  gate_size =
    case mode
    when "LSTM"
      4 * hidden_size
    when "GRU"
      3 * hidden_size
    when "RNN_TANH"
      hidden_size
    when "RNN_RELU"
      hidden_size
    else
      raise ArgumentError, "Unrecognized RNN mode: #{mode}"
    end

  @all_weights = []
  num_layers.times do |layer|
    num_directions.times do |direction|
      layer_input_size = layer == 0 ? input_size : hidden_size * num_directions

      w_ih = Parameter.new(Torch::Tensor.new(gate_size, layer_input_size))
      w_hh = Parameter.new(Torch::Tensor.new(gate_size, hidden_size))
      b_ih = Parameter.new(Torch::Tensor.new(gate_size))
      # Second bias vector included for CuDNN compatibility. Only one
      # bias vector is needed in standard definition.
      b_hh = Parameter.new(Torch::Tensor.new(gate_size))
      layer_params = [w_ih, w_hh, b_ih, b_hh]

      suffix = direction == 1 ? "_reverse" : ""
      param_names = ["weight_ih_l%s%s", "weight_hh_l%s%s"]
      if bias
        param_names += ["bias_ih_l%s%s", "bias_hh_l%s%s"]
      end
      param_names.map! { |x| x % [layer, suffix] }

      param_names.zip(layer_params) do |name, param|
        instance_variable_set("@#{name}", param)
      end
      @all_weights << param_names
    end
  end

  flatten_parameters
  reset_parameters
end
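
For example, a minimal usage sketch (RNNBase is not constructed directly; the RNN subclass supplies the mode and forwards its remaining arguments to this constructor, and the sizes here are illustrative):

require "torch"

# 2-layer RNN: input_size 10, hidden_size 20, with inter-layer dropout.
rnn = Torch::NN::RNN.new(10, 20, num_layers: 2, dropout: 0.2)

# One weight_ih/weight_hh pair (plus bias_ih/bias_hh when bias: true) is
# registered per layer and direction: weight_ih_l0, weight_hh_l0, ...
# named_parameters is inherited from Module (listed above).
rnn.named_parameters.each { |name, _param| puts name }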

Dynamic Method Handling

This class handles dynamic methods through the method_missing method in the class Torch::NN::Module

Instance Method Details

#_apply(fn) ⇒ Object



# File 'lib/torch/nn/rnn_base.rb', line 79

def _apply(fn)
  ret = super
  flatten_parameters
  ret
end

#extra_inspect ⇒ Object

TODO add more parameters



# File 'lib/torch/nn/rnn_base.rb', line 146

def extra_inspect
  s = String.new("%{input_size}, %{hidden_size}")
  if @num_layers != 1
    s += ", num_layers: %{num_layers}"
  end
  format(s, input_size: @input_size, hidden_size: @hidden_size, num_layers: @num_layers)
end
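
As a rough illustration (continuing the RNN example above), extra_inspect yields the argument summary that Module#inspect embeds after the class name:

rnn = Torch::NN::RNN.new(10, 20, num_layers: 2)
rnn.extra_inspect
# => "10, 20, num_layers: 2"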

#flatten_parameters ⇒ Object



# File 'lib/torch/nn/rnn_base.rb', line 75

def flatten_parameters
  # no-op unless module is on the GPU and cuDNN is enabled
end

#forward(input, hx: nil) ⇒ Object



# File 'lib/torch/nn/rnn_base.rb', line 99

def forward(input, hx: nil)
  is_packed = false # TODO isinstance(input, PackedSequence)
  if is_packed
    input, batch_sizes, sorted_indices, unsorted_indices = input
    max_batch_size = batch_sizes[0]
    max_batch_size = max_batch_size.to_i
  else
    batch_sizes = nil
    max_batch_size = @batch_first ? input.size(0) : input.size(1)
    sorted_indices = nil
    unsorted_indices = nil
  end

  if hx.nil?
    num_directions = @bidirectional ? 2 : 1
    hx = Torch.zeros(@num_layers * num_directions, max_batch_size,
      @hidden_size, dtype: input.dtype, device: input.device)
  else
    # Each batch of the hidden state should match the input sequence that
    # the user believes they are passing in.
    hx = permute_hidden(hx, sorted_indices)
  end

  check_forward_args(input, hx, batch_sizes)
  _rnn_impls = {
    "RNN_TANH" => Torch.method(:rnn_tanh),
    "RNN_RELU" => Torch.method(:rnn_relu)
  }
  _impl = _rnn_impls[@mode]
  if batch_sizes.nil?
    result = _impl.call(input, hx, _get_flat_weights, @bias, @num_layers,
                     @dropout, @training, @bidirectional, @batch_first)
  else
    result = _impl.call(input, batch_sizes, hx, _get_flat_weights, @bias,
                     @num_layers, @dropout, @training, @bidirectional)
  end
  output = result[0]
  hidden = result[1]

  if is_packed
    raise NotImplementedYet
    # output = PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices)
  end
  [output, permute_hidden(hidden, unsorted_indices)]
end
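
A sketch of a forward pass (assuming the default batch_first: false, so input is shaped (seq_len, batch, input_size); the sizes are illustrative):

rnn = Torch::NN::RNN.new(10, 20, num_layers: 2)

input = Torch.randn(5, 3, 10)  # (seq_len, batch, input_size)
h0 = Torch.zeros(2, 3, 20)     # (num_layers * num_directions, batch, hidden_size)

output, hn = rnn.forward(input, hx: h0)
output.shape  # => [5, 3, 20]
hn.shape      # => [2, 3, 20]

When hx is omitted, a zero hidden state of the same shape is created automatically, as shown in the method body above.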

#permute_hidden(hx, permutation) ⇒ Object

Raises:

(NotImplementedYet) if permutation is not nil


# File 'lib/torch/nn/rnn_base.rb', line 92

def permute_hidden(hx, permutation)
  if permutation.nil?
    return hx
  end
  raise NotImplementedYet
end

#reset_parameters ⇒ Object



# File 'lib/torch/nn/rnn_base.rb', line 85

def reset_parameters
  stdv = 1.0 / Math.sqrt(@hidden_size)
  parameters.each do |weight|
    Init.uniform!(weight, a: -stdv, b: stdv)
  end
end
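
For example (mirroring the method body, and reusing the rnn with hidden_size 20 from the sketches above), every weight and bias is redrawn uniformly from [-1/√hidden_size, 1/√hidden_size]:

stdv = 1.0 / Math.sqrt(20)  # ≈ 0.2236 for hidden_size 20
rnn.parameters.each do |weight|
  Torch::NN::Init.uniform!(weight, a: -stdv, b: stdv)
end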