Class: TensorStream::Train::Saver
- Inherits:
-
Object
- Object
- TensorStream::Train::Saver
- Includes:
- OpHelper
- Defined in:
- lib/tensor_stream/train/saver.rb
Overview
High-level class used for saving and loading variables.
Instance Method Summary
-
#initialize(var_list = nil) ⇒ Saver
constructor
A new instance of Saver.
- #restore(session, modelpath) ⇒ Object
- #save(session, outputdir, global_step: nil, latest_filename: nil, meta_graph_suffix: "meta", write_meta_graph: true, write_state: true, strip_default_attrs: false) ⇒ Object
Methods included from OpHelper
#_op, #cons, #format_source, #fp_type?, #i_cons, #i_op, #i_var, #int_type?, #reduced_shape, #shape_eval, #shape_full_specified, #shapes_fully_specified_and_equal
Constructor Details
#initialize(var_list = nil) ⇒ Saver
Returns a new instance of Saver.
10 11 12 13 14 15 16 17 18 |
# File 'lib/tensor_stream/train/saver.rb', line 10

# Builds the save and restore ops for a set of variables.
#
# @param var_list [Array, nil] variables to save/restore; when nil (or false) the
#   default graph's GLOBAL_VARIABLES collection is used instead
def initialize(var_list = nil)
  default_graph = TensorStream::Graph.get_default_graph

  variables = var_list
  variables ||= default_graph.get_collection(GraphKeys::GLOBAL_VARIABLES)

  # Reuse the graph's "ts_filename" entry when it already has one; otherwise
  # create a scalar string placeholder under that name.
  @filename = default_graph["ts_filename"]
  @filename ||= TensorStream.placeholder(:string, name: "ts_filename", shape: [])

  # The save op receives the variables themselves ...
  @save_op = _op(:save_ts, @filename, *variables)

  # ... while the restore op receives only their names.
  variable_names = variables.map { |variable| variable.name }
  @restore_op = _op(:restore_ts, @filename, *variable_names)
end
Instance Method Details
#restore(session, modelpath) ⇒ Object
46 47 48 49 50 51 52 53 54 |
# File 'lib/tensor_stream/train/saver.rb', line 46

# Restores variables from the checkpoint described by "model.meta" in +modelpath+.
#
# Reads the JSON file "model.meta", takes its "gs" entry, derives the checkpoint
# filename from it, and runs the restore op with that filename fed to the
# filename placeholder.
#
# @param session [#run] session used to execute the restore op
# @param modelpath [String] directory holding "model.meta" and the checkpoint file
# @return [Object, nil] whatever +session.run+ returns, or nil when "model.meta"
#   does not exist (in which case nothing is run)
def restore(session, modelpath)
  meta_file = File.join(modelpath, "model.meta")
  return unless File.exist?(meta_file)

  meta_data = JSON.parse(File.read(meta_file))
  gs = meta_data["gs"]

  # compact drops a nil step, so the name is "model-.ckpt" without a step and
  # "model-<gs>-.ckpt" with one; this mirrors the naming used when saving.
  filename = File.join(modelpath, ["model", gs, ".ckpt"].compact.join("-"))
  session.run(@restore_op, feed_dict: {@filename => filename})
end
#save(session, outputdir, global_step: nil, latest_filename: nil, meta_graph_suffix: "meta", write_meta_graph: true, write_state: true, strip_default_attrs: false) ⇒ Object
20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 |
# File 'lib/tensor_stream/train/saver.rb', line 20

# Saves variables to a checkpoint inside +outputdir+.
#
# Creates +outputdir+ if needed, writes "model.meta" (JSON with the evaluated
# global step under "gs"), runs the save op with the checkpoint filename fed to
# the filename placeholder and, when +write_meta_graph+ is truthy, writes the
# default graph as "model.yaml" via TensorStream.train.write_graph.
#
# @param session [#run] session used to execute the save op
# @param outputdir [String] directory that receives the checkpoint files
# @param global_step [Object, nil] passed to +eval_global_step+; its result is
#   stored in "model.meta" and embedded in the checkpoint filename
# @param latest_filename [Object] accepted but not used in this method body
# @param meta_graph_suffix [String] accepted but not used in this method body
# @param write_meta_graph [Boolean] whether to also write the graph as YAML
# @param write_state [Boolean] accepted but not used in this method body
# @param strip_default_attrs [Boolean] accepted but not used in this method body
# @return [String] +outputdir+
def save(session, outputdir, global_step: nil,
         latest_filename: nil,
         meta_graph_suffix: "meta",
         write_meta_graph: true,
         write_state: true,
         strip_default_attrs: false)
  graph = TensorStream::Graph.get_default_graph
  gs = eval_global_step(session, global_step)

  FileUtils.mkdir_p(outputdir)
  basename = "model"
  File.write(File.join(outputdir, "#{basename}.meta"), {"gs" => gs}.to_json)

  # compact drops a nil step, so the name is "model-.ckpt" without a step and
  # "model-<gs>-.ckpt" with one.
  new_filename = File.join(outputdir, [basename, gs, ".ckpt"].compact.join("-"))
  session.run(@save_op, feed_dict: {@filename => new_filename})

  if write_meta_graph
    graph_filename = "#{basename}.yaml"
    TensorStream.train.write_graph(graph, outputdir, graph_filename, serializer: :yaml)
  end

  outputdir
end