Class: Torch::Tensor
- Inherits:
  - Object
    - Torch::Tensor
- Includes:
- Comparable, Enumerable, Inspector
- Defined in:
- lib/torch/tensor.rb
Direct Known Subclasses
Constant Summary
Constants included from Inspector
Class Method Summary
Instance Method Summary
-
#<=>(other) ⇒ Object
TODO: better compare?
-
#[](*indexes) ⇒ Object
based on python_variable_indexing.cpp and pytorch.org/cppdocs/notes/tensor_indexing.html.
-
#[]=(*indexes, value) ⇒ Object
based on python_variable_indexing.cpp and pytorch.org/cppdocs/notes/tensor_indexing.html.
- #coerce(other) ⇒ Object
- #cpu ⇒ Object
- #cuda ⇒ Object
- #dtype ⇒ Object
- #dup ⇒ Object
- #each ⇒ Object
-
#imag ⇒ Object
Not a method in native_functions.yaml; an attribute in Python rather than a method.
- #item ⇒ Object
- #layout ⇒ Object
-
#length ⇒ Object
mirror Python len().
-
#new ⇒ Object
unsure if this is correct.
-
#numo ⇒ Object
TODO read directly from memory.
-
#random!(*args) ⇒ Object
The parser can’t handle the overlapping signatures, so they are handled manually.
-
#real ⇒ Object
Not a method in native_functions.yaml; an attribute in Python rather than a method.
- #requires_grad=(requires_grad) ⇒ Object
- #size(dim = nil) ⇒ Object
- #stride(dim = nil) ⇒ Object
- #to(device = nil, dtype: nil, non_blocking: false, copy: false) ⇒ Object
-
#to_a ⇒ Object
TODO make more performant.
- #to_f ⇒ Object
- #to_i ⇒ Object
- #to_s ⇒ Object
- #type(dtype) ⇒ Object
Methods included from Inspector
Class Method Details
.new(*args) ⇒ Object
29 30 31 |
# File 'lib/torch/tensor.rb', line 29
# Builds a tensor by forwarding every constructor argument to FloatTensor.new,
# so Torch::Tensor.new(...) is equivalent to FloatTensor.new(...) with the
# same arguments.
def self.new(*ctor_args)
  FloatTensor.new(*ctor_args)
end
Instance Method Details
#<=>(other) ⇒ Object
TODO better compare?
161 162 163 |
# File 'lib/torch/tensor.rb', line 161
# Orders this single-element tensor against +other+ by its scalar value; this
# is what the included Comparable module uses for <, <=, >, >=, between?, clamp.
#
# A Tensor argument is unwrapped with #item before comparing. Without that,
# `item <=> tensor` makes Numeric#<=> call Tensor#coerce, which wraps the
# scalar back into a Tensor and calls this method again, recursing until the
# stack overflows.
#
# @param other [Object] a Ruby value or another single-element Tensor
# @return [Integer, nil] the result of comparing the scalar values
#   (typically -1, 0, 1, or nil when they are not comparable)
# @raise [Error] (from #item) when either tensor does not have exactly one element
# TODO better compare?
def <=>(other)
  other = other.item if other.is_a?(Tensor)
  item <=> other
end
#[](*indexes) ⇒ Object
based on python_variable_indexing.cpp and pytorch.org/cppdocs/notes/tensor_indexing.html
167 168 169 170 |
# File 'lib/torch/tensor.rb', line 167
# Reads a view/selection of the tensor. Array indexes are first converted to
# tensors with Torch.tensor; every other index value is passed through
# unchanged to the native _index.
# based on python_variable_indexing.cpp and pytorch.org/cppdocs/notes/tensor_indexing.html
def [](*indexes)
  normalized = indexes.map do |index|
    next Torch.tensor(index) if index.is_a?(Array)

    index
  end
  _index(normalized)
end
#[]=(*indexes, value) ⇒ Object
based on python_variable_indexing.cpp and pytorch.org/cppdocs/notes/tensor_indexing.html
174 175 176 177 178 179 |
# File 'lib/torch/tensor.rb', line 174
# Writes +value+ into the positions selected by +indexes+.
#
# Array indexes are converted to tensors, as in #[]. A non-Tensor value is
# converted with Torch.tensor using this tensor's dtype.
#
# @raise [ArgumentError] when +value+ is nil (deletion is not supported)
# based on python_variable_indexing.cpp and pytorch.org/cppdocs/notes/tensor_indexing.html
def []=(*indexes, value)
  raise ArgumentError, "Tensor does not support deleting items" if value.nil?

  normalized = indexes.map { |index| index.is_a?(Array) ? Torch.tensor(index) : index }
  source = value.is_a?(Tensor) ? value : Torch.tensor(value, dtype: dtype)
  _index_put_custom(normalized, source)
