Module: Chainer::Utils::Array
- Defined in: lib/chainer/utils/array.rb
Class Method Summary
- .broadcast_to(x, shape) ⇒ Object
- .force_array(x, dtype = nil) ⇒ Object
- .ndindex(shape) ⇒ Object
- .rollaxis(y, axis, start: 0) ⇒ Object
- .take(x, indices, axis: nil) ⇒ Object
Class Method Details
.broadcast_to(x, shape) ⇒ Object
Source: lib/chainer/utils/array.rb, lines 80–101
# File 'lib/chainer/utils/array.rb', line 80
# Broadcasts +x+ to +shape+ by tiling it.
#
# The axes of +x+ are aligned with the LAST x.shape.size entries of +shape+:
# an axis of length 1 is repeated to the target length, an axis that already
# matches is left alone (repeat count 1), and any target axes in front of
# those are added by tiling across them in full.
#
# @param x [#shape, #tile] array to broadcast (presumably a Numo::NArray)
# @param shape [Array<Integer>] target shape
# @return whatever x.tile returns for the computed repetition counts
# @raise [TypeError] if +x+ has more axes than +shape+, or an axis of +x+ is
#   neither 1 nor equal to the corresponding entry of +shape+
def self.broadcast_to(x, shape)
  src_shape = x.shape
  ndim = src_shape.size

  if ndim > shape.size
    raise TypeError, "Shape of data mismatch\n x.shape.size(#{ndim}) > shape.size(#{shape.size})"
  end

  # A 0-d input is tiled straight to the full target shape.
  return x.tile(*shape) if ndim.zero?

  # Repetition counts for the trailing axes that x actually has.
  reps = shape.last(ndim).zip(src_shape).map do |target, have|
    case have
    when 1 then target
    when target then 1
    else raise TypeError, "Shape of data mismatch\n#{src_shape} != #{shape}"
    end
  end

  # Target axes in front of x's own axes are tiled across in full.
  x.tile(*shape.first(shape.size - ndim), *reps)
end
.force_array(x, dtype = nil) ⇒ Object
Source: lib/chainer/utils/array.rb, lines 4–19
# File 'lib/chainer/utils/array.rb', line 4
# Coerces +x+ into an array.
#
# If +dtype+ is given, the result is always dtype.cast(x), for scalars and
# arrays alike. Otherwise a Ruby Integer or Float is wrapped with the default
# device's array module (Chainer::Device.default.xm::NArray.cast), and any
# other value is returned unchanged.
#
# @param x [Integer, Float, Object] scalar or array-like value
# @param dtype [#cast, nil] optional array class to cast into
# @return the cast array, or +x+ itself when no conversion applies
def self.force_array(x, dtype = nil)
  # An explicit dtype wins regardless of what x is.
  return dtype.cast(x) unless dtype.nil?
  return x unless x.is_a?(Integer) || x.is_a?(Float)

  xm = Chainer::Device.default.xm
  xm::NArray.cast(x)
end
.ndindex(shape) ⇒ Object
Source: lib/chainer/utils/array.rb, lines 21–27
# File 'lib/chainer/utils/array.rb', line 21
# Lists every index tuple of an array with the given +shape+ in row-major
# order (the last axis varies fastest).
#
#   ndindex([2, 2]) #=> [[0, 0], [0, 1], [1, 0], [1, 1]]
#   ndindex([])     #=> [[]]  (a 0-d array has exactly one, empty, index)
#
# @param shape [Array<Integer>] array shape
# @return [Array<Array<Integer>>] all index tuples; empty if any dimension is 0
def self.ndindex(shape)
  # Without this guard a 0-d shape has no axis ranges to take a product of.
  return [[]] if shape.empty?

  first_axis, *other_axes = shape.map { |dim| (0...dim).to_a }
  # Array#product keeps the first array slowest-varying, i.e. row-major order.
  first_axis.product(*other_axes)
end
.rollaxis(y, axis, start: 0) ⇒ Object
Source: lib/chainer/utils/array.rb, lines 50–78
# File 'lib/chainer/utils/array.rb', line 50
# Moves axis +axis+ of +y+ so that it lies before position +start+, keeping
# the relative order of the other axes (numpy.rollaxis semantics).
#
# @param y [#ndim, #transpose] array to permute
# @param axis [Integer] axis to move; negative values count from the end
# @param start [Integer] position the axis is rolled in front of; valid range
#   is -ndim <= start <= ndim (negative values count from the end)
# @return +y+ itself when the axis is already in place, else y.transpose(*axes)
# @raise [ArgumentError] if +axis+ or +start+ is out of range
def self.rollaxis(y, axis, start: 0)
  n = y.ndim

  # Normalize a negative axis, then reject anything still outside 0...n.
  # The lower-bound check matters: e.g. axis = -(n + 1) stays negative.
  axis_index = axis < 0 ? axis + n : axis
  unless 0 <= axis_index && axis_index < n
    raise ArgumentError, "axis #{axis} is out of bounds for array of dimension #{n}"
  end

  start_index = start < 0 ? start + n : start
  unless 0 <= start_index && start_index <= n
    raise ArgumentError, "start arg requires #{-n} <= start < #{n + 1}, but #{start} was passed in"
  end

  # Removing the moved axis shifts every later position one step left.
  start_index -= 1 if axis_index < start_index
  return y if axis_index == start_index

  axes = (0...n).to_a
  axes.delete_at(axis_index)
  # Here start_index <= n - 1 == axes.size, so insert never pads the array.
  axes.insert(start_index, axis_index)
  y.transpose(*axes)
end
.take(x, indices, axis: nil) ⇒ Object
Source: lib/chainer/utils/array.rb, lines 29–48
# File 'lib/chainer/utils/array.rb', line 29
# Takes elements of +x+ at +indices+, in the manner of numpy.take.
#
# Without +axis+ this is simply x[indices]. With +axis+, elements are
# selected along that axis. When +indices+ has more than one dimension, the
# result has shape x.shape[0...axis] + indices.shape + x.shape[axis + 1..-1].
#
# NOTE(review): assumes +x+ is a Numo::NArray-like object (responds to
# #shape, #[] and its class to .zeros) — confirm against callers.
#
# @param x [Numo::NArray] source array
# @param indices [Integer, Array, Numo::NArray] positions to take
# @param axis [Integer, nil] axis to take along; negative values count from the end
# @return the selected elements
# @raise [ArgumentError] if +axis+ is out of range for +x+
def self.take(x, indices, axis: nil)
  return x[indices] unless axis

  ndim = x.shape.size
  axis_index = axis < 0 ? axis + ndim : axis
  unless 0 <= axis_index && axis_index < ndim
    raise ArgumentError, "axis #{axis} is out of bounds for array of dimension #{ndim}"
  end

  dimensional_indices = ::Array.new(ndim, true)
  indices_narray = Numo::Int32.cast(indices)

  if indices_narray.shape.size > 1
    # Keep the axes before +axis+ as well, so the output has
    # x.shape[0...axis] + indices.shape + x.shape[axis + 1..-1].
    leading = ::Array.new(axis_index, true)
    trailing = ::Array.new(ndim - axis_index - 1, true)
    y = x.class.zeros(*x.shape.first(axis_index), *indices_narray.shape, *x.shape.drop(axis_index + 1))
    ndindex(indices_narray.shape).each do |ndidx|
      dimensional_indices[axis_index] = indices_narray[*ndidx]
      y[*leading, *ndidx, *trailing] = x[*dimensional_indices]
    end
    return y
  end

  dimensional_indices[axis_index] = indices
  x[*dimensional_indices]
end