Module: TensorStream::VariableOps
- Included in:
- Evaluator::RubyEvaluator
- Defined in:
- lib/tensor_stream/evaluator/ruby/variable_ops.rb
Overview
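Collection of variable-related ops: reading and assigning variables (including in-place add and subtract), and saving/restoring variable values to and from a YAML checkpoint file.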
Class Method Summary
Class Method Details
.included(klass) ⇒ Object
4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 |
# File 'lib/tensor_stream/evaluator/ruby/variable_ops.rb', line 4
#
# Hook run when this module is included: registers the variable ops
# (:variable_v2, :assign, :assign_add, :assign_sub, :save_ts, :restore_ts)
# on the including class via its +register_op+ class method.
#
# @param klass [Class] the including evaluator class.
# @return [Object] the result of +klass.class_eval+.
def self.included(klass)
  klass.class_eval do
    # Returns the stored value of a variable.
    # @raise [RuntimeError] if the variable has never been assigned.
    register_op :variable_v2 do |_context, tensor, _inputs|
      value = var_read_value(tensor)
      raise "variable #{tensor.options[:var_name]} not initialized" if value.nil?

      value
    end

    # Stores inputs[0] as the variable's new value.
    register_op :assign do |_context, tensor, inputs|
      var_assign_value(tensor, inputs[0])
    end

    # variable += inputs[0], broadcasting the two operands first.
    # NOTE(review): assign_add is registered with no_eval: true but assign_sub
    # is not; confirm this asymmetry is intended.
    register_op :assign_add, no_eval: true do |_context, tensor, inputs|
      current_val = var_read_value(tensor)
      raise "variable #{tensor.options[:var_name]} not initialized" if current_val.nil?

      eval_a, eval_b = broadcast(current_val, inputs[0])
      result = multi_array_op(->(var, val) { var + val }, eval_a, eval_b)
      var_assign_value(tensor, result)
    end

    # variable -= inputs[0], broadcasting the two operands first.
    register_op :assign_sub do |_context, tensor, inputs|
      current_val = var_read_value(tensor)
      raise "variable #{tensor.options[:var_name]} not initialized" if current_val.nil?

      eval_a, eval_b = broadcast(current_val, inputs[0])
      result = multi_array_op(->(var, val) { var - val }, eval_a, eval_b)
      var_assign_value(tensor, result)
    end

    # Writes the variables to a YAML file. inputs[0] is the output path.
    # Each variable is stored under its options[:var_name] as
    # { "shape" => ..., "data" => base64(deflate(packed value)) }.
    register_op :save_ts do |_context, tensor, inputs|
      outputfile = inputs[0]
      # tensor.inputs are the input tensors themselves (not evaluated values);
      # the first is the filename, the rest are the variables to save.
      savables = tensor.inputs.drop(1)

      variables = savables.each_with_object({}) do |savable, acc|
        val = var_read_value(savable)
        packed_data = Zlib::Deflate.deflate(TensorStream::Packer.pack(val, savable.data_type))
        acc[savable.options[:var_name]] = {
          "shape" => shape_eval(val),
          "data" => Base64.strict_encode64(packed_data),
        }
      end

      File.write(outputfile, {"variables" => variables}.to_yaml)
      nil
    end

    # Restores variables from a file written by :save_ts. inputs[0] is the file
    # path; the remaining inputs are the names of the variables to restore.
    # Only variables in the graph's GLOBAL_VARIABLES collection whose name is
    # both requested and present in the file are reassigned.
    register_op :restore_ts do |_context, tensor, inputs|
      filename, *tensor_names = inputs

      # permitted_classes: keyword (Psych >= 3.1); the positional form was
      # removed in Psych 4 (Ruby 3.1+).
      input_dump = YAML.safe_load(File.read(filename), permitted_classes: [Symbol])
      saved = input_dump["variables"] if input_dump.is_a?(Hash)
      raise "#{filename} does not contain a saved \"variables\" map" unless saved.is_a?(Hash)

      # Non-destructive select: get_collection may hand back the graph's own
      # array, which must not be filtered in place.
      vars = Array(tensor.graph.get_collection(GraphKeys::GLOBAL_VARIABLES)).select do |v|
        saved.key?(v.name) && tensor_names.include?(v.name)
      end

      vars.each do |variable|
        entry = saved[variable.name]
        packed = Zlib::Inflate.inflate(Base64.decode64(entry["data"]))
        data = TensorStream::Packer.unpack(packed, variable.data_type)
        variable.buffer = nil
        var_assign_value(variable, TensorShape.reshape(data, entry["shape"]))
      end
      nil
    end
  end
end