Module: Chainer::Datasets::CIFAR
- Defined in:
- lib/chainer/datasets/cifar.rb
Class Method Summary
- .get_cifar(n_classes, with_label, ndim, scale, device: Chainer::Device.default) ⇒ Object
- .get_cifar10(with_label: true, ndim: 3, scale: 1.0) ⇒ Object
- .get_cifar100(with_label: true, ndim: 3, scale: 1.0) ⇒ Object
- .preprocess_cifar(images, labels, withlabel, ndim, scale, device: Chainer::Device.default) ⇒ Object
Class Method Details
permalink .get_cifar(n_classes, with_label, ndim, scale, device: Chainer::Device.default) ⇒ Object
[View source]
14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 |
# File 'lib/chainer/datasets/cifar.rb', line 14
# Loads the CIFAR train and test splits and preprocesses each of them.
#
# @param n_classes [Integer] passed to ::Datasets::CIFAR; when it is 10 the
#   :label column is read, otherwise the :fine_label column is read
# @param with_label [Boolean] forwarded to preprocess_cifar
# @param ndim [Integer] forwarded to preprocess_cifar (1 or 3 are accepted there)
# @param scale [Numeric] forwarded to preprocess_cifar
# @param device [Chainer::Device] device whose `xm` array module builds the
#   raw arrays; it is also forwarded to preprocess_cifar
# @return [Array] two elements, `[train, test]`, each being the value
#   returned by preprocess_cifar for that split
def self.get_cifar(n_classes, with_label, ndim, scale, device: Chainer::Device.default)
  label_column = n_classes == 10 ? :label : :fine_label
  xm = device.xm

  %i[train test].map do |type|
    table = ::Datasets::CIFAR.new(n_classes: n_classes, type: type).to_table
    # Forward the device so preprocess_cifar casts with the same array
    # backend that built the arrays here, instead of the default device.
    preprocess_cifar(
      xm::UInt8[*table[:pixels]],
      xm::UInt8[*table[label_column]],
      with_label, ndim, scale,
      device: device
    )
  end
end
permalink .get_cifar10(with_label: true, ndim: 3, scale: 1.0) ⇒ Object
[View source]
6 7 8 |
# File 'lib/chainer/datasets/cifar.rb', line 6
# Loads the 10-class CIFAR dataset by delegating to get_cifar with
# n_classes = 10.
#
# @param with_label [Boolean] forwarded to get_cifar
# @param ndim [Integer] forwarded to get_cifar
# @param scale [Numeric] forwarded to get_cifar
# @param device [Chainer::Device] forwarded to get_cifar; defaults to the
#   same value get_cifar itself defaults to, so existing callers are unaffected
# @return [Object] whatever get_cifar returns
def self.get_cifar10(with_label: true, ndim: 3, scale: 1.0, device: Chainer::Device.default)
  get_cifar(10, with_label, ndim, scale, device: device)
end
permalink .get_cifar100(with_label: true, ndim: 3, scale: 1.0) ⇒ Object
[View source]
10 11 12 |
# File 'lib/chainer/datasets/cifar.rb', line 10
# Loads the 100-class CIFAR dataset by delegating to get_cifar with
# n_classes = 100.
#
# @param with_label [Boolean] forwarded to get_cifar
# @param ndim [Integer] forwarded to get_cifar
# @param scale [Numeric] forwarded to get_cifar
# @param device [Chainer::Device] forwarded to get_cifar; defaults to the
#   same value get_cifar itself defaults to, so existing callers are unaffected
# @return [Object] whatever get_cifar returns
def self.get_cifar100(with_label: true, ndim: 3, scale: 1.0, device: Chainer::Device.default)
  get_cifar(100, with_label, ndim, scale, device: device)
end
permalink .preprocess_cifar(images, labels, withlabel, ndim, scale, device: Chainer::Device.default) ⇒ Object
[View source]
35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 |
# File 'lib/chainer/datasets/cifar.rb', line 35
# Reshapes, casts and rescales raw CIFAR images, optionally pairing them
# with their labels.
#
# @param images [Object] array-like responding to `shape`, `reshape` and
#   `cast_to` (presumably a Numo/Cumo array -- confirm against callers)
# @param labels [Object] array-like responding to `cast_to`; only used when
#   withlabel is truthy
# @param withlabel [Boolean] when truthy a TupleDataset of images and labels
#   is returned, otherwise only the images
# @param ndim [Integer] 1 reshapes to (N, 3072); 3 reshapes to (N, 3, 32, 32)
# @param scale [Numeric] images are multiplied by `scale / 255.0`
# @param device [Chainer::Device] supplies the `xm` module whose SFloat and
#   Int32 types are the cast targets
# @return [TupleDataset, Object] the dataset, or the processed images alone
# @raise [RuntimeError] when ndim is neither 1 nor 3
def self.preprocess_cifar(images, labels, withlabel, ndim, scale, device: Chainer::Device.default)
  flat = ndim == 1
  raise 'invalid ndim for CIFAR dataset' unless flat || ndim == 3

  reshaped =
    if flat
      images.reshape(images.shape[0], 3072)
    else
      images.reshape(images.shape[0], 3, 32, 32)
    end

  backend = device.xm
  scaled = reshaped.cast_to(backend::SFloat)
  scaled *= scale / 255.0
  return scaled unless withlabel

  int_labels = labels.cast_to(backend::Int32)
  TupleDataset.new(scaled, int_labels)
end