Module: Chainer::Utils::Array
- Defined in:
- lib/chainer/utils/array.rb
Class Method Summary
- .broadcast_to(x, shape) ⇒ Object
- .force_array(x, dtype = nil) ⇒ Object
- .make_indecies_with_axis(shape, indices, axis, values = []) ⇒ Object
- .rollaxis(y, axis, start: 0) ⇒ Object
- .take(x, indices, axis: nil) ⇒ Object
Class Method Details
.broadcast_to(x, shape) ⇒ Object
Source: lib/chainer/utils/array.rb, lines 76–97
# File 'lib/chainer/utils/array.rb', line 76
# Broadcasts +x+ to +shape+ by tiling, aligning dimensions from the right
# (numpy-style). A size-1 dimension of +x+ is repeated up to the target size,
# a dimension that already matches is repeated once, and any extra leading
# target dimensions are passed as leading repetition counts to +tile+.
#
# @param x [Object] array-like responding to +shape+ and +tile+
#   (presumably a Numo/Cumo NArray — confirm against callers)
# @param shape [Array<Integer>] target shape
# @return [Object] the result of +x.tile+
# @raise [TypeError] if +x+ has more dimensions than +shape+, or a trailing
#   dimension of +x+ is neither 1 nor equal to the target dimension
def self.broadcast_to(x, shape)
  ndim = x.shape.size
  if ndim > shape.size
    raise TypeError, "Shape of data mismatch\n x.shape.size(#{x.shape.size}) > shape.size(#{shape.size})"
  end

  reps =
    if ndim.zero?
      # A 0-dim input is simply tiled out to the full target shape.
      shape
    else
      shape.last(ndim).each_with_index.map do |target, i|
        case x.shape[i]
        when 1 then target
        when target then 1
        else raise TypeError, "Shape of data mismatch\n#{x.shape} != #{shape}"
        end
      end
    end

  # shape[0...-ndim] is the leading part of +shape+ that +x+ lacks entirely
  # (empty when ndim is 0, since 0...-0 is 0...0).
  x.tile(*shape[0...-ndim], *reps)
end
.force_array(x, dtype = nil) ⇒ Object
Source: lib/chainer/utils/array.rb, lines 4–19
# File 'lib/chainer/utils/array.rb', line 4
# Coerces +x+ into an array value, optionally casting it with +dtype+.
#
# Ruby Integer/Float scalars are always passed through a +cast+ call: +dtype+'s
# when given, otherwise the default device's +NArray+. Any other value is
# returned untouched unless +dtype+ is given, in which case it is cast.
#
# @param x [Integer, Float, Object] a Ruby scalar or an array-like value
# @param dtype [#cast, nil] class used to cast +x+ (presumably an NArray
#   subclass such as Numo::SFloat — confirm); nil means "no explicit dtype"
# @return [Object] the (possibly cast) value
def self.force_array(x, dtype = nil)
  if x.is_a?(Integer) || x.is_a?(Float)
    # Integer/Float#dup returns self, so scalars can be cast directly.
    return dtype.cast(x) unless dtype.nil?

    # NOTE(review): xm is presumably the array module (Numo or Cumo) of the default device.
    xm = Chainer::Device.default.xm
    xm::NArray.cast(x)
  else
    dtype.nil? ? x : dtype.cast(x)
  end
