Class: TensorStream::TensorShape
- Inherits: Object
  - Object
    - TensorStream::TensorShape
- Defined in:
- lib/tensor_stream/tensor_shape.rb
Overview
Class that defines a tensor shape, for compatibility with TensorFlow's TensorShape API.
Instance Attribute Summary collapse
-
#rank ⇒ Object
Returns the rank (number of dimensions); when not given explicitly it is derived from the size of the shape.
-
#shape ⇒ Object
Returns the list of dimensions, or nil when the shape is unknown.
Class Method Summary collapse
- .fix_inferred_elements(shape, total_size) ⇒ Object
- .infer_shape(shape_a, shape_b) ⇒ Object
- .reshape(arr, new_shape) ⇒ Object
Instance Method Summary collapse
- #[](index) ⇒ Object
- #as_dimension(value) ⇒ Object
-
#assert_compatible_with(other) ⇒ Object
Raises a TensorStream::ValueError if `other` is not compatible with this shape.
- #compatible_with?(other) ⇒ Boolean
- #fully_defined? ⇒ Boolean
-
#initialize(shape, rank = nil) ⇒ TensorShape
constructor
A new instance of TensorShape.
- #known? ⇒ Boolean
- #merge_with(other) ⇒ Object
- #ndims ⇒ Object
- #scalar? ⇒ Boolean
- #to_s ⇒ Object
- #value ⇒ Object
Constructor Details
#initialize(shape, rank = nil) ⇒ TensorShape
Returns a new instance of TensorShape.
6 7 8 9 |
# File 'lib/tensor_stream/tensor_shape.rb', line 6
# Builds a shape from a list of dimensions.
#
# @param shape [Array, nil] dimension list; nil means the shape is unknown
# @param rank [Integer, nil] explicit rank; when omitted (and +shape+ is
#   given) the rank is taken from +shape.size+
def initialize(shape, rank = nil)
  @shape = shape
  @rank = if rank.nil? && shape
            shape.size
          else
            rank
          end
end
Instance Attribute Details
#rank ⇒ Object
Returns the rank (number of dimensions); when not given explicitly it is derived from the size of the shape.
4 5 6 |
# File 'lib/tensor_stream/tensor_shape.rb', line 4
# Reader for the rank stored at construction time (nil when neither a
# shape nor an explicit rank was supplied).
def rank
  @rank
end
#shape ⇒ Object
Returns the list of dimensions, or nil when the shape is unknown.
4 5 6 |
# File 'lib/tensor_stream/tensor_shape.rb', line 4
# Reader for the raw dimension list (nil when the shape is unknown).
def shape
  @shape
end
Class Method Details
.fix_inferred_elements(shape, total_size) ⇒ Object
121 122 123 124 125 126 127 128 |
# File 'lib/tensor_stream/tensor_shape.rb', line 121
# Replaces every -1 placeholder in +shape+ with the size implied by
# +total_size+.
#
# The implied size is +total_size+ integer-divided by the product of the
# strictly positive entries; all other entries are kept as they are.
#
# @param shape [Array] dimension list, possibly containing -1 placeholders
# @param total_size [Integer, nil] total element count; nil turns each -1
#   into nil
# @return [Array, nil] the resolved list, +shape+ itself when it is empty,
#   or nil when the first entry is a Tensor (presumably a dynamically
#   computed shape — confirm against callers)
def self.fix_inferred_elements(shape, total_size)
  return shape if shape.empty?
  return nil if shape[0].is_a?(Tensor)

  known_product = shape.select { |dim| dim > 0 }.reduce(1, :*)
  inferred = total_size / known_product unless total_size.nil?
  shape.map { |dim| dim == -1 ? inferred : dim }
