Class: Torch::Tensor

Inherits:
Object
  • Object
show all
Includes:
Comparable, Inspector
Defined in:
lib/torch/tensor.rb

Direct Known Subclasses

NN::Parameter

Class Method Summary collapse

Instance Method Summary collapse

Methods included from Inspector

#inspect

Class Method Details

.new(*args) ⇒ Object



8
9
10
# File 'lib/torch/tensor.rb', line 8

# Class-level constructor: forwards every positional argument unchanged
# to FloatTensor.new and returns whatever that call returns.
def self.new(*args)
  FloatTensor.new(*args)
end

Instance Method Details

#%(other) ⇒ Object



137
138
139
# File 'lib/torch/tensor.rb', line 137

# Modulo operator: `tensor % other` is shorthand for `remainder(other)`.
# Returns whatever #remainder returns.
def %(other)
  remainder(other)
end

#*(other) ⇒ Object



129
130
131
# File 'lib/torch/tensor.rb', line 129

# Multiplication operator: `tensor * other` is shorthand for `mul(other)`.
# Returns whatever #mul returns.
def *(other)
  mul(other)
end

#**(other) ⇒ Object



141
142
143
# File 'lib/torch/tensor.rb', line 141

# Exponent operator: `tensor ** other` is shorthand for `pow(other)`.
# Returns whatever #pow returns.
def **(other)
  pow(other)
end

#+(other) ⇒ Object

end temp operations



121
122
123
# File 'lib/torch/tensor.rb', line 121

# Addition operator: `tensor + other` is shorthand for `add(other)`.
# Returns whatever #add returns.
def +(other)
  add(other)
end

#-(other) ⇒ Object



125
126
127
# File 'lib/torch/tensor.rb', line 125

# Subtraction operator: `tensor - other` is shorthand for `sub(other)`.
# Returns whatever #sub returns.
def -(other)
  sub(other)
end

#-@ ⇒ Object



145
146
147
# File 'lib/torch/tensor.rb', line 145

# Unary minus: `-tensor` is shorthand for `neg`.
# Returns whatever #neg returns.
def -@
  neg
end

#/(other) ⇒ Object



133
134
135
# File 'lib/torch/tensor.rb', line 133

# Division operator: `tensor / other` is shorthand for `div(other)`.
# Returns whatever #div returns.
def /(other)
  div(other)
end

#<=>(other) ⇒ Object



149
150
151
# File 'lib/torch/tensor.rb', line 149

# Three-way comparison: reads this tensor's scalar through #item and
# compares that value with +other+ using the scalar's own <=>.
# The result (including nil for incomparable operands) is whatever that
# scalar comparison yields.
# NOTE(review): +other+ is passed through as-is; if it is itself a tensor
# the scalar's <=> decides the outcome - confirm this is intended.
def <=>(other)
  item <=> other
end

#[](*indexes) ⇒ Object

based on python_variable_indexing.cpp



154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
# File 'lib/torch/tensor.rb', line 154

# Returns the result of applying +indexes+ one after another, starting
# from this tensor (per the original note, modelled on
# python_variable_indexing.cpp).
#
# A dimension cursor starts at 0. Each index is handled by type:
# * Numeric - _select_int on the current dimension; the cursor is not
#   advanced (presumably because the selection removes that dimension).
# * Range   - _slice_tensor with step 1, then the cursor advances.
#   A missing begin is treated as 0. A missing end, or an inclusive end
#   of -1, means "through the last element" of the current dimension.
# * nil     - unsqueeze at the current dimension, then the cursor advances.
# * true    - unsqueeze at the current dimension; the cursor stays put.
#   TODO: handle false.
#
# Raises Error for any other index type.
def [](*indexes)
  result = self
  dim = 0
  indexes.each do |index|
    # case/when matches on class or identity, so no arbitrary == is invoked
    # on the index object.
    case index
    when Numeric
      result = result._select_int(dim, index)
    when Range
      start = index.begin || 0
      finish = index.end
      if finish.nil? || (finish == -1 && !index.exclude_end?)
        # Adding 1 to an inclusive -1 would give 0 and an empty slice, and an
        # endless range has no end at all; both mean "up to the dimension size".
        finish = result.size(dim)
      elsif !index.exclude_end?
        # Convert the inclusive end into the exclusive bound used for slicing.
        finish += 1
      end
      result = result._slice_tensor(dim, start, finish, 1)
      dim += 1
    when nil
      result = result.unsqueeze(dim)
      dim += 1
    when true
      result = result.unsqueeze(dim)
      # TODO handle false
    else
      raise Error, "Unsupported index type: #{index.class.name}"
    end
  end
  result
