Module: TorchVision::Utils
Defined in: lib/torchvision/utils.rb
Class Method Summary
- .image_from_array(array) ⇒ Object
  Private, Ruby-specific helper. TODO: use Numo when a bridge is available.
- .make_grid(tensor, nrow: 8, padding: 2, normalize: false, range: nil, scale_each: false, pad_value: 0) ⇒ Object
- .save_image(tensor, fp, nrow: 8, padding: 2, normalize: false, range: nil, scale_each: false, pad_value: 0) ⇒ Object
Class Method Details
.image_from_array(array) ⇒ Object
Private, Ruby-specific helper. TODO: use Numo when a bridge is available.
Source: lib/torchvision/utils.rb, lines 89–115
# File 'lib/torchvision/utils.rb', line 89
# Builds a Vips::Image from a uint8 pixel array.
#
# private Ruby-specific method
# TODO use Numo when bridge available
#
# @param array [Torch::Tensor, Numo::UInt8] pixel data laid out row-major as
#   (height, width) or (height, width, bands) -- e.g. the H x W x C tensor that
#   save_image produces with permute(1, 2, 0)
# @return [Vips::Image] image of size array.shape[1] (width) x array.shape[0] (height),
#   with array.shape[2] bands (1 when the array is 2D)
# @raise [RuntimeError] if the element type is not uint8, the array is neither a
#   Torch::Tensor nor a Numo::NArray, or it is not 2D/3D
def image_from_array(array)
  case array
  when Torch::Tensor
    # TODO support more dtypes
    raise "Type not supported yet: #{array.dtype}" unless array.dtype == :uint8

    # Vips reads the buffer as packed rows, so the storage must be contiguous
    array = array.contiguous unless array.contiguous?
    # Copy the bytes into a Ruby String: an FFI::Pointer built from a raw address
    # does not keep the owning tensor alive (and the contiguous copy above is only
    # referenced from this method), so handing Vips the bare pointer could leave
    # the image reading freed memory after the tensor is garbage collected.
    data = FFI::Pointer.new(:uint8, array._data_ptr).read_bytes(array.numel * array.element_size)
  when Numo::NArray
    # TODO support more types
    raise "Type not supported yet: #{array.class.name}" unless array.is_a?(Numo::UInt8)

    data = array.to_binary
  else
    raise "Expected Torch::Tensor or Numo::NArray, not #{array.class.name}"
  end

  shape = array.shape
  raise "Expected a 2D or 3D array, got shape #{shape.inspect}" unless shape.length.between?(2, 3)

  # Row-major layout: axis 0 is rows (height), axis 1 is columns (width).
  # Vips takes width before height, so the order must not be swapped here.
  height, width, bands = shape
  Vips::Image.new_from_memory(data, width, height, bands || 1, :uchar)