end
#coerce(other) ⇒ Object
205 206 207 208 209 210 211 |
# File 'lib/torch/tensor.rb', line 205
# Ruby numeric coercion hook, used when a Numeric is the left operand of an
# operator applied to a tensor. The Numeric is wrapped with Torch.tensor and
# returned first, followed by this tensor.
#
# @param other [Object] the left-hand operand
# @return [Array(Tensor, Tensor)]
# @raise [TypeError] when +other+ is not a Numeric
def coerce(other)
  raise TypeError, "#{self.class} can't be coerced into #{other.class}" unless other.is_a?(Numeric)

  [Torch.tensor(other), self]
end
#cpu ⇒ Object
84 85 86 |
# File 'lib/torch/tensor.rb', line 84
# Shorthand for to("cpu"). The dtype is left unchanged, because #to falls back
# to self.dtype when none is given.
def cpu
  to("cpu")
end
#cuda ⇒ Object
88 89 90 |
# File 'lib/torch/tensor.rb', line 88
# Shorthand for to("cuda"). The dtype is left unchanged, because #to falls back
# to self.dtype when none is given.
def cuda
  to("cuda")
end
#dtype ⇒ Object
33 34 35 36 37 |
# File 'lib/torch/tensor.rb', line 33
# Maps the native dtype enum (_dtype) to its dtype value via ENUM_TO_DTYPE.
#
# @raise [Error] when the enum has no entry in ENUM_TO_DTYPE
def dtype
  ENUM_TO_DTYPE[_dtype] || raise(Error, "Unknown type: #{_dtype}")
end
#dup ⇒ Object
187 188 189 190 191 |
# File 'lib/torch/tensor.rb', line 187
# Returns #clone evaluated inside a Torch.no_grad block. Presumably this keeps
# the copy from being recorded for autograd; confirm.
def dup
  Torch.no_grad { clone }
end
#each ⇒ Object
47 48 49 50 51 52 53 |
# File 'lib/torch/tensor.rb', line 47
# Iterates over the first dimension, yielding self[0] through
# self[size(0) - 1]. This backs the included Enumerable module.
#
# @return [Enumerator] when no block is given
# @return [Integer] size(0), when a block is given
def each
  return enum_for(:each) unless block_given?

  row_count = size(0)
  row_count.times { |row| yield self[row] }
end
#imag ⇒ Object
Not a method in native_functions.yaml; an attribute in Python rather than a method.
195 196 197 |
# File 'lib/torch/tensor.rb', line 195
# Imaginary part of the tensor, via Torch.imag.
# Not a method in native_functions.yaml (in Python it is an attribute rather
# than a method), so it is defined here by hand.
def imag
  Torch.imag(self)
end
#item ⇒ Object
114 115 116 117 118 119 |
# File 'lib/torch/tensor.rb', line 114
# Returns the value of a one-element tensor as a Ruby scalar.
#
# #to_a keeps the tensor's nesting: a shape [1, 1] tensor gives [[x]]. The
# result is therefore flattened before the first element is taken, so any
# one-element tensor yields x itself rather than a one-element Array.
#
# @return [Object] the single element
# @raise [Error] when the tensor does not contain exactly one element
def item
  raise Error, "only one element tensors can be converted to Ruby scalars" unless numel == 1

  to_a.flatten.first
end
#layout ⇒ Object
39 40 41 |
# File 'lib/torch/tensor.rb', line 39
# Memory layout reported by the native _layout, lowercased and returned as a Symbol.
def layout
  layout_name = _layout.downcase
  layout_name.to_sym
end
#length ⇒ Object
mirror Python len()
109 110 111 |
# File 'lib/torch/tensor.rb', line 109
# Size of the first dimension, mirroring Python's len(); same as size(0).
def length
  size(0)
end
#new ⇒ Object
unsure if this is correct
130 131 132 |
# File 'lib/torch/tensor.rb', line 130
# Returns Torch.empty(0) created with this tensor's dtype.
# unsure if this is correct
def new
  element_type = dtype
  Torch.empty(0, dtype: element_type)
end
#numo ⇒ Object
TODO read directly from memory
135 136 137 138 139 140 141 142 143 |
# File 'lib/torch/tensor.rb', line 135
# Converts the tensor to a Numo array with the same shape, reading the raw
# bytes returned by _data_str.
#
# Bool tensors are read as Numo::UInt8 and compared with 0 via ne(0). Every
# other dtype uses the class from Torch._dtype_to_numo.
#
# @raise [Error] when the dtype has no Numo counterpart
# TODO read directly from memory
def numo
  return Numo::UInt8.from_string(_data_str).ne(0).reshape(*shape) if dtype == :bool

  numo_class = Torch._dtype_to_numo[dtype]
  raise Error, "Cannot convert #{dtype} to Numo" unless numo_class

  numo_class.from_string(_data_str).reshape(*shape)
end
#random!(*args) ⇒ Object
The parser can’t handle the overlapping signatures, so they are handled manually.
182 183 184 185 |
# File 'lib/torch/tensor.rb', line 182
# In-place random fill that forwards to the native _random!.
#
# A single argument is forwarded as _random!(0, arg); presumably that argument
# is the upper bound with an implicit lower bound of 0 (confirm). Any other
# argument count is passed through unchanged.
# parser can't handle overlap, so need to handle manually
def random!(*args)
  case args.size
  when 1 then _random!(0, *args)
  else _random!(*args)
  end
