14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
|
# File 'lib/chainer/datasets/cifar.rb', line 14
# Loads the CIFAR train and test splits via red-datasets and runs each one
# through +preprocess_cifar+.
#
# @param n_classes [Integer] passed to ::Datasets::CIFAR; 10 reads the +:label+
#   column, any other value reads the +:fine_label+ column.
# @param with_label forwarded unchanged to +preprocess_cifar+
# @param ndim forwarded unchanged to +preprocess_cifar+
# @param scale forwarded unchanged to +preprocess_cifar+
# @param device [Chainer::Device] supplies the array module (+device.xm+) used
#   to build the UInt8 image and label arrays
# @return [Array] two elements: the +preprocess_cifar+ result for the train
#   split, then the one for the test split
def self.get_cifar(n_classes, with_label, ndim, scale, device: Chainer::Device.default)
  xm = device.xm
  label_column = n_classes == 10 ? :label : :fine_label

  %i[train test].map do |type|
    table = ::Datasets::CIFAR.new(n_classes: n_classes, type: type).to_table
    # NArray.cast takes the whole nested Ruby array as one argument, instead of
    # splatting every image row into the method's argument list.
    images = xm::UInt8.cast(table[:pixels].to_a)
    labels = xm::UInt8.cast(table[label_column].to_a)
    preprocess_cifar(images, labels, with_label, ndim, scale)
  end
end
|