Class: DBMLP

Inherits:
Object
Includes:
Network, TestResults, TestResultsParser, Training
Defined in:
lib/db_mlp.rb

Class Method Summary

Instance Method Summary

Methods included from TestResultsParser

included

Constructor Details

#initialize(db_path, options = {}) ⇒ DBMLP

Returns a new instance of DBMLP.



25
26
27
28
29
30
31
32
33
34
# File 'lib/db_mlp.rb', line 25

# Builds a new perceptron backed by the file at +db_path+.
#
# @param db_path [String] file the network is persisted to by #save
# @param options [Hash] configuration:
#   :inputs         - stored as @input_size
#   :hidden_layers  - stored as @hidden_layers
#   :output_nodes   - stored as @output_nodes
#   :verbose        - stored as @verbose
#   :validate_every - stored as @validate_every; falls back to 200 when nil/false
def initialize(db_path, options={})
  # Multiple assignment runs left to right, so the ivars are created in the
  # same order as before (this matters for Marshal output and #instance_variables).
  @input_size, @hidden_layers, @output_nodes, @verbose =
    options.values_at(:inputs, :hidden_layers, :output_nodes, :verbose)
  @validate_every = options[:validate_every] || 200
  @db_path = db_path

  # setup_network is defined outside this method (presumably in Network);
  # it runs after every option above has been stored.
  @network = setup_network
end

Class Method Details

.load(db_path) ⇒ Object



14
15
16
17
18
19
20
21
22
# File 'lib/db_mlp.rb', line 14

# Restores a network previously written by #save.
#
# The file is opened in binary mode and handed straight to Marshal, so the
# serialized bytes are never re-encoded or newline-translated. The previous
# version read it line by line in text mode, which could corrupt the data
# on platforms that translate newlines.
#
# SECURITY: Marshal.load can instantiate arbitrary objects. Only load files
# this program wrote itself, never untrusted input.
#
# @param db_path [String] path of the serialized network
# @return [Object] the deserialized object (whatever #save dumped)
# @raise [Errno::ENOENT] if +db_path+ does not exist
def load(db_path)
  File.open(db_path, 'rb') { |f| Marshal.load(f) }
end

Instance Method Details

#feed_forward(input) ⇒ Object



36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
# File 'lib/db_mlp.rb', line 36

# Propagates +input+ forward through every layer of the network.
#
# Every neuron in the first layer fires on +input+. Every neuron in a later
# layer fires on the +last_output+ values of the layer before it.
#
# @param input [Array] values fed to the first layer
# @return [Object] the value of +last_outputs+ after the pass (defined
#   outside this method, presumably the final layer's outputs; verify in Network)
def feed_forward(input)
  @network.each_with_index do |layer, layer_index|
    # All neurons in a layer share the same inputs. Gather the previous
    # layer's outputs once per layer instead of once per neuron.
    # NOTE(review): this assumes Neuron#fire does not mutate its argument or
    # the previous layer. Layer 0 already shares +input+ across its neurons.
    layer_input = layer_index.zero? ? input : @network[layer_index - 1].map(&:last_output)
    layer.each { |neuron| neuron.fire(layer_input) }
  end
  last_outputs
end

#inspect ⇒ Object



59
60
61
# File 'lib/db_mlp.rb', line 59

# Returns the underlying network (the Array of layers) itself, not a String.
# NOTE(review): #inspect conventionally returns a String (irb and p expect
# one). It is kept as-is because callers may rely on getting the raw network back.
def inspect; @network; end

#save ⇒ Object



63
64
65
66
67
# File 'lib/db_mlp.rb', line 63

# Serializes this whole object with Marshal and writes it to @db_path,
# replacing any existing file.
#
# Marshal output is binary. Opening the file with 'wb' (write-only, binary)
# prevents newline translation and re-encoding of the bytes. The previous
# text-mode 'w+' handle could corrupt them on some platforms.
#
# @return [Integer] number of bytes written
def save
  File.open(@db_path, 'wb') do |f|
    f.write(Marshal.dump(self))
  end
end

#train(training, testing, validations, n = 3000, report_path = nil) ⇒ Object



52
53
54
55
56
57
# File 'lib/db_mlp.rb', line 52

# Trains the network with periodic cross-validation, optionally writes a
# test report, and then persists the result via #save.
#
# @param training [Object] training set, passed to train_and_cross_validate
# @param testing [Object] test set, used only when a report is requested
# @param validations [Object] validation set, passed to train_and_cross_validate
# @param n [Integer] value forwarded as the third argument to
#   train_and_cross_validate (presumably the iteration count)
# @param report_path [String, nil] where to write a test report; nil skips it
# @return [Object] whatever #save returns
def train(training, testing, validations, n=3000, report_path=nil)
  train_and_cross_validate(training, validations, n)

  # Only an explicit nil skips the report (a false value still produces one).
  if !report_path.nil?
    create_test_report(testing, report_path)
  end

  save
end