Class: Transformers::BatchFeature
- Inherits: Object
- Hierarchy: Object → Transformers::BatchFeature
- Defined in:
- lib/transformers/feature_extraction_utils.rb
Instance Method Summary collapse
- #[](item) ⇒ Object
- #_get_is_as_tensor_fns(tensor_type: nil) ⇒ Object
- #convert_to_tensors(tensor_type: nil) ⇒ Object
-
#initialize(data:, tensor_type:) ⇒ BatchFeature
constructor
A new instance of BatchFeature.
- #items ⇒ Object
- #keys ⇒ Object
- #to(*args, **kwargs) ⇒ Object
- #to_h ⇒ Object (also: #to_hash)
- #values ⇒ Object
Constructor Details
#initialize(data:, tensor_type:) ⇒ BatchFeature
Returns a new instance of BatchFeature.
17 18 19 20 |
# File 'lib/transformers/feature_extraction_utils.rb', line 17 def initialize(data:, tensor_type:) @data = data convert_to_tensors(tensor_type: tensor_type) end |
Instance Method Details
#[](item) ⇒ Object
27 28 29 |
# Hash-style lookup of a single feature by key.
#
# @param item [Object] the feature name to fetch
# @return [Object, nil] the stored value, or nil when the key is absent
def [](item)
  @data[item]
end
#_get_is_as_tensor_fns(tensor_type: nil) ⇒ Object
43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 |
# Build the predicate / converter pair used during tensor conversion.
#
# @param tensor_type [Symbol, String, nil] requested tensor framework
# @return [Array] [is_tensor, as_tensor] — a Method that tests whether a
#   value is already a tensor and a Proc that converts one; both nil when
#   no tensor_type was requested
def _get_is_as_tensor_fns(tensor_type: nil)
  return [nil, nil] if tensor_type.nil?

  # An array of Numo arrays is batched into one NArray first, so that
  # Torch.tensor receives a single contiguous array rather than a list.
  as_tensor = ->(value) do
    if value.is_a?(Array) && !value.empty? && value.first.is_a?(Numo::NArray)
      value = Numo::NArray.cast(value)
    end
    Torch.tensor(value)
  end

  [Torch.method(:tensor?), as_tensor]
end
#convert_to_tensors(tensor_type: nil) ⇒ Object
60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 |
# Convert every value in @data to a tensor of the requested framework.
#
# @param tensor_type [Symbol, String, nil] target framework; nil is a no-op
# @return [BatchFeature] self, for chaining
# @raise [ArgumentError] when a value cannot be batched into a tensor
#   (typically ragged sequences that still need padding)
def convert_to_tensors(tensor_type: nil)
  return self if tensor_type.nil?

  is_tensor, as_tensor = _get_is_as_tensor_fns(tensor_type: tensor_type)

  # Do the tensor conversion in batch
  items.each do |key, value|
    begin
      # Values that are already tensors are left alone.
      @data[key] = as_tensor.(value) unless is_tensor.(value)
    rescue
      if key == :overflowing_values
        raise ArgumentError, "Unable to create tensor returning overflowing values of different lengths."
      end
      raise ArgumentError, "Unable to create tensor, you should probably activate padding " \
        "with 'padding: true' to have batched tensors with the same length."
    end
  end

  self
end
#items ⇒ Object
39 40 41 |
# @return [Hash] the underlying feature hash itself (not key/value pairs)
def items
  @data
end
#keys ⇒ Object
31 32 33 |
# @return [Array] the names of all stored features
def keys
  @data.keys
end
#to(*args, **kwargs) ⇒ Object
88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 |
# Send all tensors to a device and/or cast their dtype.
#
# Only floating-point tensors receive the full cast arguments, so integer
# tensors (e.g. token ids produced as LongTensor) keep their dtype.
#
# @param args positional device/dtype arguments; not yet supported without
#   a :device keyword (raises Todo)
# @param kwargs [Hash] keyword options; :device selects the target device
# @return [BatchFeature] self, with @data rebuilt
def to(*args, **kwargs)
  device = kwargs[:device]
  # Detecting a device/dtype from bare positional args is not implemented.
  raise Todo if device.nil? && !args.empty?

  converted = {}
  # We cast only floating point tensors to avoid issues with tokenizers
  # casting LongTensor to FloatTensor.
  items.each do |key, value|
    converted[key] =
      if Torch.floating_point?(value)
        # Floats take every argument (device and dtype alike).
        value.to(*args, **kwargs)
      elsif !device.nil?
        # Non-float tensors only move devices; dtype is preserved.
        value.to(device)
      else
        value
      end
  end
  @data = converted
  self
end
#to_h ⇒ Object Also known as: to_hash
22 23 24 |
# @return [Hash] the raw feature hash backing this batch
def to_h
  @data
end
#values ⇒ Object
35 36 37 |
# @return [Array] all stored feature values, in key order
def values
  @data.values
end