Class: Chainer::Variable

Inherits:
Object
  • Object
show all
Defined in:
lib/chainer/variable.rb

Direct Known Subclasses

Parameter

Instance Attribute Summary collapse

Class Method Summary collapse

Instance Method Summary collapse

Constructor Details

#initialize(data = nil, name: nil, grad: nil, requires_grad: true) ⇒ Variable

Returns a new instance of Variable.



24
25
26
27
28
29
30
31
32
33
34
# File 'lib/chainer/variable.rb', line 24

# Builds a variable that wraps +data+.
#
# @param data [Numo::NArray, Cumo::NArray, nil] array to wrap; nil is accepted.
# @param name optional name, forwarded to the VariableNode.
# @param grad optional initial gradient; when present it is also wrapped in a
#   Chainer::Variable and kept as the gradient variable.
# @param requires_grad [Boolean] whether this variable takes part in gradient computation.
# @raise [TypeError] when +data+ is neither nil nor accepted by Chainer.array?.
def initialize(data=nil, name: nil, grad: nil, requires_grad: true)
  raise TypeError, "Numo::NArray or Cumo::NArray are expected." unless data.nil? || Chainer.array?(data)

  # The array lives in a one-element box (retain_data reads it back as @data[0]).
  @data = [data]
  @grad = grad
  @requires_grad = requires_grad
  # Built after @data/@requires_grad are assigned, in case VariableNode reads
  # them from the variable. NOTE(review): confirm against VariableNode.
  @node = VariableNode.new(variable: self, name: name)
  @grad_var = (Chainer::Variable.new(grad) unless grad.nil?)
end

Instance Attribute Details

#data ⇒ Object (also known as: array)

Returns the value of attribute data.



3
4
5
# File 'lib/chainer/variable.rb', line 3

# Returns the wrapped array.
#
# The constructor stores the array boxed as +[data]+ and #retain_data reads it
# back as +@data[0]+. The box is unwrapped here so that #shape, #ndim, #size
# and #dtype operate on the array rather than on the one-element Ruby Array
# holding it.
def data
  @data[0]
end

#grad ⇒ Object

Returns the value of attribute grad.



3
4
5
# File 'lib/chainer/variable.rb', line 3

# Returns the gradient array, or nil when no gradient is set.
#
# The value is read from the gradient variable (@grad_var) rather than from
# the constructor-time @grad. Backprop and #grad_var= update @grad_var, and
# #cleargrad resets it, so reading @grad would return a stale value after
# either of those.
def grad
  @grad_var.nil? ? nil : @grad_var.data
end

#node ⇒ Object

Returns the value of attribute node.



3
4
5
# File 'lib/chainer/variable.rb', line 3

# The VariableNode created for this variable in the constructor.
def node
  return @node
end

#requires_grad ⇒ Object

Returns the value of attribute requires_grad.



3
4
5
# File 'lib/chainer/variable.rb', line 3

# Whether this variable participates in gradient computation (set in the constructor).
def requires_grad
  return @requires_grad
end

Class Method Details

.as_variable(obj) ⇒ Chainer::Variable

Converts an array or a variable into Chainer::Variable. This is a convenience function for obtaining a Chainer::Variable object transparently from either a raw array or a variable. Note that this function should only be used for type consistency (i.e. to enforce that an API returns a Chainer::Variable). The Chainer::Variable#requires_grad flag is kept as is; if obj is a raw array, the newly created variable has requires_grad = false. To create a variable with respect to which you want to compute the gradient, use Chainer::Variable directly.

Parameters:

  • obj (Numo::NArray or Chainer::Variable)

    An array or a variable that you want to convert to Chainer::Variable.

Returns:

  • (Chainer::Variable)

    A variable converted from obj. If obj is a raw array, this is a new Chainer::Variable object that wraps the array. If obj is already a Chainer::Variable object, this function returns obj as is.



