Module: Chainer::Utils::Variable

Defined in:
lib/chainer/utils/variable.rb

Class Method Summary

Class Method Details

.check_grad_type(func, x, gx) ⇒ Object



(Source lines 4–20 of lib/chainer/utils/variable.rb)
# File 'lib/chainer/utils/variable.rb', line 4

# Verifies that a gradient is compatible with the data it belongs to.
#
# The checks run in order and stop at the first failure:
#   1. family  - gx must be an instance of the superclass of x.data's class
#   2. class   - gx must have exactly the same class as x.data
#   3. shape   - gx.shape must equal x.data.shape
#
# @param func [Object] currently unused by this method
# @param x [#data] holder whose +data+ is compared against the gradient
# @param gx [Object, nil] candidate gradient; skipped when nil
# @return [nil]
# @raise [TypeError] on a family, class, or shape mismatch
def self.check_grad_type(func, x, gx)
  data = x.data
  # Nothing to compare when either side is absent.
  return if data.nil? || gx.nil?

  expected = data.class
  actual = gx.class

  # Coarse check: the gradient must share the data's parent class
  # (presumably the common array base type - TODO confirm against callers).
  unless gx.is_a?(expected.superclass)
    raise TypeError, "Type of data and grad mismatch\n#{expected} != #{actual}"
  end

  # Fine check: sibling classes under the same parent are still rejected.
  unless actual == expected
    raise TypeError, "Dtype(Class) of data and grad mismatch\n#{expected} != #{actual}"
  end

  # Shape is only inspected once both type checks have passed.
  return if gx.shape == data.shape

  raise TypeError, "Shape of data and grad mismatch\n#{data.shape} != #{gx.shape}"
end