Class: Chainer::UpdateRule

Inherits:
Object
  • Object
show all
Defined in:
lib/chainer/optimizer.rb

Instance Attribute Summary collapse

Instance Method Summary collapse

Constructor Details

#initialize(parent_hyperparam:) ⇒ UpdateRule

Returns a new instance of UpdateRule.



57
58
59
60
61
62
63
# File 'lib/chainer/optimizer.rb', line 57

# Builds an update rule whose hyperparameters fall back to +parent_hyperparam+.
#
# @param parent_hyperparam parent passed to Chainer::Hyperparameter.new(parent:)
def initialize(parent_hyperparam:)
  @hooks   = {}    # hooks run by #update before the core update
  @state   = nil   # stays nil until the rule's state is first initialized
  @enabled = true  # #update is a no-op while this is false
  @hyperparam = Chainer::Hyperparameter.new(parent: parent_hyperparam)
  @t = 0           # number of updates performed so far
end

Instance Attribute Details

#state ⇒ Object (readonly)

Returns the value of attribute state.



55
56
57
# File 'lib/chainer/optimizer.rb', line 55

# Returns the rule's state: nil until initialized, otherwise a Hash of entries.
def state; @state; end

Instance Method Details

#init_state(param) ⇒ Object

Raises:

  • (NotImplementedError)


95
96
97
# File 'lib/chainer/optimizer.rb', line 95

# Creates the state entries for +param+.
#
# Abstract: concrete update rules must override this. #serialize also calls it
# on a copy of the rule to discover which state entries to restore.
#
# @param param the parameter the state is created for (unused here)
# @raise [NotImplementedError] always, naming the class that lacks an override
def init_state(param)
  raise NotImplementedError, "#{self.class}#init_state is not implemented"
end

#serialize(serializer) ⇒ Object

Serializes the update rule state. Note that this method only saves/loads the state of the update rule. The parameters of the target link are not saved/loaded by this method, so you need to serialize the target link separately if you want to fully recover the training state, including parameters.

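Parameters:

  • serializer (Object) — serializer or deserializer, called as `serializer.(key, value)` for each state entry; its return value replaces the stored entry.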



107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
# File 'lib/chainer/optimizer.rb', line 107

# Saves or restores this rule's state entries through +serializer+.
#
# Only the update rule's state is handled here; the target link's parameters
# must be serialized separately.
#
# When no state exists yet and +serializer+ is a Chainer::Deserializer, the
# state hash is created first. A shallow copy of the rule (Object#dup) shares
# that hash, and running #init_state on the copy with a one-element dummy
# variable reveals which entry keys to load. This assumes init_state fills
# @state in place rather than reassigning it — TODO confirm for subclasses.
#
# @param serializer [#call] invoked as serializer.(key_string, value) per entry;
#   its result replaces the stored entry
# @return [Hash, Array, nil] the state hash (existing state), the loaded keys
#   (fresh load), or nil (no state and not deserializing)
def serialize(serializer)
  unless @state.nil?
    # State already exists: pass every entry through the serializer.
    return @state.each { |name, value| @state[name] = serializer.call(name.to_s, value) }
  end

  # Nothing to save, and we are not loading either.
  return unless serializer.is_a?(Chainer::Deserializer)

  @state = {}
  scratch = dup
  # TODO(sonots): pass device from outside
  xm = Chainer::Device.default.xm
  dummy = xm::SFloat.new(1)
  scratch.init_state(Chainer::Variable.new(dummy, grad: dummy))
  @state.keys.each { |name| @state[name] = serializer.call(name.to_s, nil) }
end

#update(param) ⇒ Object



65
66
67
68
69
70
71
72
73
74
75
76
# File 'lib/chainer/optimizer.rb', line 65

# Applies one update step to +param+ unless the rule is disabled.
#
# Increments the step counter, prepares the rule when +param+ has data, runs
# every registered hook with +param+, then delegates to #update_core.
#
# @param param parameter to update; must respond to #data
# @return the result of #update_core, or nil when the rule is disabled
def update(param)
  return unless @enabled

  @t += 1
  prepare(param) unless param.data.nil?
  @hooks.each_value { |hook| hook.call(param) }
  update_core(param)
end

#update_core(param) ⇒ Object



78
79
80
81
82
83
84
85
# File 'lib/chainer/optimizer.rb', line 78

# Dispatches the actual update to the GPU or CPU implementation.
#
# The GPU path is chosen when Chainer.get_array_module reports Cumo for
# +param+; every other array module takes the CPU path.
#
# @param param parameter to update
# @return whatever the chosen implementation returns
def update_core(param)
  on_gpu = Chainer.get_array_module(param) == Cumo
  on_gpu ? update_core_gpu(param) : update_core_cpu(param)
end

#update_core_cpu ⇒ Object

Raises:

  • (NotImplementedError)


87
88
89
# File 'lib/chainer/optimizer.rb', line 87

# CPU implementation of the update step; concrete rules must override this.
#
# #update_core calls this with the parameter as its single argument. The
# argument is therefore accepted here, so a missing override raises
# NotImplementedError rather than ArgumentError. It defaults to nil so that
# existing zero-argument calls keep working.
#
# @param param parameter to update (unused here)
# @raise [NotImplementedError] always, naming the class that lacks an override
def update_core_cpu(param = nil)
  raise NotImplementedError, "#{self.class}#update_core_cpu is not implemented"
end

#update_core_gpu ⇒ Object

Raises:

  • (NotImplementedError)


91
92
93
# File 'lib/chainer/optimizer.rb', line 91

# GPU (Cumo) implementation of the update step; concrete rules must override this.
#
# #update_core calls this with the parameter as its single argument. The
# argument is therefore accepted here, so a missing override raises
# NotImplementedError rather than ArgumentError. It defaults to nil so that
# existing zero-argument calls keep working.
#
# @param param parameter to update (unused here)
# @raise [NotImplementedError] always, naming the class that lacks an override
def update_core_gpu(param = nil)
  raise NotImplementedError, "#{self.class}#update_core_gpu is not implemented"
end