18
19
20
21
22
# File 'lib/chainer/variable.rb', line 18

# Returns +obj+ unchanged when it already is a Chainer::Variable; otherwise
# wraps it in a new variable with requires_grad: false.
#
# @param obj [Numo::NArray, Chainer::Variable]
# @return [Chainer::Variable]
def self.as_variable(obj)
  # TODO if obj is_backprop_required is true, set requires_grad = true
  if obj.kind_of?(Chainer::Variable)
    obj
  else
    new(obj, requires_grad: false)
  end
end

Instance Method Details

#*(other) ⇒ Object



259
260
261
262
263
264
265
# File 'lib/chainer/variable.rb', line 259

# Multiplies this variable by +other+.
#
# A Chainer::Variable operand, including subclasses such as Parameter, builds
# a Mul node so gradients flow to both operands. Anything else is treated as
# a constant factor. kind_of? is used because instance_of? rejects
# subclasses, which would make a Parameter operand a "constant".
def *(other)
  if other.kind_of?(Chainer::Variable)
    Functions::Math::Mul.new.apply([self, other])[0]
  else
    Functions::Math::MulConstant.new(other).apply([self])[0]
  end
end

#**(other) ⇒ Object



275
276
277
278
279
280
281
# File 'lib/chainer/variable.rb', line 275

# Raises this variable to the power +other+.
#
# A Chainer::Variable exponent, including subclasses such as Parameter,
# builds a PowVarVar node. Anything else is treated as a constant exponent
# (PowVarConst). kind_of? is used so that subclasses are not mistaken for
# constants.
def **(other)
  if other.kind_of?(Chainer::Variable)
    Functions::Math::PowVarVar.new.apply([self, other])[0]
  else
    Functions::Math::PowVarConst.new(other).apply([self])[0]
  end
end

#+(other) ⇒ Object



243
244
245
246
247
248
249
# File 'lib/chainer/variable.rb', line 243

# Adds +other+ to this variable.
#
# A Chainer::Variable operand, including subclasses such as Parameter, builds
# an Add node. Anything else is treated as a constant (AddConstant). kind_of?
# is used so that subclasses are not mistaken for constants.
def +(other)
  if other.kind_of?(Chainer::Variable)
    Functions::Math::Add.new.apply([self, other])[0]
  else
    Functions::Math::AddConstant.new(other).apply([self])[0]
  end
end

#-(other) ⇒ Object



251
252
253
254
255
256
257
# File 'lib/chainer/variable.rb', line 251

# Subtracts +other+ from this variable.
#
# A Chainer::Variable operand, including subclasses such as Parameter, builds
# a Sub node. Anything else is treated as a constant, and subtraction becomes
# addition of its negation. kind_of? is used so that subclasses are not
# mistaken for constants.
def -(other)
  if other.kind_of?(Chainer::Variable)
    Functions::Math::Sub.new.apply([self, other])[0]
  else
    Functions::Math::AddConstant.new(-other).apply([self])[0]
  end
end

#-@ ⇒ Object



239
240
241
# File 'lib/chainer/variable.rb', line 239

# Unary minus: applies the Neg function to this variable and returns its single output.
def -@
  neg_outputs = Functions::Math::Neg.new.apply([self])
  neg_outputs.first
end

#/(other) ⇒ Object



267
268
269
270
271
272
273
# File 'lib/chainer/variable.rb', line 267

# Divides this variable by +other+.
#
# A Chainer::Variable operand, including subclasses such as Parameter, builds
# a Div node. Anything else is treated as a constant, and the result is this
# variable multiplied by the reciprocal of +other+.
#
# The reciprocal is taken as 1.0 / other. With an Integer divisor, 1 / other
# truncates (1 / 2 == 0 in Ruby), which silently multiplied by zero.
def /(other)
  if other.kind_of?(Chainer::Variable)
    Functions::Math::Div.new.apply([self, other])[0]
  else
    Functions::Math::MulConstant.new(1.0 / other).apply([self])[0]
  end
