Class: Newral::Training::GradientDescent

Inherits:
Object
  • Object
show all
Defined in:
lib/newral/training/gradient_descent.rb

Instance Attribute Summary collapse

Instance Method Summary collapse

Constructor Details

#initialize(input: [], output: [], iterations: 10**5, klass: Newral::Functions::Polynomial, klass_args: {}, start_function: nil) ⇒ GradientDescent

Returns a new instance of GradientDescent.



5
6
7
8
9
10
11
12
# File 'lib/newral/training/gradient_descent.rb', line 5

# Stores the training data and the settings used later by the trainer.
#
# @param input [Array] training inputs, kept as given
# @param output [Array] expected outputs, kept as given
# @param iterations [Integer] stored as the iteration count
# @param klass [Class] function class stored for later use
# @param klass_args [Hash] arguments stored alongside +klass+
# @param start_function [Object, nil] initial value of the best function
def initialize(input: [], output: [], iterations: 10**5, klass: Newral::Functions::Polynomial, klass_args: {}, start_function: nil)
  @input, @output = input, output
  @iterations = iterations
  @klass, @klass_args = klass, klass_args
  @best_function = start_function
end

Instance Attribute Details

#best_error ⇒ Object (readonly)

Returns the value of attribute best_error.



4
5
6
# File 'lib/newral/training/gradient_descent.rb', line 4

# Reader for the stored best error value.
#
# @return [Object] the current value of +@best_error+
def best_error; @best_error; end

#best_function ⇒ Object (readonly)

Returns the value of attribute best_function.



4
5
6
# File 'lib/newral/training/gradient_descent.rb', line 4

# Reader for the stored best function.
#
# @return [Object] the current value of +@best_function+
def best_function; @best_function; end

#input ⇒ Object (readonly)

Returns the value of attribute input.



4
5
6
# File 'lib/newral/training/gradient_descent.rb', line 4

# Reader for the stored training input.
#
# @return [Object] the current value of +@input+
def input; @input; end

Instance Method Details

#process(start_fresh: false, learning_rate: 0.01, step: 0.01) ⇒ Object



15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
# File 'lib/newral/training/gradient_descent.rb', line 15

# Runs +@iterations+ rounds of gradient descent and returns the best function found.
#
# The starting point is a copy of +@best_function+, or of a function built with
# +@klass.create_random( @klass_args )+ when no best function exists yet or when
# +start_fresh+ is set. Each round moves a copy of the current best function with
# +move_with_gradient+ and keeps the result only if its error is strictly lower
# than the best error seen so far.
#
# @param start_fresh [Boolean] discard the current best function and start from a random one
# @param learning_rate [Numeric] forwarded to +move_with_gradient+; divided by 10 after each
#   non-improving round while it is still above 10**-8
# @param step [Numeric] forwarded to +move_with_gradient+; reduced the same way as +learning_rate+
# @return [Object] the best function found (also stored in +@best_function+; its error is in +@best_error+)
def process(start_fresh: false, learning_rate: 0.01, step: 0.01)
  seed = start_fresh ? @klass.create_random( @klass_args ) : @best_function || @klass.create_random( @klass_args )
  @best_function = seed.dup
  @best_error = @best_function.calculate_error( input: @input, output: @output )
  floor = 10**-8 # below this, step and learning_rate are no longer reduced

  @iterations.times do
    candidate = @best_function.dup.move_with_gradient( input: @input, output: @output, learning_rate: learning_rate, step: step )
    candidate_error = candidate.calculate_error( input: @input, output: @output )

    if candidate_error < @best_error
      @best_function = candidate
      # Must update the instance variable: assigning a local here would leave every
      # later round comparing against the initial error and let worse candidates win.
      @best_error = candidate_error
    else
      # No improvement (or an incomparable error such as NaN): slow down.
      step /= 10 if step > floor
      learning_rate /= 10 if learning_rate > floor
    end
  end

  @best_function
end