end
#real ⇒ Object
Not a method in native_functions.yaml; an attribute in Python rather than a method.
201 202 203 |
# File 'lib/torch/tensor.rb', line 201
# Real part of the tensor, via Torch.real.
# Not a method in native_functions.yaml (in Python it is an attribute rather
# than a method), so it is defined here by hand.
def real
  Torch.real(self)
end
#requires_grad=(requires_grad) ⇒ Object
145 146 147 |
# File 'lib/torch/tensor.rb', line 145
# Setter for the requires_grad flag; forwards the value to the native _requires_grad!.
def requires_grad=(requires_grad)
  _requires_grad!(requires_grad)
end
#size(dim = nil) ⇒ Object
92 93 94 95 96 97 98 |
# File 'lib/torch/tensor.rb', line 92
# Without +dim+, returns the full shape. With +dim+, returns the size of that
# dimension, via the native _size.
def size(dim = nil)
  return shape unless dim

  _size(dim)
end
#stride(dim = nil) ⇒ Object
100 101 102 103 104 105 106 |
# File 'lib/torch/tensor.rb', line 100
# Without +dim+, returns all strides (native _strides). With +dim+, returns
# the stride of that dimension (native _stride).
def stride(dim = nil)
  return _strides unless dim

  _stride(dim)
end
#to(device = nil, dtype: nil, non_blocking: false, copy: false) ⇒ Object
68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 |
# File 'lib/torch/tensor.rb', line 68
# Converts the tensor to another device and/or dtype.
#
# A Symbol passed positionally with no dtype: keyword is treated as the
# dtype, e.g. to(:float64). A missing device or dtype defaults to the
# tensor's current one, and String devices are wrapped in Device.new.
#
# @raise [Error] when the dtype has no entry in DTYPE_TO_ENUM
def to(device = nil, dtype: nil, non_blocking: false, copy: false)
  device, dtype = nil, device if device.is_a?(Symbol) && !dtype

  target_device = device || self.device
  target_device = Device.new(target_device) if target_device.is_a?(String)

  target_dtype = dtype || self.dtype
  dtype_enum = DTYPE_TO_ENUM[target_dtype]
  raise Error, "Unknown type: #{target_dtype}" unless dtype_enum

  _to(target_device, dtype_enum, non_blocking, copy)
end
#to_a ⇒ Object
TODO make more performant
56 57 58 59 60 61 62 63 64 65 66 |
# File 'lib/torch/tensor.rb', line 56
# Converts the tensor into nested Ruby arrays whose nesting follows #shape.
#
# The flat element list from _flat_data is regrouped, innermost dimension
# first, with each_slice.
#
# A tensor with a zero-sized dimension has no elements, but each_slice(0)
# raises ArgumentError. A shape such as [2, 0] would therefore fail instead
# of giving [[], []], so that case builds the empty nesting directly from
# the shape.
#
# A zero-dimensional tensor (empty shape) returns the flat data unchanged.
#
# @return [Array] the elements, nested according to the shape
# TODO make more performant
def to_a
  arr = _flat_data
  dims = shape
  return arr if dims.empty?

  if dims.include?(0)
    # Dimensions before the first zero are materialized; the zero-sized level collapses to [].
    build_empty = lambda do |rest|
      rest.first.zero? ? [] : Array.new(rest.first) { build_empty.call(rest.drop(1)) }
    end
    return build_empty.call(dims)
  end

  dims[1..-1].reverse_each { |dim| arr = arr.each_slice(dim) }
  arr.to_a
end
#to_f ⇒ Object
125 126 127 |
# File 'lib/torch/tensor.rb', line 125
# Float value of a one-element tensor. #item raises Error for any other element count.
def to_f
  scalar = item
  scalar.to_f
end
#to_i ⇒ Object
121 122 123 |
# File 'lib/torch/tensor.rb', line 121
# Integer value of a one-element tensor. #item raises Error for any other element count.
def to_i
  scalar = item
  scalar.to_i
end
#to_s ⇒ Object
43 44 45 |
# File 'lib/torch/tensor.rb', line 43
# String form is identical to #inspect. Presumably #inspect comes from the
# included Inspector module; confirm.
def to_s
  inspect
end
#type(dtype) ⇒ Object
149 150 151 152 153 154 155 156 157 158 |
# File 'lib/torch/tensor.rb', line 149
# Casts the tensor.
#
# A Class argument must be listed in TENSOR_TYPE_CLASSES and is instantiated
# with this tensor. Any other argument is looked up in DTYPE_TO_ENUM and
# passed to the native _type.
#
# @raise [Error] when the class or dtype is not recognized
def type(dtype)
  unless dtype.is_a?(Class)
    type_enum = DTYPE_TO_ENUM[dtype]
    raise Error, "Invalid type: #{dtype}" unless type_enum

    return _type(type_enum)
  end

  raise Error, "Invalid type: #{dtype}" unless TENSOR_TYPE_CLASSES.include?(dtype)

  dtype.new(self)
end