Class: Transformers::BatchFeature

Inherits: Object
Defined in:
lib/transformers/feature_extraction_utils.rb

Instance Method Summary

Constructor Details

#initialize(data:, tensor_type:) ⇒ BatchFeature

Returns a new instance of BatchFeature.



17
18
19
20
# File 'lib/transformers/feature_extraction_utils.rb', line 17

# Wraps a hash of features and, when +tensor_type+ is given, converts its values to tensors.
#
# @param data [Hash, nil] features keyed by name. The hash is shallow-copied so that the
#   in-place tensor conversion does not replace values in the caller's hash (or fail on a
#   frozen one). +nil+ is treated as an empty hash.
# @param tensor_type [Object, nil] forwarded to #convert_to_tensors; +nil+ skips conversion
def initialize(data:, tensor_type:)
  # convert_to_tensors assigns into @data, so keep our own (shallow) copy.
  @data = data.nil? ? {} : data.dup
  convert_to_tensors(tensor_type: tensor_type)
end

Instance Method Details

#[](item) ⇒ Object



27
28
29
# File 'lib/transformers/feature_extraction_utils.rb', line 27

# Looks up a single feature by key.
#
# @param item [Object] feature key
# @return [Object] the value stored under +item+ (plain Hash#[] lookup on the wrapped data)
def [](item)
  items[item]
end

#_get_is_as_tensor_fns(tensor_type: nil) ⇒ Object



43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
# File 'lib/transformers/feature_extraction_utils.rb', line 43

# Builds the pair of callables #convert_to_tensors uses to detect and create tensors.
#
# Any non-nil +tensor_type+ currently selects Torch; the value itself is not inspected.
#
# @param tensor_type [Object, nil] requested tensor type, or +nil+ for no conversion
# @return [Array] +[is_tensor, as_tensor]+ callables, or +[nil, nil]+ when +tensor_type+ is nil
def _get_is_as_tensor_fns(tensor_type: nil)
  return [nil, nil] if tensor_type.nil?

  as_tensor = lambda do |value|
    # Stack a list of Numo arrays -- or a list of lists of Numo arrays -- into a single
    # NArray first, so Torch receives one n-dimensional array instead of nested Ruby arrays.
    if value.is_a?(Array) && !value.empty?
      first = value[0]
      nested_numo = first.is_a?(Array) && !first.empty? && first[0].is_a?(Numo::NArray)
      value = Numo::NArray.cast(value) if first.is_a?(Numo::NArray) || nested_numo
    end
    Torch.tensor(value)
  end

  [Torch.method(:tensor?), as_tensor]
end

#convert_to_tensors(tensor_type: nil) ⇒ Object



60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
# File 'lib/transformers/feature_extraction_utils.rb', line 60

# Converts every non-tensor value in the batch to a tensor, in place.
#
# @param tensor_type [Object, nil] requested tensor type; +nil+ leaves the data untouched
# @return [self]
# @raise [ArgumentError] when a value cannot be converted to a tensor; the original
#   error is kept as the exception's +cause+
def convert_to_tensors(tensor_type: nil)
  return self if tensor_type.nil?

  is_tensor, as_tensor = _get_is_as_tensor_fns(tensor_type: tensor_type)

  # Do the tensor conversion in batch. Replacing the value of an existing key while
  # iterating a Hash is allowed (only adding new keys is not).
  items.each do |key, value|
    next if is_tensor.(value)

    tensor =
      begin
        as_tensor.(value)
      rescue StandardError
        # Accept both Symbol and String keys for this special case.
        if key.to_s == "overflowing_values"
          raise ArgumentError, "Unable to create tensor returning overflowing values of different lengths."
        end

        raise ArgumentError,
          "Unable to create tensor, you should probably activate padding " \
          "with 'padding: true' to have batched tensors with the same length."
      end

    # Assigned outside the rescue so unrelated failures (e.g. a frozen hash) are not
    # misreported as a padding problem.
    @data[key] = tensor
  end

  self
end

#items ⇒ Object



39
40
41
# File 'lib/transformers/feature_extraction_utils.rb', line 39

# Returns the wrapped feature hash itself (not a copy); callers iterate it as
# key/value pairs, and changes made through it affect this batch.
def items
  @data
end

#keys ⇒ Object



31
32
33
# File 'lib/transformers/feature_extraction_utils.rb', line 31

# Lists the feature names held by this batch.
#
# @return [Array] the keys of the wrapped hash, in insertion order
def keys
  items.keys
end

#to(*args, **kwargs) ⇒ Object



88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
# File 'lib/transformers/feature_extraction_utils.rb', line 88

# Moves and/or casts the batch's tensors by calling +to+ on each of them.
#
# Floating-point tensors receive every argument (device and/or dtype). Other tensors are
# only moved, and only when +device:+ is given, so their dtype is preserved. Values that
# are not tensors are left untouched.
#
# @param args [Array] positional arguments forwarded to the tensors' +to+; not supported
#   yet unless +device:+ is also given
# @param kwargs [Hash] keyword arguments forwarded to floating-point tensors' +to+;
#   +:device+ is also applied to non-floating-point tensors
# @return [self]
# @raise [Todo] when positional arguments are given without +device:+
def to(*args, **kwargs)
  device = kwargs[:device]
  # Check if the args are a device or a dtype
  raise Todo if device.nil? && !args.empty?

  # We cast only floating point tensors to avoid issues with tokenizers casting `LongTensor` to `FloatTensor`
  @data = items.each_with_object({}) do |(key, value), new_data|
    new_data[key] =
      if Torch.tensor?(value) && Torch.floating_point?(value)
        # cast and send to device
        value.to(*args, **kwargs)
      elsif Torch.tensor?(value) && !device.nil?
        value.to(device)
      else
        # non-tensor values, and non-float tensors when no device is given, stay as-is
        value
      end
  end
  self
end

#to_h ⇒ Object (also known as: to_hash)



22
23
24
# File 'lib/transformers/feature_extraction_utils.rb', line 22

# Exposes the batch as a Hash.
#
# @return [Hash] the wrapped hash itself (not a copy)
def to_h
  items
end

#values ⇒ Object



35
36
37
# File 'lib/transformers/feature_extraction_utils.rb', line 35

# Lists the feature values held by this batch.
#
# @return [Array] the values of the wrapped hash, in key-insertion order
def values
  items.values
end