Module: DNN::Layers::MathUtils
Defined in: lib/dnn/core/layers/math_layers.rb
Class Method Summary
- .align_ndim(shape1, shape2) ⇒ Object
- .broadcast_to(x, target_shape) ⇒ Object
- .sum_to(x, target_shape) ⇒ Object
Class Method Details
.align_ndim(shape1, shape2) ⇒ Object
# File 'lib/dnn/core/layers/math_layers.rb', line 6

# Pads the shorter shape with 1s at the mismatching axes so that both
# shapes end up with the same number of dimensions.
def align_ndim(shape1, shape2)
  if shape1.length < shape2.length
    shape2.length.times do |axis|
      unless shape1[axis] == shape2[axis]
        shape1.insert(axis, 1)
      end
    end
  elsif shape1.length > shape2.length
    shape1.length.times do |axis|
      unless shape1[axis] == shape2[axis]
        shape2.insert(axis, 1)
      end
    end
  end
  [shape1, shape2]
end
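Reading the source, align_ndim pads the shorter of the two shapes with 1s at the mismatching axes so that both shapes come back with the same rank; note that it mutates the shape arrays it receives (Array#insert). A minimal usage sketch, assuming the dnn gem is available (the shape values are illustrative):

require "dnn"

a, b = DNN::Layers::MathUtils.align_ndim([3], [2, 3])
p a  # => [1, 3]  (padded with a leading 1)
p b  # => [2, 3]  (unchanged)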
.broadcast_to(x, target_shape) ⇒ Object
# File 'lib/dnn/core/layers/math_layers.rb', line 23

# Expands x to target_shape by repeating it (via concatenate) along
# every axis whose size differs from the target.
def broadcast_to(x, target_shape)
  return x if x.shape == target_shape
  x_shape, target_shape = align_ndim(x.shape, target_shape)
  x = x.reshape(*x_shape)
  x_shape.length.times do |axis|
    unless x.shape[axis] == target_shape[axis]
      tmp = x
      (target_shape[axis] - 1).times do
        x = x.concatenate(tmp, axis: axis)
      end
    end
  end
  x
end
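From the source above, broadcast_to expands x to target_shape by first aligning ranks with align_ndim, reshaping, and then concatenating x with itself along every axis whose size differs from the target. A minimal sketch, assuming the dnn gem with Numo::NArray as the array backend (the sample array is illustrative):

require "dnn"
require "numo/narray"

x = Numo::SFloat[1, 2, 3]                            # shape [3]
y = DNN::Layers::MathUtils.broadcast_to(x, [2, 3])
p y.shape  # => [2, 3]
p y.to_a   # => [[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]]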
.sum_to(x, target_shape) ⇒ Object
# File 'lib/dnn/core/layers/math_layers.rb', line 38

# Reduces x to target_shape by summing (keepdims: true) along
# every axis whose size differs from the target.
def sum_to(x, target_shape)
  return x if x.shape == target_shape
  x_shape, target_shape = align_ndim(x.shape, target_shape)
  x = x.reshape(*x_shape)
  x_shape.length.times do |axis|
    unless x.shape[axis] == target_shape[axis]
      x = x.sum(axis: axis, keepdims: true)
    end
  end
  x
end
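Conversely, sum_to reduces x down to target_shape by summing with keepdims: true along every axis whose size differs from the target, which in effect reverses broadcast_to. A minimal sketch under the same assumptions (dnn gem, Numo::NArray backend, illustrative values):

require "dnn"
require "numo/narray"

x = Numo::SFloat.new(2, 3).seq                 # [[0, 1, 2], [3, 4, 5]], shape [2, 3]
s = DNN::Layers::MathUtils.sum_to(x, [1, 3])   # sums over axis 0, keeping dims
p s.shape  # => [1, 3]
p s.to_a   # => [[3.0, 5.0, 7.0]]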