Module: Chainer::Utils::Array

Defined in:
lib/chainer/utils/array.rb

Class Method Summary

Class Method Details

.broadcast_to(x, shape) ⇒ Object



(source lines 76–97)
# File 'lib/chainer/utils/array.rb', line 76

# Repeats +x+ with +tile+ so that it takes on the target +shape+.
#
# +shape+ is aligned with +x.shape+ from the right. Each trailing axis of +x+
# must be either 1 (repeated +target+ times) or already equal to the target
# size (repeated once). Any extra leading dimensions of +shape+ become leading
# repetition counts.
#
# @param x [#shape, #tile] array to broadcast (presumably a Numo::NArray — confirm with callers)
# @param shape [Array<Integer>] target shape
# @return whatever +x.tile+ returns for the computed repetition counts
# @raise [TypeError] if +x+ has more dimensions than +shape+, or an axis is incompatible
def self.broadcast_to(x, shape)
  src_shape = x.shape
  if src_shape.size > shape.size
     raise TypeError, "Shape of data  mismatch\n x.shape.size(#{x.shape.size}) > shape.size(#{shape.size})"
  end

  lead = shape.size - src_shape.size
  # Leading target dimensions that x lacks are used as repetition counts directly.
  reps = shape.take(lead)

  # Pair each trailing target size with the corresponding axis of x.
  shape.drop(lead).zip(src_shape) do |target, current|
    rep =
      if current == 1 then target
      elsif current == target then 1
      else raise TypeError, "Shape of data  mismatch\n#{x.shape} != #{shape}"
      end
    reps << rep
  end

  x.tile(*reps)
end

.force_array(x, dtype = nil) ⇒ Object



(source lines 4–19)
# File 'lib/chainer/utils/array.rb', line 4

# Coerces +x+ into an array object.
#
# An Integer or Float scalar is cast to an NArray. Without +dtype+, the cast
# goes through the default device's array module. With +dtype+, a copy of the
# scalar goes through +dtype.cast+. Any non-scalar +x+ is returned unchanged
# when +dtype+ is nil, and otherwise passed to +dtype.cast+ as-is.
#
# @param x [Integer, Float, Object] value to coerce
# @param dtype [#cast, nil] optional array class used for the cast
# @return the cast array, or +x+ itself for a non-scalar with no +dtype+
def self.force_array(x, dtype=nil)
  numeric = x.is_a?(Integer) || x.is_a?(Float)

  unless dtype.nil?
    return dtype.cast(numeric ? x.dup : x)
  end

  return x unless numeric

  xm = Chainer::Device.default.xm
  xm::NArray.cast(x)
end

.make_indecies_with_axis(shape, indices, axis, values = []) ⇒ Object



(source lines 28–44)
# File 'lib/chainer/utils/array.rb', line 28

# Builds row-major flat offsets that select +indices+ along +axis+ of an
# array with the given +shape+, while keeping every position of the other axes.
#
# The result is built recursively, one axis per level:
# - On +axis+, the result maps over +indices+ (or recurses once when +indices+ is an Integer).
# - On every other axis, the result maps over 0...shape[that_axis].
# Once every axis has a coordinate, the coordinates are folded into one flat offset.
#
# Example: shape [2, 3], indices [0, 2], axis 1 => [[0, 2], [3, 5]].
#
# @param shape [Array<Integer>] shape of the source array
# @param indices [Integer, #map] index or indices to pick along +axis+;
#   negative values count from the end of that axis
# @param axis [Integer] axis to take along; negative values count from the last axis
# @param values [Array<Integer>] coordinates chosen so far (recursion state)
# @return [Integer, Array] a flat offset, or a nested Array of flat offsets
def self.make_indecies_with_axis(shape, indices, axis, values = [])
  # Normalize a negative axis. Otherwise it never matches a position below,
  # and +indices+ would be silently ignored.
  axis += shape.size if axis < 0
  target_axis = values.size

  if shape.size == values.size
    # Every axis has a coordinate: fold them into a row-major flat offset.
    values.zip(shape.drop(1) + [1]).reduce(0) do |sum, (x, ndim)|
      (sum + x) * ndim
    end
  elsif axis == target_axis
    # Wrap negative indices along the take axis (e.g. -1 is the last element).
    # Without this, a negative index would land in the previous row.
    wrap = ->(i) { i < 0 ? i + shape[axis] : i }
    if indices.is_a?(Integer)
      make_indecies_with_axis(shape, indices, axis, values + [wrap.call(indices)])
    else
      indices.map do |x|
        make_indecies_with_axis(shape, indices, axis, values + [wrap.call(x)])
      end
    end
  else
    (0...shape[target_axis]).map do |x|
      make_indecies_with_axis(shape, indices, axis, values + [x])
    end
  end
end

.rollaxis(y, axis, start: 0) ⇒ Object



(source lines 46–74)
# File 'lib/chainer/utils/array.rb', line 46

# Moves axis +axis+ of +y+ so that it sits before position +start+, keeping
# the other axes in their relative order (numpy.rollaxis semantics).
#
# @param y [#ndim, #transpose] array to permute
# @param axis [Integer] axis to move; negative values count from the end
# @param start [Integer] the axis is placed before this position; negative
#   values count from the end; valid range is -ndim <= start < ndim + 1
# @return +y+ itself if no move is needed, otherwise +y.transpose(*axes)+
# @raise [ArgumentError] if +axis+ or +start+ is out of range
def self.rollaxis(y, axis, start: 0)
  n = y.ndim

  # Normalize a negative axis. Reject anything still outside [0, n), which also
  # covers axis < -n; otherwise delete_at would silently count from the end.
  normalized = axis < 0 ? n + axis : axis
  unless 0 <= normalized && normalized < n
    raise ArgumentError, "axis #{axis} is out of bounds for array of dimension #{n}"
  end
  axis = normalized

  start += n if start < 0
  unless 0 <= start && start < n + 1
    raise ArgumentError, "start arg requires #{-n} <= start < #{n + 1}, but #{start} was passed in"
  end

  # Removing +axis+ from the list shifts every later position left by one.
  start -= 1 if axis < start
  return y if axis == start

  axes = (0...n).to_a
  axes.delete_at(axis)
  # start <= axes.size here, so insert never pads with nils.
  axes.insert(start, axis)
  y.transpose(*axes)
end

.take(x, indices, axis: nil) ⇒ Object



(source lines 21–26)
# File 'lib/chainer/utils/array.rb', line 21

# Indexes +x+ with +indices+.
#
# When +axis+ is given, +indices+ is first turned into flat offsets along that
# axis by make_indecies_with_axis. Otherwise +indices+ is used as-is with +x[]+.
#
# @param x [#[], #shape] array to index
# @param indices index, indices, or range accepted by +x[]+
# @param axis [Integer, nil] axis to take along, or nil to index +x+ directly
# @return the result of +x[...]+
def self.take(x, indices, axis: nil)
  selector = axis ? make_indecies_with_axis(x.shape, indices, axis) : indices
  x[selector]
end