end

#_backward_main(retain_grad) ⇒ Object



144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
# File 'lib/chainer/variable.rb', line 144

# Core of #backward: propagates gradients from this variable back through the
# function nodes that produced it.
#
# Starting at this variable's creator_node, function nodes are taken from the
# front of a queue that is re-sorted by descending rank after each step. For
# each node, the output gradients and the current gradients of the inputs
# that require a gradient are passed to +backward_accumulate+. The returned
# gradients are stored in the local +grads+ table (keyed by node) and on each
# input's variable via +grad_var=+.
#
# @param retain_grad [Boolean] when false, the gradients of intermediate
#   outputs (every output node except this variable's own node) are dropped
#   after the function node that produced them has been processed.
# @return [nil]
def _backward_main(retain_grad)
  node.check_old_style_gradient
  return if self.creator_node.nil?

  seen_set = Set.new
  grads = {}
  # Seed the gradient of a single-element variable when none was given.
  # NOTE(review): assumes data.new_ones returns a ones array like data - confirm.
  if self.data.size == 1 && self.grad_var.nil?
    self.grad = self.data.new_ones
  end
  grads[self.node] = self.grad_var

  funcs = [self.creator_node]
  seen_set.add(self.creator_node)

  while (func = funcs.shift)
    inputs = func.inputs
    # Positions of the inputs that need a gradient; skip the node if none do.
    target_input_indexes = inputs.each_with_index.map { |x, i| i if x.requires_grad }.compact
    next if target_input_indexes.empty?
    # func.outputs holds delegating references; __getobj__ dereferences them.
    outputs = func.outputs.map(&:__getobj__)

    # Prefer a gradient computed earlier in this pass; otherwise fall back to
    # the gradient currently stored on the output node.
    out_grad = outputs.map do |y|
      next nil if y.nil?
      next grads[y] unless grads[y].nil?
      y.grad_var
    end

    # Collect the current input gradients.
    #
    # When the same variable is passed to multiple input slots (e.g. an expression like +f(x, x)+),
    # it makes the gradient accumulation complicated since the back-propagated gradients w.r.t.
    # the first and second argument should be accumulated to the current gradient w.r.t. the same variable.
    # In this case, the current implementation passes the current gradient only to the first occurrence of the variable
    # in the input tuple and passes +nil+ to the rest of the occurrences.
    # For example, when the input variables are +(x, x)+,
    # the input gradient passed to the +backward_accumulate+ method is +(gx, nil)+ where +gx+ is the current gradient of +x+.
    # See also the docstring of +FunctionNode.backward_accumulate+.
    target_inputs = target_input_indexes.map { |i| inputs[i] }
    in_grad = target_inputs.each_with_index.map do |x, i|
      if target_inputs[0...i].include?(x)
        nil
      elsif grads[x]
        grads[x]
      elsif x.creator_node.nil?
        # Leaf node (no creator): hand over its existing gradient.
        x.grad_var
      end
    end

    gxs = func.backward_accumulate(target_input_indexes, out_grad, in_grad)
    raise "Unmatched matries size: gxs.size(#{gxs.size}) != in_grad.size(#{in_grad.size})" unless gxs.size == in_grad.size

    unless retain_grad
      # Drop the output gradients (except this variable's own) now that this
      # function node has consumed them.
      outputs.each do |y|
        next if y.nil? || y == @node
        grads[y] = nil
        y_var = y.get_variable
        y_var.grad_var = nil unless y_var.nil?
      end
    end

    gxs.each_with_index do |gx, i|
      next if gx.nil?
      x = target_inputs[i]
      next unless x.requires_grad

      Utils::Variable.check_grad_type(func, x, gx.data)

      if target_inputs[0...i].include?(x)
        # Duplicated input slot: add this contribution to the one already stored.
        cur_gx = grads[x]
        grads[x] = cur_gx.nil? ? gx : gx + cur_gx
      else
        grads[x] = gx
      end

      x_var = x.get_variable
      x_var.grad_var = grads[x] if x_var

      # Enqueue the producer of this input at most once.
      if x.creator_node && !seen_set.include?(x.creator_node)
        funcs << x.creator_node
        seen_set.add(x.creator_node)
      end
    end

    # Higher-rank function nodes are processed first.
    funcs.sort_by! { |f| -f.rank }
  end
