4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
|
# File 'lib/chainer/utils/math.rb', line 4
# Tensor dot product of +a+ and +b+ over the given axes (same contract as
# NumPy's +tensordot+).
#
# How it works:
# - The contracted axes of +a+ are moved to the end and those of +b+ to the
#   front.
# - Both operands are flattened to 2-D matrices and multiplied with +dot+.
# - The product is reshaped to the remaining (free) dimensions of +a+
#   followed by the free dimensions of +b+.
#
# @param a left operand; must respond to +shape+, +ndim+, +transpose+ and
#   +reshape+, and its reshaped form to +dot+ (presumably a Numo::NArray —
#   TODO confirm against callers)
# @param b right operand; must respond to +shape+, +ndim+, +transpose+ and
#   +reshape+
# @param axes [Integer, Array] one of two forms:
#   - an Integer +n+: sum the last +n+ axes of +a+ against the first +n+
#     axes of +b+;
#   - a pair +[axes_a, axes_b]+: each entry is an axis index or an Array of
#     axis indices. Negative indices count from the end. The caller's arrays
#     are not modified.
# @return the reshaped product, with shape free_dims(a) + free_dims(b)
# @raise [RuntimeError] "shape-mismatch for sum" when the two axis lists
#   differ in length, or pair up dimensions of different sizes
def self.tensordot(a, b, axes)
  if axes.is_a?(Integer)
    axes_a = (-axes...0).to_a
    axes_b = (0...axes).to_a
  else
    # Anything that is not an Integer is treated as a pair [axes_a, axes_b].
    axes_a, axes_b = axes
  end

  shape_a = a.shape
  shape_b = b.shape
  nda = a.ndim
  ndb = b.ndim

  # Array() returns an Array argument as the very same object. So build new
  # arrays with map rather than normalizing in place, which would rewrite
  # the caller's +axes+.
  axes_a = Array(axes_a).map { |ax| ax < 0 ? ax + nda : ax }
  axes_b = Array(axes_b).map { |ax| ax < 0 ? ax + ndb : ax }

  matched = axes_a.size == axes_b.size &&
            axes_a.zip(axes_b).all? { |ax_a, ax_b| shape_a[ax_a] == shape_b[ax_b] }
  raise "shape-mismatch for sum" unless matched

  free_a = (0...nda).reject { |i| axes_a.include?(i) }
  free_b = (0...ndb).reject { |i| axes_b.include?(i) }
  olda = free_a.map { |i| shape_a[i] }
  oldb = free_b.map { |i| shape_b[i] }

  # Number of elements summed per output entry (1 when nothing is contracted).
  n_sum = axes_a.reduce(1) { |prod, ax| prod * shape_a[ax] }
  # Multiply the free dims directly instead of dividing the total size by
  # n_sum. That division raises ZeroDivisionError for zero-length contracted
  # axes, and fails on 0-d operands whose shape is empty.
  n_free_a = olda.reduce(1, :*)
  n_free_b = oldb.reduce(1, :*)

  at = a.transpose(*(free_a + axes_a)).reshape(n_free_a, n_sum)
  bt = b.transpose(*(axes_b + free_b)).reshape(n_sum, n_free_b)
  at.dot(bt).reshape(*(olda + oldb))
end
|