end

#[]=(index, value) ⇒ Object

TODO based on python_variable_indexing.cpp

Raises:

  • (ArgumentError)


180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
# File 'lib/torch/tensor.rb', line 180

# Writes +value+ into the part of this tensor selected by +index+ along
# dimension 0 (the original note marks this as a TODO port of
# python_variable_indexing.cpp).
#
# +value+ is wrapped with Torch.tensor unless it is already a Tensor.
# +index+ may be:
# * Numeric - the destination is _select_int(0, index).
# * Range   - the destination is _slice_tensor over dimension 0 with step 1.
#   A missing begin is treated as 0. A missing end, or an inclusive end
#   of -1, means "through the last element".
#
# Returns the result of copy_to.
# Raises ArgumentError when +value+ is nil, and Error for any other index type.
def []=(index, value)
  raise ArgumentError, "Tensor does not support deleting items" if value.nil?

  value = Torch.tensor(value) unless value.is_a?(Tensor)

  case index
  when Numeric
    copy_to(_select_int(0, index), value)
  when Range
    start = index.begin || 0
    finish = index.end
    if finish.nil? || (finish == -1 && !index.exclude_end?)
      # Adding 1 to an inclusive -1 would give 0 and select nothing, and an
      # endless range has no end; both mean "up to the size of dimension 0".
      finish = size(0)
    elsif !index.exclude_end?
      # Convert the inclusive end into the exclusive bound used for slicing.
      finish += 1
    end
    copy_to(_slice_tensor(0, start, finish, 1), value)
  else
    raise Error, "Unsupported index type: #{index.class.name}"
  end
end

#add!(value = 1, other) ⇒ Object

start temp operations



91
92
93
94
95
96
97
98
# File 'lib/torch/tensor.rb', line 91

# Bang-style add. +value+ (default 1) is an optional leading argument and
# +other+ is the operand: Numeric operands go to _add__scalar, everything
# else goes to _add__tensor. Both receive (other, value) in that order.
# Returns whatever the chosen primitive returns.
def add!(value = 1, other)
  return _add__scalar(other, value) if other.is_a?(Numeric)

  # need to use alpha for sparse tensors instead of multiplying
  _add__tensor(other, value)
end

#backward(gradient = nil) ⇒ Object



64
65
66
# File 'lib/torch/tensor.rb', line 64

# Forwards +gradient+ (nil when omitted) to _backward and returns its
# result.
def backward(gradient = nil)
  _backward(gradient)
end

#dtype ⇒ Object

Raises:

  • (Error)

12
13
14
15
16
# File 'lib/torch/tensor.rb', line 12

# Looks up the value of _dtype in ENUM_TO_DTYPE and returns the mapped
# entry.
# Raises Error ("Unknown type: ...") when the table has no truthy entry
# for it.
def dtype
  ENUM_TO_DTYPE[_dtype] || raise(Error, "Unknown type: #{_dtype}")
end

#item ⇒ Object



52
53
54
55
56
57
# File 'lib/torch/tensor.rb', line 52

# Returns the single value held by this tensor: the first entry of
# _flat_data.
# Raises Error unless numel is exactly 1.
def item
  raise Error, "only one element tensors can be converted to Ruby scalars" unless numel == 1

  _flat_data.first
end

#layout ⇒ Object



18
19
20
# File 'lib/torch/tensor.rb', line 18

# Returns the value of _layout, lower-cased and converted to a Symbol.
def layout
  _layout.downcase.to_sym
end

#mul!(other) ⇒ Object



100
101
102
103
104
105
106
# File 'lib/torch/tensor.rb', line 100

# Bang-style multiply: Numeric operands are handed to _mul__scalar,
# anything else to _mul__tensor. Returns whatever the chosen primitive
# returns.
def mul!(other)
  other.is_a?(Numeric) ? _mul__scalar(other) : _mul__tensor(other)
end

#new ⇒ Object

unsure if this is correct



60
61
62
# File 'lib/torch/tensor.rb', line 60

