Class: TensorStream::YamlLoader
- Inherits: Object (hierarchy: Object > TensorStream::YamlLoader)
- Defined in: lib/tensor_stream/graph_deserializers/yaml_loader.rb
Overview
Deserializes a TensorStream graph from YAML, read either from a file or from a string.
Instance Method Summary
- #initialize(graph = nil) ⇒ YamlLoader (constructor)
  Returns a new instance of YamlLoader.
- #load_from_file(filename) ⇒ Object
  Loads a model YAML file and builds the model from it.
- #load_from_string(buffer) ⇒ Object
  Builds the model from a string in YAML format.
Constructor Details
#initialize(graph = nil) ⇒ YamlLoader
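Returns a new instance of YamlLoader. The loader restores into the given graph, or into the default graph (TensorStream.get_default_graph) when graph is nil.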
    # File 'lib/tensor_stream/graph_deserializers/yaml_loader.rb', line 5

    def initialize(graph = nil)
      @graph = graph || TensorStream.get_default_graph
    end
Instance Method Details
#load_from_file(filename) ⇒ Object
Loads a model YAML file and builds the model from it. The file is read in full and its contents are passed to #load_from_string.
Args:
  filename: String - path to the YAML file
Returns: the Graph the model is restored into
    # File 'lib/tensor_stream/graph_deserializers/yaml_loader.rb', line 16

    def load_from_file(filename)
      load_from_string(File.read(filename))
    end
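Since load_from_file only reads the file and forwards its contents, both entry points should see the same data for the same YAML. The sketch below illustrates that delegation with the standard library alone. The helpers parse_ops and parse_ops_from_file are hypothetical stand-ins, not part of TensorStream, and they only parse the YAML rather than build a graph. The sample document is made up, and the keyword form of YAML.safe_load is an assumption about the installed Psych version (the listing above uses the older positional form).

```ruby
require "yaml"
require "tempfile"

# Hypothetical stand-in for the parsing step of load_from_string.
# It only parses the YAML; it does not build a TensorStream graph.
def parse_ops(buffer)
  YAML.safe_load(buffer, permitted_classes: [Symbol], aliases: true)
end

# Mirrors the delegation in load_from_file: read the file, reuse the string path.
def parse_ops_from_file(filename)
  parse_ops(File.read(filename))
end

# Made-up sample document, for illustration only.
sample_yaml = <<~YAML
  - :op: const
    :name: a
    :inputs: []
    :attrs:
      :data_type: :float32
YAML

from_string = parse_ops(sample_yaml)

# Write the same text to a temporary file and load it through the file path.
from_file = Tempfile.create(["model", ".yaml"]) do |file|
  file.write(sample_yaml)
  file.flush
  parse_ops_from_file(file.path)
end

puts from_string == from_file
```

Because the file is read with File.read, a missing path surfaces as Errno::ENOENT before any YAML parsing starts.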
#load_from_string(buffer) ⇒ Object
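Builds the model from a string in YAML format. Each serialized operation is recreated in order and added to the graph; variable_v2 operations also register a Variable.
Args:
  buffer: String - the model in YAML format
Returns: the Graph the model is restored into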
    # File 'lib/tensor_stream/graph_deserializers/yaml_loader.rb', line 27

    # NOTE: the extracted listing dropped the names of the two hash variables;
    # `options` and `var_options` below are reconstructed, not confirmed.
    def load_from_string(buffer)
      serialized_ops = YAML.safe_load(buffer, [Symbol], [], true)
      serialized_ops.each do |op_def|
        # Resolve inputs by name; they must already exist in the graph.
        inputs = op_def[:inputs].map { |i| @graph.get_tensor_by_name(i) }
        options = {}

        new_var = nil
        if op_def[:op].to_sym == :variable_v2
          new_var = Variable.new(op_def.dig(:attrs, :data_type))

          var_options = {}
          var_options[:name] = op_def.dig(:attrs, :var_name)

          new_var.prepare(nil, nil, TensorStream.get_variable_scope, var_options)
          @graph.add_variable(new_var, var_options)
        end

        new_op = Operation.new(@graph, inputs: inputs, options: op_def[:attrs].merge(options))
        new_op.operation = op_def[:op].to_sym
        new_op.name = op_def[:name]
        new_op.shape = TensorShape.new(TensorStream::InferShape.infer_shape(new_op))
        new_op.rank = new_op.shape.rank
        new_op.data_type = new_op.set_data_type(op_def.dig(:attrs, :data_type))
        new_op.is_const = new_op.infer_const
        new_op.given_name = new_op.name
        new_var.op = new_op if new_var

        @graph.add_node(new_op)
      end
      @graph
    end
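The loop above resolves each operation's inputs by name before adding the operation itself, so the serialized list appears to need producers ahead of consumers. The sketch below replays that ordering with the standard library only. The replay method and the Hash it uses are hypothetical stand-ins for the graph, and Hash#fetch raising KeyError is an assumption made for illustration; the real get_tensor_by_name may handle unknown names differently. The YAML document is made up. The keyword form of YAML.safe_load is used because, as far as I know, the positional form shown in the listing was deprecated in Psych 3.1 and removed in Psych 4.

```ruby
require "yaml"

# Made-up serialized graph: two constants feeding an add operation.
yaml = <<~YAML
  - :op: const
    :name: a
    :inputs: []
    :attrs: {}
  - :op: const
    :name: b
    :inputs: []
    :attrs: {}
  - :op: add
    :name: sum
    :inputs: [a, b]
    :attrs: {}
YAML

# Hypothetical stand-in for the graph: a Hash keyed by node name.
def replay(op_defs)
  nodes = {}
  op_defs.each do |op_def|
    # Inputs are resolved before the node itself is registered.
    inputs = op_def[:inputs].map { |name| nodes.fetch(name) }
    nodes[op_def[:name]] = { name: op_def[:name], op: op_def[:op].to_sym, inputs: inputs }
  end
  nodes
end

op_defs = YAML.safe_load(yaml, permitted_classes: [Symbol], aliases: true)
nodes = replay(op_defs)
input_names = nodes["sum"][:inputs].map { |node| node[:name] }
puts input_names.inspect

# Replaying the list in reverse puts "sum" ahead of its inputs.
reordered_ok = begin
  replay(op_defs.reverse)
  true
rescue KeyError
  false
end
puts reordered_ok
```

In this stand-in the reversed list fails on the first lookup, which is why a serializer that feeds this loader would be expected to emit operations in dependency order.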