Class: FastText::Classifier

Inherits:
Model
  • Object
show all
Defined in:
lib/fasttext/classifier.rb

Constant Summary collapse

# Baseline training options. #fit merges @options over this hash, so any key
# present in @options replaces the value listed here; #fit then forces
# :input and :model itself.
#
# Frozen so the shared defaults cannot be altered by accident at runtime;
# Hash#merge returns a new hash, so reading and merging are unaffected.
DEFAULT_OPTIONS = {
  lr: 0.1,
  lr_update_rate: 100,
  dim: 100,
  ws: 5,
  epoch: 5,
  min_count: 1,
  min_count_label: 0,
  neg: 5,
  word_ngrams: 1,
  loss: "softmax",
  model: "supervised",
  bucket: 2_000_000,
  minn: 0,
  maxn: 0,
  thread: 3,
  t: 0.0001,
  label_prefix: "__label__",
  verbose: 2,
  pretrained_vectors: "",
  save_output: false,
  # seed: 0
}.freeze

Instance Method Summary collapse

Methods inherited from Model

#dimension, #initialize, #quantized?, #save_model, #sentence_vector, #subword_id, #subwords, #word_id, #word_vector, #words

Constructor Details

This class inherits a constructor from FastText::Model

Instance Method Details

#fit(x, y = nil) ⇒ Object



27
28
29
30
31
# File 'lib/fasttext/classifier.rb', line 27

# Trains the underlying model in supervised mode.
#
# x and y are handed to input_path, whose result is passed to the trainer as
# the :input option. Options are layered as DEFAULT_OPTIONS, then @options,
# then the forced :input and :model keys, so :model is always "supervised"
# regardless of what @options contains.
#
# Returns whatever the extension model's #train returns.
def fit(x, y = nil)
  training_input = input_path(x, y)
  # Reuse an existing extension model if one is already attached.
  @m ||= Ext::Model.new
  forced = { input: training_input, model: "supervised" }
  m.train(DEFAULT_OPTIONS.merge(@options).merge(forced))
end

#labels(include_freq: false) ⇒ Object



63
64
65
66
67
68
69
70
71
# File 'lib/fasttext/classifier.rb', line 63

# Lists the model's labels with remove_prefix applied to each one.
#
# include_freq - when truthy, return a Hash mapping each cleaned label to the
#                matching entry of the second value returned by m.labels;
#                otherwise return just the Array of cleaned labels.
def labels(include_freq: false)
  names, counts = m.labels
  # Cleaned in place, exactly as before: the array from m.labels is mutated.
  names.map! { |name| remove_prefix(name) }
  return names unless include_freq

  names.zip(counts).to_h
end

#predict(text, k: 1, threshold: 0.0) ⇒ Object



33
34
35
36
37
38
39
40
41
42
43
44
45
46
# File 'lib/fasttext/classifier.rb', line 33

# Predicts labels for one text or for an Array of texts.
#
# text      - a single input, or an Array of inputs; each is run through
#             prep_text before being handed to the model.
# k         - forwarded unchanged as the model's second argument.
# threshold - forwarded unchanged as the model's third argument.
#
# Each model result entry is read as [value, label]; the label goes through
# remove_prefix and becomes the Hash key, the value becomes the Hash value.
# Returns one such Hash for a single input, or an Array of them (one per
# element) when an Array was given.
def predict(text, k: 1, threshold: 0.0)
  batch = text.is_a?(Array)
  inputs = batch ? text : [text]

  # TODO predict multiple in C++ for performance
  predictions = inputs.map do |item|
    raw = m.predict(prep_text(item), k, threshold)
    raw.each_with_object({}) do |entry, by_label|
      by_label[remove_prefix(entry[1])] = entry[0]
    end
  end

  batch ? predictions : predictions.first
end

#quantize ⇒ Object

TODO: support options.



59
60
61
# File 'lib/fasttext/classifier.rb', line 59

# Asks the underlying model to quantize itself, passing an empty options hash.
# Returns whatever the model's #quantize returns.
#
# TODO support options
def quantize
  no_options = {}
  m.quantize(no_options)
end

#test(x, y = nil, k: 1) ⇒ Object



48
49
50
51
52
53
54
55
56
# File 'lib/fasttext/classifier.rb', line 48

def test(x, y = nil, k: 1)
  input = input_path(x, y)
  res = m.test(input, k)
  {
    examples: res[0],
    precision: res[1],
    recall: res[2]
  }
end