# Returns Torch.empty(0, dtype: dtype): a size-0 request that reuses this
# tensor's dtype.
# NOTE(review): the original author marked this as "unsure if this is
# correct" - confirm the intended semantics before relying on it.
def new
  Torch.empty(0, dtype: dtype)
end

#new_ones(*size, **options) ⇒ Object



75
76
77
# File 'lib/torch/tensor.rb', line 75

# Returns a tensor of ones of the given +size+.
#
# The scratch tensor handed to Torch.ones_like is created with this
# tensor's dtype, so the result defaults to that dtype instead of the
# library default. This mirrors the instance-level #new and PyTorch's
# Tensor.new_ones contract. Any +options+ (for example an explicit
# dtype:) are still forwarded to Torch.ones_like unchanged.
# NOTE(review): assumes Torch.ones_like takes its defaults from the tensor
# it is given - confirm against its implementation.
def new_ones(*size, **options)
  template = Torch.empty(*size, dtype: dtype)
  Torch.ones_like(template, **options)
end

#numo ⇒ Object

TODO read directly from memory

Raises:

  • (Error)

69
70
71
72
73
# File 'lib/torch/tensor.rb', line 69

# Converts this tensor into a Numo array: picks the class registered for
# this dtype in Torch._dtype_to_numo, casts _flat_data with it, and
# reshapes the result to this tensor's shape.
# Raises Error when no class is registered for the dtype.
# TODO read directly from memory
def numo
  numo_class = Torch._dtype_to_numo[dtype] || raise(Error, "Cannot convert #{dtype} to Numo")
  flat = numo_class.cast(_flat_data)
  flat.reshape(*shape)
end

#requires_grad!(requires_grad = true) ⇒ Object



79
80
81
# File 'lib/torch/tensor.rb', line 79

# Forwards the +requires_grad+ flag (true when omitted) to _requires_grad!
# and returns its result.
def requires_grad!(requires_grad = true)
  _requires_grad!(requires_grad)
end

#shape ⇒ Object



44
45
46
# File 'lib/torch/tensor.rb', line 44

# Returns an Array with one entry per dimension: size(0), size(1), ...
# up to (but excluding) dim.
def shape
  (0...dim).map { |axis| size(axis) }
end

#size(dim = nil) ⇒ Object



36
37
38
39
40
41
42
# File 'lib/torch/tensor.rb', line 36

# Without an argument (or with nil/false) returns the full #shape;
# otherwise returns _size_int(dim) for the requested dimension.
def size(dim = nil)
  return shape unless dim

  _size_int(dim)
end

#to(device, non_blocking: false, copy: false) ⇒ Object

TODO support dtype



31
32
33
34
# File 'lib/torch/tensor.rb', line 31

# Calls _to with the target device, this tensor's current _dtype, and the
# +non_blocking+ / +copy+ flags, returning its result.
# A String +device+ is wrapped in Device.new first; any other value is
# passed through untouched.
# TODO support dtype (the current _dtype is always forwarded)
def to(device, non_blocking: false, copy: false)
  target = device.is_a?(String) ? Device.new(device) : device
  _to(target, _dtype, non_blocking, copy)
end

#to_a ⇒ Object



26
27
28
# File 'lib/torch/tensor.rb', line 26

# Passes _flat_data and shape to reshape_arr and returns its result.
# NOTE(review): presumably this yields nested Ruby arrays matching the
# tensor's shape - confirm against reshape_arr.
def to_a
  reshape_arr(_flat_data, shape)
end

#to_s ⇒ Object



22
23
24
# File 'lib/torch/tensor.rb', line 22

# String form of the tensor: identical to #inspect.
def to_s
  inspect
end

#type(dtype) ⇒ Object

Raises:

  • (Error)

83
84
85
86
87
# File 'lib/torch/tensor.rb', line 83

# Looks up +dtype+ in DTYPE_TO_ENUM and passes the mapped value to _type,
# returning its result.
# Raises Error ("Unknown type: ...") when the table has no truthy entry
# for +dtype+.
def type(dtype)
  enum_value = DTYPE_TO_ENUM[dtype]
  return _type(enum_value) if enum_value

  raise Error, "Unknown type: #{dtype}"
end

#view(*size) ⇒ Object



48
49
50
# File 'lib/torch/tensor.rb', line 48

# Collects the splatted +size+ arguments into one Array and passes that
# Array to _view, returning its result.
def view(*size)
  _view(size)
end