end

#backward(retain_grad: false, enable_double_backprop: true) ⇒ Object



137
138
139
140
141
142
# File 'lib/chainer/variable.rb', line 137

# Runs backpropagation from this variable.
#
# Chainer.configuration.enable_backprop is temporarily set to
# +enable_double_backprop+ while the backward pass runs. The previous value
# is restored in an +ensure+, so an exception raised during backprop cannot
# leave the global configuration changed.
#
# @param retain_grad [Boolean] forwarded to #_backward_main; when false,
#   intermediate output gradients are dropped during the pass.
# @param enable_double_backprop [Boolean] value of enable_backprop during the pass.
def backward(retain_grad: false, enable_double_backprop: true)
  old_enable_backprop = Chainer.configuration.enable_backprop
  Chainer.configuration.enable_backprop = enable_double_backprop
  begin
    _backward_main(retain_grad)
  ensure
    Chainer.configuration.enable_backprop = old_enable_backprop
  end
end

#cleargrad ⇒ Object

Clears the gradient array.



126
127
128
# File 'lib/chainer/variable.rb', line 126

# Clears the gradient.
#
# Both the gradient variable (@grad_var) and the constructor-time gradient
# array (@grad) are reset. Otherwise a reader of @grad would keep returning
# the old gradient after the clear.
def cleargrad
  @grad = nil
  @grad_var = nil
end

#coerce(other) ⇒ Object

Called by Ruby's numeric coercion when the left operand is a Numeric value and the right operand is a Chainer::Variable (for example, 2 * x).



288
289
290
291
# File 'lib/chainer/variable.rb', line 288

# Ruby's numeric coercion hook, invoked when a non-Variable is the left
# operand of an arithmetic operator with this variable on the right.
#
# A Numeric left operand is first broadcast into an array of this variable's
# array class, filled with that value. The left operand is then wrapped in a
# Chainer::Variable with requires_grad: false.
#
# @return [Array(Chainer::Variable, Chainer::Variable)] [converted left operand, self]
def coerce(other)
  lhs =
    case other
    when Numeric then data.class.new.fill(other)
    else other
    end
  [Chainer::Variable.new(lhs, requires_grad: false), self]
end

#creator ⇒ Object

Deprecated. Use #creator_node, which returns the creator FunctionNode, instead.



60
61
62
# File 'lib/chainer/variable.rb', line 60

# Deprecated: returns the creator recorded on this variable's node.
# Prefer #creator_node.
def creator
  variable_node = @node
  variable_node.creator
end

#creator=(func) ⇒ Object



64
65
66
# File 'lib/chainer/variable.rb', line 64

# Records +func+ as the (deprecated-style) creator on this variable's node.
def creator=(func)
  variable_node = @node
  variable_node.creator = func
end

#creator_node ⇒ Object



68
69
70
# File 'lib/chainer/variable.rb', line 68

# Returns the function node that created this variable, as recorded on its
# node; nil for a leaf variable.
def creator_node
  variable_node = @node
  variable_node.creator_node
end

#creator_node=(func) ⇒ Object



72
73
74
# File 'lib/chainer/variable.rb', line 72

# Stores +func+ as the creator function node on this variable's node.
def creator_node=(func)
  variable_node = @node
  variable_node.creator_node = func
end

#dtype ⇒ Object



106
107
108
# File 'lib/chainer/variable.rb', line 106

