Class: Chainer::Variable
- Inherits:
-
Object
- Object
- Chainer::Variable
- Defined in:
- lib/chainer/variable.rb
Direct Known Subclasses
Instance Attribute Summary collapse
-
#data ⇒ Object
(also: #array)
Returns the value of attribute data.
-
#grad ⇒ Object
Returns the value of attribute grad.
-
#node ⇒ Object
Returns the value of attribute node.
-
#requires_grad ⇒ Object
Returns the value of attribute requires_grad.
Class Method Summary collapse
-
.as_variable(obj) ⇒ Chainer::Variable
Converts an array or a variable into
Chainer::Variable
.
Instance Method Summary collapse
- #*(other) ⇒ Object
- #**(other) ⇒ Object
- #+(other) ⇒ Object
- #-(other) ⇒ Object
- #-@ ⇒ Object
- #/(other) ⇒ Object
- #_backward_main(retain_grad) ⇒ Object
- #backward(retain_grad: false, enable_double_backprop: true) ⇒ Object
-
#cleargrad ⇒ Object
Clears the gradient array.
-
#coerce(other) ⇒ Object
Called when the left operand is a Numeric value and the right operand is a Chainer::Variable.
-
#creator ⇒ Object
deprecated FunctionNode.
- #creator=(func) ⇒ Object
- #creator_node ⇒ Object
- #creator_node=(func) ⇒ Object
- #dtype ⇒ Object
- #grad_var ⇒ Object
- #grad_var=(g) ⇒ Object
-
#initialize(data = nil, name: nil, grad: nil, requires_grad: true) ⇒ Variable
constructor
A new instance of Variable.
- #label ⇒ Object
- #name ⇒ Object
- #name=(n) ⇒ Object
- #ndim ⇒ Object
- #rank ⇒ Object
- #reshape(*shape) ⇒ Object
- #retain_data ⇒ Object
-
#set_creator_node(fnode) ⇒ Object
Notifies the variable that the given node is its creator.
- #shape ⇒ Object
- #size ⇒ Object
- #transpose ⇒ Object
Constructor Details
#initialize(data = nil, name: nil, grad: nil, requires_grad: true) ⇒ Variable
Returns a new instance of Variable.
24 25 26 27 28 29 30 31 32 33 34 |
# File 'lib/chainer/variable.rb', line 24
# Builds a variable around +data+ and creates the VariableNode that
# represents it in the computational graph.
#
# @param data [Object, nil] nil, or an array accepted by +Chainer.array?+
#   (the error message names Numo::NArray / Cumo::NArray).
# @param name [String, nil] name forwarded to the VariableNode.
# @param grad [Object, nil] initial gradient; when non-nil it is also
#   wrapped in a Chainer::Variable and kept as +@grad_var+.
# @param requires_grad [Boolean] stored as +@requires_grad+.
# @raise [TypeError] if +data+ is neither nil nor an accepted array.
def initialize(data = nil, name: nil, grad: nil, requires_grad: true)
  if !data.nil? && !Chainer.array?(data)
    raise TypeError, "Numo::NArray or Cumo::NArray are expected."
  end

  # The raw array is held inside a one-element Array (retain_data reads @data[0]).
  @data = [data]
  @grad = grad
  @requires_grad = requires_grad
  # Built from +self+, so it stays after @data/@requires_grad are set.
  # NOTE(review): confirm which attributes VariableNode#initialize reads.
  @node = VariableNode.new(variable: self, name: name)
  @grad_var = (Chainer::Variable.new(grad) unless grad.nil?)
end
Instance Attribute Details
#data ⇒ Object Also known as: array
Returns the value of attribute data.
3 4 5 |
# File 'lib/chainer/variable.rb', line 3
# Reader for the +data+ attribute (documented alias: +array+).
#
# NOTE(review): initialize stores the raw array as +[data]+; as rendered here
# this reader returns that wrapper — confirm whether an explicit #data that
# unwraps @data[0] exists elsewhere in the file.
#
# @return [Object] the current value of +@data+.
def data
  @data
end
#grad ⇒ Object
Returns the value of attribute grad.
3 4 5 |
# File 'lib/chainer/variable.rb', line 3
# Reader for the +grad+ attribute.
#
# @return [Object, nil] the current value of +@grad+.
def grad
  @grad
end
#node ⇒ Object
Returns the value of attribute node.
3 4 5 |
# File 'lib/chainer/variable.rb', line 3
# Reader for the +node+ attribute (the VariableNode built in #initialize).
#
# @return [Object] the current value of +@node+.
def node
  @node
end
#requires_grad ⇒ Object
Returns the value of attribute requires_grad.
3 4 5 |
# File 'lib/chainer/variable.rb', line 3
# Reader for the +requires_grad+ attribute.
#
# @return [Object] the current value of +@requires_grad+.
def requires_grad
  @requires_grad
end
Class Method Details
.as_variable(obj) ⇒ Chainer::Variable
Converts an array or a variable into Chainer::Variable
. This is a convenient function to get a Chainer::Variable
object transparently from a raw array or a variable. Note that this function should only be used for type consistency (i.e. to enforce the return value of an API having type Chainer::Variable
). The Chainer::Variable#requires_grad
flag is kept as is; if obj
is a raw array, the newly created variable has requires_grad = false. To create a variable with respect to which you want to compute the gradient, use Chainer::Variable directly.
18 19 20 21 22 |
# File 'lib/chainer/variable.rb', line 18
# Returns +obj+ unchanged when it already is a Chainer::Variable (or a
# subclass instance); otherwise wraps it in a new variable of the receiving
# class with +requires_grad: false+.
#
# @param obj [Chainer::Variable, Object] a variable or a raw array.
# @return [Chainer::Variable]
def self.as_variable(obj)
  # TODO if obj is_backprop_required is true, set requires_grad = true
  if obj.kind_of?(Chainer::Variable)
    obj
  else
    new(obj, requires_grad: false)
  end
end
Instance Method Details
#*(other) ⇒ Object
259 260 261 262 263 264 265 |
# File 'lib/chainer/variable.rb', line 259
# Element-wise multiplication.
#
# Any Chainer::Variable operand, including instances of its subclasses, is
# treated as a graph input (Mul). Anything else is treated as a constant
# (MulConstant).
#
# @param other [Chainer::Variable, Object] right-hand operand.
# @return [Object] the first output of the applied function node
#   (presumably a Chainer::Variable).
def *(other)
  # kind_of? rather than instance_of?: a subclass instance must not be
  # mistaken for a constant.
  if other.kind_of?(Chainer::Variable)
    Functions::Math::Mul.new.apply([self, other])[0]
  else
    Functions::Math::MulConstant.new(other).apply([self])[0]
  end
end
#**(other) ⇒ Object
275 276 277 278 279 280 281 |
# File 'lib/chainer/variable.rb', line 275
# Element-wise power.
#
# A Chainer::Variable exponent, including instances of its subclasses, uses
# PowVarVar. Any other exponent is a constant and uses PowVarConst.
#
# @param other [Chainer::Variable, Object] the exponent.
# @return [Object] the first output of the applied function node
#   (presumably a Chainer::Variable).
def **(other)
  # kind_of? rather than instance_of?: a subclass instance must not be
  # mistaken for a constant.
  if other.kind_of?(Chainer::Variable)
    Functions::Math::PowVarVar.new.apply([self, other])[0]
  else
    Functions::Math::PowVarConst.new(other).apply([self])[0]
  end
end
#+(other) ⇒ Object
243 244 245 246 247 248 249 |
# File 'lib/chainer/variable.rb', line 243
# Element-wise addition.
#
# Any Chainer::Variable operand, including instances of its subclasses, is
# treated as a graph input (Add). Anything else is treated as a constant
# (AddConstant).
#
# @param other [Chainer::Variable, Object] right-hand operand.
# @return [Object] the first output of the applied function node
#   (presumably a Chainer::Variable).
def +(other)
  # kind_of? rather than instance_of?: a subclass instance must not be
  # mistaken for a constant.
  if other.kind_of?(Chainer::Variable)
    Functions::Math::Add.new.apply([self, other])[0]
  else
    Functions::Math::AddConstant.new(other).apply([self])[0]
  end
end
#-(other) ⇒ Object
251 252 253 254 255 256 257 |
# File 'lib/chainer/variable.rb', line 251
# Element-wise subtraction.
#
# Any Chainer::Variable operand, including instances of its subclasses, is
# treated as a graph input (Sub). A constant is negated and added via
# AddConstant.
#
# @param other [Chainer::Variable, Object] right-hand operand; a constant
#   must support unary minus.
# @return [Object] the first output of the applied function node
#   (presumably a Chainer::Variable).
def -(other)
  # kind_of? rather than instance_of?: a subclass instance must not be
  # mistaken for a constant.
  if other.kind_of?(Chainer::Variable)
    Functions::Math::Sub.new.apply([self, other])[0]
  else
    Functions::Math::AddConstant.new(-other).apply([self])[0]
  end
end
#-@ ⇒ Object
239 240 241 |
# File 'lib/chainer/variable.rb', line 239
# Unary negation: applies the Neg function node to this variable.
#
# @return [Object] the first output of Neg#apply (presumably a Chainer::Variable).
def -@
  negation = Functions::Math::Neg.new
  negation.apply([self])[0]
end
#/(other) ⇒ Object
267 268 269 270 271 272 273 |
# File 'lib/chainer/variable.rb', line 267
# Element-wise division.
#
# Any Chainer::Variable divisor, including instances of its subclasses, uses
# Div. A constant divisor becomes multiplication by its reciprocal.
#
# @param other [Chainer::Variable, Object] the divisor.
# @return [Object] the first output of the applied function node
#   (presumably a Chainer::Variable).
def /(other)
  # kind_of? rather than instance_of?: a subclass instance must not be
  # mistaken for a constant.
  if other.kind_of?(Chainer::Variable)
    Functions::Math::Div.new.apply([self, other])[0]
  else
    # 1.0 (not 1): Integer#/ truncates, so `1 / 2` would be 0 and silently
    # zero the result for any integer divisor.
    Functions::Math::MulConstant.new(1.0 / other).apply([self])[0]
  end
end
#_backward_main(retain_grad) ⇒ Object
144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 |
# File 'lib/chainer/variable.rb', line 144
# Core of #backward. It walks the graph from this variable's creator node
# towards the inputs, calling +backward_accumulate+ on each function node and
# accumulating gradients per VariableNode in +grads+.
#
# @param retain_grad [Boolean] when false, the gradients of each function's
#   output nodes (other than this variable's own node) are cleared once the
#   function's backward has run.
# @return [nil]
def _backward_main(retain_grad)
  node.check_old_style_gradient
  # Nothing to propagate for a variable that was not produced by a function.
  return if self.creator_node.nil?
  seen_set = Set.new
  grads = {}
  # A single-element output with no gradient yet gets a seed gradient from
  # self.data.new_ones.
  if self.data.size == 1 && self.grad_var.nil?
    self.grad = self.data.new_ones
  end
  grads[self.node] = self.grad_var

  funcs = [self.creator_node]
  seen_set.add(self.creator_node)

  while func = funcs.shift
    inputs = func.inputs
    # Positions of the inputs whose requires_grad is set; skip the node if none.
    target_input_indexes = inputs.each_with_index.map { |x, i| i if x.requires_grad }.compact
    next if target_input_indexes.empty?
    outputs = func.outputs.map(&:__getobj__)

    # NOTE(review): in_data and out_grad_data are computed but never used below.
    in_data = inputs.map(&:data)
    # Output gradients: prefer the value accumulated in +grads+, otherwise the
    # output node's own grad_var.
    out_grad = outputs.map do |y|
      next nil if y.nil?
      next grads[y] unless grads[y].nil?
      y.grad_var
    end
    out_grad_data = out_grad.map { |g| g.nil? ? g : g.data }

    # Collect the current input gradients.
    #
    # When the same variable is passed to multiple input slots (e.g. an expression like +f(x, x)+),
    # it makes the gradient accumulation complicated since the back-propagated gradients w.r.t.
    # the first and second argument should be accumulated to the current gradient w.r.t. the same variable.
    # In this case, the current implementation passes the current gradient only to the first occurrence of the variable
    # in the input tuple and passes +nil+ to the rest of the occurrences.
    # For example, when the input variables are +(x, x)+,
    # the input gradient passed to the +backward_accumulate+ method is +(gx, nil)+ where +gx+ is the current gradient of +x+.
    # See also the docstring of +FunctionNode.backward_accumulate+.
    target_inputs = target_input_indexes.map { |i| inputs[i] }
    in_grad = []
    target_input_indexes.each_with_index do |index_i, i|
      x = inputs[index_i]
      if target_inputs[0...i].include?(x)
        gx = nil
      elsif grads[x]
        gx = grads[x]
      elsif x.creator_node.nil?
        # Leaf node: fall back to the gradient stored on the node itself.
        gx = x.grad_var
      else
        gx = nil
      end
      in_grad << gx
    end

    gxs = func.backward_accumulate(target_input_indexes, out_grad, in_grad)
    raise "Unmatched matries size: gxs.size(#{gxs.size}) != in_grad.size(#{in_grad.size})" unless gxs.size == in_grad.size

    # Drop output gradients that are no longer needed (except this variable's own node).
    unless retain_grad
      outputs.each do |y|
        unless y.nil? || y == @node
          grads[y] = nil
          y_var = y.get_variable
          y_var.grad_var = nil unless y_var.nil?
        end
      end
    end

    gxs.each_with_index do |gx, i|
      next if gx.nil?
      x = target_inputs[i]
      next unless x.requires_grad

      Utils::Variable.check_grad_type(func, x, gx.data)

      # A repeated input accumulates onto the gradient stored for its earlier
      # occurrence; the first occurrence overwrites (its in_grad already
      # carried the previous value into backward_accumulate).
      if target_inputs[0...i].include?(x)
        cur_gx = grads[x]
        grads[x] = cur_gx.nil? ? gx : gx + cur_gx
      else
        grads[x] = gx
      end

      x_var = x.get_variable
      x_var.grad_var = grads[x] if x_var

      # Enqueue each creator node at most once.
      if x.creator_node && !seen_set.include?(x.creator_node)
        funcs << x.creator_node
        seen_set.add(x.creator_node)
      end
    end

    # Highest-rank function nodes are shifted off the queue first.
    funcs.sort_by! { |f| -f.rank }
  end
end
#backward(retain_grad: false, enable_double_backprop: true) ⇒ Object
137 138 139 140 141 142 |
# File 'lib/chainer/variable.rb', line 137
# Runs back-propagation from this variable.
#
# While the pass runs, +Chainer.configuration.enable_backprop+ is set to
# +enable_double_backprop+. The previous value is restored afterwards, even
# if the backward pass raises, so an error cannot leave the global flag
# flipped.
#
# @param retain_grad [Boolean] forwarded to #_backward_main.
# @param enable_double_backprop [Boolean] value of enable_backprop during the pass.
# @return [Object] whatever #_backward_main returns (not meaningful to callers).
def backward(retain_grad: false, enable_double_backprop: true)
  old_enable_backprop = Chainer.configuration.enable_backprop
  Chainer.configuration.enable_backprop = enable_double_backprop
  begin
    _backward_main(retain_grad)
  ensure
    Chainer.configuration.enable_backprop = old_enable_backprop
  end
end
#cleargrad ⇒ Object
Clears the gradient array.
126 127 128 |
# File 'lib/chainer/variable.rb', line 126
# Clears the gradient by dropping the stored gradient variable.
#
# @return [nil]
def cleargrad
  @grad_var = nil
end
#coerce(other) ⇒ Object
Called when the left operand is a Numeric value and the right operand is a Chainer::Variable.
288 289 290 291 |
# File 'lib/chainer/variable.rb', line 288
# Ruby coercion hook, invoked when a Numeric is on the left of an operator
# and this variable is on the right.
#
# A Numeric +other+ is first turned into an array of the same class as
# +data+ (via +data.class.new.fill+). Any other value is used as-is.
# NOTE(review): +.new+ is called without a shape — presumably a 0-d array; confirm.
#
# @param other [Numeric, Object] the left-hand operand.
# @return [Array(Chainer::Variable, Chainer::Variable)] the wrapped operand
#   (requires_grad: false) followed by +self+.
def coerce(other)
  lhs = other.kind_of?(Numeric) ? data.class.new.fill(other) : other
  [Chainer::Variable.new(lhs, requires_grad: false), self]
end
#creator ⇒ Object
deprecated FunctionNode
60 61 62 |
# File 'lib/chainer/variable.rb', line 60
# Deprecated accessor: returns whatever the underlying node reports as
# +creator+ (see #creator_node).
#
# @return [Object, nil]
def creator
  @node.creator
end
#creator=(func) ⇒ Object
64 65 66 |
# File 'lib/chainer/variable.rb', line 64
# Forwards the creator assignment to the underlying node.
#
# @param func [Object] the creator to record.
def creator=(func)
  @node.creator = func
end
#creator_node ⇒ Object
68 69 70 |
# File 'lib/chainer/variable.rb', line 68
# Returns the function node that produced this variable, as recorded on the
# underlying node (nil for a leaf; #_backward_main stops early in that case).
#
# @return [Object, nil]
def creator_node
  @node.creator_node
end
#creator_node=(func) ⇒ Object
72 73 74 |
# File 'lib/chainer/variable.rb', line 72
# Forwards the creator-node assignment to the underlying node.
#
# @param func [Object] the function node to record.
def creator_node=(func)
  @node.creator_node = func
end
#dtype ⇒ Object
106 107 108 |
# File 'lib/chainer/variable.rb', line 106
# Returns the class of the value returned by #data (e.g. the NArray subclass).
#
# @return [Class]
def dtype
  data.class
end
#grad_var ⇒ Object
85 86 87 |
# File 'lib/chainer/variable.rb', line 85
# Returns the gradient held as a variable (nil when none is set).
#
# @return [Chainer::Variable, nil]
def grad_var
  @grad_var
end
#grad_var=(g) ⇒ Object
89 90 91 92 |
# File 'lib/chainer/variable.rb', line 89
# Sets the gradient variable. A non-nil gradient is validated first with
# Utils::Variable.check_grad_type against its +data+.
#
# @param g [Chainer::Variable, nil] the new gradient variable.
def grad_var=(g)
  unless g.nil?
    Utils::Variable.check_grad_type(nil, self, g.data)
  end
  @grad_var = g
end
#label ⇒ Object
55 56 57 |
# File 'lib/chainer/variable.rb', line 55
# Returns the label reported by the underlying node.
#
# @return [Object]
def label
  @node.label
end
#name ⇒ Object
47 48 49 |
# File 'lib/chainer/variable.rb', line 47
# Returns the name stored on the underlying node.
#
# @return [String, nil]
def name
  @node.name
end
#name=(n) ⇒ Object
51 52 53 |
# File 'lib/chainer/variable.rb', line 51
# Renames the variable by updating the underlying node's name.
#
# @param n [String, nil] the new name.
def name=(n)
  @node.name = n
end
#ndim ⇒ Object
98 99 100 |
# File 'lib/chainer/variable.rb', line 98
# Number of dimensions of #data, delegated to its +ndim+.
#
# @return [Integer]
def ndim
  data.ndim
end
#rank ⇒ Object
110 111 112 |
# File 'lib/chainer/variable.rb', line 110
# Returns the rank recorded on the underlying node.
#
# @return [Integer]
def rank
  @node.rank
end
#reshape(*shape) ⇒ Object
118 119 120 121 122 123 |
# File 'lib/chainer/variable.rb', line 118
# Reshapes this variable via Chainer::Functions::Array::Reshape.
#
# Accepts either splatted dimensions (+reshape(2, 3)+) or a single Array
# (+reshape([2, 3])+).
#
# @param shape [Array<Integer>] target dimensions.
# @return [Object] whatever Reshape.reshape returns (presumably a Chainer::Variable).
def reshape(*shape)
  # Unwrap a single Array argument. ::Array is the core class, not
  # Chainer::Functions::Array. The old `::Aray` typo raised NameError on
  # every one-argument call.
  if shape.size == 1 && shape[0].kind_of?(::Array)
    shape = shape[0]
  end
  Chainer::Functions::Array::Reshape.reshape(self, shape)
end
#retain_data ⇒ Object
283 284 285 |
# File 'lib/chainer/variable.rb', line 283
# Copies the wrapped raw array (@data[0]) onto the underlying node.
#
# @return [Object] the value assigned.
def retain_data
  payload = @data[0]
  @node.data = payload
end
#set_creator_node(fnode) ⇒ Object
Notifies the variable that the given node is its creator.
133 134 135 |
# File 'lib/chainer/variable.rb', line 133
# Notifies the variable that +fnode+ is its creator by forwarding to the
# underlying node's +set_creator_node+.
#
# @param fnode [Object] the creating function node.
# @return [Object] whatever the node's +set_creator_node+ returns.
def set_creator_node(fnode)
  @node.set_creator_node(fnode)
end
#shape ⇒ Object
94 95 96 |
# File 'lib/chainer/variable.rb', line 94
# Shape of #data, delegated to its +shape+.
#
# @return [Array<Integer>]
def shape
  data.shape
end
#size ⇒ Object
102 103 104 |
# File 'lib/chainer/variable.rb', line 102
# Number of elements in #data, delegated to its +size+.
#
# @return [Integer]
def size
  data.size
end