end
.make_indecies_with_axis(shape, indices, axis, values = []) ⇒ Object
Source: lib/chainer/utils/array.rb, lines 28–44
# File 'lib/chainer/utils/array.rb', line 28
# Builds the flat row-major element offsets that select +indices+ along
# +axis+ of an array with the given +shape+ (used by +take+). Every other
# axis is enumerated in full, so the result is nested one level per axis,
# except that an Integer +indices+ drops the +axis+ level entirely.
#
# @param shape [Array<Integer>] shape of the source array
# @param indices [Integer, #map] index, or collection of indices, along +axis+
# @param axis [Integer] axis to index along; negative values count from the end
# @param values [Array<Integer>] coordinates fixed so far (internal recursion state)
# @return [Integer, Array] flat offset(s), nested according to the enumerated axes
# @raise [ArgumentError] if +axis+ is out of bounds for +shape+
def self.make_indecies_with_axis(shape, indices, axis, values = [])
  ndim = shape.size
  # Normalize negative axes; previously they never matched any position and
  # +indices+ was silently ignored (every element was selected).
  normalized_axis = axis < 0 ? axis + ndim : axis
  unless normalized_axis >= 0 && normalized_axis < ndim
    raise ArgumentError, "axis #{axis} is out of bounds for array of dimension #{ndim}"
  end

  target_axis = values.size
  if target_axis == ndim
    # All coordinates are fixed: fold them into a row-major offset, pairing
    # each coordinate with the size of the dimension that follows it.
    values.zip(shape.drop(1) + [1]).reduce(0) do |offset, (coord, next_dim)|
      (offset + coord) * next_dim
    end
  elsif normalized_axis == target_axis && indices.is_a?(Integer)
    # A scalar index fixes this coordinate without adding a nesting level.
    make_indecies_with_axis(shape, indices, normalized_axis, values + [indices])
  else
    coords = normalized_axis == target_axis ? indices : (0...shape[target_axis])
    coords.map do |coord|
      make_indecies_with_axis(shape, indices, normalized_axis, values + [coord])
    end
  end
end
.rollaxis(y, axis, start: 0) ⇒ Object
Source: lib/chainer/utils/array.rb, lines 46–74
# File 'lib/chainer/utils/array.rb', line 46
# Rolls axis +axis+ of +y+ backwards until it lies before position +start+,
# following numpy.rollaxis; the other axes keep their relative order.
#
# @param y [Object] array-like responding to +ndim+ and +transpose+
# @param axis [Integer] axis to move; negative values count from the end
# @param start [Integer] target position, -ndim <= start <= ndim
# @return [Object] +y+ itself when no move is needed, otherwise +y.transpose(...)+
# @raise [ArgumentError] if +axis+ or +start+ is out of range
def self.rollaxis(y, axis, start: 0)
  n = y.ndim

  # Normalize a negative axis and reject anything still out of range; an axis
  # below -n used to stay negative and slip past the upper-bound-only check.
  normalized_axis = axis < 0 ? axis + n : axis
  unless normalized_axis >= 0 && normalized_axis < n
    raise ArgumentError, "axis #{axis} is out of bounds for array of dimension #{n}"
  end

  start += n if start < 0
  unless start >= 0 && start <= n
    raise ArgumentError, "start arg requires #{-n} <= start < #{n + 1}, but #{start} was passed in"
  end

  # Removing the axis shifts every later position down by one.
  start -= 1 if normalized_axis < start
  return y if normalized_axis == start

  # Here 0 <= start <= n - 1, so the insertion position is always valid.
  axes = (0...n).to_a
  axes.delete_at(normalized_axis)
  axes.insert(start, normalized_axis)
  y.transpose(*axes)
end
.take(x, indices, axis: nil) ⇒ Object
Source: lib/chainer/utils/array.rb, lines 21–26
# File 'lib/chainer/utils/array.rb', line 21
# Selects elements of +x+ in the spirit of numpy.take.
#
# Without +axis+, +indices+ is handed straight to +x[]+. With +axis+, the
# indices are first expanded into flat offsets by +make_indecies_with_axis+,
# and those offsets index +x+ instead.
#
# @param x [Object] array-like responding to +[]+ (and +shape+ when +axis+ is given)
# @param indices [Object] index or indices to select
# @param axis [Integer, nil] axis to take along; nil/false indexes +x+ directly
# @return [Object] the result of +x[...]+
def self.take(x, indices, axis: nil)
  selector = axis ? make_indecies_with_axis(x.shape, indices, axis) : indices
  x[selector]
end