Class: TensorStream::Protobuf
- Inherits:
-
Object
- Object
- TensorStream::Protobuf
- Defined in:
- lib/tensor_stream/graph_deserializers/protobuf.rb
Overview
A .pb graph deserializer
Instance Method Summary collapse
- #evaluate_tensor_node(node) ⇒ Object
-
#initialize ⇒ Protobuf
constructor
A new instance of Protobuf.
-
#load(pbfile) ⇒ Object
Parses a protobuf file and returns a Ruby hash.
- #load_from_string(buffer) ⇒ Object
- #map_type_to_ts(attr_value) ⇒ Object
- #options_evaluator(node) ⇒ Object
- #parse_value(value_node) ⇒ Object
Constructor Details
#initialize ⇒ Protobuf
Returns a new instance of Protobuf.
6 7 |
# File 'lib/tensor_stream/graph_deserializers/protobuf.rb', line 6
# Creates a deserializer. Takes no arguments and sets up no state.
def initialize; end
Instance Method Details
#evaluate_tensor_node(node) ⇒ Object
31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 |
# File 'lib/tensor_stream/graph_deserializers/protobuf.rb', line 31
# Converts a parsed tensor node (a Hash with string keys) into a Ruby value.
#
# When the node has a non-empty "shape" and a "tensor_content" payload, the
# payload is unescaped, unpacked according to "dtype" and handed to
# TensorShape.reshape together with the shape. Otherwise the value is read
# from "float_val", "int_val" or "string_val" and wrapped in a one-element
# Array when the shape is exactly [1].
#
# @param node [Hash] tensor description; reads "dtype", "shape",
#   "tensor_content", "float_val", "int_val" and "string_val"
# @return [Object] the result of TensorShape.reshape, a Float, an Integer,
#   the stored "string_val", [] for a missing numeric value, or any of these
#   wrapped in an Array for shape [1]
# @raise [RuntimeError] when "dtype" is not DT_FLOAT, DT_INT32 or DT_STRING
def evaluate_tensor_node(node)
  dtype = node["dtype"]
  shape = node["shape"]

  # A missing shape is treated like an empty one rather than raising
  # NoMethodError on nil.
  if shape && !shape.empty? && node["tensor_content"]
    # Strings never use the packed payload, so return before touching it.
    return node["string_val"] if dtype == "DT_STRING"

    pack_format = { "DT_FLOAT" => "f*", "DT_INT32" => "l*" }.fetch(dtype) do
      raise "unknown dtype #{dtype}"
    end

    # SECURITY(review): eval is used to expand the escape sequences in
    # tensor_content, but it also executes any "#{...}" interpolation found in
    # the file. Only load graph files from trusted sources, or replace this
    # with a dedicated unescape routine.
    content = node["tensor_content"]
    unpacked = eval(%("#{content}"))
    return TensorShape.reshape(unpacked.unpack(pack_format), shape)
  end

  val =
    case dtype
    when "DT_FLOAT" then node["float_val"] ? node["float_val"].to_f : []
    when "DT_INT32" then node["int_val"] ? node["int_val"].to_i : []
    when "DT_STRING" then node["string_val"]
    else raise "unknown dtype #{dtype}"
    end

  shape == [1] ? [val] : val
end
#load(pbfile) ⇒ Object
Parses a protobuf file and returns a Ruby hash.
16 17 18 19 20 21 22 23 |
# File 'lib/tensor_stream/graph_deserializers/protobuf.rb', line 16
# Reads a protobuf text file line by line, strips surrounding whitespace from
# every line and passes the resulting Array to evaluate_lines.
#
# @param pbfile [String] path of the file to read
# @return [Object] whatever evaluate_lines returns (documented as a Ruby hash)
def load(pbfile)
  # The block form closes the handle even if reading raises; the previous
  # File.new call left the file open.
  lines = File.open(pbfile, "r") { |f| f.each_line.map(&:strip) }
  evaluate_lines(lines)
end
#load_from_string(buffer) ⇒ Object
9 10 11 |
# File 'lib/tensor_stream/graph_deserializers/protobuf.rb', line 9
# Splits an in-memory buffer on "\n", strips each piece and passes the
# resulting Array to evaluate_lines.
#
# @param buffer [String] textual graph definition
# @return [Object] whatever evaluate_lines returns
def load_from_string(buffer)
  stripped_lines = buffer.split("\n").map { |raw_line| raw_line.strip }
  evaluate_lines(stripped_lines)
end
#map_type_to_ts(attr_value) ⇒ Object
65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 |
# File 'lib/tensor_stream/graph_deserializers/protobuf.rb', line 65
# Translates a protobuf dtype name into the corresponding type symbol.
#
# @param attr_value [String] one of "DT_FLOAT", "DT_INT32", "DT_INT64",
#   "DT_STRING" or "DT_BOOL"
# @return [Symbol] :float32, :int32, :int64, :string or :boolean
# @raise [RuntimeError] for any other value
def map_type_to_ts(attr_value)
  type_names = {
    "DT_FLOAT" => :float32,
    "DT_INT32" => :int32,
    "DT_INT64" => :int64,
    "DT_STRING" => :string,
    "DT_BOOL" => :boolean,
  }

  type_names.fetch(attr_value) { raise "unknown type #{attr_value}" }
end
#options_evaluator(node) ⇒ Object
82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 |
# File 'lib/tensor_stream/graph_deserializers/protobuf.rb', line 82
# Builds an options Hash from a node's "attributes" list.
#
# Each attribute is expected to carry a "key" and a "value"; the first
# type/value pair of "value" decides how the value is converted:
# "tensor" goes through evaluate_tensor_node, "type" through map_type_to_ts,
# "b" becomes true only for the string "true", and anything else is kept
# unchanged.
#
# @param node [Hash] parsed node; only "attributes" is read
# @return [Hash] attribute key => converted value, or {} when the node has
#   no "attributes"
def options_evaluator(node)
  attributes = node["attributes"]
  return {} if attributes.nil?

  attributes.map { |attribute|
    # Only the first type/value pair is used; any further pairs are ignored.
    attr_type, attr_value = attribute["value"].first

    converted =
      case attr_type
      when "tensor" then evaluate_tensor_node(attr_value)
      when "type" then map_type_to_ts(attr_value)
      when "b" then attr_value == "true"
      else attr_value
      end

    [attribute["key"], converted]
  }.to_h
end
#parse_value(value_node) ⇒ Object
25 26 27 28 29 |
# File 'lib/tensor_stream/graph_deserializers/protobuf.rb', line 25
# Evaluates the "tensor" entry of a value node, when there is one.
#
# @param value_node [Hash] parsed value; only "tensor" is read
# @return [Object, nil] the result of evaluate_tensor_node, or nil when the
#   node has no "tensor" entry
def parse_value(value_node)
  tensor = value_node["tensor"]
  evaluate_tensor_node(tensor) if tensor
end