2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
|
# File 'lib/nn/cifar10.rb', line 2
# Load one training batch, parsing the raw binary file and caching the result.
#
# Each record in "#{dir}/data_batch_<index>.bin" is read as one label byte
# followed by 3072 image bytes (presumably the CIFAR-10 binary layout).
# The parsed batch is cached with Marshal as "CIFAR-10-train<index>.marshal"
# in the current working directory, and later calls return the cached copy.
#
# NOTE(review): `dir` is defined outside this method; the cache file name does
# not include it, so changing `dir` will not invalidate an existing cache.
#
# @param index [Integer] batch number, interpolated into both file names
# @return [Array(Array<Array<Integer>>, Array<Integer>)] [images, labels]
def self.load_train(index)
  cache_path = "CIFAR-10-train#{index}.marshal"
  # Marshal.load is only safe on trusted data; this cache is written below.
  return Marshal.load(File.binread(cache_path)) if File.exist?(cache_path)

  image_size = 3072
  bytes = File.binread("#{dir}/data_batch_#{index}.bin").unpack("C*")
  x_train = []
  y_train = []
  # A single pass over fixed-size records. Repeatedly calling slice!(0, n) on
  # the byte array moved the whole remaining tail each time (quadratic).
  bytes.each_slice(image_size + 1) do |record|
    y_train << record.first
    x_train << record.drop(1)
  end
  train = [x_train, y_train]

  # Write to a temporary file and rename, so an interrupted write cannot leave
  # a truncated cache that would make every later call fail in Marshal.load.
  tmp_path = "#{cache_path}.tmp"
  File.binwrite(tmp_path, Marshal.dump(train))
  File.rename(tmp_path, cache_path)
  train
end
|