Class: Torch::Tensor

Inherits:
Object
Includes:
Comparable, Enumerable, Inspector
Defined in:
lib/torch/tensor.rb

Direct Known Subclasses

NN::Parameter

Constant Summary

Constants included from Inspector

Inspector::PRINT_OPTS

Class Method Summary

Instance Method Summary

Methods included from Inspector

#inspect

Class Method Details

.new(*args) ⇒ Object



# File 'lib/torch/tensor.rb', line 29

def self.new(*args)
  FloatTensor.new(*args)
end
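
As a quick illustration, Tensor.new builds a FloatTensor, so the default dtype is :float32 (a sketch; the exact constructor arguments FloatTensor accepts may vary by version):

x = Torch::Tensor.new([1.0, 2.0, 3.0])
x.dtype  # => :float32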

Instance Method Details

#<=>(other) ⇒ Object

TODO better compare?



# File 'lib/torch/tensor.rb', line 161

def <=>(other)
  item <=> other
end

#[](*indexes) ⇒ Object

based on python_variable_indexing.cpp and pytorch.org/cppdocs/notes/tensor_indexing.html



# File 'lib/torch/tensor.rb', line 167

def [](*indexes)
  indexes = indexes.map { |v| v.is_a?(Array) ? Torch.tensor(v) : v }
  _index(indexes)
end
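
A short indexing sketch (values are illustrative; Range support follows the tensor indexing notes referenced above):

x = Torch.tensor([[1, 2, 3], [4, 5, 6]])
x[0]        # first row
x[0, 1]     # single element, still a tensor (call item for a Ruby number)
x[[0, 0]]   # an Array is converted to a tensor index (advanced indexing)
x[0..1, 2]  # Ranges select slices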

#[]=(*indexes, value) ⇒ Object

based on python_variable_indexing.cpp and pytorch.org/cppdocs/notes/tensor_indexing.html

Raises:

  • (ArgumentError)


# File 'lib/torch/tensor.rb', line 174

def []=(*indexes, value)
  raise ArgumentError, "Tensor does not support deleting items" if value.nil?
  indexes = indexes.map { |v| v.is_a?(Array) ? Torch.tensor(v) : v }
  value = Torch.tensor(value, dtype: dtype) unless value.is_a?(Tensor)
  _index_put_custom(indexes, value)
end
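
Assignment follows the same indexing rules; a plain Ruby value is wrapped with Torch.tensor using the tensor's dtype, and nil is rejected:

x = Torch.tensor([[1, 2], [3, 4]])
x[0, 0] = 9                  # scalar converted to a tensor of the same dtype
x[1] = Torch.tensor([7, 8])  # tensor values are used as-is
# x[0, 0] = nil              # raises ArgumentError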

#coerce(other) ⇒ Object



# File 'lib/torch/tensor.rb', line 205

def coerce(other)
  if other.is_a?(Numeric)
    [Torch.tensor(other), self]
  else
    raise TypeError, "#{self.class} can't be coerced into #{other.class}"
  end
end
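
coerce is what lets a Tensor appear on the right-hand side of arithmetic with a plain Ruby Numeric; Ruby calls coerce on the tensor and the operation proceeds tensor-to-tensor (a minimal sketch):

x = Torch.tensor([1.0, 2.0, 3.0])
2 * x  # Integer#* calls x.coerce(2), so this becomes Torch.tensor(2) * x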

#cpu ⇒ Object



# File 'lib/torch/tensor.rb', line 84

def cpu
  to("cpu")
end

#cuda ⇒ Object



# File 'lib/torch/tensor.rb', line 88

def cuda
  to("cuda")
end

#dtype ⇒ Object

Raises:

  • (Error)

# File 'lib/torch/tensor.rb', line 33

def dtype
  dtype = ENUM_TO_DTYPE[_dtype]
  raise Error, "Unknown type: #{_dtype}" unless dtype
  dtype
end

#dup ⇒ Object



# File 'lib/torch/tensor.rb', line 187

def dup
  Torch.no_grad do
    clone
  end
end
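
Because the clone happens inside no_grad, the copy does not carry the autograd history of the original (a brief sketch, assuming Torch.tensor accepts requires_grad):

x = Torch.tensor([1.0, 2.0], requires_grad: true)
y = x.dup  # same values, cloned with gradient tracking suspended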

#each ⇒ Object



# File 'lib/torch/tensor.rb', line 47

def each
  return enum_for(:each) unless block_given?

  size(0).times do |i|
    yield self[i]
  end
end
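
Since Enumerable is included and each yields slices along the first dimension, methods like map come for free (values illustrative):

x = Torch.tensor([[1, 2], [3, 4]])
x.each { |row| p row }        # yields x[0], then x[1]
x.map { |row| row.sum.item }  # => [3, 7]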

#imag ⇒ Object

Not a method in native_functions.yaml; an attribute in Python rather than a method.



# File 'lib/torch/tensor.rb', line 195

def imag
  Torch.imag(self)
end

#item ⇒ Object



# File 'lib/torch/tensor.rb', line 114

def item
  if numel != 1
    raise Error, "only one element tensors can be converted to Ruby scalars"
  end
  to_a.first
end
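
item returns the underlying Ruby scalar and is only valid for one-element tensors:

Torch.tensor(3.5).item     # => 3.5
Torch.tensor([1, 2]).item  # raises Error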

#layout ⇒ Object



# File 'lib/torch/tensor.rb', line 39

def layout
  _layout.downcase.to_sym
end

#length ⇒ Object

Mirrors Python's len().



# File 'lib/torch/tensor.rb', line 109

def length
  size(0)
end

#new ⇒ Object

unsure if this is correct



# File 'lib/torch/tensor.rb', line 130

def new
  Torch.empty(0, dtype: dtype)
end

#numo ⇒ Object

TODO read directly from memory



# File 'lib/torch/tensor.rb', line 135

def numo
  if dtype == :bool
    Numo::UInt8.from_string(_data_str).ne(0).reshape(*shape)
  else
    cls = Torch._dtype_to_numo[dtype]
    raise Error, "Cannot convert #{dtype} to Numo" unless cls
    cls.from_string(_data_str).reshape(*shape)
  end
end
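
A conversion sketch (assumes the numo-narray gem and a dtype with a Numo counterpart, e.g. :float32 mapping to Numo::SFloat):

x = Torch.tensor([[1.0, 2.0], [3.0, 4.0]])
n = x.numo  # => a 2x2 Numo::SFloat
n.shape     # => [2, 2]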

#random!(*args) ⇒ Object

parser can’t handle overlap, so need to handle manually



# File 'lib/torch/tensor.rb', line 182

def random!(*args)
  return _random!(0, *args) if args.size == 1
  _random!(*args)
end
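
With a single argument, the value is treated as the exclusive upper bound and the lower bound defaults to 0, mirroring random_ in PyTorch (a sketch):

x = Torch.empty(3, dtype: :int64)
x.random!(10)     # integers in [0, 10)
x.random!(5, 10)  # integers in [5, 10)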

#real ⇒ Object

Not a method in native_functions.yaml; an attribute in Python rather than a method.



# File 'lib/torch/tensor.rb', line 201

def real
  Torch.real(self)
end

#requires_grad=(requires_grad) ⇒ Object



# File 'lib/torch/tensor.rb', line 145

def requires_grad=(requires_grad)
  _requires_grad!(requires_grad)
end

#size(dim = nil) ⇒ Object



# File 'lib/torch/tensor.rb', line 92

def size(dim = nil)
  if dim
    _size(dim)
  else
    shape
  end
end
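
With no argument, size returns the full shape as an Array; with a dimension, it returns that single extent:

x = Torch.zeros(2, 3)
x.size     # => [2, 3]
x.size(1)  # => 3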

#stride(dim = nil) ⇒ Object



# File 'lib/torch/tensor.rb', line 100

def stride(dim = nil)
  if dim
    _stride(dim)
  else
    _strides
  end
end

#to(device = nil, dtype: nil, non_blocking: false, copy: false) ⇒ Object

Raises:

  • (Error)

# File 'lib/torch/tensor.rb', line 68

def to(device = nil, dtype: nil, non_blocking: false, copy: false)
  if device.is_a?(Symbol) && !dtype
    dtype = device
    device = nil
  end

  device ||= self.device
  device = Device.new(device) if device.is_a?(String)

  dtype ||= self.dtype
  enum = DTYPE_TO_ENUM[dtype]
  raise Error, "Unknown type: #{dtype}" unless enum

  _to(device, enum, non_blocking, copy)
end
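
to accepts a device, a dtype, or both; a bare Symbol with no dtype keyword is treated as the dtype (a sketch; the last line assumes a CUDA-enabled libtorch):

x = Torch.tensor([1, 2, 3])
x.to(:float64)              # dtype-only form
x.to("cpu", dtype: :int32)  # device plus dtype
x.to("cuda")                # device-only form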

#to_a ⇒ Object

TODO make more performant



# File 'lib/torch/tensor.rb', line 56

def to_a
  arr = _flat_data
  if shape.empty?
    arr
  else
    shape[1..-1].reverse_each do |dim|
      arr = arr.each_slice(dim)
    end
    arr.to_a
  end
end
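
to_a re-slices the flat data to match the shape, so the nesting mirrors the tensor's dimensions:

x = Torch.tensor([[1, 2], [3, 4]])
x.to_a  # => [[1, 2], [3, 4]]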

#to_f ⇒ Object



# File 'lib/torch/tensor.rb', line 125

def to_f
  item.to_f
end

#to_i ⇒ Object



# File 'lib/torch/tensor.rb', line 121

def to_i
  item.to_i
end

#to_s ⇒ Object



# File 'lib/torch/tensor.rb', line 43

def to_s
  inspect
end

#type(dtype) ⇒ Object



# File 'lib/torch/tensor.rb', line 149

def type(dtype)
  if dtype.is_a?(Class)
    raise Error, "Invalid type: #{dtype}" unless TENSOR_TYPE_CLASSES.include?(dtype)
    dtype.new(self)
  else
    enum = DTYPE_TO_ENUM[dtype]
    raise Error, "Invalid type: #{dtype}" unless enum
    _type(enum)
  end
end
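
type accepts either a dtype Symbol or one of the tensor type classes (a sketch; :float64 and Torch::DoubleTensor are assumed to be among the supported types):

x = Torch.tensor([1, 2, 3])
x.type(:float64)             # dtype symbol form
x.type(Torch::DoubleTensor)  # tensor class form, equivalent result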