Class: TorchVision::Datasets::MNIST

Inherits:
VisionDataset (which inherits from Object)
Defined in:
lib/torchvision/datasets/mnist.rb

Direct Known Subclasses

FashionMNIST, KMNIST

Instance Attribute Summary

Attributes inherited from VisionDataset

#data, #targets

Instance Method Summary

Constructor Details

#initialize(root, train: true, download: false, transform: nil, target_transform: nil) ⇒ MNIST



# File 'lib/torchvision/datasets/mnist.rb', lines 5-17

# Creates the dataset and loads the requested split into memory.
#
# @param root directory under which the per-dataset folder lives
#   (passed to VisionDataset; presumably stored as @root — confirm there)
# @param train [Boolean] load the training split when truthy, the test split otherwise
# @param download [Boolean] run #download first (it returns early if the processed files already exist)
# @param transform [#call, nil] image transform, forwarded to VisionDataset
# @param target_transform [#call, nil] target transform, forwarded to VisionDataset
# @raise [Error] if the processed split files are still missing after the optional download
def initialize(root, train: true, download: false, transform: nil, target_transform: nil)
  super(root, transform: transform, target_transform: target_transform)
  @train = train

  # The keyword argument `download` shadows the method, so an explicit receiver is required.
  self.download if download

  raise Error, "Dataset not found. You can use download: true to download it" unless check_exists

  data_file = @train ? training_file : test_file
  # The processed file holds a two-element [data, targets] pair (see #download).
  @data, @targets = Torch.load(File.join(processed_folder, data_file))
end

Instance Method Details

#[](index) ⇒ Object



# File 'lib/torchvision/datasets/mnist.rb', lines 23-33

# Fetches a single sample.
#
# The entry of @data at +index+ is turned into an image via
# Utils.image_from_array, and the matching entry of @targets is unwrapped
# with +item+ (presumably a scalar tensor -> plain Ruby value; confirm).
# The optional @transform / @target_transform callables are then applied.
#
# @param index position of the sample
# @return [Array] a two-element [image, target] pair
def [](index)
  raw_sample = @data[index]
  label = @targets[index].item

  image = Utils.image_from_array(raw_sample)
  image = @transform ? @transform.call(image) : image
  label = @target_transform ? @target_transform.call(label) : label

  [image, label]
end

#check_exists ⇒ Object

# File 'lib/torchvision/datasets/mnist.rb', lines 43-46

# Reports whether both processed split files are present in processed_folder.
#
# The test file is only checked when the training file exists.
#
# @return [Boolean] true only if the training file and the test file both exist
def check_exists
  present = ->(name) { File.exist?(File.join(processed_folder, name)) }
  present[training_file] && present[test_file]
end

#download ⇒ Object

# File 'lib/torchvision/datasets/mnist.rb', lines 48-84

# Downloads the raw dataset files and converts them into the processed
# training/test files that #initialize loads. Returns immediately when both
# processed files already exist.
#
# Each entry of +resources+ is fetched from the first mirror that succeeds.
# HTTP error statuses and connection-level failures (DNS lookup errors,
# refused/reset/unreachable connections, timeouts, premature EOF) move on to
# the next mirror instead of aborting the whole download. This assumes
# download_file uses Net::HTTP, as the rescued Net:: classes suggest.
#
# @raise [Error] if a resource cannot be fetched from any mirror; the last
#   mirror failure is attached as the exception's +cause+
def download
  return if check_exists

  FileUtils.mkdir_p(raw_folder)
  FileUtils.mkdir_p(processed_folder)

  resources.each do |resource|
    filename = resource[:filename]
    last_error = nil

    # any? stops at the first mirror whose download succeeds.
    fetched = mirrors.any? do |mirror|
      begin
        download_file("#{mirror}#{filename}", download_root: raw_folder, filename: filename, sha256: resource[:sha256])
        true
      rescue Net::HTTPFatalError, Net::HTTPClientException,
             SocketError, EOFError, Timeout::Error,
             Errno::ECONNREFUSED, Errno::ECONNRESET, Errno::EHOSTUNREACH,
             Errno::ENETUNREACH, Errno::ETIMEDOUT => e
        last_error = e
        puts "Failed to download (trying next): #{e.message}"
        false
      end
    end

    raise Error, "Error downloading #{filename}", cause: last_error unless fetched
  end

  puts "Processing..."

  # Header sizes: 16 bytes for image files, 8 bytes for label files.
  training_set = [
    unpack_mnist("train-images-idx3-ubyte", 16, [60000, 28, 28]),
    unpack_mnist("train-labels-idx1-ubyte", 8, [60000])
  ]
  test_set = [
    unpack_mnist("t10k-images-idx3-ubyte", 16, [10000, 28, 28]),
    unpack_mnist("t10k-labels-idx1-ubyte", 8, [10000])
  ]

  Torch.save(training_set, File.join(processed_folder, training_file))
  Torch.save(test_set, File.join(processed_folder, test_file))

  puts "Done!"
end

#processed_folder ⇒ Object

# File 'lib/torchvision/datasets/mnist.rb', lines 39-41

# Directory holding the processed split files.
#
# It lives under @root in a folder named after the last segment of the class
# name (e.g. MNIST, FashionMNIST, KMNIST), so subclasses keep separate files.
def processed_folder
  dataset_dir = self.class.name.rpartition("::").last
  File.join(@root, dataset_dir, "processed")
end

#raw_folder ⇒ Object

# File 'lib/torchvision/datasets/mnist.rb', lines 35-37

# Directory the downloaded (unprocessed) files are written to.
#
# Like processed_folder, it sits under @root in a folder named after the last
# segment of the class name, but in a "raw" subdirectory.
def raw_folder
  class_segments = self.class.name.split("::")
  File.join(@root, class_segments.last, "raw")
end

#size ⇒ Object

# File 'lib/torchvision/datasets/mnist.rb', lines 19-21

# Number of samples in the dataset: the length of dimension 0 of @data.
def size
  sample_dim = 0
  @data.size(sample_dim)
end