end
.infer_shape(shape_a, shape_b) ⇒ Object
76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 |
# File 'lib/tensor_stream/tensor_shape.rb', line 76
# Infers a combined shape from two dimension lists.
#
# Rules, applied in order:
# * either side nil            -> nil
# * either side empty          -> the other side
# * identical lists            -> that list
# * different lengths          -> the longer list, unchanged
# * equal lengths              -> element-wise: nil if either entry is nil or
#   a Tensor, otherwise the larger of the two entries
#
# @param shape_a [Array, nil]
# @param shape_b [Array, nil]
# @return [Array, nil]
def self.infer_shape(shape_a, shape_b)
  return nil if shape_a.nil? || shape_b.nil?
  return shape_a if shape_b.empty?
  return shape_b if shape_a.empty?
  return shape_a if shape_a == shape_b
  return shape_b if shape_b.size > shape_a.size
  return shape_a if shape_a.size > shape_b.size

  # Both lists have the same length here, so dimensions pair up
  # positionally; no reversal or length-mismatch branch is needed.
  shape_a.zip(shape_b).map do |dim_a, dim_b|
    next nil if dim_a.nil? || dim_b.nil?
    next nil if dim_a.is_a?(Tensor) || dim_b.is_a?(Tensor)

    dim_b > dim_a ? dim_b : dim_a
  end
end
.reshape(arr, new_shape) ⇒ Object
97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 |
# File 'lib/tensor_stream/tensor_shape.rb', line 97
# Reshapes +arr+ into nested Arrays matching +new_shape+.
#
# +arr+ is flattened first (a non-Array is treated as a one-element list),
# -1 placeholders in +new_shape+ are resolved via fix_inferred_elements, and
# the flat elements are then sliced recursively, outermost dimension first.
#
# @param arr [Object, Array] value(s) to reshape
# @param new_shape [TensorShape, Array] target dimensions
# @return [Object, Array] the lone element when a single value is reshaped
#   to an empty shape, otherwise a (possibly nested) Array
# @raise [RuntimeError] when the innermost dimension does not equal the
#   number of remaining elements
def self.reshape(arr, new_shape)
  flat = arr.is_a?(Array) ? arr.flatten : [arr]
  dims = new_shape.is_a?(TensorShape) ? new_shape.shape : new_shape
  dims = TensorShape.fix_inferred_elements(dims, flat.size)
  return flat[0] if flat.size == 1 && dims.empty?

  # Work on a copy so the caller's dimension list is never mutated.
  inner = dims.dup
  outer = inner.shift

  if inner.empty?
    raise "reshape dimen mismatch #{flat.size} != #{outer}" if flat.size != outer

    return flat
  end

  chunk = flat.size / outer
  # Fewer elements than the outer dimension: the flat list is returned as-is.
  return flat if chunk.zero?

  flat.each_slice(chunk).map { |slice| reshape(slice, inner.dup) }
end
Instance Method Details
#[](index) ⇒ Object
20 21 22 23 |
# File 'lib/tensor_stream/tensor_shape.rb', line 20
# Returns a new TensorShape wrapping +@shape[index]+.
#
# @param index [Object] passed straight to +#[]+ on the stored dimension
#   list (assumes @shape is an Array, so an Integer or a Range)
# @return [TensorShape]
# NOTE(review): with an Integer index the wrapped value is a single
# dimension rather than an Array, and the constructor then derives rank
# from Integer#size (the integer's byte width). Confirm callers expect this.
def [](index)
  TensorShape.new(@shape[index])
end
#as_dimension(value) ⇒ Object
62 63 64 |
# File 'lib/tensor_stream/tensor_shape.rb', line 62
# Normalises +value+ to a raw dimension list: a TensorShape is unwrapped to
# its +shape+, anything else is returned untouched.
#
# @param value [TensorShape, Object]
# @return [Object]
def as_dimension(value)
  case value
  when TensorShape then value.shape
  else value
  end
