Module: Transformers::ImageTransforms

Defined in:
lib/transformers/image_transforms.rb

Class Method Summary collapse

Class Method Details

._rescale_for_pil_conversion(image) ⇒ Object



199
200
201
202
203
204
205
206
# File 'lib/transformers/image_transforms.rb', line 199

# Decides whether +image+ has to be rescaled before it is turned into an
# image object.
#
# Only Numo::UInt8 input is handled so far; such input is reported as not
# needing a rescale.
#
# @param image [Object] the array that is about to be converted
# @return [Boolean] +false+ when +image+ is a Numo::UInt8
# @raise [Todo] for any input that is not a Numo::UInt8
def self._rescale_for_pil_conversion(image)
  # Every non-uint8 input is still unimplemented.
  raise Todo unless image.is_a?(Numo::UInt8)

  false
end

.normalize(image, mean, std, data_format: nil, input_data_format: nil) ⇒ Object



116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
# File 'lib/transformers/image_transforms.rb', line 116

# Normalizes +image+ per channel as <tt>(image - mean) / std</tt>.
#
# @param image [Numo::NArray] the image to normalize
# @param mean [Numeric, Enumerable] a single value applied to every channel,
#   or one value per channel
# @param std [Numeric, Enumerable] a single value applied to every channel,
#   or one value per channel
# @param data_format [Object, nil] channel layout of the result; when +nil+
#   the layout of the input is kept
# @param input_data_format [Object, nil] channel layout of +image+; inferred
#   through ImageUtils.infer_channel_dimension_format when +nil+
# @return [Numo::NArray] the normalized image
# @raise [ArgumentError] if +image+ is not a Numo::NArray, or if an
#   enumerable +mean+/+std+ does not have one element per channel
def self.normalize(
  image,
  mean,
  std,
  data_format: nil,
  input_data_format: nil
)
  unless image.is_a?(Numo::NArray)
    raise ArgumentError, "image must be a Numo::NArray, got #{image.class.name}"
  end

  # The inference helper lives on ImageUtils (like get_channel_dimension_axis
  # below), not on this module.
  if input_data_format.nil?
    input_data_format = ImageUtils.infer_channel_dimension_format(image)
  end

  channel_axis = ImageUtils.get_channel_dimension_axis(image, input_data_format: input_data_format)
  num_channels = image.shape[channel_axis]

  # We cast to float32 to avoid errors that can occur when subtracting uint8 values.
  # Float inputs are left as they are.
  if !image.is_a?(Numo::SFloat) && !image.is_a?(Numo::DFloat)
    image = image.cast_to(Numo::SFloat)
  end

  # NOTE(review): the stats are cast to DFloat, so an SFloat image is
  # presumably upcast to DFloat by the arithmetic below - confirm whether
  # that is intended.
  mean = _per_channel_stat(mean, "mean", num_channels)
  std = _per_channel_stat(std, "std", num_channels)

  image =
    if input_data_format == ChannelDimension::LAST
      (image - mean) / std
    else
      # Reverse the axes so the channel axis is last for the arithmetic,
      # then restore the original axis order.
      ((image.transpose - mean) / std).transpose
    end

  return image if data_format.nil?

  # The input layout is a keyword argument of to_channel_dimension_format.
  to_channel_dimension_format(image, data_format, input_channel_dim: input_data_format)
end

# Expands a scalar statistic to one value per channel (or validates the
# length of an enumerable one) and casts the result to Numo::DFloat.
def self._per_channel_stat(value, name, num_channels)
  if value.is_a?(Enumerable)
    if value.length != num_channels
      raise ArgumentError, "#{name} must have #{num_channels} elements if it is an iterable, got #{value.length}"
    end
  else
    value = [value] * num_channels
  end
  Numo::DFloat.cast(value)
end

.rescale(image, scale, data_format: nil, dtype: Numo::SFloat, input_data_format: nil) ⇒ Object



46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
# File 'lib/transformers/image_transforms.rb', line 46

# Multiplies every value of +image+ by +scale+ and casts the result to +dtype+.
#
# @param image [Numo::NArray] the image to rescale
# @param scale [Numeric] the factor the image is multiplied by
# @param data_format [Object, nil] channel layout of the result; when +nil+
#   no layout conversion is performed
# @param dtype [Class] Numo type the result is cast to (Numo::SFloat by default)
# @param input_data_format [Object, nil] channel layout of +image+, forwarded
#   to to_channel_dimension_format when +data_format+ is given
# @return [Numo::NArray] the rescaled image, cast to +dtype+
# @raise [ArgumentError] if +image+ is not a Numo::NArray
def self.rescale(
  image,
  scale,
  data_format: nil,
  dtype: Numo::SFloat,
  input_data_format: nil
)
  unless image.is_a?(Numo::NArray)
    raise ArgumentError, "Input image must be of type Numo::NArray, got #{image.class.name}"
  end

  rescaled_image = image * scale
  unless data_format.nil?
    # The input layout is a keyword argument of to_channel_dimension_format;
    # passing it positionally raises ArgumentError.
    rescaled_image = to_channel_dimension_format(
      rescaled_image, data_format, input_channel_dim: input_data_format
    )
  end

  rescaled_image.cast_to(dtype)
end

.resize(image, size, resample: nil, reducing_gap: nil, data_format: nil, return_numpy: true, input_data_format: nil) ⇒ Object



67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
# File 'lib/transformers/image_transforms.rb', line 67

