Module: TensorStream::Train::Utils
- Included in:
- TensorStream::Trainer
- Defined in:
- lib/tensor_stream/train/utils.rb
Overview
Convenience methods used for training.
Instance Method Summary
Instance Method Details
#create_global_step(graph = nil) ⇒ Object
(source lines 5–15)
# File 'lib/tensor_stream/train/utils.rb', line 5
#
# Creates the global step variable: an int64 scalar, zero-initialized and
# not trainable. It is registered in both the GLOBAL_VARIABLES and
# GLOBAL_STEP collections.
#
# NOTE(review): the existence check runs against `graph` (or the default
# graph), but the variable itself is created through
# `TensorStream.variable_scope`, which is never given `graph`. If a caller
# passes a non-default graph, the check and the creation may target
# different graphs — confirm the intended semantics.
#
# @param graph [Object, nil] graph to check for an existing global step;
#   falls back to TensorStream.get_default_graph when nil
# @return [Object] the variable returned by `get_variable`
# @raise [TensorStream::ValueError] if get_global_step already finds one
def create_global_step(graph = nil)
  graph_to_check = graph || TensorStream.get_default_graph
  existing_step = get_global_step(graph_to_check)
  raise TensorStream::ValueError, '"global_step" already exists.' unless existing_step.nil?

  scope = TensorStream.variable_scope
  step_collections = [
    TensorStream::GraphKeys::GLOBAL_VARIABLES,
    TensorStream::GraphKeys::GLOBAL_STEP
  ]

  scope.get_variable(
    TensorStream::GraphKeys::GLOBAL_STEP,
    shape: [],
    dtype: :int64,
    initializer: TensorStream.zeros_initializer,
    trainable: false,
    collections: step_collections
  )
end
#get_global_step(graph = nil) ⇒ Object
(source lines 17–33)
# File 'lib/tensor_stream/train/utils.rb', line 17
#
# Looks up the global step tensor of a graph.
#
# The GLOBAL_STEP collection is consulted first. When that collection is nil
# or empty, the tensor named "global_step:0" is looked up instead. A
# TensorStream::KeyError from that lookup is treated as "not found".
#
# @param graph [Object, nil] graph to search; falls back to
#   TensorStream.get_default_graph when nil
# @return [Object, nil] the single global step tensor; nil when none is
#   found, or when the collection holds more than one entry (an error is
#   logged in that case)
def get_global_step(graph = nil)
  graph_to_search = graph || TensorStream.get_default_graph
  candidates = graph_to_search.get_collection(TensorStream::GraphKeys::GLOBAL_STEP)

  if candidates.nil? || candidates.empty?
    # Nothing registered in the collection: fall back to a lookup by name.
    begin
      return graph_to_search.get_tensor_by_name("global_step:0")
    rescue TensorStream::KeyError
      return nil
    end
  end

  return candidates.first if candidates.size == 1

  # More than one candidate is ambiguous, so report it and return nothing.
  TensorStream.logger.error("Multiple tensors in global_step collection.")
  nil
end