# Returns the class of the wrapped array (e.g. a Numo/Cumo array class).
def dtype
  array = data
  array.class
end

#grad_var ⇒ Object



85
86
87
# File 'lib/chainer/variable.rb', line 85

# Returns the gradient variable, or nil when none is set.
def grad_var
  return @grad_var
end

#grad_var=(g) ⇒ Object



89
90
91
92
# File 'lib/chainer/variable.rb', line 89

# Replaces the gradient variable.
#
# A non-nil +g+ has its array checked against this variable with
# Utils::Variable.check_grad_type before it is stored. nil clears the
# gradient without any check.
def grad_var=(g)
  unless g.nil?
    Utils::Variable.check_grad_type(nil, self, g.data)
  end
  @grad_var = g
end

#label ⇒ Object



55
56
57
# File 'lib/chainer/variable.rb', line 55

# Returns the label provided by this variable's node.
def label
  variable_node = @node
  variable_node.label
end

#name ⇒ Object



47
48
49
# File 'lib/chainer/variable.rb', line 47

# Returns the name stored on this variable's node.
def name
  @node.name
end

#name=(n) ⇒ Object



51
52
53
# File 'lib/chainer/variable.rb', line 51

# Stores +n+ as the name on this variable's node.
def name=(n)
  variable_node = @node
  variable_node.name = n
end

#ndim ⇒ Object



98
99
100
# File 'lib/chainer/variable.rb', line 98

# Returns the number of dimensions of the wrapped array.
def ndim
  array = data
  array.ndim
end

#rank ⇒ Object



110
111
112
# File 'lib/chainer/variable.rb', line 110

# Returns the rank recorded on this variable's node (used by #_backward_main
# to order function nodes).
def rank
  variable_node = @node
  variable_node.rank
end

#reshape(*shape) ⇒ Object



118
119
120
121
122
123
# File 'lib/chainer/variable.rb', line 118

# Returns a variable with the same data reshaped to +shape+.
#
# Accepts the dimensions either splatted, reshape(2, 3), or as a single
# Array, reshape([2, 3]). The original code referenced the misspelled
# constant ::Aray, so every call raised NameError. ::Array is spelled out to
# name the core class unambiguously, distinct from namespaced modules such as
# Chainer::Functions::Array.
#
# @param shape [Array<Integer>] target dimensions
def reshape(*shape)
  shape = shape[0] if shape.size == 1 && shape[0].kind_of?(::Array)
  Chainer::Functions::Array::Reshape.reshape(self, shape)
end

#retain_data ⇒ Object



283
284
285
# File 'lib/chainer/variable.rb', line 283

# Copies the wrapped array (the first element of the @data box) onto this
# variable's node.
def retain_data
  variable_node = @node
  variable_node.data = @data[0]
end

#set_creator_node(fnode) ⇒ Object

Notifies the variable that the given node is its creator.

Parameters:

  • fnode — the function node to register as this variable's creator.



133
134
135
# File 'lib/chainer/variable.rb', line 133

# Notifies the variable that +fnode+ is its creator by forwarding to the node.
#
# @param fnode the function node that created this variable
def set_creator_node(fnode)
  variable_node = @node
  variable_node.set_creator_node(fnode)
end

#shape ⇒ Object



94
95
96
# File 'lib/chainer/variable.rb', line 94

# Returns the shape of the wrapped array.
def shape
  array = data
  array.shape
end

#size ⇒ Object



102
103
104
# File 'lib/chainer/variable.rb', line 102

# Returns the number of elements in the wrapped array.
def size
  array = data
  array.size
end

#transpose ⇒ Object



114
115
116
# File 'lib/chainer/variable.rb', line 114

# Returns the transpose of this variable via Functions::Array::Transpose.
def transpose
  transpose_fn = Chainer::Functions::Array::Transpose
  transpose_fn.transpose(self)
end