Class: Chainer::Functions::Loss::MeanSquaredError

Inherits:
Chainer::FunctionNode
Defined in:
lib/chainer/functions/loss/mean_squared_error.rb

Overview

Mean squared error (a.k.a. Euclidean loss) function.

Instance Attribute Summary

Attributes inherited from Chainer::FunctionNode

#inputs, #outputs, #rank

Class Method Summary

Instance Method Summary

Methods inherited from Chainer::FunctionNode

#apply, #backward_accumulate, #forward_cpu, #get_retained_inputs, #get_retained_outputs, #initialize, #label, #output_data, #retain_inputs, #retain_outputs, #unchain

Constructor Details

This class inherits a constructor from Chainer::FunctionNode

Class Method Details

.mean_squared_error(x0, x1) ⇒ Chainer::Variable

Mean squared error function.

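This function computes the mean squared error between two variables. The mean is taken over all elements of the inputs: #forward divides the sum of squared differences by the total number of elements (diff.size), not by the minibatch size alone. Note that the error is not scaled by 1/2.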
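Written out, the quantity computed by #forward and the gradient returned by #backward are as follows. This is a sketch derived from the listings on this page, assuming N denotes the total element count of the inputs (diff.size) and i indexes the flattened elements:

```latex
\mathrm{MSE}(x_0, x_1) = \frac{1}{N} \sum_{i=1}^{N} \left( x_{0,i} - x_{1,i} \right)^2,
\qquad
\frac{\partial\,\mathrm{MSE}}{\partial x_{0,i}} = \frac{2}{N} \left( x_{0,i} - x_{1,i} \right) = -\frac{\partial\,\mathrm{MSE}}{\partial x_{1,i}}
```

Because there is no 1/2 factor in front of the sum, the factor 2 from differentiating the square survives in the gradient, which is why #backward multiplies by 2.0 / diff.size.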

Parameters:

  • x0

    The first input.

  • x1

    The second input, expected to have the same shape as x0.

Returns:

  • (Chainer::Variable)

    A variable holding a scalar (0-dimensional) array with the mean squared error of the two inputs.

# File 'lib/chainer/functions/loss/mean_squared_error.rb', line 15

def self.mean_squared_error(x0, x1)
  self.new.apply([x0, x1]).first
end

Instance Method Details

#backward(indexes, gy) ⇒ Object

# File 'lib/chainer/functions/loss/mean_squared_error.rb', line 25

def backward(indexes, gy)
  x0, x1 = get_retained_inputs
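  diff = x0 - x1
  # Expand the upstream gradient of the scalar loss to the inputs' shape.
  gy0 = Chainer::Functions::Array::BroadcastTo.broadcast_to(gy[0], diff.shape)
  # d/dx0 of mean((x0 - x1)**2) is 2 * (x0 - x1) / N, with N = diff.size.
  gx0 = gy0 * diff * (2.0 / diff.size)

  # Return gradients only for the requested inputs; the one for x1 is -gx0.
  ret = []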
  if indexes.include?(0)
    ret << gx0
  end
  if indexes.include?(1)
    ret << -gx0
  end
  ret
end
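The gradient rule in #backward can be checked numerically. The following is a minimal plain-Ruby sketch, assuming 1-D Ruby Arrays in place of Numo::NArray and a Float upstream gradient gy (so the broadcast step is trivial). The helpers mse, mse_backward and numeric_grad_x0 are hypothetical names for illustration, not part of red-chainer; the check compares the analytic gradient with a central finite difference.

```ruby
# Hypothetical plain-Ruby MSE over 1-D Arrays (mean over every element).
def mse(x0, x1)
  x0.zip(x1).sum { |a, b| (a - b)**2 } / x0.size.to_f
end

# Hypothetical mirror of MeanSquaredError#backward for 1-D Arrays.
# gy is the upstream gradient of the scalar loss (a Float here).
def mse_backward(x0, x1, gy)
  n = x0.size.to_f
  # gx0 = gy * diff * (2 / N), as in the listing above.
  gx0 = x0.zip(x1).map { |a, b| gy * (a - b) * (2.0 / n) }
  # The gradient with respect to x1 is the negation of the one for x0.
  gx1 = gx0.map { |g| -g }
  [gx0, gx1]
end

# Central finite-difference estimate of d(mse)/d(x0[i]).
def numeric_grad_x0(x0, x1, i, eps = 1e-6)
  plus = x0.dup
  plus[i] += eps
  minus = x0.dup
  minus[i] -= eps
  (mse(plus, x1) - mse(minus, x1)) / (2 * eps)
end

x0 = [1.0, 2.0, 3.0]
x1 = [0.5, 2.5, 1.0]
gx0, gx1 = mse_backward(x0, x1, 1.0)
x0.each_index do |i|
  printf("analytic %.6f  numeric %.6f\n", gx0[i], numeric_grad_x0(x0, x1, i))
end
p gx1
```

Since the loss is quadratic in each element, the central difference matches the analytic value up to floating-point rounding.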

#forward(inputs) ⇒ Object

# File 'lib/chainer/functions/loss/mean_squared_error.rb', line 19

def forward(inputs)
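  # Retain both inputs so #backward can recompute their difference.
  retain_inputs([0, 1])
  # Element-wise difference, flattened to 1-D.
  diff = (inputs[0] - inputs[1]).flatten.dup
  # diff.dot(diff) is the sum of squared differences; dividing by diff.size
  # averages over every element. The scalar is cast back to the array class.
  [diff.class.cast(diff.dot(diff) / diff.size)]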
end
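To make the averaging explicit, here is a minimal plain-Ruby sketch of the same computation, assuming nested Ruby Arrays (for example a [batch][features] minibatch) instead of Numo::NArray. The helper name mse_forward is hypothetical. Like the listing, it flattens the difference and divides the sum of squares by the total element count.

```ruby
# Hypothetical plain-Ruby mirror of MeanSquaredError#forward.
# x0 and x1 are nested Arrays with the same number of elements.
def mse_forward(x0, x1)
  a = x0.flatten
  b = x1.flatten
  raise ArgumentError, "element count mismatch" unless a.size == b.size

  # Element-wise difference on the flattened inputs.
  diff = a.zip(b).map { |u, v| u - v }
  # Equivalent of diff.dot(diff): the sum of squared differences.
  sum_sq = diff.sum { |d| d * d }
  # Divide by the total number of elements, not by the batch size.
  sum_sq / diff.size.to_f
end

x0 = [[1.0, 2.0], [3.0, 4.0]]
x1 = [[0.0, 2.0], [1.0, 1.0]]
puts mse_forward(x0, x1) # prints 3.5
```

Here the squared differences are 1, 0, 4 and 9, so the sum is 14 and the mean over 4 elements is 3.5; dividing by the batch size of 2 instead would give 7.0.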