Method: CIFAR10.load_train

Defined in:
lib/nn/cifar10.rb

.load_train(index) ⇒ Object



2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
# File 'lib/nn/cifar10.rb', line 2

# Loads CIFAR-10 training batch +index+ and returns it as +[x_train, y_train]+.
#
# The first call parses the raw batch file "#{dir}/data_batch_#{index}.bin"
# and caches the parsed result with Marshal in "CIFAR-10-train#{index}.marshal".
# The cache path is relative to the current working directory. Later calls
# return the cached result without reading the raw file.
#
# Each record in the raw file is one label byte followed by 3072 image bytes.
# A truncated final record yields a shorter row rather than an error.
#
# NOTE(review): +dir+ is not defined in this method; it is assumed to be a
# sibling module method returning the dataset directory. TODO confirm.
#
# @param index [Integer] batch number interpolated into both file names
# @return [Array(Array<Array<Integer>>, Array<Integer>)] the image rows
#   (byte values 0..255) and their labels, in file order
# @raise [Errno::ENOENT] if the cache is absent and the raw batch file is missing
def self.load_train(index)
  cache_path = "CIFAR-10-train#{index}.marshal"
  # NOTE(review): Marshal.load can instantiate arbitrary objects. Only trust a
  # cache file this code wrote itself; it lives in the CWD, not under +dir+.
  return Marshal.load(File.binread(cache_path)) if File.exist?(cache_path)

  x_train = []
  y_train = []
  # Walk the bytes once, 1 label + 3072 image bytes at a time. The previous
  # slice!(0, 3072) on the front of the array moved the remaining elements on
  # every record, which made parsing quadratic in the file size.
  File.binread("#{dir}/data_batch_#{index}.bin").unpack("C*").each_slice(3073) do |label, *pixels|
    y_train << label
    x_train << pixels
  end

  train = [x_train, y_train]
  File.binwrite(cache_path, Marshal.dump(train))
  train
end