3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
|
# File 'lib/neighbor/model.rb', line 3
# Registers one or more vector attributes on the model.
#
# Every call records the attribute options in the class-level registry and
# wraps the attribute type with Neighbor::Attribute. When the registry holds
# exactly the attributes passed to this call (i.e. nothing was registered on
# this class before), it also installs:
#
# * a validation that checks dimensions and finiteness of every registered
#   attribute,
# * the +nearest_neighbors+ scope, and
# * the +nearest_neighbors+ instance method.
#
# NOTE: locals inside the blocks below intentionally use names that differ
# from this method's keyword arguments. Blocks share the method's local
# variables, so assigning to +dimensions+, +normalize+ or +type+ inside the
# validation or the scope would overwrite the values captured by the blocks
# handed to +decorate_attributes+ / +attribute+ (which Active Record may
# invoke later) and would be shared between all callers of the scope.
#
# @param attribute_names [Array<Symbol, String>] attributes to register; at least one is required
# @param dimensions [Integer, nil] expected number of dimensions; when nil the column limit is used (except for binary columns)
# @param normalize [Boolean, nil] when truthy, stored values and query vectors go through Neighbor::Utils.normalize
# @param type [Symbol, String, nil] explicit type, stored in the registry as a symbol
# @raise [ArgumentError] if no attribute name is given
# @raise [Neighbor::Error] if an attribute has already been registered
def has_neighbors(*attribute_names, dimensions: nil, normalize: nil, type: nil)
  if attribute_names.empty?
    raise ArgumentError, "has_neighbors requires an attribute name"
  end
  attribute_names.map!(&:to_sym)

  class_eval do
    @neighbor_attributes ||= {}

    # Define the registry reader once per class; it merges the options
    # registered on the superclass chain with the ones registered here.
    if @neighbor_attributes.empty?
      def self.neighbor_attributes
        parent_attributes =
          if superclass.respond_to?(:neighbor_attributes)
            superclass.neighbor_attributes
          else
            {}
          end

        parent_attributes.merge(@neighbor_attributes || {})
      end
    end

    attribute_names.each do |attribute_name|
      raise Neighbor::Error, "has_neighbors already called for #{attribute_name.inspect}" if neighbor_attributes[attribute_name]
      @neighbor_attributes[attribute_name] = {dimensions: dimensions, normalize: normalize, type: type&.to_sym}
    end

    # Wrap the attribute's cast type; the API used depends on the Active Record version.
    if ActiveRecord::VERSION::STRING.to_f >= 7.2
      decorate_attributes(attribute_names) do |name, cast_type|
        Neighbor::Attribute.new(cast_type: cast_type, model: self, type: type, attribute_name: name)
      end
    else
      attribute_names.each do |attribute_name|
        attribute attribute_name do |cast_type|
          Neighbor::Attribute.new(cast_type: cast_type, model: self, type: type, attribute_name: attribute_name)
        end
      end
    end

    if normalize
      if ActiveRecord::VERSION::STRING.to_f >= 7.1
        attribute_names.each do |attribute_name|
          normalizes attribute_name, with: ->(v) { Neighbor::Utils.normalize(v, column_info: columns_hash[attribute_name.to_s]) }
        end
      else
        attribute_names.each do |attribute_name|
          attribute attribute_name do |cast_type|
            Neighbor::NormalizedAttribute.new(cast_type: cast_type, model: self, attribute_name: attribute_name)
          end
        end
      end
    end

    # Everything below is one-time setup: skip it when the registry already
    # contained attributes before this call.
    return if @neighbor_attributes.size != attribute_names.size

    validate do
      adapter = Neighbor::Utils.adapter(self.class)

      self.class.neighbor_attributes.each do |k, v|
        value = read_attribute(k)
        next if value.nil?

        column_info = self.class.columns_hash[k.to_s]
        expected_dimensions = v[:dimensions]
        # Fall back to the column limit, which is not used for binary columns.
        expected_dimensions ||= column_info&.limit unless column_info&.type == :binary
        value_type = v[:type] || Neighbor::Utils.type(adapter, column_info&.type)

        unless Neighbor::Utils.validate_dimensions(value, value_type, expected_dimensions, adapter).nil?
          errors.add(k, "must have #{expected_dimensions} dimensions")
        end

        unless Neighbor::Utils.validate_finite(value, value_type)
          errors.add(k, "must have finite values")
        end
      end
    end

    # Orders records by distance between +attribute_name+ and +vector+ and
    # selects the distance as +neighbor_distance+. Returns +none+ for a nil vector.
    scope :nearest_neighbors, ->(attribute_name, vector, distance:, precision: nil) {
      attribute_name = attribute_name.to_sym
      options = neighbor_attributes[attribute_name]
      raise ArgumentError, "Invalid attribute" unless options
      normalize_vector = options[:normalize]
      vector_dimensions = options[:dimensions]
      vector_type = options[:type]

      return none if vector.nil?

      distance = distance.to_s

      column_info = columns_hash[attribute_name.to_s]
      column_type = column_info&.type

      adapter = Neighbor::Utils.adapter(klass)
      if vector_type && adapter != :sqlite
        raise ArgumentError, "type only works with SQLite"
      end

      operator = Neighbor::Utils.operator(adapter, column_type, distance)
      raise ArgumentError, "Invalid distance: #{distance}" unless operator

      normalize_required = Neighbor::Utils.normalize_required?(adapter, column_type)
      if distance == "cosine" && normalize_required && normalize_vector.nil?
        raise Neighbor::Error, "Set normalize for cosine distance with cube"
      end

      column_attribute = klass.type_for_attribute(attribute_name)
      vector = column_attribute.cast(vector)
      vector_dimensions ||= column_info&.limit unless column_info&.type == :binary
      Neighbor::Utils.validate(vector, dimensions: vector_dimensions, type: vector_type || Neighbor::Utils.type(adapter, column_info&.type), adapter: adapter)
      vector = Neighbor::Utils.normalize(vector, column_info: column_info) if normalize_vector

      quoted_attribute = nil
      query = nil
      connection_pool.with_connection do |c|
        quoted_attribute = "#{c.quote_table_name(table_name)}.#{c.quote_column_name(attribute_name)}"
        query = c.quote(column_attribute.serialize(vector))
      end

      unless precision.nil?
        if adapter != :postgresql || column_type != :vector
          raise ArgumentError, "Precision not supported for this type"
        end

        case precision.to_s
        when "half"
          cast_dimensions = vector_dimensions || column_info&.limit
          raise ArgumentError, "Unknown dimensions" unless cast_dimensions
          quoted_attribute += "::halfvec(#{connection_pool.with_connection { |c| c.quote(cast_dimensions.to_i) }})"
        else
          raise ArgumentError, "Invalid precision"
        end
      end

      order = Neighbor::Utils.order(adapter, vector_type, operator, quoted_attribute, query)

      # The selected distance is derived from the ordering expression and
      # adjusted for cosine-with-normalization and inner-product cases.
      neighbor_distance =
        if distance == "cosine" && normalize_required
          "POWER(#{order}, 2) / 2.0"
        elsif [:vector, :halfvec, :sparsevec].include?(column_type) && distance == "inner_product"
          "(#{order}) * -1"
        else
          order
        end

      # Keep an existing explicit select list; otherwise select every column.
      select_columns = select_values.any? ? [] : column_names
      select(*select_columns, "#{neighbor_distance} AS neighbor_distance")
        .where.not(attribute_name => nil)
        .reorder(Arel.sql(order))
    }

    # Finds the nearest neighbors of this record for +attribute_name+,
    # excluding the record itself by primary key. Remaining options are
    # forwarded to the +nearest_neighbors+ scope.
    def nearest_neighbors(attribute_name, **options)
      attribute_name = attribute_name.to_sym
      raise ArgumentError, "Invalid attribute" unless self.class.neighbor_attributes[attribute_name]

      self.class
        .where.not(Array(self.class.primary_key).to_h { |k| [k, self[k]] })
        .nearest_neighbors(attribute_name, self[attribute_name], **options)
    end
  end
end
|