Class: Chainer::Functions::Array::Cast

Inherits:
Chainer::FunctionNode
Defined in:
lib/chainer/functions/array/cast.rb

Instance Attribute Summary

Attributes inherited from Chainer::FunctionNode

#inputs, #outputs, #rank

Class Method Summary

Instance Method Summary

Methods inherited from Chainer::FunctionNode

#apply, #backward_accumulate, #forward_cpu, #get_retained_inputs, #get_retained_outputs, #label, #output_data, #retain_inputs, #retain_outputs, #unchain

Constructor Details

#initialize(type) ⇒ Cast

Returns a new instance of Cast.



# File 'lib/chainer/functions/array/cast.rb', line 25

def initialize(type)
  # Target Numo::NArray subclass that forward casts to, e.g. Numo::DFloat.
  @type = type
end

Class Method Details

.cast(x, type) ⇒ Chainer::Variable

Cast an input variable to a given type.

Example:

  > x = Numo::UInt8.new(3, 5).seq
  > x.class
  # => Numo::UInt8
  > y = Chainer::Functions::Array::Cast.cast(x, Numo::DFloat)
  > y.dtype
  # => Numo::DFloat

Parameters:

  • x (Chainer::Variable or Numo::NArray)

    x : Input variable to be cast.

  • type (Numo::NArray subclass)

    type : Array class to cast to, e.g. Numo::DFloat.

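Returns:

  • (Chainer::Variable)

    The input cast to type. If x already has the requested type, it is returned wrapped as a Chainer::Variable (via Chainer::Variable.as_variable) without adding a Cast node to the graph.
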
# File 'lib/chainer/functions/array/cast.rb', line 18

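def self.cast(x, type)
  # Short-circuit: if x already has the target type, only wrap it as a
  # Variable instead of adding a Cast node to the graph.
  if (Chainer.array?(x) && x.class == type) || (x.is_a?(Chainer::Variable) && x.dtype == type)
    return Chainer::Variable.as_variable(x)
  end
  self.new(type).apply([x]).first
end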
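The short-circuit in .cast can be modeled without red-chainer or Numo. The sketch below is a hypothetical pure-Ruby stand-in: ToyVariable and toy_cast are illustrative names, not part of the library, and plain Ruby arrays of Integer or Float take the place of Numo arrays. It shows only the dispatch: an input that already has the requested type is wrapped and returned with no creator recorded, while any other input yields a new value marked as produced by a cast.

```ruby
# Hypothetical pure-Ruby model of the dispatch in Cast.cast.
# Plain arrays stand in for Numo arrays; this is not the red-chainer API.
ToyVariable = Struct.new(:data, :creator) do
  # Plays the role of Variable#dtype: the class of the underlying data
  # (in this toy, the class of the first element).
  def dtype
    data.first.class
  end
end

def toy_cast(x, type)
  # Wrap raw arrays, loosely mirroring Chainer::Variable.as_variable.
  x = ToyVariable.new(x, nil) unless x.is_a?(ToyVariable)
  # Already the requested type: return it without recording a cast.
  return x if x.dtype == type

  converted = x.data.map { |v| type == Float ? v.to_f : v.to_i }
  # :cast marks that this value was produced by a cast "node".
  ToyVariable.new(converted, :cast)
end

same = toy_cast([1.0, 2.0], Float)
cast = toy_cast([1, 2], Float)
p same.creator # => nil
p cast.creator # => :cast
```

Returning the input unchanged in the matching case is what keeps a no-op cast from adding a node to the graph.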

Instance Method Details

#backward(indexes, g) ⇒ Object



# File 'lib/chainer/functions/array/cast.rb', line 34

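def backward(indexes, g)
  # Cast the upstream gradient back to the input's original type,
  # which forward recorded in @in_type.
  [Cast.cast(g.first, @in_type)]
end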

#forward(x) ⇒ Object



# File 'lib/chainer/functions/array/cast.rb', line 29

def forward(x)
  # Remember the input array class so backward can cast gradients back.
  @in_type = x.first.class
  [x.first.cast_to(@type)]
end
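forward and backward follow one pattern: forward records the input's class before converting, and backward converts the incoming gradient back to that class, so the gradient passed to the input has the input's type. A minimal sketch of that pattern, assuming plain Ruby arrays of Integer or Float in place of Numo arrays; ToyCast and its CONVERTERS table are hypothetical, not part of red-chainer.

```ruby
# Hypothetical stand-in for the Cast forward/backward pair.
# Operates on a single flat Ruby array instead of a list of Numo arrays.
class ToyCast
  # Maps each supported element class to a conversion function.
  CONVERTERS = {
    Integer => ->(v) { v.to_i },
    Float => ->(v) { v.to_f }
  }.freeze

  def initialize(type)
    @type = type
  end

  # Record the input's element class, then convert to the target type.
  def forward(xs)
    @in_type = xs.first.class
    xs.map(&CONVERTERS.fetch(@type))
  end

  # Convert the upstream gradient back to the recorded input type.
  def backward(gs)
    gs.map(&CONVERTERS.fetch(@in_type))
  end
end

f = ToyCast.new(Float)
p f.forward([1, 2, 3])        # => [1.0, 2.0, 3.0]
p f.backward([0.5, 1.0, 2.0]) # => [0, 1, 2]
```

In this toy, converting a Float gradient back to Integer truncates (0.5 becomes 0) because it uses Float#to_i; the sketch makes no claim about how Numo's cast_to rounds.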