4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
|
# File 'lib/chainer/utils/variable.rb', line 4
#
# Verifies that a gradient is compatible with the data of the variable it
# belongs to. Nothing is checked when either the variable's data or the
# gradient is nil.
#
# The checks run in order and stop at the first failure:
#   1. the gradient must be an instance of the data class's superclass,
#   2. the gradient's class must equal the data's class exactly,
#   3. the gradient's shape must equal the data's shape.
#
# @param func [Object] not referenced by this method; kept for interface
#   compatibility with callers
# @param x [#data] variable whose +data+ (responding to +shape+) is compared
# @param gx [Object, nil] gradient candidate (must respond to +shape+ once
#   the class checks pass)
# @return [nil]
# @raise [TypeError] when any of the three checks above fails
def self.check_grad_type(func, x, gx)
  data = x.data
  return if data.nil? || gx.nil?

  data_class = data.class
  grad_class = gx.class

  # Broad family check first (e.g. both are arrays of the same base type).
  unless gx.is_a?(data_class.superclass)
    raise TypeError, "Type of data and grad mismatch\n#{data_class} != #{grad_class}"
  end

  # Then the exact element type / dtype class.
  if grad_class != data_class
    raise TypeError, "Dtype(Class) of data and grad mismatch\n#{data_class} != #{grad_class}"
  end

  # Shape is only queried once the classes are known to agree.
  return if gx.shape == data.shape

  raise TypeError, "Shape of data and grad mismatch\n#{data.shape} != #{gx.shape}"
end
|