Module: MNIST

Defined in:
lib/nn/mnist.rb

Class Method Summary

Class Method Details

.categorical(y_data) ⇒ Object


# File 'lib/nn/mnist.rb', lines 28-34

# Converts integer class labels into one-hot vectors.
#
# @param y_data [Enumerable<Integer>] class labels, each in 0...num_classes
# @param num_classes [Integer] length of each one-hot vector (defaults to 10,
#   one slot per MNIST digit)
# @return [Array<Array<Integer>>] one vector per label: all zeros except a 1
#   at the label's index
# @raise [ArgumentError] if a label is outside 0...num_classes. Without this
#   check an out-of-range label would silently produce a vector of the wrong
#   length (padded with nil) or set a slot counted from the end.
def self.categorical(y_data, num_classes = 10)
  valid_labels = 0...num_classes
  y_data.map do |label|
    unless valid_labels.cover?(label)
      raise ArgumentError, "label #{label.inspect} out of range #{valid_labels}"
    end

    one_hot = Array.new(num_classes, 0)
    one_hot[label] = 1
    one_hot
  end
end

.load(images_file_name, labels_file_name) ⇒ Object


# File 'lib/nn/mnist.rb', lines 38-53

# Reads a gzipped MNIST image file and its matching label file (IDX format).
#
# Each file begins with big-endian 32-bit header fields. The image file has
# magic, count, rows and cols; the label file has magic and count. The
# payload that follows is unsigned bytes.
#
# @param images_file_name [String] path to a gzipped IDX image file
# @param labels_file_name [String] path to a gzipped IDX label file
# @return [Array] +[images, labels]+, where +images+ is an Array of flat
#   Integer arrays (n_rows * n_cols byte values each) and +labels+ is an
#   Array of Integers (one byte each)
# @raise [IOError] if either file has an unexpected magic number, is
#   truncated, or the image and label counts differ
def self.load(images_file_name, labels_file_name)
  # GzipReader.open with a block closes the file and returns the block's value.
  images = Zlib::GzipReader.open(images_file_name) do |f|
    header = f.read(16)
    if header.nil? || header.bytesize < 16
      raise IOError, "#{images_file_name}: truncated IDX header"
    end

    magic, n_images, n_rows, n_cols = header.unpack("N4")
    # 2051 (0x00000803) = IDX, unsigned-byte payload, 3 dimensions.
    unless magic == 2051
      raise IOError, "#{images_file_name}: bad magic number #{magic} (expected 2051)"
    end

    image_size = n_rows * n_cols
    Array.new(n_images) do |index|
      pixels = f.read(image_size)
      if pixels.nil? || pixels.bytesize < image_size
        raise IOError, "#{images_file_name}: truncated at image #{index}"
      end

      pixels.unpack("C*")
    end
  end

  labels = Zlib::GzipReader.open(labels_file_name) do |f|
    header = f.read(8)
    if header.nil? || header.bytesize < 8
      raise IOError, "#{labels_file_name}: truncated IDX header"
    end

    magic, n_labels = header.unpack("N2")
    # 2049 (0x00000801) = IDX, unsigned-byte payload, 1 dimension.
    unless magic == 2049
      raise IOError, "#{labels_file_name}: bad magic number #{magic} (expected 2049)"
    end

    data = f.read(n_labels)
    if data.nil? || data.bytesize < n_labels
      raise IOError, "#{labels_file_name}: truncated label data"
    end

    data.unpack("C*")
  end

  unless images.size == labels.size
    raise IOError, "image count (#{images.size}) does not match label count (#{labels.size})"
  end

  [images, labels]
end

.load_test ⇒ Object


# File 'lib/nn/mnist.rb', lines 16-26

# Loads the MNIST test split (the t10k-* files) and caches the parsed result.
#
# On the first call the gzipped files are parsed with +load+, and the result
# is Marshal-dumped to +<dir>/test.marshal+. Later calls read that cache
# instead. The cache is written to a temporary file and then renamed into
# place, so an interrupted run cannot leave a truncated cache behind.
#
# NOTE(review): Marshal.load can instantiate arbitrary objects. Only use a
# +dir+ whose cache file was written by this code.
#
# @param dir [String] directory holding the MNIST .gz files and the cache
#   (defaults to "mnist", relative to the working directory)
# @return [Array] +[x_test, y_test]+ as produced by +load+
def self.load_test(dir = "mnist")
  cache_path = File.join(dir, "test.marshal")
  return Marshal.load(File.binread(cache_path)) if File.exist?(cache_path)

  x_test, y_test = load(File.join(dir, "t10k-images-idx3-ubyte.gz"),
                        File.join(dir, "t10k-labels-idx1-ubyte.gz"))

  tmp_path = "#{cache_path}.#{Process.pid}.tmp"
  begin
    File.binwrite(tmp_path, Marshal.dump([x_test, y_test]))
    File.rename(tmp_path, cache_path)
  ensure
    # Only present if the write or rename failed; never leave a partial file.
    File.delete(tmp_path) if File.exist?(tmp_path)
  end

  [x_test, y_test]
end

.load_train ⇒ Object


# File 'lib/nn/mnist.rb', lines 4-14

# Loads the MNIST training split (the train-* files) and caches the parsed result.
#
# On the first call the gzipped files are parsed with +load+, and the result
# is Marshal-dumped to +<dir>/train.marshal+. Later calls read that cache
# instead. The cache is written to a temporary file and then renamed into
# place, so an interrupted run cannot leave a truncated cache behind.
#
# NOTE(review): Marshal.load can instantiate arbitrary objects. Only use a
# +dir+ whose cache file was written by this code.
#
# @param dir [String] directory holding the MNIST .gz files and the cache
#   (defaults to "mnist", relative to the working directory)
# @return [Array] +[x_train, y_train]+ as produced by +load+
def self.load_train(dir = "mnist")
  cache_path = File.join(dir, "train.marshal")
  return Marshal.load(File.binread(cache_path)) if File.exist?(cache_path)

  x_train, y_train = load(File.join(dir, "train-images-idx3-ubyte.gz"),
                          File.join(dir, "train-labels-idx1-ubyte.gz"))

  tmp_path = "#{cache_path}.#{Process.pid}.tmp"
  begin
    File.binwrite(tmp_path, Marshal.dump([x_train, y_train]))
    File.rename(tmp_path, cache_path)
  ensure
    # Only present if the write or rename failed; never leave a partial file.
    File.delete(tmp_path) if File.exist?(tmp_path)
  end

  [x_train, y_train]
end