Module: Chainer::Utils::Array

Defined in:
lib/chainer/utils/array.rb

Class Method Summary

Class Method Details

.broadcast_to(x, shape) ⇒ Object



80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
# File 'lib/chainer/utils/array.rb', line 80

# Broadcasts +x+ to +shape+ by tiling it.
#
# Each trailing dimension of +x+ must either be 1 (it is repeated up to the
# target size) or already equal the target size (repeated once). Target
# dimensions in front of those are produced by repeating the whole array.
#
# @param x [#shape, #tile] array to broadcast (presumably a Numo::NArray)
# @param shape [Array<Integer>] target shape
# @return whatever +x.tile+ returns for the computed repetition counts
# @raise [TypeError] if +x+ has more dimensions than +shape+, or if a
#   trailing dimension of +x+ is neither 1 nor equal to the target size
def self.broadcast_to(x, shape)
  rank = x.shape.size
  if rank > shape.size
    raise TypeError, "Shape of data  mismatch\n x.shape.size(#{rank}) > shape.size(#{shape.size})"
  end

  # A 0-d input is simply repeated into every target dimension.
  return x.tile(*shape) if rank.zero?

  # Target dimensions that x does not have at all: repeat the whole array.
  lead_reps = shape[0...-rank]
  # Target dimensions aligned with x's own dimensions (right-aligned).
  trail_reps = shape[-rank..-1].each_with_index.map do |target, axis|
    source = x.shape[axis]
    if source == 1
      target
    elsif source == target
      1
    else
      raise TypeError, "Shape of data  mismatch\n#{x.shape} != #{shape}"
    end
  end

  x.tile(*lead_reps, *trail_reps)
end

.force_array(x, dtype = nil) ⇒ Object



4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
# File 'lib/chainer/utils/array.rb', line 4

# Coerces +x+ into an array object.
#
# Ruby Integer/Float scalars are always wrapped into an array. Other
# inputs pass through unchanged unless a +dtype+ is given.
#
# @param x [Integer, Float, Object] scalar or array-like value
# @param dtype [#cast, nil] target type; presumably a Numo dtype class such
#   as Numo::SFloat — TODO confirm against callers
# @return +dtype.cast(x)+ when +dtype+ is given; otherwise an NArray built by
#   the default device's array module for scalars, or +x+ itself
def self.force_array(x, dtype=nil)
  numeric_scalar = x.is_a?(Integer) || x.is_a?(Float)

  # An explicit dtype always wins.
  return dtype.cast(numeric_scalar ? x.dup : x) unless dtype.nil?

  # Non-scalar inputs are assumed to already be arrays.
  return x unless numeric_scalar

  array_module = Chainer::Device.default.xm
  array_module::NArray.cast(x)
end

.ndindex(shape) ⇒ Object



21
22
23
24
25
26
27
# File 'lib/chainer/utils/array.rb', line 21

# Enumerates every index tuple of an array with the given shape, in
# row-major (C) order, analogous to NumPy's +ndindex+.
#
#   ndindex([2, 2]) #=> [[0, 0], [0, 1], [1, 0], [1, 1]]
#   ndindex([])     #=> [[]]   (a 0-d array has exactly one, empty, index)
#
# @param shape [Array<Integer>] array shape
# @return [Array<Array<Integer>>] all index tuples; empty if any dim is <= 0
def self.ndindex(shape)
  # A 0-d shape has a single empty index. Without this guard the product
  # below would have no first factor; the old reduce-based version crashed
  # on nil.times.
  return [[]] if shape.empty?

  first_axis, *other_axes = shape.map { |dim| (0...dim).to_a }
  # Array#product varies the last axis fastest, i.e. row-major order.
  first_axis.product(*other_axes)
end

.rollaxis(y, axis, start: 0) ⇒ Object



50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
# File 'lib/chainer/utils/array.rb', line 50

# Rolls dimension +axis+ of +y+ backwards until it lies in front of
# position +start+, mirroring NumPy's +rollaxis+.
#
# @param y [#ndim, #transpose] array (presumably a Numo::NArray)
# @param axis [Integer] axis to move; negative values count from the end
# @param start [Integer] the axis is placed before this position;
#   valid range is -ndim..ndim (negative values count from the end)
# @return +y+ itself when no movement is needed, otherwise +y.transpose+
#   with the permuted axis order
# @raise [ArgumentError] if +axis+ or +start+ is out of range
def self.rollaxis(y, axis, start: 0)
  n = y.ndim

  # Normalize a negative axis. Reject anything outside [0, n) on either
  # side, so that e.g. axis = -n - 1 cannot slip through as a negative index.
  normalized_axis = axis < 0 ? axis + n : axis
  unless 0 <= normalized_axis && normalized_axis < n
    raise ArgumentError, "axis #{axis} is out of bounds for array of dimension #{n}"
  end
  axis = normalized_axis

  start += n if start < 0
  unless 0 <= start && start < n + 1
    raise ArgumentError, "start arg requires #{-n} <= start < #{n + 1}, but #{start} was passed in"
  end

  # Removing `axis` from the order shifts every later position left by one.
  start -= 1 if axis < start
  return y if axis == start

  axes = (0...n).to_a
  axes.delete_at(axis)
  # start <= n - 1 == axes.size here, so insert never needs clamping.
  axes.insert(start, axis)
  y.transpose(*axes)
end

.take(x, indices, axis: nil) ⇒ Object



29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
# File 'lib/chainer/utils/array.rb', line 29

# Takes elements of +x+ along +axis+, analogous to NumPy's +take+.
#
# With +axis+ nil, +x[indices]+ is returned. Otherwise the result has shape
# x.shape[0...axis] + indices.shape + x.shape[axis+1..-1].
#
# @param x [Numo::NArray] source array (assumed to be a Numo::NArray —
#   +x.class.zeros+ and multi-dimensional indexing are used)
# @param indices [Integer, Array, Numo::NArray] indices along +axis+
# @param axis [Integer, nil] axis to index; negative values count from the end
# @return [Numo::NArray] the selected elements
def self.take(x, indices, axis: nil)
  return x[indices] unless axis

  ndim = x.shape.size
  # Normalize a negative axis so the shape arithmetic below is correct.
  axis += ndim if axis < 0
  dimensional_indices = ::Array.new(ndim, true)

  indices_narray = Numo::Int32.cast(indices)
  if indices_narray.shape.size > 1
    # Multi-dimensional indices replace the single `axis` dimension with
    # the whole index shape. The leading dimensions (before `axis`) must be
    # kept; the previous version dropped them and broke for axis > 0.
    leading = x.shape.take(axis)
    trailing = x.shape.drop(axis + 1)
    y = x.class.zeros(*leading, *indices_narray.shape, *trailing)
    leading_all = ::Array.new(leading.size, true)
    trailing_all = ::Array.new(trailing.size, true)
    self.ndindex(indices_narray.shape).each do |ndidx|
      # A scalar index at `axis` removes that dimension from the slice, so it
      # has shape leading + trailing, matching the destination view below.
      dimensional_indices[axis] = indices_narray[*ndidx]
      y[*leading_all, *ndidx, *trailing_all] = x[*dimensional_indices]
    end
    return y
  end

  dimensional_indices[axis] = indices
  x[*dimensional_indices]
end