Class: TensorFlow::Keras::Layers::Dense

Inherits: Object
Defined in:
lib/tensorflow/keras/layers/dense.rb

Instance Method Summary

Constructor Details

#initialize(units, activation: nil, use_bias: true, kernel_initializer: "glorot_uniform", bias_initializer: "zeros", dtype: :float) ⇒ Dense

Returns a new instance of Dense.



# File 'lib/tensorflow/keras/layers/dense.rb', lines 5–13

# Configures a fully connected layer. No weights are allocated here; #build
# creates them (call invokes it on first use while @built is false).
#
# @param units [Integer] number of output units (columns of the kernel)
# @param activation [String, nil] "relu", "softmax", or nil for identity
# @param use_bias [Boolean] whether #build creates a bias vector
# @param kernel_initializer initializer spec forwarded to Utils.add_weight for the kernel
# @param bias_initializer initializer spec forwarded to Utils.add_weight for the bias
# @param dtype [Symbol] dtype for the weights and for casting inputs in #call
def initialize(units, activation: nil, use_bias: true, kernel_initializer: "glorot_uniform", bias_initializer: "zeros", dtype: :float)
  @units, @activation, @use_bias = units, activation, use_bias
  @kernel_initializer, @bias_initializer = kernel_initializer, bias_initializer
  @dtype = dtype
  # Weights do not exist yet; this flag triggers lazy building.
  @built = false
end

Instance Method Details

#build(input_shape) ⇒ Object



# File 'lib/tensorflow/keras/layers/dense.rb', lines 15–28

# Creates the layer's weights for inputs of the given shape and records the
# resulting output shape.
#
# @param input_shape [Array] shape of the incoming tensor; only its last
#   dimension (the feature count) determines the kernel size
# @return [true] the layer is now marked as built
def build(input_shape)
  last_dim = input_shape.last
  @kernel = Utils.add_weight(name: "kernel", shape: [last_dim, @units], initializer: @kernel_initializer, dtype: @dtype)

  # Explicitly nil when there is no bias, so a rebuilt layer never keeps a stale one.
  @bias =
    if @use_bias
      Utils.add_weight(name: "bias", shape: [@units], initializer: @bias_initializer, dtype: @dtype)
    end

  # A dense layer replaces only the last (feature) dimension. The leading
  # dimensions pass through unchanged, so the output shape is not the
  # kernel shape [last_dim, @units].
  @output_shape = input_shape.to_a[0...-1] + [@units]

  @built = true
end

#call(inputs) ⇒ Object



# File 'lib/tensorflow/keras/layers/dense.rb', lines 38–64

# Applies the layer to +inputs+. It casts them to the layer dtype, multiplies
# by the kernel, optionally adds the bias, then applies the configured
# activation. Weights are built from +inputs.shape+ on first use.
#
# @param inputs input tensor; rank must not exceed 2
# @return the (optionally activated) output tensor
# @raise [Error] when +inputs+ has rank greater than 2
# @raise [RuntimeError] when the activation name is not recognised
def call(inputs)
  build(inputs.shape) unless @built

  raise Error, "Rank > 2 not supported yet" if inputs.shape.size > 2

  projected = TensorFlow.matmul(TensorFlow.cast(inputs, @dtype), @kernel)
  projected = NN.bias_add(projected, @bias) if @use_bias

  case @activation
  when nil then projected
  when "relu" then NN.relu(projected)
  when "softmax" then NN.softmax(projected)
  else raise "Unknown activation: #{@activation}"
  end
end

#count_params ⇒ Object



# File 'lib/tensorflow/keras/layers/dense.rb', lines 34–36

# Number of weight scalars in the layer: every kernel entry plus, when the
# layer was configured with use_bias, one bias entry per unit.
#
# Requires #build to have run, since it reads @kernel.
#
# @return [Integer]
def count_params
  kernel_params = @kernel.shape.inject(:*)
  # #build creates no bias vector when use_bias is false, so nothing to add.
  @use_bias ? kernel_params + @units : kernel_params
end

#output_shape ⇒ Object



# File 'lib/tensorflow/keras/layers/dense.rb', lines 30–32

# Output shape recorded by #build; nil until the layer has been built.
#
# @return [Array, nil]
def output_shape
  @output_shape if defined?(@output_shape)
end