Class: TensorStream::Freezer
- Inherits: Object
  - Object
    - TensorStream::Freezer
- Includes: OpHelper
- Defined in: lib/tensor_stream/utils/freezer.rb
Instance Method Summary
- #convert(session, checkpoint_folder, output_file) ⇒ Object
  Converts a checkpointed model's variables into constants for production deployment.
Methods included from OpHelper
#_op, #cons, #format_source, #fp_type?, #i_cons, #i_op, #i_var, #int_type?, #reduced_shape, #shape_eval, #shape_full_specified, #shapes_fully_specified_and_equal
Instance Method Details
#convert(session, checkpoint_folder, output_file) ⇒ Object
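Converts a checkpointed model's variables into constants for production deployment. It loads model.yaml from checkpoint_folder, restores the saved variable values into session, and writes the graph to output_file with each variable replaced by a constant that holds its restored value.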
Source code (lib/tensor_stream/utils/freezer.rb, lines 8–47):
# File 'lib/tensor_stream/utils/freezer.rb', line 8

# Converts a checkpointed model's variables into constants for production
# deployment.
#
# Loads "model.yaml" from +checkpoint_folder+ into the default graph and
# restores the saved variable values into +session+ with a
# TensorStream::Train::Saver. It then serializes the graph with
# TensorStream::Yaml and writes the result to +output_file+. During
# serialization:
# * every :variable_v2 node is replaced by a :const node with the same name,
#   data type and shape, holding the variable's restored value;
# * :assign nodes are dropped;
# * nodes that appear as consumers of an :assign op are dropped.
#
# @param session the session the checkpointed variable values are restored into
# @param checkpoint_folder [String] directory containing model.yaml and the checkpoint
# @param output_file [String] path the frozen graph is written to
# @return [Object] whatever +TensorStream.graph.as_default+ returns for the
#   block (presumably the File.write result — TODO confirm); not meant for callers
# @raise [RuntimeError] "<var_name> has no value" if a variable has no restored value
def convert(session, checkpoint_folder, output_file)
  model_file = File.join(checkpoint_folder, "model.yaml")

  TensorStream.graph.as_default do |current_graph|
    YamlLoader.new.load_from_string(File.read(model_file))
    saver = TensorStream::Train::Saver.new
    saver.restore(session, checkpoint_folder)

    # Collect the names of every node that consumes an assign op. Those nodes
    # are skipped below, and the assign ops themselves are skipped by the
    # `when :assign` branch.
    assign_ops = current_graph.nodes.values.select { |op| op.is_a?(TensorStream::Operation) && op.operation == :assign }
    remove_nodes = Set.new(assign_ops.flat_map { |op| op.consumers.to_a })

    # Returning nil from this block presumably makes get_string omit the node
    # (TODO confirm against TensorStream::Yaml#get_string).
    output_buffer = TensorStream::Yaml.new.get_string(current_graph) { |graph, node_key|
      node = graph.get_tensor_by_name(node_key)
      case node.operation
      when :variable_v2
        var_name = node.options[:var_name]
        value = Evaluator.read_variable(node.graph, var_name)
        raise "#{var_name} has no value" if value.nil?

        options = {
          value: value,
          data_type: node.data_type,
          shape: shape_eval(value),
        }
        # Stand-in :const op that keeps the variable's name so downstream
        # references still resolve.
        const_op = TensorStream::Operation.new(current_graph, inputs: [], options: options)
        const_op.name = node.name
        const_op.operation = :const
        const_op.data_type = node.data_type
        const_op.shape = TensorShape.new(shape_eval(value))
        const_op
      when :assign
        nil
      else
        node unless remove_nodes.include?(node.name)
      end
    }

    File.write(output_file, output_buffer)
  end
end