end
.make_grid(tensor, nrow: 8, padding: 2, normalize: false, range: nil, scale_each: false, pad_value: 0) ⇒ Object
Source: lib/torchvision/utils.rb, lines 4–76
# File 'lib/torchvision/utils.rb', line 4
# Arranges a mini-batch of images into a single grid image.
#
# Single-channel inputs are replicated to 3 channels. When +normalize+ is true,
# a clone of the input is rescaled so the caller's tensor is left untouched.
#
# @param tensor [Torch::Tensor, Array<Torch::Tensor>] a 4D mini-batch (B x C x H x W),
#   a single image (C x H x W or H x W), or an array of tensors that is stacked
#   along a new leading dimension
# @param nrow [Integer] maximum number of images per grid row; must be a positive
#   integer whenever more than one image is laid out
# @param padding [Integer] pixels of padding around and between images
# @param normalize [Boolean] clamp and shift/scale the (cloned) data to roughly [0, 1]
# @param range [Array(Numeric, Numeric), nil] (min, max) used when normalizing;
#   nil means use the data's own min/max
# @param scale_each [Boolean] when normalizing, rescale each image independently
#   instead of using one scale for the whole batch
# @param pad_value [Numeric] fill value for the padding
# @return [Torch::Tensor] a C x H' x W' grid; if the batch holds a single image,
#   that image is returned without any padding
# @raise [ArgumentError] if +tensor+ is neither a tensor nor an array of tensors,
#   or if +nrow+ is not a positive integer when a grid has to be built
# @raise [RuntimeError] if +normalize+ is true and +range+ is given but is not an Array
def make_grid(tensor, nrow: 8, padding: 2, normalize: false, range: nil, scale_each: false, pad_value: 0)
  unless Torch.tensor?(tensor) || (tensor.is_a?(Array) && tensor.all? { |t| Torch.tensor?(t) })
    raise ArgumentError, "tensor or list of tensors expected, got #{tensor.class.name}"
  end

  # if list of tensors, convert to a 4D mini-batch Tensor
  tensor = Torch.stack(tensor, dim: 0) if tensor.is_a?(Array)

  # single image H x W -> 1 x H x W
  tensor = tensor.unsqueeze(0) if tensor.dim == 2

  if tensor.dim == 3 # single image
    # if single-channel, convert to 3-channel
    tensor = Torch.cat([tensor, tensor, tensor], 0) if tensor.size(0) == 1
    tensor = tensor.unsqueeze(0)
  end

  # single-channel images in a mini-batch -> 3-channel
  tensor = Torch.cat([tensor, tensor, tensor], 1) if tensor.dim == 4 && tensor.size(1) == 1

  if normalize
    tensor = tensor.clone # avoid modifying the caller's tensor in-place
    unless range.nil? || range.is_a?(Array)
      raise "range has to be an array (min, max) if specified. min and max are numbers"
    end

    # Clamp img to [lo, hi], then rescale it in-place to roughly [0, 1];
    # the 1e-5 keeps the divisor non-zero when lo == hi.
    norm_ip = lambda do |img, lo, hi|
      img.clamp!(lo, hi)
      img.add!(-lo).div!(hi - lo + 1e-5)
    end
    # Normalize by explicit bounds, or by t's own min/max when bounds is nil
    norm_range = lambda do |t, bounds|
      if bounds.nil?
        norm_ip.call(t, t.min.to_f, t.max.to_f)
      else
        norm_ip.call(t, bounds[0], bounds[1])
      end
    end

    if scale_each
      tensor.each { |t| norm_range.call(t, range) } # loop over mini-batch dimension
    else
      norm_range.call(tensor, range)
    end
  end

  return tensor.squeeze(0) if tensor.size(0) == 1

  # nrow only matters once a grid is laid out; without this check 0 fails with
  # FloatDomainError (Infinity#ceil) and negative values fail deep inside new_full.
  unless nrow.is_a?(Integer) && nrow.positive?
    raise ArgumentError, "nrow must be a positive integer, got #{nrow.inspect}"
  end

  # make the mini-batch of images into a grid
  nmaps = tensor.size(0)
  xmaps = [nrow, nmaps].min
  ymaps = (nmaps.to_f / xmaps).ceil
  # cell size includes one padding strip; the extra +padding closes the far edges
  height = tensor.size(2) + padding
  width = tensor.size(3) + padding
  num_channels = tensor.size(1)
  grid = tensor.new_full([num_channels, height * ymaps + padding, width * xmaps + padding], pad_value)
  nmaps.times do |k|
    # images fill the grid row by row
    y, x = k.divmod(xmaps)
    grid.narrow(1, y * height + padding, height - padding)
        .narrow(2, x * width + padding, width - padding)
        .copy!(tensor[k])
  end
  grid
end
.save_image(tensor, fp, nrow: 8, padding: 2, normalize: false, range: nil, scale_each: false, pad_value: 0) ⇒ Object
Source: lib/torchvision/utils.rb, lines 78–84
# File 'lib/torchvision/utils.rb', line 78
# Lays +tensor+ out as an image grid with make_grid and writes it to the file +fp+.
#
# All keyword arguments are forwarded unchanged to make_grid.
#
# @param tensor [Torch::Tensor, Array<Torch::Tensor>] image(s) to save; pixel values
#   are assumed to lie in [0, 1] (anything outside is clamped after scaling)
# @param fp [String] destination path handed to Vips' write_to_file
# @return the result of write_to_file
def save_image(tensor, fp, nrow: 8, padding: 2, normalize: false, range: nil, scale_each: false, pad_value: 0)
  grid = make_grid(
    tensor,
    nrow: nrow,
    padding: padding,
    pad_value: pad_value,
    normalize: normalize,
    range: range,
    scale_each: scale_each
  )

  # Scale to [0, 255]; adding 0.5 before the uint8 cast rounds to the nearest
  # integer instead of truncating.
  scaled = grid.mul(255).add!(0.5).clamp!(0, 255)
  # C x H x W -> H x W x C, the interleaved layout image_from_array reads bands from
  pixels = scaled.permute(1, 2, 0).to("cpu", dtype: :uint8)

  image_from_array(pixels).write_to_file(fp)
end