# Resizes +image+ to +size+ (+[height, width]+) through Vips' +thumbnail_image+
# with <tt>size: :force</tt>.
#
# @param image [Vips::Image, Object] the image to resize; anything that is
#   not a Vips::Image is first converted with to_pil_image
# @param size [Array] two elements, destructured as +height, width+
# @param resample [Object, nil] currently unused (resampling is not supported yet)
# @param reducing_gap [Object, nil] currently unused
# @param data_format [Object, nil] channel layout of the returned array;
#   defaults to the input layout
# @param return_numpy [Boolean] when truthy the result is converted with
#   ImageUtils.to_numo_array; otherwise the resized image object is returned
# @param input_data_format [Object, nil] channel layout of +image+; inferred
#   through ImageUtils.infer_channel_dimension_format when +nil+
# @return [Object] the resized image
# @raise [ArgumentError] if +size+ does not have exactly two elements
def self.resize(
  image,
  size,
  resample: nil,
  reducing_gap: nil,
  data_format: nil,
  return_numpy: true,
  input_data_format: nil
)
  raise ArgumentError, "size must have 2 elements" unless size.length == 2

  # Unless told otherwise, the output keeps the channel layout of the input,
  # so the input layout has to be known before the image is converted.
  input_data_format = ImageUtils.infer_channel_dimension_format(image) if input_data_format.nil?
  data_format = input_data_format if data_format.nil?

  # Array input goes through an image object for the actual resize.
  needs_rescale = false
  unless image.is_a?(Vips::Image)
    needs_rescale = _rescale_for_pil_conversion(image)
    image = to_pil_image(image, do_rescale: needs_rescale, input_data_format: input_data_format)
  end

  target_height, target_width = size
  # TODO support resample
  resized = image.thumbnail_image(target_width, height: target_height, size: :force)
  return resized unless return_numpy

  resized = ImageUtils.to_numo_array(resized)
  # A two-dimensional result has no channel axis; add one at the end.
  resized = resized.expand_dims(-1) if resized.ndim == 2
  # The converted array is treated as channels-last.
  resized = to_channel_dimension_format(resized, data_format, input_channel_dim: ChannelDimension::LAST)
  # Undo the [0, 255] rescale that was applied before the conversion, if any.
  resized = rescale(resized, 1 / 255.0) if needs_rescale
  resized
end

.to_channel_dimension_format(image, channel_dim, input_channel_dim: nil) ⇒ Object



17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
# File 'lib/transformers/image_transforms.rb', line 17

# Moves the channel axis of +image+ to the position requested by +channel_dim+.
#
# @param image [Numo::NArray] the image to convert
# @param channel_dim [Object] target channel layout; validated through
#   ChannelDimension.new
# @param input_channel_dim [Object, nil] current channel layout of +image+;
#   inferred through ImageUtils.infer_channel_dimension_format when +nil+
# @return [Numo::NArray] +image+ itself when it is already in the target
#   layout, otherwise a transposed array
# @raise [ArgumentError] if +image+ is not a Numo::NArray or the target
#   layout is neither ChannelDimension::FIRST nor ChannelDimension::LAST
def self.to_channel_dimension_format(
  image,
  channel_dim,
  input_channel_dim: nil
)
  unless image.is_a?(Numo::NArray)
    raise ArgumentError, "Input image must be of type Numo::NArray, got #{image.class.name}"
  end

  # The inference helper is defined on ImageUtils, not on this module.
  if input_channel_dim.nil?
    input_channel_dim = ImageUtils.infer_channel_dimension_format(image)
  end

  target_channel_dim = ChannelDimension.new(channel_dim).to_s
  return image if input_channel_dim == target_channel_dim

  # The fixed permutations below only apply to three-axis images.
  if target_channel_dim == ChannelDimension::FIRST
    image.transpose(2, 0, 1)
  elsif target_channel_dim == ChannelDimension::LAST
    image.transpose(1, 2, 0)
  else
    raise ArgumentError, "Unsupported channel dimension format: #{channel_dim}"
  end
end

.to_pil_image(image, do_rescale: nil, input_data_format: nil) ⇒ Object



168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
# File 'lib/transformers/image_transforms.rb', line 168

# Converts +image+ into a Vips::Image.
#
# @param image [Vips::Image, Numo::NArray] the image to convert; a
#   Vips::Image is returned untouched
# @param do_rescale [Boolean, nil] whether to multiply the values by 255
#   before the uint8 cast; decided by _rescale_for_pil_conversion when +nil+
# @param input_data_format [Object, nil] channel layout of +image+, forwarded
#   to to_channel_dimension_format
# @return [Vips::Image] the image object
# @raise [ArgumentError] if +image+ is neither a Vips::Image nor a Numo::NArray
def self.to_pil_image(
  image,
  do_rescale: nil,
  input_data_format: nil
)
  return image if image.is_a?(Vips::Image)

  unless image.is_a?(Numo::NArray)
    raise ArgumentError, "Input image type not supported: #{image.class.name}"
  end

  # Bring the channel axis to the end before the pixel data is handed over.
  # A single-channel axis is kept as is (no squeeze is performed).
  channels_last = to_channel_dimension_format(
    image, ChannelDimension::LAST, input_channel_dim: input_data_format
  )

  # Scale up to the [0, 255] range when requested, or when the helper says so.
  do_rescale = _rescale_for_pil_conversion(channels_last) if do_rescale.nil?
  channels_last = rescale(channels_last, 255) if do_rescale

  pixels = channels_last.cast_to(Numo::UInt8)
  # The first three shape entries are used as height, width and band count.
  height, width, bands = pixels.shape
  Vips::Image.new_from_memory(pixels.to_binary, width, height, bands, :uchar)
end