# File 'lib/torch/optim/adadelta.rb', line 15
def step(closure = nil)
  # Re-evaluate the model and capture the loss when a closure is given
  loss = nil
  if closure
    loss = closure.call
  end

  @param_groups.each do |group|
    group[:params].each do |p|
      next unless p.grad
      grad = p.grad.data
      if grad.sparse?
        raise Error, "Adadelta does not support sparse gradients"
      end
      state = @state[p]

      # Lazy state initialization on the first step for this parameter
      if state.size == 0
        state[:step] = 0
        state[:square_avg] = Torch.zeros_like(p.data)
        state[:acc_delta] = Torch.zeros_like(p.data)
      end

      square_avg, acc_delta = state[:square_avg], state[:acc_delta]
      rho, eps = group[:rho], group[:eps]

      state[:step] += 1

      # L2 regularization: fold weight_decay * p into the gradient
      if group[:weight_decay] != 0
        grad = grad.add(p.data, alpha: group[:weight_decay])
      end

      # Decaying average of squared gradients:
      # E[g^2] = rho * E[g^2] + (1 - rho) * g^2
      square_avg.mul!(rho).addcmul!(grad, grad, value: 1 - rho)
      # Update direction: delta = sqrt(E[dx^2] + eps) / sqrt(E[g^2] + eps) * g
      std = square_avg.add(eps).sqrt!
      delta = acc_delta.add(eps).sqrt!.div!(std).mul!(grad)
      p.data.add!(delta, alpha: -group[:lr])
      # Decaying average of squared updates:
      # E[dx^2] = rho * E[dx^2] + (1 - rho) * delta^2
      acc_delta.mul!(rho).addcmul!(delta, delta, value: 1 - rho)
    end
  end

  loss
end
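
For context, a minimal sketch of how this method is typically driven from a training loop. The model, data, and hyperparameter values below are illustrative assumptions, not part of the file above; the constructor keywords mirror the group[:rho], group[:eps], and group[:weight_decay] entries read inside step.

require "torch"

# Hypothetical model and data, for illustration only
model = Torch::NN::Linear.new(10, 1)
x = Torch.randn(8, 10)
y = Torch.randn(8, 1)

# lr scales the computed delta; rho, eps, and weight_decay are the
# group values read inside step above
optimizer = Torch::Optim::Adadelta.new(
  model.parameters, lr: 1.0, rho: 0.9, eps: 1e-6, weight_decay: 0
)

10.times do
  optimizer.zero_grad
  loss = Torch::NN::F.mse_loss(model.call(x), y)
  loss.backward
  optimizer.step
end

step may also be given a closure that re-evaluates the model and returns the loss; as the source shows, the closure is called once before the update and its result is returned.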