Class: Chainer::Functions::Loss::MeanSquaredError
- Inherits: Chainer::FunctionNode
  - Object > Chainer::FunctionNode > Chainer::Functions::Loss::MeanSquaredError
- Defined in: lib/chainer/functions/loss/mean_squared_error.rb
Overview
Mean squared error (a.k.a. Euclidean loss) function.
Instance Attribute Summary
Attributes inherited from Chainer::FunctionNode
Class Method Summary
- .mean_squared_error(x0, x1) ⇒ Chainer::Variable
Mean squared error function.
Instance Method Summary
Methods inherited from Chainer::FunctionNode
#apply, #backward_accumulate, #forward_cpu, #get_retained_inputs, #get_retained_outputs, #initialize, #label, #output_data, #retain_inputs, #retain_outputs, #unchain
Constructor Details
This class inherits a constructor from Chainer::FunctionNode
Class Method Details
.mean_squared_error(x0, x1) ⇒ Chainer::Variable
Mean squared error function.
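This function computes the mean squared error between two variables. The mean is taken over every element of the minibatch: the squared differences are summed and divided by the total number of elements. Note that the error is not scaled by 1/2.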
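For reference, the value computed by `#forward` and the gradients returned by `#backward` can be written as follows. Here N is the total number of elements (`diff.size` in the source), L stands for whatever downstream loss the result feeds into, and g = ∂L/∂MSE is the upstream gradient `gy[0]`. Because the error is not scaled by 1/2, the gradients carry a factor of 2/N rather than 1/N.

```latex
\begin{aligned}
\mathrm{MSE}(x_0, x_1) &= \frac{1}{N} \sum_{i=1}^{N} \left( x_{0,i} - x_{1,i} \right)^2 \\
\frac{\partial L}{\partial x_{0,i}} &= g \cdot \frac{2}{N} \left( x_{0,i} - x_{1,i} \right) \\
\frac{\partial L}{\partial x_{1,i}} &= -\,g \cdot \frac{2}{N} \left( x_{0,i} - x_{1,i} \right)
\end{aligned}
```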
# File 'lib/chainer/functions/loss/mean_squared_error.rb', line 15

def self.mean_squared_error(x0, x1)
  self.new.apply([x0, x1]).first
end
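The call `self.new.apply([x0, x1]).first` follows a common pattern: `apply` takes an Array of inputs and returns an Array of outputs, and `.first` unwraps the single loss value. The mock below illustrates only that calling shape with plain Ruby Arrays. `SketchFunctionNode` and `SketchMeanSquaredError` are hypothetical names invented for this sketch; the real `apply` returns `Chainer::Variable` objects (as the `⇒ Chainer::Variable` signature indicates) and keeps the bookkeeping that `#backward` relies on, none of which is reproduced here.

```ruby
# Hypothetical, simplified stand-ins for illustration only; these are NOT
# red-chainer classes. They mimic the shape of `self.new.apply([x0, x1]).first`.
class SketchFunctionNode
  # Runs forward on the inputs and returns an Array of outputs.
  def apply(inputs)
    forward(inputs)
  end
end

class SketchMeanSquaredError < SketchFunctionNode
  # Mirrors the class-method entry point: build a node, apply it, unwrap.
  def self.mean_squared_error(x0, x1)
    new.apply([x0, x1]).first
  end

  # Returns a one-element Array holding the mean of squared differences.
  def forward(inputs)
    a, b = inputs
    diff = a.zip(b).map { |u, v| u - v }
    [diff.sum { |d| d * d } / diff.size.to_f]
  end
end

loss = SketchMeanSquaredError.mean_squared_error([1.0, 2.0, 3.0], [1.0, 0.0, 5.0])
puts loss # => 2.6666666666666665
```

Returning an Array even for a single output keeps the interface uniform for functions that produce several outputs, which is why the real class method needs `.first`.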
Instance Method Details
#backward(indexes, gy) ⇒ Object
# File 'lib/chainer/functions/loss/mean_squared_error.rb', line 25

def backward(indexes, gy)
  x0, x1 = get_retained_inputs
  diff = x0 - x1
  gy0 = Chainer::Functions::Array::BroadcastTo.broadcast_to(gy[0], diff.shape)
  gx0 = gy0 * diff * (2.0 / diff.size)

  ret = []
  if indexes.include?(0)
    ret << gx0
  end
  if indexes.include?(1)
    ret << -gx0
  end
  ret
end
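The gradient rule in `#backward` (gx0 = gy · diff · 2/N, gx1 = −gx0) can be checked numerically. The sketch below is a plain-Ruby stand-in that uses flat Arrays instead of `Chainer::Variable`/Numo arrays; the helper names `mse`, `mse_backward` and `numeric_grad_x0` are hypothetical. Broadcasting the scalar `gy[0]` to `diff.shape` is modelled by multiplying each element by `gy`.

```ruby
# Plain-Ruby sketch (not red-chainer) of the MSE gradient rule used in #backward.

# Forward value: mean of squared elementwise differences over all elements.
def mse(x0, x1)
  diff = x0.zip(x1).map { |a, b| a - b }
  diff.sum { |d| d * d } / diff.size.to_f
end

# Returns [gx0, gx1] for a scalar upstream gradient gy.
# Multiplying every element by gy plays the role of broadcast_to(gy[0], diff.shape).
def mse_backward(x0, x1, gy)
  diff = x0.zip(x1).map { |a, b| a - b }
  gx0 = diff.map { |d| gy * d * (2.0 / diff.size) }
  gx1 = gx0.map { |g| -g }
  [gx0, gx1]
end

# Central finite-difference estimate of d(gy * mse)/d(x0[i]) for each i.
def numeric_grad_x0(x0, x1, gy, eps = 1e-6)
  x0.each_index.map do |i|
    plus = x0.dup
    plus[i] += eps
    minus = x0.dup
    minus[i] -= eps
    gy * (mse(plus, x1) - mse(minus, x1)) / (2 * eps)
  end
end

x0 = [1.0, 2.0, 3.0, 4.0]
x1 = [0.5, 2.5, 2.0, 4.0]
gx0, gx1 = mse_backward(x0, x1, 1.0)
p gx0 # => [0.25, -0.25, 0.5, 0.0]
p gx1 # => [-0.25, 0.25, -0.5, -0.0]
```

Since MSE is quadratic in each element, the central difference matches the analytic gradient up to floating-point rounding.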
#forward(inputs) ⇒ Object
# File 'lib/chainer/functions/loss/mean_squared_error.rb', line 19

def forward(inputs)
  retain_inputs([0, 1])
  diff = (inputs[0] - inputs[1]).flatten.dup
  [diff.class.cast(diff.dot(diff) / diff.size)]
end
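`#forward` flattens the elementwise difference, takes its dot product with itself (the sum of squares), and divides by `diff.size`, the total element count rather than the batch size. The sketch below mirrors those steps with plain nested Ruby Arrays in place of Numo arrays; `mse_forward` is a hypothetical helper, and its element-count check is added for the sketch only (the listing above does not show one).

```ruby
# Plain-Ruby sketch (nested Arrays, not Numo::NArray) mirroring #forward.
def mse_forward(x0, x1)
  a = x0.flatten
  b = x1.flatten
  # Sketch-only guard; not part of the red-chainer source shown above.
  raise ArgumentError, "element count mismatch" unless a.size == b.size

  diff = a.zip(b).map { |u, v| u - v }
  sum_sq = diff.sum { |d| d * d } # plays the role of diff.dot(diff)
  sum_sq / diff.size.to_f         # divide by ALL elements, not by batch size
end

# A minibatch of 2 samples with 3 features each (6 elements in total).
x0 = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
x1 = [[1.0, 1.0, 1.0], [4.0, 4.0, 4.0]]
puts mse_forward(x0, x1) # => 1.6666666666666667 (10.0 / 6)
```

Dividing the same sum of squares (10.0) by the batch size (2) would instead give 5.0, so the divisor choice matters when comparing loss values across batch shapes.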