Class: TorchVision::Transforms::Functional

Inherits:
Object
  • Object
show all
Defined in:
lib/torchvision/transforms/functional.rb

Class Method Summary collapse

Class Method Details

.center_crop(img, output_size) ⇒ Object



121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
# File 'lib/torchvision/transforms/functional.rb', line 121

# Crops +img+ around its center to +output_size+, padding first when the
# requested crop is larger than the image in either dimension.
#
# @param img the image; it is only passed through to image_size, pad and crop
# @param output_size [Numeric, Array] a single number (truncated with to_i and
#   used for both sides), a one-element array (used for both sides), or a
#   [height, width] pair
# @return the padded image when padding alone yields exactly the requested
#   size, otherwise whatever crop returns
def center_crop(img, output_size)
  # Accept any Numeric, not just Integer: a Float such as 224.0 previously fell
  # through to the destructuring below and left crop_width nil.
  if output_size.is_a?(Numeric)
    side = output_size.to_i
    output_size = [side, side]
  elsif output_size.is_a?(Array) && output_size.length == 1
    output_size = [output_size[0], output_size[0]]
  end

  image_width, image_height = image_size(img)
  crop_height, crop_width = output_size

  if crop_width > image_width || crop_height > image_height
    # Missing pixels per axis (zero when the image is already large enough).
    missing_w = [crop_width - image_width, 0].max
    missing_h = [crop_height - image_height, 0].max
    # Left/top get the floor half, right/bottom the ceiling half.
    padding_ltrb = [
      missing_w.div(2),
      missing_h.div(2),
      (missing_w + 1).div(2),
      (missing_h + 1).div(2)
    ]
    # TODO
    img = pad(img, padding_ltrb, fill: 0)
    image_width, image_height = image_size(img)
    return img if crop_width == image_width && crop_height == image_height
  end

  # NOTE(review): Float#round rounds halves away from zero; confirm whether
  # half-to-even is wanted for parity with other implementations.
  crop_top = ((image_height - crop_height) / 2.0).round
  crop_left = ((image_width - crop_width) / 2.0).round
  crop(img, crop_top, crop_left, crop_height, crop_width)
end

.crop(img, top, left, height, width) ⇒ Object



111
112
113
114
115
116
117
118
119
# File 'lib/torchvision/transforms/functional.rb', line 111

# Extracts the +height+ x +width+ region whose top-left corner is at
# (+top+, +left+).
#
# Torch::Tensor inputs are sliced along their last two dimensions, with all
# leading dimensions kept whole; any other input is delegated to its own
# crop(left, top, width, height) method.
def crop(img, top, left, height, width)
  return img.crop(left, top, width, height) unless img.is_a?(Torch::Tensor)

  assert_image_tensor(img)
  # One `true` per leading dimension selects that dimension in full.
  leading = [true] * (img.dim - 2)
  row_range = top...(top + height)
  col_range = left...(left + width)
  img[*leading, row_range, col_range]
end

.hflip(img) ⇒ Object



93
94
95
96
97
98
99
100
# File 'lib/torchvision/transforms/functional.rb', line 93

# Flips +img+ horizontally.
#
# Torch::Tensor inputs are flipped along their last dimension; any other
# input is asked to flip itself with the :horizontal direction.
def hflip(img)
  return img.flip(:horizontal) unless img.is_a?(Torch::Tensor)

  assert_image_tensor(img)
  img.flip(-1)
end

.normalize(tensor, mean, std, inplace: false) ⇒ Object



5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
# File 'lib/torchvision/transforms/functional.rb', line 5

# Subtracts +mean+ from +tensor+ and divides the result by +std+.
#
# @param tensor a torch tensor with exactly three dimensions
# @param mean values converted to a tensor with +tensor+'s dtype and device
# @param std values converted to a tensor with +tensor+'s dtype and device
# @param inplace [Boolean] when falsy, a clone of +tensor+ is modified instead
# @return the normalized tensor (the input itself when +inplace+ is truthy)
# @raise [ArgumentError] if +tensor+ is not a torch tensor, does not have
#   three dimensions, or the zero check on the converted +std+ fires
def normalize(tensor, mean, std, inplace: false)
  raise ArgumentError, "tensor should be a torch tensor. Got #{tensor.class.name}" unless Torch.tensor?(tensor)

  unless tensor.ndimension == 3
    raise ArgumentError, "Expected tensor to be a tensor image of size (C, H, W). Got tensor.size() = #{tensor.size}"
  end

  tensor = inplace ? tensor : tensor.clone

  target_dtype = tensor.dtype
  # TODO Torch.as_tensor
  mean = Torch.tensor(mean, dtype: target_dtype, device: tensor.device)
  std = Torch.tensor(std, dtype: target_dtype, device: tensor.device)

  # TODO
  has_zero = std.to_a.any? { |value| value == 0 }
  if has_zero
    raise ArgumentError, "std evaluated to zero after conversion to #{target_dtype}, leading to division by zero."
  end

  # 1-D statistics get two trailing axes added via the nil indexes.
  mean = mean[0...mean.size(0), nil, nil] if mean.ndim == 1
  std = std[0...std.size(0), nil, nil] if std.ndim == 1

  tensor.sub!(mean).div!(std)
  tensor
