Module: TensorStream::VariableOps

Included in:
Evaluator::RubyEvaluator
Defined in:
lib/tensor_stream/evaluator/ruby/variable_ops.rb

Overview

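Collection of variable ops: reading and assigning variables (including in-place add/subtract) and saving/restoring variable values to and from a file.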

Class Method Summary

Class Method Details

.included(klass) ⇒ Object



(Source listing below covers lines 4–72 of the file.)
# File 'lib/tensor_stream/evaluator/ruby/variable_ops.rb', line 4

# Registers the variable ops (+variable_v2+, +assign+, +assign_add+,
# +assign_sub+, +save_ts+, +restore_ts+) on the including evaluator class.
#
# Every op body runs against the evaluator and relies on its
# +var_read_value+ / +var_assign_value+ helpers (plus +broadcast+,
# +multi_array_op+ and +shape_eval+ for the arithmetic and save ops).
#
# @param klass [Class] the evaluator class that includes this module
def self.included(klass)
  klass.class_eval do
    # Returns the current value of a variable.
    # @raise [RuntimeError] if the variable has not been initialized
    register_op :variable_v2 do |_context, tensor, _inputs|
      value = var_read_value(tensor)
      raise "variable #{tensor.options[:var_name]} not initialized" if value.nil?

      value
    end

    # Overwrites the variable with the first input.
    register_op :assign do |_context, tensor, inputs|
      var_assign_value(tensor, inputs[0])
    end

    # variable = variable + inputs[0], broadcasting the operands first.
    # NOTE(review): only assign_add is registered with no_eval: true while
    # assign_sub is not — confirm this asymmetry is intentional.
    # @raise [RuntimeError] if the variable has not been initialized
    register_op :assign_add, no_eval: true do |_context, tensor, inputs|
      current_val = var_read_value(tensor)
      raise "variable #{tensor.options[:var_name]} not initialized" if current_val.nil?

      eval_a, eval_b = broadcast(current_val, inputs[0])
      result = multi_array_op(->(var, val) { var + val }, eval_a, eval_b)
      var_assign_value(tensor, result)
    end

    # variable = variable - inputs[0], broadcasting the operands first.
    # @raise [RuntimeError] if the variable has not been initialized
    register_op :assign_sub do |_context, tensor, inputs|
      current_val = var_read_value(tensor)
      raise "variable #{tensor.options[:var_name]} not initialized" if current_val.nil?

      eval_a, eval_b = broadcast(current_val, inputs[0])
      result = multi_array_op(->(var, val) { var - val }, eval_a, eval_b)
      var_assign_value(tensor, result)
    end

    # Writes variables to a YAML file as
    #   {"variables" => {var_name => {"shape" => ..., "data" => base64(deflate(packed))}}}
    # inputs[0] is the output filename; the remaining graph inputs of the
    # op are the variables to save.
    # @raise [RuntimeError] if a variable to save has not been initialized
    register_op :save_ts do |_context, tensor, inputs|
      outputfile = inputs[0]
      # The first graph input is the filename tensor, not a variable.
      savables = tensor.inputs.drop(1)

      variables = savables.each_with_object({}) do |savable, acc|
        val = var_read_value(savable)
        # Fail with a clear message instead of letting the packer choke on nil.
        raise "variable #{savable.options[:var_name]} not initialized" if val.nil?

        packed_data = Zlib::Deflate.deflate(TensorStream::Packer.pack(val, savable.data_type))
        acc[savable.options[:var_name]] = {
          "shape" => shape_eval(val),
          "data" => Base64.strict_encode64(packed_data),
        }
      end

      File.write(outputfile, {"variables" => variables}.to_yaml)
      nil
    end

    # Restores variables from a file written by save_ts.
    # inputs[0] is the filename; the remaining inputs are the names of the
    # variables to restore. Only graph-global variables that are both
    # requested and present in the file are assigned.
    register_op :restore_ts do |_context, tensor, inputs|
      filename, *tensor_names = inputs
      yaml_text = File.read(filename)

      # Psych >= 4 only accepts the permitted_classes: keyword; Psych < 3.1
      # only accepts the positional form, so pick whichever is supported.
      supports_keyword = YAML.method(:safe_load).parameters.any? { |_kind, name| name == :permitted_classes }
      input_dump = if supports_keyword
        YAML.safe_load(yaml_text, permitted_classes: [Symbol])
      else
        YAML.safe_load(yaml_text, [Symbol])
      end
      saved = input_dump["variables"]

      # Non-mutating select: get_collection may hand back the graph's own
      # collection array, which must not be pruned by a restore.
      vars = Array(tensor.graph.get_collection(GraphKeys::GLOBAL_VARIABLES)).select do |v|
        saved.key?(v.name) && tensor_names.include?(v.name)
      end

      vars.each do |variable|
        entry = saved[variable.name]
        packed = Zlib::Inflate.inflate(Base64.decode64(entry["data"]))
        data = TensorStream::Packer.unpack(packed, variable.data_type)
        variable.buffer = nil
        var_assign_value(variable, TensorShape.reshape(data, entry["shape"]))
      end

      nil
    end
  end
end