Module: Neighbor::Model

Defined in:
lib/neighbor/model.rb

Instance Method Summary

Instance Method Details

#has_neighbors(*attribute_names, dimensions: nil, normalize: nil, type: nil) ⇒ Object



Source: lib/neighbor/model.rb, lines 3–169
# File 'lib/neighbor/model.rb', line 3

# Declares one or more vector attributes on the model and wires up
# nearest-neighbor search for them.
#
# Each call records its options in the class-level +@neighbor_attributes+
# hash and wraps the attributes' types in Neighbor::Attribute. When
# +normalize+ is truthy it also adds normalization. The validation, the
# +nearest_neighbors+ scope and the instance-level +nearest_neighbors+ method
# are defined only by the first call for a given class. Later calls just
# register more attributes, which those definitions read back through
# +neighbor_attributes+.
#
# @param attribute_names [Array<Symbol, String>] attributes to register
#   (converted to symbols in place).
# @param dimensions [Integer, nil] expected vector length; when nil, the
#   column's +limit+ is used instead (except for +:binary+ columns).
# @param normalize [Boolean, nil] normalize stored values and query vectors.
#   The scope raises if this is left nil for cosine distance on a column
#   where Utils.normalize_required? is true.
# @param type [Symbol, String, nil] element type, stored as a symbol; the
#   scope rejects it unless the adapter is SQLite.
# @raise [ArgumentError] if no attribute names are given.
# @raise [Error] (presumably Neighbor::Error) if has_neighbors was already
#   called for one of the attributes, including on a superclass.
def has_neighbors(*attribute_names, dimensions: nil, normalize: nil, type: nil)
  if attribute_names.empty?
    raise ArgumentError, "has_neighbors requires an attribute name"
  end
  attribute_names.map!(&:to_sym)

  class_eval do
    # Class-level instance variable: each class keeps its own hash, and
    # neighbor_attributes merges in whatever the superclass registered.
    @neighbor_attributes ||= {}

    # Define the reader only once, on the first has_neighbors call for this class.
    if @neighbor_attributes.empty?
      def self.neighbor_attributes
        parent_attributes =
          if superclass.respond_to?(:neighbor_attributes)
            superclass.neighbor_attributes
          else
            {}
          end

        parent_attributes.merge(@neighbor_attributes || {})
      end
    end

    attribute_names.each do |attribute_name|
      # neighbor_attributes includes inherited entries, so redeclaring an
      # attribute already registered on a parent class is rejected as well.
      raise Error, "has_neighbors already called for #{attribute_name.inspect}" if neighbor_attributes[attribute_name]
      @neighbor_attributes[attribute_name] = {dimensions: dimensions, normalize: normalize, type: type&.to_sym}
    end

    # These decorator blocks close over this call's +type+ argument. They are
    # presumably invoked by Active Record after has_neighbors returns (when
    # attribute types are built, possibly more than once), so nothing below
    # may reassign +type+.
    if ActiveRecord::VERSION::STRING.to_f >= 7.2
      decorate_attributes(attribute_names) do |name, cast_type|
        Neighbor::Attribute.new(cast_type: cast_type, model: self, type: type, attribute_name: name)
      end
    else
      attribute_names.each do |attribute_name|
        attribute attribute_name do |cast_type|
          Neighbor::Attribute.new(cast_type: cast_type, model: self, type: type, attribute_name: attribute_name)
        end
      end
    end

    if normalize
      if ActiveRecord::VERSION::STRING.to_f >= 7.1
        attribute_names.each do |attribute_name|
          normalizes attribute_name, with: ->(v) { Neighbor::Utils.normalize(v, column_info: columns_hash[attribute_name.to_s]) }
        end
      else
        attribute_names.each do |attribute_name|
          attribute attribute_name do |cast_type|
            Neighbor::NormalizedAttribute.new(cast_type: cast_type, model: self, attribute_name: attribute_name)
          end
        end
      end
    end

    # `return` inside this block exits has_neighbors itself. The hash holds
    # more entries than this call added only when an earlier call for this
    # class already ran the definitions below, so they are not repeated.
    return if @neighbor_attributes.size != attribute_names.size

    # +dimensions+ and +type+ are block-local: without the declaration these
    # assignments would overwrite has_neighbors' own arguments, which the
    # attribute decorator blocks above still close over.
    validate do |; dimensions, type|
      adapter = Utils.adapter(self.class)

      self.class.neighbor_attributes.each do |k, v|
        value = read_attribute(k)
        next if value.nil?

        column_info = self.class.columns_hash[k.to_s]
        dimensions = v[:dimensions]
        # Binary columns have no usable dimension limit.
        dimensions ||= column_info&.limit unless column_info&.type == :binary
        type = v[:type] || Utils.type(adapter, column_info&.type)

        if !Neighbor::Utils.validate_dimensions(value, type, dimensions, adapter).nil?
          errors.add(k, "must have #{dimensions} dimensions")
        end
        if !Neighbor::Utils.validate_finite(value, type)
          errors.add(k, "must have finite values")
        end
      end
    end

    # +normalize+, +dimensions+ and +type+ are declared block-local (after the
    # `;`), so calling the scope cannot clobber the has_neighbors arguments
    # captured by the attribute decorator blocks above.
    scope :nearest_neighbors, ->(attribute_name, vector, distance:, precision: nil; normalize, dimensions, type) {
      attribute_name = attribute_name.to_sym
      options = neighbor_attributes[attribute_name]
      raise ArgumentError, "Invalid attribute" unless options
      normalize = options[:normalize]
      dimensions = options[:dimensions]
      type = options[:type]

      # `return` in a lambda only leaves the lambda.
      return none if vector.nil?

      distance = distance.to_s

      column_info = columns_hash[attribute_name.to_s]
      column_type = column_info&.type

      adapter = Neighbor::Utils.adapter(klass)
      if type && adapter != :sqlite
        raise ArgumentError, "type only works with SQLite"
      end

      operator = Neighbor::Utils.operator(adapter, column_type, distance)
      raise ArgumentError, "Invalid distance: #{distance}" unless operator

      # ensure normalize set (can be true or false)
      normalize_required = Utils.normalize_required?(adapter, column_type)
      if distance == "cosine" && normalize_required && normalize.nil?
        raise Neighbor::Error, "Set normalize for cosine distance with cube"
      end

      column_attribute = klass.type_for_attribute(attribute_name)
      vector = column_attribute.cast(vector)
      dimensions ||= column_info&.limit unless column_info&.type == :binary
      Neighbor::Utils.validate(vector, dimensions: dimensions, type: type || Utils.type(adapter, column_info&.type), adapter: adapter)
      vector = Neighbor::Utils.normalize(vector, column_info: column_info) if normalize

      # Quote the column reference and the serialized query vector with a
      # checked-out connection.
      quoted_attribute = nil
      query = nil
      connection_pool.with_connection do |c|
        quoted_attribute = "#{c.quote_table_name(table_name)}.#{c.quote_column_name(attribute_name)}"
        query = c.quote(column_attribute.serialize(vector))
      end

      if !precision.nil?
        if adapter != :postgresql || column_type != :vector
          raise ArgumentError, "Precision not supported for this type"
        end

        case precision.to_s
        when "half"
          cast_dimensions = dimensions || column_info&.limit
          raise ArgumentError, "Unknown dimensions" unless cast_dimensions
          quoted_attribute += "::halfvec(#{connection_pool.with_connection { |c| c.quote(cast_dimensions.to_i) }})"
        else
          raise ArgumentError, "Invalid precision"
        end
      end

      order = Utils.order(adapter, type, operator, quoted_attribute, query)

      # https://stats.stackexchange.com/questions/146221/is-cosine-similarity-identical-to-l2-normalized-euclidean-distance
      # with normalized vectors:
      # cosine similarity = 1 - (euclidean distance)**2 / 2
      # cosine distance = 1 - cosine similarity
      # this transformation doesn't change the order, so only needed for select
      neighbor_distance =
        if distance == "cosine" && normalize_required
          "POWER(#{order}, 2) / 2.0"
        elsif [:vector, :halfvec, :sparsevec].include?(column_type) && distance == "inner_product"
          "(#{order}) * -1"
        else
          order
        end

      # for select, use column_names instead of * to account for ignored columns
      select_columns = select_values.any? ? [] : column_names
      select(*select_columns, "#{neighbor_distance} AS neighbor_distance")
        .where.not(attribute_name => nil)
        .reorder(Arel.sql(order))
    }

    # Records nearest to this one by +attribute_name+, excluding this record
    # itself (matched on every primary key column). Options are forwarded to
    # the class-level nearest_neighbors scope.
    def nearest_neighbors(attribute_name, **options)
      attribute_name = attribute_name.to_sym
      # important! check if neighbor attribute before accessing
      raise ArgumentError, "Invalid attribute" unless self.class.neighbor_attributes[attribute_name]

      self.class
        .where.not(Array(self.class.primary_key).to_h { |k| [k, self[k]] })
        .nearest_neighbors(attribute_name, self[attribute_name], **options)
    end
  end
end