end

.resize(img, size) ⇒ Object



35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
# File 'lib/torchvision/transforms/functional.rb', line 35

# Resizes a Vips::Image.
#
# When +size+ is an Integer the shorter side is scaled to +size+ and the other
# side is computed with integer arithmetic from the original aspect ratio; the
# image is returned untouched if its shorter side already equals +size+.
# Otherwise +size+ is indexed as [size[0], size[1]] and passed to
# thumbnail_image as the width and the height: option, with size: :force.
#
# @raise [RuntimeError] if +img+ is not a Vips::Image
def resize(img, size)
  raise "img should be Vips::Image. Got #{img.class.name}" unless img.is_a?(Vips::Image)

  return img.thumbnail_image(size[0], height: size[1], size: :force) unless size.is_a?(Integer)

  width, height = img.size
  # Nothing to do when the shorter side is already the requested length.
  return img if [width, height].min == size

  target_width, target_height =
    if width < height
      [size, (size * height / width).to_i]
    else
      [(size * width / height).to_i, size]
    end
  img.thumbnail_image(target_width, height: target_height)
end

.resized_crop(img, top, left, height, width, size) ⇒ Object

TODO: support an interpolation option.



152
153
154
155
156
# File 'lib/torchvision/transforms/functional.rb', line 152

# Crops +img+ to the given box with crop, then resizes the cropped result to
# +size+ with resize, and returns the resized image.
def resized_crop(img, top, left, height, width, size)
  cropped = crop(img, top, left, height, width)
  resize(cropped, size) # , interpolation)
end

.to_tensor(pic) ⇒ Object

TODO improve



58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
# File 'lib/torchvision/transforms/functional.rb', line 58

# Converts a Numo::NArray or Vips::Image into a torch tensor with the channel
# axis first.
#
# Numo::NArray input must have 2 or 3 dimensions; a 2-D array gets a trailing
# axis of length 1, the axes are then reordered with transpose(2, 0, 1). The
# result is converted to float and divided by 255 when its dtype is :uint8,
# and returned unchanged otherwise.
#
# Vips::Image input is only supported for the :uchar format: its memory is
# viewed as (height, width, bands), permuted to put bands first, converted to
# float and divided by 255.
#
# @param pic [Numo::NArray, Vips::Image]
# @return the resulting tensor
# @raise [ArgumentError] if +pic+ has the wrong type or dimensionality
# @raise [Error] if the Vips::Image format is not :uchar
def to_tensor(pic)
  if !pic.is_a?(Numo::NArray) && !pic.is_a?(Vips::Image)
    raise ArgumentError, "pic should be Vips::Image or Numo::NArray. Got #{pic.class.name}"
  end

  if pic.is_a?(Numo::NArray)
    unless [2, 3].include?(pic.ndim)
      # Use ndim here as well: the check above relies on it, and calling dim
      # made this path fail with NoMethodError instead of ArgumentError.
      raise ArgumentError, "pic should be 2/3 dimensional. Got #{pic.ndim} dimensions."
    end

    # Give 2-D input an explicit trailing axis so the transpose below applies.
    pic = pic.reshape(*pic.shape, 1) if pic.ndim == 2

    img = Torch.from_numo(pic.transpose(2, 0, 1))
    return img.dtype == :uint8 ? img.float.div(255) : img
  end

  case pic.format
  when :uchar
    img = Torch::ByteTensor.new(Torch::ByteStorage.from_buffer(pic.write_to_memory))
  else
    raise Error, "Format not supported yet: #{pic.format}"
  end

  img = img.view(pic.height, pic.width, pic.bands)
  # put it from HWC to CHW format
  img = img.permute([2, 0, 1]).contiguous
  img.float.div(255)
end

.vflip(img) ⇒ Object



102
103
104
105
106
107
108
109
# File 'lib/torchvision/transforms/functional.rb', line 102

# Flips +img+ vertically.
#
# Torch::Tensor inputs are flipped along their second-to-last dimension; any
# other input is asked to flip itself with the :vertical direction.
def vflip(img)
  return img.flip(:vertical) unless img.is_a?(Torch::Tensor)

  assert_image_tensor(img)
  img.flip(-2)
end