Class: TorchVision::Datasets::CIFAR10
- Inherits: VisionDataset
- Object
- Torch::Utils::Data::Dataset
- VisionDataset
- TorchVision::Datasets::CIFAR10
- Defined in:
- lib/torchvision/datasets/cifar10.rb
Direct Known Subclasses
Instance Attribute Summary
Attributes inherited from VisionDataset
Instance Method Summary
- #[](index) ⇒ Object
- #_check_integrity ⇒ Object
- #download ⇒ Object
- #initialize(root, train: true, download: false, transform: nil, target_transform: nil) ⇒ CIFAR10 constructor
- #size ⇒ Object
Constructor Details
#initialize(root, train: true, download: false, transform: nil, target_transform: nil) ⇒ CIFAR10
6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 |
# File 'lib/torchvision/datasets/cifar10.rb', line 6

# Builds the dataset by reading every batch file of the selected split into
# memory and decoding it into a label array plus a 4-D UInt8 image array.
#
# @param root [String] passed to the superclass constructor; the batch files
#   are then looked up under @root/base_folder (presumably the superclass
#   stores root as @root - confirm in VisionDataset)
# @param train [Boolean] truthy reads the files in train_list, otherwise test_list
# @param download [Boolean] when truthy, #download is called before the integrity check
# @param transform [#call, nil] forwarded to the superclass
# @param target_transform [#call, nil] forwarded to the superclass
# @raise [Error] when _check_integrity reports a missing or corrupted file
def initialize(root, train: true, download: false, transform: nil, target_transform: nil)
  super(root, transform: transform, target_transform: target_transform)
  @train = train

  # The keyword argument shadows the method name, hence the explicit receiver.
  self.download if download

  unless _check_integrity
    raise Error, "Dataset not found or corrupted. You can use download: true to download it"
  end

  downloaded_list = @train ? train_list : test_list

  # String.new yields binary (ASCII-8BIT) buffers, matching the "rb" reads below.
  @data = String.new
  @targets = String.new
  downloaded_list.each do |file|
    file_path = File.join(@root, base_folder, file[:filename])
    File.open(file_path, "rb") do |f|
      # One record per iteration: an extra leading byte that is discarded when
      # multiple_labels? is true, one label byte, then 3072 (= 3 * 32 * 32)
      # bytes of pixel data.
      until f.eof?
        f.read(1) if multiple_labels?
        @targets << f.read(1)
        @data << f.read(3072)
      end
    end
  end

  # Each label byte becomes one unsigned 8-bit integer.
  @targets = @targets.unpack("C*")
  # TODO switch i to -1 when Numo supports it
  @data = Numo::UInt8.from_binary(@data).reshape(@targets.size, 3, 32, 32)
  # Move axis 1 (size 3) to the end, giving shape (count, 32, 32, 3).
  @data = @data.transpose(0, 2, 3, 1)
end
Instance Method Details
#[](index) ⇒ Object
42 43 44 45 46 47 48 49 50 51 52 53 |
# File 'lib/torchvision/datasets/cifar10.rb', line 42

# Returns the sample stored at +index+ as an [image, target] pair.
#
# The image is produced by Utils.image_from_array from the matching slice of
# @data; @transform and @target_transform are applied only when they are set.
#
# @param index [Integer] position along the first axis of @data and within @targets
# @return [Array] two elements: the (possibly transformed) image and target
def [](index)
  # TODO remove trues when Numo supports it
  pixels = @data[index, true, true, true]
  label = @targets[index]

  image = Utils.image_from_array(pixels)
  image = @transform.call(image) if @transform
  label = @target_transform.call(label) if @target_transform

  [image, label]
end
#_check_integrity ⇒ Object
55 56 57 58 59 60 61 62 |
# File 'lib/torchvision/datasets/cifar10.rb', line 55

# Verifies every file listed in train_list and test_list.
#
# Each entry's path is built as @root/base_folder/filename and handed to
# check_integrity together with the entry's :sha256 value. Checking stops at
# the first file that fails.
#
# @return [Boolean] true only when every listed file passes check_integrity
def _check_integrity
  (train_list + test_list).all? do |entry|
    check_integrity(File.join(@root, base_folder, entry[:filename]), entry[:sha256])
  end
end
#download ⇒ Object
64 65 66 67 68 69 70 71 72 73 74 75 76 |
# File 'lib/torchvision/datasets/cifar10.rb', line 64

# Downloads the dataset archive into @root and unpacks it there.
#
# When _check_integrity already succeeds, only a notice is printed and nothing
# is fetched. Otherwise the archive is retrieved with download_file (which is
# given tgz_sha256) and extracted into @root.
def download
  if _check_integrity
    puts "Files already downloaded and verified"
  else
    download_file(url, download_root: @root, filename: filename, sha256: tgz_sha256)

    archive_path = File.join(@root, filename)
    # The Gem::Package instance is created with an empty gem path; it is used
    # here only for its tar.gz extraction routine.
    File.open(archive_path, "rb") { |io| Gem::Package.new("").extract_tar_gz(io, @root) }
  end
end
#size ⇒ Object
38 39 40 |
# File 'lib/torchvision/datasets/cifar10.rb', line 38

# Number of samples held by the dataset, i.e. the length of the first axis of @data.
#
# @return [Integer]
def size
  @data.shape.first
end