14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
|
# File 'lib/ruby_brain/dataset/mnist/data.rb', line 14
# Builds the MNIST training set as plain Ruby arrays.
#
# Both gzipped training archives are looked up in the current working
# directory; any archive that is not there yet is announced on stdout and
# handed to +download_file+ before the archives are parsed with
# +Mnist.load_images+ and +Mnist.load_labels+.
#
# @return [Hash] +:input+ holds one array of floats per image (every unpacked
#   byte divided by 255.0); +:output+ holds one 10-element one-hot array per
#   label.
def data
  working_dir = Dir.pwd
  images_archive = working_dir + '/train-images-idx3-ubyte.gz'
  labels_archive = working_dir + '/train-labels-idx1-ubyte.gz'

  # Each entry: local archive path, remote URL, progress message to print.
  # Images are handled before labels.
  [
    [
      images_archive,
      'http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz',
      'downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz ...'
    ],
    [
      labels_archive,
      'http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz',
      'downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz'
    ]
  ].each do |archive, url, notice|
    next if File.exist?(archive)

    puts notice
    download_file(url, archive)
  end

  parsed_images = Mnist.load_images(images_archive)
  parsed_labels = Mnist.load_labels(labels_archive)

  # Element 2 of the load_images result is iterated as the list of packed
  # image strings (presumably elements 0 and 1 are the dimensions -- confirm
  # against Mnist.load_images).
  pixels = parsed_images[2].map do |raw|
    raw.unpack('C*').map { |byte| byte / 255.0 }
  end

  # One-hot encode: ten zeros with a 1 at the label's index.
  targets = parsed_labels.map do |digit|
    Array.new(10, 0).tap { |row| row[digit] = 1 }
  end

  { input: pixels, output: targets }
end
|