Module: TorchVision::Utils

Defined in:
lib/torchvision/utils.rb

Class Method Summary collapse

Class Method Details

.image_from_array(array) ⇒ Object

Private, Ruby-specific method. TODO: use Numo when a bridge is available.



89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
# File 'lib/torchvision/utils.rb', line 89

# Builds a Vips::Image from an array of unsigned 8-bit pixel data.
#
# The first two dimensions of +array+ are read as (height, width) and an
# optional third dimension as the number of bands. This is the rows-first
# layout that save_image hands over after permute(1, 2, 0).
#
# @param array [Torch::Tensor, Numo::NArray] uint8 pixel data
# @return [Vips::Image] result of Vips::Image.new_from_memory
# @raise [RuntimeError] if the dtype/class is not uint8, or if +array+ is
#   neither a Torch::Tensor nor a Numo::NArray
def image_from_array(array)
  case array
  when Torch::Tensor
    # TODO support more dtypes
    raise "Type not supported yet: #{array.dtype}" unless array.dtype == :uint8

    array = array.contiguous unless array.contiguous?

    # Rows come first in the buffer, so shape[0] is the height and shape[1]
    # the width. Reading them the other way round scrambles non-square images.
    height, width = array.shape
    bands = array.shape[2] || 1
    # NOTE(review): the pointer wraps the tensor's own storage without copying;
    # presumably the tensor has to outlive the returned image - confirm.
    data = FFI::Pointer.new(:uint8, array._data_ptr)
    data = data.slice(0, array.numel * array.element_size)

    Vips::Image.new_from_memory(data, width, height, bands, :uchar)
  when Numo::NArray
    # TODO support more types
    raise "Type not supported yet: #{array.class.name}" unless array.is_a?(Numo::UInt8)

    # Same rows-first convention as the tensor branch.
    height, width = array.shape
    bands = array.shape[2] || 1
    data = array.to_binary

    Vips::Image.new_from_memory(data, width, height, bands, :uchar)
  else
    raise "Expected Torch::Tensor or Numo::NArray, not #{array.class.name}"
  end
end

.make_grid(tensor, nrow: 8, padding: 2, normalize: false, range: nil, scale_each: false, pad_value: 0) ⇒ Object



4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
# File 'lib/torchvision/utils.rb', line 4

# Lays out a batch of images as tiles in a single grid tensor.
#
# @param tensor [Torch::Tensor, Array<Torch::Tensor>] a 2D, 3D or 4D tensor,
#   or an array of tensors that is stacked along a new leading dimension
# @param nrow [Integer] upper bound on the number of tiles per grid row
# @param padding [Integer] border added around every tile
# @param normalize [Boolean] when true, values are clamped and rescaled on a
#   clone of the input
# @param range [Array, nil] (min, max) used for normalization; when nil the
#   tensor's own min and max are used
# @param scale_each [Boolean] normalize every batch entry separately instead
#   of the batch as a whole
# @param pad_value [Numeric] fill value for the grid background
# @return [Torch::Tensor] the grid, or the lone image with its batch
#   dimension squeezed away when the batch holds exactly one entry
# @raise [ArgumentError] if +tensor+ is neither a tensor nor an array of tensors
# @raise [RuntimeError] if normalizing and +range+ is given but is not an Array
def make_grid(tensor, nrow: 8, padding: 2, normalize: false, range: nil, scale_each: false, pad_value: 0)
  acceptable = Torch.tensor?(tensor) || (tensor.is_a?(Array) && tensor.all? { |item| Torch.tensor?(item) })
  raise ArgumentError, "tensor or list of tensors expected, got #{tensor.class.name}" unless acceptable

  # An array of tensors becomes one mini-batch tensor.
  tensor = Torch.stack(tensor, dim: 0) if tensor.is_a?(Array)

  # Lone H x W image: add a leading dimension.
  tensor = tensor.unsqueeze(0) if tensor.dim == 2

  if tensor.dim == 3
    # A single-channel image is repeated three times along dimension 0.
    tensor = Torch.cat([tensor] * 3, 0) if tensor.size(0) == 1
    tensor = tensor.unsqueeze(0)
  end

  # Batch of single-channel images: repeat along the channel dimension.
  tensor = Torch.cat([tensor] * 3, 1) if tensor.dim == 4 && tensor.size(1) == 1

  if normalize
    # Work on a copy so the in-place operations below leave the input alone.
    tensor = tensor.clone
    unless range.nil? || range.is_a?(Array)
      raise "range has to be an array (min, max) if specified. min and max are numbers"
    end

    # Clamp to [low, high], then shift and scale in place; the 1e-5 term
    # keeps the divisor non-zero when low == high.
    rescale = lambda do |img|
      low, high = range.nil? ? [img.min.to_f, img.max.to_f] : [range[0], range[1]]
      img.clamp!(low, high)
      img.add!(-low).div!(high - low + 1e-5)
    end

    if scale_each
      # Iterate over the mini-batch dimension.
      tensor.each { |image| rescale.call(image) }
    else
      rescale.call(tensor)
    end
  end

  return tensor.squeeze(0) if tensor.size(0) == 1

  # Arrange the mini-batch as a grid of padded tiles.
  image_count = tensor.size(0)
  columns = [nrow, image_count].min
  rows = (image_count.to_f / columns).ceil
  cell_height = tensor.size(2) + padding
  cell_width = tensor.size(3) + padding
  channels = tensor.size(1)
  canvas = tensor.new_full([channels, cell_height * rows + padding, cell_width * columns + padding], pad_value)

  (0...rows).each do |row|
    (0...columns).each do |col|
      index = row * columns + col
      # The last row may be only partially filled.
      break if index >= image_count

      top = row * cell_height + padding
      left = col * cell_width + padding
      canvas.narrow(1, top, cell_height - padding).narrow(2, left, cell_width - padding).copy!(tensor[index])
    end
  end
  canvas
end

.save_image(tensor, fp, nrow: 8, padding: 2, normalize: false, range: nil, scale_each: false, pad_value: 0) ⇒ Object



78
79
80
81
82
83
84
# File 'lib/torchvision/utils.rb', line 78

# Builds an image grid from +tensor+ with make_grid and writes it to +fp+.
#
# @param tensor [Torch::Tensor, Array<Torch::Tensor>] forwarded to make_grid
# @param fp [String] destination passed to the image's write_to_file
# @param nrow, padding, normalize, range, scale_each, pad_value forwarded to
#   make_grid unchanged
# @return [Object] whatever write_to_file returns
def save_image(tensor, fp, nrow: 8, padding: 2, normalize: false, range: nil, scale_each: false, pad_value: 0)
  grid_options = {
    nrow: nrow,
    padding: padding,
    pad_value: pad_value,
    normalize: normalize,
    range: range,
    scale_each: scale_each
  }
  grid = make_grid(tensor, **grid_options)

  # Scale up by 255, then add 0.5 so the later uint8 cast rounds to the
  # nearest integer.
  rounded = grid.mul(255).add!(0.5)
  clamped = rounded.clamp!(0, 255)
  # Move the leading dimension last before converting to uint8 on the CPU.
  pixels = clamped.permute(1, 2, 0).to("cpu", dtype: :uint8)

  image_from_array(pixels).write_to_file(fp)
end