Class: Chainer::UpdateRule

Inherits:
Object
Defined in:
lib/chainer/optimizer.rb

Instance Attribute Summary

Instance Method Summary

Constructor Details

#initialize(parent_hyperparam:) ⇒ UpdateRule

Returns a new instance of UpdateRule.



38
39
40
41
42
43
44
# File 'lib/chainer/optimizer.rb', line 38

# Creates an update rule whose hyperparameters are chained to a parent.
#
# @param parent_hyperparam [Object] passed as +parent:+ to Chainer::Hyperparameter.new
def initialize(parent_hyperparam:)
  @hooks = {}      # callables run by #update before #update_core
  @state = nil     # stays nil until state entries are created
  @enabled = true  # #update is a no-op while this is false
  @hyperparam = Chainer::Hyperparameter.new(parent: parent_hyperparam)
  @t = 0           # incremented once per enabled #update call
end

Instance Attribute Details

#state ⇒ Object (readonly)

Returns the value of attribute state.



36
37
38
# File 'lib/chainer/optimizer.rb', line 36

# Reader for the rule's optimizer state (presumably a Hash of entries),
# or nil when no state has been created yet.
#
# @return [Hash, nil]
def state; @state; end

Instance Method Details

#init_state(param) ⇒ Object

Raises:

  • (NotImplementedError)


66
67
68
# File 'lib/chainer/optimizer.rb', line 66

# Abstract hook: subclasses create their state entries for +param+.
# #serialize calls it on a copy of the rule to discover which entries exist.
#
# @param param [Chainer::Variable] parameter the state is created for
# @raise [NotImplementedError] always, in this base class
def init_state(param)
  raise(NotImplementedError)
end

#serialize(serializer) ⇒ Object

Serializes the update rule state. Note that this method only saves/loads the state of the update rule. The parameters of the target link are not saved/loaded by this method, so you need to serialize the target link separately if you want to fully recover the training state, including the parameters.

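Parameters:

  • serializer — the serializer or deserializer (Chainer::Deserializer) that is called as serializer.(key, value) for each state entry; its return value replaces the stored entry.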



78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
# File 'lib/chainer/optimizer.rb', line 78

# Saves or loads this rule's state entries through +serializer+.
#
# Only the update rule's own state is handled. The target link's parameters
# are not, so serialize the link separately to fully restore training.
#
# @param serializer [#call] invoked as +serializer.call(key_string, value)+;
#   its return value replaces the stored entry
# @return [Hash, Array, nil] the state hash when state already existed, the
#   array of keys when state was rebuilt for a Chainer::Deserializer, else nil
def serialize(serializer)
  unless @state.nil?
    # Round-trip every existing entry; only existing keys are reassigned,
    # which is permitted while iterating the hash.
    return @state.each { |name, value| @state[name] = serializer.call(name.to_s, value) }
  end

  # No state and we are saving: there is nothing to write.
  return unless serializer.is_a?(Chainer::Deserializer)

  # Discover the state entries by running init_state on a shallow copy.
  # The copy shares the fresh @state hash, so keys it adds (presumably
  # subclasses write into @state) become visible here.
  @state = {}
  scratch = dup
  dummy = Numo::SFloat.new(1)
  scratch.init_state(Chainer::Variable.new(dummy, grad: dummy))
  @state.keys.each { |name| @state[name] = serializer.call(name.to_s, nil) }
end

#update(param) ⇒ Object



46
47
48
49
50
51
52
53
54
55
# File 'lib/chainer/optimizer.rb', line 46

# Performs one update step on +param+ unless the rule is disabled.
#
# Increments the step counter, calls #prepare, runs every registered hook
# with +param+, and finally delegates to #update_core.
#
# @param param [Object] parameter being updated
# @return [Object, nil] result of #update_core, or nil when disabled
def update(param)
  return unless @enabled

  @t += 1
  prepare(param)
  # Hooks are iterated from a copy of the hash's values, so edits to
  # @hooks made by a hook do not affect this pass.
  pending_hooks = @hooks.values
  pending_hooks.each { |hook| hook.call(param) }
  update_core(param)
end

#update_core(param) ⇒ Object



57
58
59
60
# File 'lib/chainer/optimizer.rb', line 57

# Runs the device-specific update for +param+; only the CPU path exists.
#
# @param param [Object] parameter being updated, forwarded unchanged
# @return [Object] the value returned by #update_core_cpu
def update_core(param)
  update_core_cpu(param) # TODO: support GPU
end

#update_core_cpu(param = nil) ⇒ Object

Raises:

  • (NotImplementedError)


62
63
64
# File 'lib/chainer/optimizer.rb', line 62

# Abstract CPU implementation of the update step; subclasses must override it.
#
# #update_core calls this method with the parameter being updated. The
# base-class signature therefore accepts +param+, so a rule lacking an
# override raises NotImplementedError rather than an arity ArgumentError.
# +param+ defaults to nil so the old zero-argument form still works.
#
# @param param [Object, nil] parameter being updated (unused here)
# @raise [NotImplementedError] always, in this base class
def update_core_cpu(param = nil)
  raise NotImplementedError, "#{self.class}#update_core_cpu is not implemented"
end