Class: TensorStream::Graph
- Inherits: Object
  - Object
  - TensorStream::Graph
- Includes:
- OpHelper
- Defined in:
- lib/tensor_stream/graph.rb
Overview
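A TensorStream computation graph. It holds the registered nodes, constants and named collections, plus per-thread scope stacks for name scopes, devices and control dependencies.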
Instance Attribute Summary
Class Method Summary
Instance Method Summary
Methods included from OpHelper
#_op, #cons, #format_source, #fp_type?, #i_cons, #i_op, #i_var, #int_type?, #reduced_shape, #shape_eval, #shape_full_specified, #shapes_fully_specified_and_equal
Constructor Details
#initialize ⇒ Graph
Returns a new instance of Graph.
Instance Attribute Details
#collections ⇒ Object
Returns the value of attribute collections.
6
7
8
|
# File 'lib/tensor_stream/graph.rb', line 6
# Hash of named node collections, keyed by Symbol (see #add_to_collection).
# #add_variable puts every variable into GraphKeys::GLOBAL_VARIABLES here.
def collections
@collections
end
|
#constants ⇒ Object
Returns the value of attribute constants.
6
7
8
|
# File 'lib/tensor_stream/graph.rb', line 6
# Hash of node name => node for nodes whose +is_const+ flag was truthy when
# they were registered through #add_node.
def constants
@constants
end
|
#eager_execution ⇒ Object
Returns the value of attribute eager_execution.
6
7
8
|
# File 'lib/tensor_stream/graph.rb', line 6
# Raw eager-execution flag, toggled by #enable_eager_execution and
# #disable_eager_execution. NOTE(review): its initial value is set outside this
# view (presumably in #initialize); use #executing_eagerly? for a predicate.
def eager_execution
@eager_execution
end
|
#node_keys ⇒ Object
Returns the value of attribute node_keys.
7
8
9
|
# File 'lib/tensor_stream/graph.rb', line 7
# Node names in the order #add_node registered them. #add_node! does not
# append to this list.
def node_keys
@node_keys
end
|
#nodes ⇒ Object
Returns the value of attribute nodes.
6
7
8
|
# File 'lib/tensor_stream/graph.rb', line 6
# Hash of node name => node. Written by #add_node and #add_node!, read by
# #get_node, #node_added? and #get_tensor_by_name.
def nodes
@nodes
end
|
#random_seed ⇒ Object
Returns the value of attribute random_seed.
6
7
8
|
# File 'lib/tensor_stream/graph.rb', line 6
# Graph-level random seed. NOTE(review): where it is assigned and consumed is
# not visible here; confirm against #initialize and the random ops.
def random_seed
@random_seed
end
|
Class Method Details
.create_default ⇒ Object
75
76
77
|
# File 'lib/tensor_stream/graph.rb', line 75
# Creates a brand-new graph and installs it as the current thread's default.
#
# @return [TensorStream::Graph] the freshly created graph
def self.create_default
  fresh_graph = TensorStream::Graph.new
  Thread.current[:tensor_stream_current_graph] = fresh_graph
  fresh_graph
end
|
.get_default_graph ⇒ Object
71
72
73
|
# File 'lib/tensor_stream/graph.rb', line 71
# Returns the current thread's default graph, lazily creating one via
# .create_default when none has been installed yet.
#
# @return [TensorStream::Graph]
def self.get_default_graph
  current = Thread.current[:tensor_stream_current_graph]
  return current if current

  create_default
end
|
.parse_from_string(buffer) ⇒ Object
Instance Method Details
#[](name) ⇒ Object
122
123
124
|
# File 'lib/tensor_stream/graph.rb', line 122
# Index-style shorthand for #get_node: +graph[name]+.
#
# @param name [String] node name
# @return the node, or nil when no node has that name
def [](name)
get_node(name)
end
|
#add_node(node, name = nil) ⇒ Object
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
|
# File 'lib/tensor_stream/graph.rb', line 88
# Registers +node+ in this graph and propagates its output/consumer links.
#
# @param node the tensor/operation to register; its +name+, +device+ and
#   +is_const+ attributes are read or written here
# @param name [String, nil] only its nil-ness matters. When nil (the default),
#   a name that is already taken is replaced via #uniqunify. When non-nil, the
#   node keeps its current name and may overwrite an existing @nodes entry.
# @raise [RuntimeError] when +node+ is a Placeholder and eager execution is on
# @return the result of +node.propagate_consumer+. NOTE(review): #add_op!
#   ignores this and returns the node itself via +tap+.
def add_node(node, name = nil)
raise "Placeholder cannot be used when eager_execution is enabled" if @eager_execution && node.is_a?(Placeholder)
if name.nil?
node.name = if @nodes[node.name]
uniqunify(node.name)
else
node.name
end
end
# Stamp the current device scope (see #get_device_scope) onto the node.
node.device = get_device_scope
# NOTE(review): with a non-nil +name+ an existing name can be appended twice.
@node_keys << node.name
@nodes[node.name] = node
@constants[node.name] = node if node.is_const
# send is presumably used because these propagate_* helpers are not public
# on the node class; confirm before switching to direct calls.
node.send(:propagate_outputs)
node.send(:propagate_consumer, node)
end
|
#add_node!(name, node) ⇒ Object
126
127
128
129
|
# File 'lib/tensor_stream/graph.rb', line 126
# Stores +node+ under +name+ directly in @nodes. It does not uniquify the name,
# record it in @node_keys, or propagate links (contrast #add_node).
#
# @param name [String] key to store the node under
# @param node the node to store
# @return the same +node+
def add_node!(name, node)
  node.tap { |stored| @nodes[name] = stored }
end
|
#add_op(operation, *args) ⇒ Object
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
|
# File 'lib/tensor_stream/graph.rb', line 131
# Builds a new Operation in this graph WITHOUT registering it. Use #add_op! to
# build and register it in one step.
#
# @param operation [Symbol] the operation type (call sites in this class pass
#   symbols such as +:no_op+ and +:variable_v2+)
# @param args [Array] input values, each converted with
#   TensorStream.convert_to_tensor. A trailing Hash is taken as options; this
#   method reads +:name+, +:internal_name+, +:internal+ and +:data_type+.
# @return [Operation] the configured, unregistered operation
def add_op(operation, *args)
# Split off a trailing options Hash. The `|| {}` never fires, because pop
# returns the Hash that was just checked.
options = if args.last.is_a?(Hash)
args.pop || {}
else
{}
end
# Inputs are stored as their underlying ops; falsy conversions become nil.
inputs = args.map { |i| TensorStream.convert_to_tensor(i) }.map { |i| i ? i.op : nil }
new_op = Operation.new(self, inputs: inputs, options: options)
# Record the creation site, presumably for error messages (OpHelper#format_source).
new_op.source = format_source(caller_locations)
new_op.operation = operation
# NOTE(review): statement order looks load-bearing. Shape inference,
# set_data_type and infer_const presumably read the fields assigned above;
# confirm before reordering.
new_op.shape = TensorShape.new(TensorStream::InferShape.infer_shape(new_op))
new_op.rank = new_op.shape.rank
# Name precedence: :internal_name is used verbatim. Otherwise the name is
# "<name scope>/<options[:name] or operation type>", skipping nil/empty parts.
new_op.name = options[:internal_name] || [get_name_scope, options[:name] || set_operation_name(new_op)].compact.reject(&:empty?).join("/")
new_op.internal = options[:internal]
new_op.data_type = new_op.set_data_type(options[:data_type])
new_op.is_const = new_op.infer_const
new_op.given_name = new_op.name
new_op
end
|
#add_op!(operation, *args) ⇒ Object
156
157
158
|
# File 'lib/tensor_stream/graph.rb', line 156
# Builds an operation with #add_op and immediately registers it with
# #add_node.
#
# @param operation [Symbol] the operation type
# @param args [Array] inputs, optionally followed by an options Hash
# @return the newly built and registered operation
def add_op!(operation, *args)
  built = add_op(operation, *args)
  add_node(built)
  built
end
|
#add_to_collection(collection_name, val) ⇒ Object
83
84
85
86
|
# File 'lib/tensor_stream/graph.rb', line 83
# Appends +val+ to the named collection, creating the collection on first use.
# Names are normalized with +to_sym+, so "foo" and :foo share one collection.
#
# @param collection_name [String, Symbol] collection to append to
# @param val the value to append
# @return [Array] the collection after appending
def add_to_collection(collection_name, val)
  bucket = (@collections[collection_name.to_sym] ||= [])
  bucket << val
end
|
#add_variable(node, options = {}) ⇒ Object
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
|
# File 'lib/tensor_stream/graph.rb', line 164
# Registers variable +node+ in this graph's variable collections.
#
# If a node with the same name already exists, it is returned when the current
# variable scope allows reuse; otherwise an error is raised.
#
# @param node the variable; its +name+, +shape+ and +trainable?+ are read
# @param options [Hash] +:collections+ may be a single collection name or an
#   Array of names to add the variable to (in addition to the global one).
#   The hash is not modified.
# @return the newly registered +node+, or the existing node with that name
# @raise [RuntimeError] on a duplicate name when the scope's +reuse+ is false,
#   or when +node.shape+ is nil
def add_variable(node, options = {})
  scope = _variable_scope
  existing = @nodes[node.name]
  raise "duplicate variable detected #{node.name} and reuse=false in current scope" if existing && !scope.reuse
  return existing if existing
  raise "shape is not declared for #{node.name}" if node.shape.nil?

  # Normalize into a local instead of writing back into the caller's hash.
  extra_collections = options[:collections]
  unless extra_collections.nil? || extra_collections.empty?
    extra_collections = [extra_collections] unless extra_collections.is_a?(Array)
    extra_collections.each { |coll| add_to_collection(coll, node) }
  end
  add_to_collection(GraphKeys::GLOBAL_VARIABLES, node)
  add_to_collection(GraphKeys::TRAINABLE_VARIABLES, node) if node.trainable?
  node
end
|
#add_variable!(node, options = {}) ⇒ Object
182
183
184
185
186
|
# File 'lib/tensor_stream/graph.rb', line 182
# Registers variable +node+ (see #add_variable) and builds a registered
# +:variable_v2+ op for it.
#
# The op is built on this graph, the same graph that just registered the
# variable, rather than on whichever graph happens to be the thread default.
#
# @param node the variable to register
# @param options [Hash] forwarded to #add_variable; +:shape+ and +:data_type+
#   are also passed to the op
# @return the operation created by #add_op!
def add_variable!(node, options = {})
  node = add_variable(node, options)
  add_op!(:variable_v2, var_name: node.name, shape: options[:shape], data_type: options[:data_type])
end
|
#as_default {|_self| ... } ⇒ Object
36
37
38
39
40
41
42
43
44
|
# File 'lib/tensor_stream/graph.rb', line 36
# Makes this graph the current thread's default graph while the block runs,
# then restores the previous default.
#
# The restore happens in +ensure+, so an exception raised inside the block no
# longer leaves this graph installed as the default or leaks a queue entry.
#
# @yieldparam graph [Graph] this graph
# @return [Graph] self
def as_default
  Thread.current[:tensor_stream_current_graph_queue] ||= []
  Thread.current[:tensor_stream_current_graph_queue] << Graph.get_default_graph
  Thread.current[:tensor_stream_current_graph] = self
  begin
    yield(self) if block_given?
  ensure
    Thread.current[:tensor_stream_current_graph] = Thread.current[:tensor_stream_current_graph_queue].pop
  end
  self
end
|
#control_dependencies(control_inputs = []) ⇒ Object
188
189
190
191
192
193
194
195
196
197
|
# File 'lib/tensor_stream/graph.rb', line 188
# Runs the block with a +:no_op+ that depends on +control_inputs+ pushed onto
# this graph's per-thread control-dependency stack.
#
# The no_op is built on this graph, the same graph whose stack it is pushed
# onto and that #get_dependency_scope reads back. Previously it was built on
# whichever graph was the thread default.
#
# @param control_inputs [Array] inputs the no_op should depend on
# @yield runs with the dependency active; it is popped even if the block raises
# @return the block's value
def control_dependencies(control_inputs = [])
  storage_key = "ts_graph_#{object_id}"
  graph_store = (Thread.current[storage_key] ||= {})
  graph_store[:control_dependencies] ||= []
  graph_store[:control_dependencies] << add_op!(:no_op, *control_inputs)
  begin
    yield
  ensure
    Thread.current[storage_key][:control_dependencies].pop
  end
end
|
#device(device_name) ⇒ Object
Runs the given block with `device_name` pushed as this graph's default device for the current thread; the device is popped again when the block exits, even on error.
60
61
62
63
64
65
66
67
68
69
|
# File 'lib/tensor_stream/graph.rb', line 60
# Runs the block with +device_name+ as the innermost default device on this
# graph's per-thread device stack. The entry is popped when the block exits,
# including on error.
#
# @param device_name the device to make current for the block
# @return the block's value
def device(device_name)
  storage_key = "ts_graph_#{object_id}"
  graph_store = (Thread.current[storage_key] ||= {})
  (graph_store[:default_device] ||= []).push(device_name)
  begin
    yield
  ensure
    Thread.current[storage_key][:default_device].pop
  end
end
|
#disable_eager_execution ⇒ Object
203
204
205
|
# File 'lib/tensor_stream/graph.rb', line 203
# Turns eager execution off for this graph.
#
# @return [false]
def disable_eager_execution
@eager_execution = false
end
|
#enable_eager_execution ⇒ Object
199
200
201
|
# File 'lib/tensor_stream/graph.rb', line 199
# Turns eager execution on for this graph. While it is on, #add_node rejects
# Placeholder nodes.
#
# @return [true]
def enable_eager_execution
@eager_execution = true
end
|
#executing_eagerly? ⇒ Boolean
207
208
209
|
# File 'lib/tensor_stream/graph.rb', line 207
# Whether eager execution is enabled for this graph.
#
# Always returns a strict Boolean: an unset flag (nil) reports +false+ rather
# than leaking nil out of a +?+ predicate.
#
# @return [Boolean]
def executing_eagerly?
  !!@eager_execution
end
|
#get_collection(name, _options = {}) ⇒ Object
79
80
81
|
# File 'lib/tensor_stream/graph.rb', line 79
# Looks up a named collection. The name is normalized with +to_sym+.
#
# @param name [String, Symbol] collection name
# @param _options [Hash] accepted for API compatibility; ignored
# @return [Array, nil] the collection, or nil if nothing was ever added to it
def get_collection(name, _options = {})
  collection_key = name.to_sym
  @collections[collection_key]
end
|
#get_const_counter ⇒ Object
238
239
240
241
242
243
244
245
|
# File 'lib/tensor_stream/graph.rb', line 238
# Returns a name suffix for the next constant and advances the counter.
# The sequence is "", "_1", "_2", ...
#
# @return [String]
def get_const_counter
  current = @const_counter || 0
  @const_counter = current + 1
  current.zero? ? "" : "_#{current}"
end
|
#get_dependency_scope ⇒ Object
254
255
256
257
258
|
# File 'lib/tensor_stream/graph.rb', line 254
# Returns the innermost control-dependency op pushed by #control_dependencies
# on this graph for the current thread, or nil when there is none.
#
# @return the innermost dependency op, or nil
def get_dependency_scope
  dependency_stack = Thread.current["ts_graph_#{object_id}"]&.dig(:control_dependencies)
  dependency_stack&.last
end
|
#get_device_scope ⇒ Object
260
261
262
263
264
|
# File 'lib/tensor_stream/graph.rb', line 260
# Returns the innermost device pushed by #device on this graph for the current
# thread, or +:default+ when no device block is active.
#
# An empty stack is also reported as +:default+. Previously, once a #device
# block had exited, its leftover empty stack produced nil instead of :default.
#
# @return the innermost device name, or :default
def get_device_scope
  devices = Thread.current["ts_graph_#{object_id}"]&.dig(:default_device)
  return :default if devices.nil? || devices.empty?

  devices.last
end
|
#get_name_scope ⇒ Object
247
248
249
250
251
252
|
# File 'lib/tensor_stream/graph.rb', line 247
# Returns this graph's current name scope for the current thread: the active
# #name_scope names joined with "/", or nil when no name scope is active.
#
# An empty stack is also reported as nil. Previously, once a #name_scope block
# had exited, the leftover empty stack produced "" instead of nil.
#
# @return [String, nil]
def get_name_scope
  scopes = Thread.current["ts_graph_#{object_id}"]&.dig(:current_scope)
  return nil if scopes.nil? || scopes.empty?

  scopes.join("/")
end
|
#get_node(name) ⇒ Object
112
113
114
|
# File 'lib/tensor_stream/graph.rb', line 112
# Looks up a registered node by name.
#
# @param name [String] node name
# @return the node, or nil when no node has that name (plain Hash lookup;
#   see #get_tensor_by_name for a raising variant)
def get_node(name)
@nodes[name]
end
|
#get_operation_counter ⇒ Object
211
212
213
214
215
216
217
218
219
|
# File 'lib/tensor_stream/graph.rb', line 211
# Returns a name suffix for the next operation and advances the counter.
# The sequence is "", "_1", "_2", ...
#
# @return [String]
def get_operation_counter
  index = @op_counter || 0
  @op_counter = index + 1
  index.zero? ? "" : "_#{index}"
end
|
#get_placeholder_counter ⇒ Object
221
222
223
224
225
226
227
228
|
# File 'lib/tensor_stream/graph.rb', line 221
# Returns a name suffix for the next placeholder and advances the counter.
# The sequence is "", "_2", "_3", ...
# NOTE(review): unlike #get_const_counter / #get_operation_counter, this skips
# "_1".
#
# @return [String]
def get_placeholder_counter
  @placeholder_counter = (@placeholder_counter || 0) + 1
  @placeholder_counter == 1 ? "" : "_#{@placeholder_counter}"
end
|
#get_tensor_by_name(name) ⇒ Object
116
117
118
119
120
|
# File 'lib/tensor_stream/graph.rb', line 116
# Looks up a registered node by name, failing loudly when it is missing.
#
# @param name [String] node name
# @return the node registered under +name+
# @raise [TensorStream::KeyError] when no node with that name exists
def get_tensor_by_name(name)
  return get_node(name) if @nodes.key?(name)

  raise TensorStream::KeyError, "#{name} not found"
end
|
#get_var_counter ⇒ Object
230
231
232
233
234
235
236
|
# File 'lib/tensor_stream/graph.rb', line 230
# Returns a name suffix for the next variable and advances the counter.
# The sequence is "", "_2", "_3", ...
# NOTE(review): like #get_placeholder_counter, this skips "_1".
#
# @return [String]
def get_var_counter
  count = (@var_counter || 0) + 1
  @var_counter = count
  if count == 1
    ""
  else
    "_#{count}"
  end
end
|
#graph_def_versions ⇒ Object
275
276
277
|
# File 'lib/tensor_stream/graph.rb', line 275
# Hard-coded graph-def version string. NOTE(review): presumably emitted when
# serializing the graph; confirm against the serializer.
#
# @return [String] always "producer: 26"
def graph_def_versions
"producer: 26"
end
|
#name_scope(name = nil) ⇒ Object
46
47
48
49
50
51
52
53
54
55
56
|
# File 'lib/tensor_stream/graph.rb', line 46
# Pushes +name+ onto this graph's per-thread name-scope stack for the duration
# of the block, which receives the joined scope from #get_name_scope. The name
# is popped on exit, including on error.
#
# @param name [String, nil] scope component to push
# @yieldparam scope the current joined name scope
# @return the block's value, or nil when no block is given
def name_scope(name = nil)
  storage_key = "ts_graph_#{object_id}"
  graph_store = (Thread.current[storage_key] ||= {})
  (graph_store[:current_scope] ||= []).push(name)
  begin
    return unless block_given?

    yield get_name_scope
  ensure
    Thread.current[storage_key][:current_scope].pop
  end
end
|
#node_added?(name) ⇒ Boolean
108
109
110
|
# File 'lib/tensor_stream/graph.rb', line 108
# Whether a node is registered under +name+, even if the stored value is nil.
#
# @param name [String] node name
# @return [Boolean]
def node_added?(name)
  @nodes.include?(name)
end
|
#set_operation_name(op) ⇒ Object
160
161
162
|
# File 'lib/tensor_stream/graph.rb', line 160
# Default name for +op+: its operation type as a String. #add_op uses this
# when no :name option is given.
#
# @param op an operation whose +operation+ attribute is read
# @return [String]
def set_operation_name(op)
op.operation.to_s
end
|