Module: MNIST
- Defined in: lib/nn/mnist.rb
Class Method Summary
- .categorical(y_data) ⇒ Object
- .load(images_file_name, labels_file_name) ⇒ Object
- .load_test ⇒ Object
- .load_train ⇒ Object
Class Method Details
.categorical(y_data) ⇒ Object
Source: lib/nn/mnist.rb, lines 28–34
# File 'lib/nn/mnist.rb', line 28
# One-hot encodes integer class labels.
#
# @param y_data [Array<Integer>] class labels, each expected in +0...num_classes+
# @param num_classes [Integer] length of each one-hot row (10 for the MNIST digits)
# @return [Array<Array<Integer>>] one row per label: 1 at the label's index, 0 elsewhere
# @raise [ArgumentError] if a label lies outside +0...num_classes+
def self.categorical(y_data, num_classes = 10)
  y_data.map do |label|
    # Reject bad labels explicitly: plain index assignment would silently lengthen the
    # row (label >= num_classes) or mark the wrong slot (negative label).
    unless (0...num_classes).cover?(label)
      raise ArgumentError, "label #{label.inspect} is outside 0...#{num_classes}"
    end

    Array.new(num_classes) { |i| i == label ? 1 : 0 }
  end
end
.load(images_file_name, labels_file_name) ⇒ Object
Source: lib/nn/mnist.rb, lines 38–53
# File 'lib/nn/mnist.rb', line 38
# Parses a gzip-compressed MNIST image file and label file in IDX format.
#
# @param images_file_name [String] path to a gzipped IDX3 image file
#   (magic number 2051, e.g. "train-images-idx3-ubyte.gz")
# @param labels_file_name [String] path to a gzipped IDX1 label file
#   (magic number 2049, e.g. "train-labels-idx1-ubyte.gz")
# @return [Array] +[images, labels]+: each image is a flat Array of +n_rows * n_cols+
#   unsigned bytes (0..255) in file order; each label is an unsigned byte
# @raise [IOError] if a file has the wrong magic number, ends early, or the image and
#   label counts differ
def self.load(images_file_name, labels_file_name)
  # Reads exactly +size+ bytes so a truncated file fails loudly instead of yielding a
  # short image or a NoMethodError on the nil returned at end of file.
  read_exactly = lambda do |io, size, path|
    bytes = io.read(size)
    got = bytes ? bytes.bytesize : 0
    raise IOError, "#{path}: unexpected end of file (wanted #{size} bytes, got #{got})" if got < size

    bytes
  end

  images = nil
  Zlib::GzipReader.open(images_file_name) do |f|
    magic, n_images, n_rows, n_cols = read_exactly.call(f, 16, images_file_name).unpack("N4")
    # 0x0803: unsigned-byte data (0x08) with 3 dimensions (0x03).
    raise IOError, "#{images_file_name}: bad magic number #{magic} (expected 2051)" unless magic == 0x0803

    image_size = n_rows * n_cols
    images = Array.new(n_images) { read_exactly.call(f, image_size, images_file_name).unpack("C*") }
  end

  labels = nil
  Zlib::GzipReader.open(labels_file_name) do |f|
    magic, n_labels = read_exactly.call(f, 8, labels_file_name).unpack("N2")
    # 0x0801: unsigned-byte data (0x08) with 1 dimension (0x01).
    raise IOError, "#{labels_file_name}: bad magic number #{magic} (expected 2049)" unless magic == 0x0801

    labels = read_exactly.call(f, n_labels, labels_file_name).unpack("C*")
  end

  unless images.size == labels.size
    raise IOError, "image/label count mismatch: #{images.size} images vs #{labels.size} labels"
  end

  [images, labels]
end
.load_test ⇒ Object
Source: lib/nn/mnist.rb, lines 16–26
# File 'lib/nn/mnist.rb', line 16
# Returns the MNIST test set as +[x_test, y_test]+, caching the parsed result.
#
# The first call parses the gzipped IDX files with +load+ and stores the result with
# Marshal in +test.marshal+ inside +dir+. Later calls read that cache instead.
#
# @param dir [String] directory holding +t10k-images-idx3-ubyte.gz+ and
#   +t10k-labels-idx1-ubyte.gz+; defaults to the relative path "mnist"
# @return [Array] +[x_test, y_test]+ as produced by +load+
def self.load_test(dir = "mnist")
  cache_path = File.join(dir, "test.marshal")
  # NOTE: Marshal.load can instantiate arbitrary objects. Only point +dir+ at a
  # directory this program controls, never at untrusted data.
  return Marshal.load(File.binread(cache_path)) if File.exist?(cache_path)

  data = load(File.join(dir, "t10k-images-idx3-ubyte.gz"),
              File.join(dir, "t10k-labels-idx1-ubyte.gz"))
  # Write to a temp file and rename it into place, so an interrupted run cannot leave
  # a truncated cache that Marshal.load would reject on every later call.
  tmp_path = "#{cache_path}.#{Process.pid}.tmp"
  begin
    File.binwrite(tmp_path, Marshal.dump(data))
    File.rename(tmp_path, cache_path)
  ensure
    File.delete(tmp_path) if File.exist?(tmp_path)
  end
  data
end
.load_train ⇒ Object
Source: lib/nn/mnist.rb, lines 4–14
# File 'lib/nn/mnist.rb', line 4
# Returns the MNIST training set as +[x_train, y_train]+, caching the parsed result.
#
# The first call parses the gzipped IDX files with +load+ and stores the result with
# Marshal in +train.marshal+ inside +dir+. Later calls read that cache instead.
#
# @param dir [String] directory holding +train-images-idx3-ubyte.gz+ and
#   +train-labels-idx1-ubyte.gz+; defaults to the relative path "mnist"
# @return [Array] +[x_train, y_train]+ as produced by +load+
def self.load_train(dir = "mnist")
  cache_path = File.join(dir, "train.marshal")
  # NOTE: Marshal.load can instantiate arbitrary objects. Only point +dir+ at a
  # directory this program controls, never at untrusted data.
  return Marshal.load(File.binread(cache_path)) if File.exist?(cache_path)

  data = load(File.join(dir, "train-images-idx3-ubyte.gz"),
              File.join(dir, "train-labels-idx1-ubyte.gz"))
  # Write to a temp file and rename it into place, so an interrupted run cannot leave
  # a truncated cache that Marshal.load would reject on every later call.
  tmp_path = "#{cache_path}.#{Process.pid}.tmp"
  begin
    File.binwrite(tmp_path, Marshal.dump(data))
    File.rename(tmp_path, cache_path)
  ensure
    File.delete(tmp_path) if File.exist?(tmp_path)
  end
  data
end