end
#assert_compatible_with(other) ⇒ Object
Raises a TensorStream::ValueError if `other` is not compatible with this shape.
72 73 74 |
# File 'lib/tensor_stream/tensor_shape.rb', line 72
# Guards that +other+ is compatible with this shape.
#
# @param other [TensorShape, Array, nil]
# @return [nil] when the shapes are compatible
# @raise [TensorStream::ValueError] when compatible_with? returns false
def assert_compatible_with(other)
  return if compatible_with?(other)

  raise TensorStream::ValueError, "Dimensions #{self} and #{other} are not compatible"
end
#compatible_with?(other) ⇒ Boolean
56 57 58 59 60 |
# File 'lib/tensor_stream/tensor_shape.rb', line 56
# Two shapes are compatible when either is unknown (nil) or when their
# dimension lists are exactly equal.
#
# @param other [TensorShape, Array, nil] unwrapped via as_dimension
# @return [Boolean]
def compatible_with?(other)
  other_dims = as_dimension(other)
  [shape, other_dims].any?(&:nil?) || shape == other_dims
end
#fully_defined? ⇒ Boolean
42 43 44 |
# File 'lib/tensor_stream/tensor_shape.rb', line 42
# Alias-style predicate: a shape is fully defined exactly when known? holds
# (no nil or negative dimensions, and the shape itself is not nil).
#
# @return [Boolean]
def fully_defined?
  known?
end
#known? ⇒ Boolean
33 34 35 36 37 38 39 40 |
# File 'lib/tensor_stream/tensor_shape.rb', line 33
# Whether every dimension is known.
#
# A nil shape is unknown. Otherwise the shape (wrapped in an Array when it
# is a bare value) must contain no nil entries and no entries for which
# +dim < 0+ is truthy (e.g. the -1 placeholder).
#
# @return [Boolean]
def known?
  return false if shape.nil?

  dims = shape.is_a?(Array) ? shape : [shape]
  dims.none? { |dim| dim.nil? || dim < 0 }
end
#merge_with(other) ⇒ Object
46 47 48 49 50 51 52 53 54 |
# File 'lib/tensor_stream/tensor_shape.rb', line 46
# Merges this shape with +other+, returning a new TensorShape.
#
# @param other [TensorShape, Array, nil] shape to merge with
# @return [TensorShape] built from +other+'s dimensions when this shape is
#   unknown (nil), otherwise built from this shape's dimensions
# @raise [TensorStream::ValueError] if the shapes are not compatible
def merge_with(other)
  assert_compatible_with(other)
  # Unwrap a TensorShape argument the same way compatible_with? does.
  # Passing it to TensorShape.new directly would nest one TensorShape inside
  # another and fail when the constructor calls #size on it.
  return TensorShape.new(as_dimension(other)) if @shape.nil?

  TensorShape.new(@shape)
end
#ndims ⇒ Object
25 26 27 |
# File 'lib/tensor_stream/tensor_shape.rb', line 25
# Number of dimensions in the shape, or nil when the shape is unknown.
#
# @return [Integer, nil]
def ndims
  return nil unless shape

  shape.size
end
#scalar? ⇒ Boolean
29 30 31 |
# File 'lib/tensor_stream/tensor_shape.rb', line 29
# A shape is scalar when it is known and has zero dimensions.
#
# @return [Boolean]
def scalar?
  return false unless known?

  shape.size.zero?
end
#to_s ⇒ Object
11 12 13 14 15 16 17 18 |
# File 'lib/tensor_stream/tensor_shape.rb', line 11
# Renders the shape in TensorFlow's repr style, e.g.
# "TensorShape([Dimension(2),Dimension(3)])", or "?" when the shape is nil.
#
# @return [String]
def to_s
  return "?" if @shape.nil?

  rendered = @shape.map { |dim| "Dimension(#{dim})" }
  dimensions = rendered.join(",")
  "TensorShape([#{dimensions}])"
end
#value ⇒ Object
66 67 68 |
# File 'lib/tensor_stream/tensor_shape.rb', line 66
# Raw dimension list; equivalent to calling #shape.
#
# @return [Array, nil]
def value
  shape
end