Class: TensorFlow::Data::BatchDataset

Inherits:
Dataset (which inherits from Object)
Defined in:
lib/tensorflow/data/batch_dataset.rb

Instance Attribute Summary

Attributes inherited from Dataset

#output_shapes, #output_types

Instance Method Summary

Methods inherited from Dataset

#batch, #each, .from_tensor_slices, #shuffle, #to_ptr

Constructor Details

#initialize(input_dataset, batch_size, drop_remainder) ⇒ BatchDataset

Returns a new instance of BatchDataset.



(Source listing below begins at line 4 of lib/tensorflow/data/batch_dataset.rb.)
# File 'lib/tensorflow/data/batch_dataset.rb', line 4

# Builds a dataset that groups consecutive elements of +input_dataset+ into
# batches using the BatchDatasetV2 raw op.
#
# @param input_dataset [Dataset] dataset whose elements are batched; its
#   +output_types+ and +output_shapes+ describe a single element
# @param batch_size [Integer] number of elements per batch (converted to an
#   int64 tensor before being handed to the op)
# @param drop_remainder [Boolean] passed to the op; when truthy, a final batch
#   with fewer than +batch_size+ elements is dropped
#   (NOTE(review): assumes a plain Ruby boolean here, as the static-shape
#   choice below is made on the Ruby side)
def initialize(input_dataset, batch_size, drop_remainder)
  @input_dataset = input_dataset # keep reference for memory
  @output_types = input_dataset.output_types

  # Unless the remainder is dropped, the last batch may hold fewer than
  # batch_size elements, so the static batch dimension must be declared
  # unknown (-1) instead of batch_size; otherwise the declared shape would not
  # match that final batch. This mirrors TensorFlow's Python BatchDataset.
  batch_dim = drop_remainder ? batch_size : -1
  @output_shapes = input_dataset.output_shapes.map { |shape| [batch_dim] + shape }

  variant_tensor = RawOps.batch_dataset_v2(
    input_dataset: input_dataset,
    batch_size: TensorFlow.convert_to_tensor(batch_size, dtype: :int64),
    drop_remainder: drop_remainder,
    output_types: @output_types,
    output_shapes: @output_shapes
  )
  super(variant_tensor)
end