Class: Vecsearch::GTETiny

Inherits:
Object
  • Object
show all
Defined in:
lib/vecsearch/gte_tiny.rb

Defined Under Namespace

Modules: Bert, CStdio

Constant Summary collapse

VENDOR =
File.expand_path('../../vendor', __dir__)
GTE_BIN =
File.expand_path('gte-tiny-q4_1.ggml.bin', VENDOR)
MAX_TOKENS =
512

Instance Method Summary collapse

Constructor Details

#initialize(fname = GTE_BIN) ⇒ GTETiny

Returns a new instance of GTETiny.



27
28
29
30
31
32
33
# File 'lib/vecsearch/gte_tiny.rb', line 27

# Loads a model through the Bert bindings and caches its embedding width.
#
# The whole load runs inside #suppress_streams, so anything written to
# stdout while loading is discarded.
#
# @param fname [String] path handed to Bert.bert_load_from_file
#   (defaults to GTE_BIN)
def initialize(fname = GTE_BIN)
  suppress_streams do
    context = Bert.bert_load_from_file(fname)
    @ctx = context
    @n_embd = Bert.bert_n_embd(context)
    # NOTE(review): the reason for this pause is not evident from the code;
    # presumably it lets native output settle before stdout is restored.
    # Confirm before removing.
    sleep 0.1
  end
end

Instance Method Details

#encode(sentence, n_threads: 1) ⇒ Object



45
46
47
48
49
50
51
52
53
54
55
# File 'lib/vecsearch/gte_tiny.rb', line 45

# Encodes a single sentence by delegating to #encode_batch with a
# one-element batch.
#
# @param sentence [String] text to encode
# @param n_threads [Integer] thread count forwarded to #encode_batch
#   (previously accepted but ignored: the call was hard-coded to 1)
# @return [Object] the unmodified result of #encode_batch for the
#   one-element batch, i.e. still wrapped in the outer batch array
def encode(sentence, n_threads: 1)
  # A mean-pooling step over an all-ones attention mask was sketched here
  # but never enabled, so the batch result is returned untouched.
  encode_batch([sentence], n_threads: n_threads)
end

#encode_batch(input, n_threads: 1) ⇒ Object



57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
# File 'lib/vecsearch/gte_tiny.rb', line 57

# Runs Bert.bert_encode_batch over a batch of strings.
#
# @param input [Array<String>] strings to encode
# @param n_threads [Integer] thread count passed to the native call
# @return [Array<Array<Float>>] one array of @n_embd floats per input,
#   in input order
def encode_batch(input, n_threads: 1)
  batch_size = input.length

  # Native copies of the input strings, plus a pointer array referencing them.
  text_buffers = input.map { |text| FFI::MemoryPointer.from_string(text) }
  text_table = FFI::MemoryPointer.new(:pointer, batch_size)
  text_table.write_array_of_pointer(text_buffers)

  # One float buffer of @n_embd elements per input, plus a pointer array
  # referencing those buffers.
  result_buffers = Array.new(batch_size) { FFI::MemoryPointer.new(:float, @n_embd) }
  result_table = FFI::MemoryPointer.new(:pointer, batch_size)
  result_table.write_array_of_pointer(result_buffers)

  Bert.bert_encode_batch(@ctx, n_threads, MAX_TOKENS, batch_size, text_table, result_table)

  # Read each result buffer back as a Ruby array of floats.
  result_buffers.map { |buffer| buffer.read_array_of_float(@n_embd) }
end

#suppress_streams ⇒ Object



35
36
37
38
39
40
41
42
43
# File 'lib/vecsearch/gte_tiny.rb', line 35

# Runs the given block with STDOUT redirected to the null device, then
# restores the original STDOUT.
#
# Only STDOUT is redirected; STDERR is left alone.
#
# @yield the work whose stdout output should be discarded
# @return [Object] the block's return value
def suppress_streams
  # Duplicate before entering the protected region: if the dup itself fails
  # there is nothing to restore, and the original error is not masked.
  saved_stdout = STDOUT.dup
  begin
    STDOUT.reopen(File::NULL, "w")
    STDOUT.sync = true
    yield
  ensure
    CStdio.fflush(nil) # Regular STDOUT.flush doesn't do it.
    STDOUT.reopen(saved_stdout)
    # Close the duplicate so each call does not leak a file descriptor.
    saved_stdout.close
  end
end