Module: DNN::Layers::MathUtils

Defined in:
lib/dnn/core/layers/math_layers.rb

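Class Method Summary

- .align_ndim(shape1, shape2) ⇒ Object
- .broadcast_to(x, target_shape) ⇒ Object
- .sum_to(x, target_shape) ⇒ Object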

Class Method Details

.align_ndim(shape1, shape2) ⇒ Object



(Source lines 6–21.)
# File 'lib/dnn/core/layers/math_layers.rb', line 6

# Pads the lower-rank of two shapes with size-1 axes so that both shapes end
# up with the same number of dimensions.
#
# Axes are compared left to right: wherever the shorter shape disagrees with
# the longer one at the current position (or has run out of axes), a 1 is
# inserted at that position. Insertion stops as soon as the ranks are equal,
# so the padded shape never ends up longer than the other one.
#
# The arguments are not modified; padded copies are returned.
#
# @param shape1 [Array<Integer>] first shape
# @param shape2 [Array<Integer>] second shape
# @return [Array(Array<Integer>, Array<Integer>)] the two shapes, in the order
#   given, with the shorter one padded to the rank of the longer one
#
# @example
#   align_ndim([3], [2, 3])       # => [[1, 3], [2, 3]]
#   align_ndim([2], [2, 3])       # => [[2, 1], [2, 3]]
#   align_ndim([1, 3], [2, 4, 3]) # => [[1, 1, 3], [2, 4, 3]]
def align_ndim(shape1, shape2)
  # Work on copies: callers such as sum_to pass the caller-supplied
  # target_shape straight through, and in-place insertion would leak into it.
  shape1 = shape1.dup
  shape2 = shape2.dup
  shorter, longer = shape1.length < shape2.length ? [shape1, shape2] : [shape2, shape1]
  longer.length.times do |axis|
    # Without this guard, size-1 or mismatching leading axes keep triggering
    # inserts, e.g. [1, 3] vs [2, 4, 3] would grow to [1, 1, 1, 1, 3].
    break if shorter.length >= longer.length
    shorter.insert(axis, 1) unless shorter[axis] == longer[axis]
  end
  [shape1, shape2]
end

.broadcast_to(x, target_shape) ⇒ Object



(Source lines 23–36.)
# File 'lib/dnn/core/layers/math_layers.rb', line 23

# Expands +x+ towards +target_shape+ by tiling it along every axis whose size
# differs from the target.
#
# If +x+ already has +target_shape+, +x+ itself is returned. Otherwise +x+ is
# reshaped to the rank-aligned shape produced by +align_ndim+ (size-1 axes
# inserted), and for each axis where its size differs from the target size it
# is concatenated with (target size - 1) copies of itself along that axis.
# This equals broadcasting when the differing axis of +x+ has size 1.
# NOTE(review): other mismatches are not rejected and produce a shape that is
# not target_shape — confirm callers only pass broadcast-compatible shapes.
#
# The caller's +target_shape+ array is not modified.
#
# @param x array object responding to +shape+, +reshape+ and
#   +concatenate(*arrays, axis:)+ (presumably a Numo/Xumo NArray — confirm)
# @param target_shape [Array<Integer>] desired shape
# @return the expanded array (or +x+ itself when no expansion is needed)
def broadcast_to(x, target_shape)
  return x if x.shape == target_shape
  # Pass copies so the alignment step cannot modify the caller's arrays.
  x_shape, target_shape = align_ndim(x.shape.dup, target_shape.dup)
  x = x.reshape(*x_shape)
  x_shape.length.times do |axis|
    next if x.shape[axis] == target_shape[axis]
    copies = target_shape[axis] - 1
    next if copies <= 0
    # A single concatenate of all copies yields the same array as appending
    # one copy at a time, without re-copying the growing result each step.
    x = x.concatenate(*Array.new(copies, x), axis: axis)
  end
  x
end

.sum_to(x, target_shape) ⇒ Object



(Source lines 38–48.)
# File 'lib/dnn/core/layers/math_layers.rb', line 38

# Reduces +x+ by summation along every axis whose size differs from
# +target_shape+ (the reverse of broadcast_to, e.g. for gradients).
#
# If +x+ already has +target_shape+, +x+ itself is returned. Otherwise both
# shapes are rank-aligned with +align_ndim+ (size-1 axes inserted into the
# lower-rank one), +x+ is reshaped to its aligned shape, and each mismatching
# axis is summed with +keepdims: true+.
# NOTE: because of keepdims, the result has the *aligned* rank — summing a
# (2, 3) array to [3] yields shape (1, 3), not (3). Callers needing exactly
# target_shape must reshape the result.
#
# The caller's +target_shape+ array is not modified.
#
# @param x array object responding to +shape+, +reshape+ and
#   +sum(axis:, keepdims:)+ (presumably a Numo/Xumo NArray — confirm)
# @param target_shape [Array<Integer>] shape to reduce towards
# @return the reduced array (or +x+ itself when no reduction is needed)
def sum_to(x, target_shape)
  return x if x.shape == target_shape
  # Pass copies: when target_shape is the lower-rank shape (the usual case
  # here) align_ndim would otherwise insert axes into the caller's array.
  x_shape, target_shape = align_ndim(x.shape.dup, target_shape.dup)
  x = x.reshape(*x_shape)
  x_shape.length.times do |axis|
    x = x.sum(axis: axis, keepdims: true) unless x.shape[axis] == target_shape[axis]
  end
  x
end