Class: TensorStream::Train::Saver

Inherits:
Object
  • Object
show all
Includes:
OpHelper
Defined in:
lib/tensor_stream/train/saver.rb

Overview

High-level class used for saving variables to, and restoring them from, checkpoint files.

Instance Method Summary

Methods included from OpHelper

#_op, #cons, #format_source, #fp_type?, #i_cons, #i_op, #i_var, #int_type?, #reduced_shape, #shape_eval, #shape_full_specified, #shapes_fully_specified_and_equal

Constructor Details

#initialize(var_list = nil) ⇒ Saver

Returns a new instance of Saver.



10
11
12
13
14
15
16
17
18
# File 'lib/tensor_stream/train/saver.rb', line 10

# Builds the save and restore ops for a set of variables.
#
# Both ops take a string placeholder named "ts_filename" as their first
# input; #save and #restore feed the checkpoint path into it at run time.
# If the default graph already holds a node under that name, it is reused
# instead of creating another placeholder.
#
# @param var_list [Array, nil] variables to save/restore; when nil (or false)
#   the default graph's GraphKeys::GLOBAL_VARIABLES collection is used
def initialize(var_list = nil)
  default_graph = TensorStream::Graph.get_default_graph
  tracked_vars = var_list || default_graph.get_collection(GraphKeys::GLOBAL_VARIABLES)

  # Reuse an existing "ts_filename" node instead of adding a duplicate placeholder.
  @filename = default_graph["ts_filename"]
  @filename ||= TensorStream.placeholder(:string, name: "ts_filename", shape: [])

  # The save op receives the variables themselves; the restore op receives their names.
  @save_op = _op(:save_ts, @filename, *tracked_vars)
  variable_names = tracked_vars.map(&:name)
  @restore_op = _op(:restore_ts, @filename, *variable_names)
end

Instance Method Details

#restore(session, modelpath) ⇒ Object



46
47
48
49
50
51
52
53
54
# File 'lib/tensor_stream/train/saver.rb', line 46

# Restores variables from a checkpoint previously written by #save.
#
# Reads the global step recorded in "<modelpath>/model.meta" and runs the
# restore op against the matching checkpoint file. The file name is built by
# joining "model", the step and ".ckpt" with "-" (e.g. "model-7-.ckpt", or
# "model-.ckpt" when no step was recorded), which mirrors the naming in #save.
#
# @param session [Object] session used to run the restore op
# @param modelpath [String] directory previously passed to #save
# @return [Object, nil] the result of running the restore op, or nil when
#   there is no meta file (in which case nothing is restored)
def restore(session, modelpath)
  meta_file = File.join(modelpath, "model.meta")
  return unless File.exist?(meta_file)

  meta_data = JSON.parse(File.read(meta_file))
  gs = meta_data["gs"]
  filename = File.join(modelpath, ["model", gs, ".ckpt"].compact.join("-"))
  session.run(@restore_op, feed_dict: {@filename => filename})
end

#save(session, outputdir, global_step: nil, latest_filename: nil, meta_graph_suffix: "meta", write_meta_graph: true, write_state: true, strip_default_attrs: false) ⇒ Object



20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
# File 'lib/tensor_stream/train/saver.rb', line 20

# Saves the variables captured by this Saver into +outputdir+.
#
# Runs the save op with a checkpoint path built by joining "model", the
# resolved global step and ".ckpt" with "-" (e.g. "<outputdir>/model-3-.ckpt",
# or "<outputdir>/model-.ckpt" when the step is nil). It then records the step
# in "<outputdir>/model.meta" so #restore can locate the checkpoint. When
# +write_meta_graph+ is true, the default graph is also serialized to
# "<outputdir>/model.yaml".
#
# @param session [Object] session used to run the save op
# @param outputdir [String] directory for the checkpoint files; created if missing
# @param global_step [Object, nil] handed to eval_global_step to obtain the step
#   used in the file name and meta file (presumably a tensor/variable or nil)
# @param latest_filename [Object] accepted for API compatibility; currently ignored
# @param meta_graph_suffix [String] accepted for API compatibility; currently
#   ignored (the meta file is always "model.meta")
# @param write_meta_graph [Boolean] whether to also write "model.yaml"
# @param write_state [Boolean] accepted for API compatibility; currently ignored
# @param strip_default_attrs [Boolean] accepted for API compatibility; currently ignored
# @return [String] +outputdir+
def save(session, outputdir, global_step: nil,
  latest_filename: nil,
  meta_graph_suffix: "meta",
  write_meta_graph: true,
  write_state: true,
  strip_default_attrs: false)
  graph = TensorStream::Graph.get_default_graph
  gs = eval_global_step(session, global_step)

  FileUtils.mkdir_p(outputdir)
  basename = "model"
  checkpoint_path = File.join(outputdir, [basename, gs, ".ckpt"].compact.join("-"))
  session.run(@save_op, feed_dict: {@filename => checkpoint_path})

  # Record the step only after the checkpoint was written; otherwise a failed
  # save would leave a meta file pointing #restore at a missing checkpoint.
  File.write(File.join(outputdir, "#{basename}.meta"), {"gs" => gs}.to_json)

  if write_meta_graph
    TensorStream.train.write_graph(graph, outputdir, "#{basename}.yaml", serializer: :yaml)
  end
  outputdir
end