Class: ConfusionMatrix

Inherits:
Object
Defined in:
lib/confusion_matrix.rb

Overview

Instances of this class hold the confusion matrix information. The object is designed to be updated incrementally, as results are received. At any point, statistics may be obtained from the current results.

A two-label confusion matrix example is:

Observed        Observed      | 
Positive        Negative      | Predicted
------------------------------+------------
    a               b         | Positive
    c               d         | Negative

Statistical methods will be described with reference to this example.
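As a brief usage sketch (the labels and counts here are purely illustrative), the example matrix could be built incrementally with a=8, b=2, c=2 and d=8:

cm = ConfusionMatrix.new(:positive, :negative)
cm.add_for(:positive, :positive, 8)  # a: predicted positive, observed positive
cm.add_for(:positive, :negative, 2)  # b: predicted positive, observed negative
cm.add_for(:negative, :positive, 2)  # c: predicted negative, observed positive
cm.add_for(:negative, :negative, 8)  # d: predicted negative, observed negative
cm.total  # => 20

The worked values quoted in the method descriptions below reuse these illustrative counts (and assume the private divide helper returns ordinary floating-point quotients).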

Instance Method Summary

Constructor Details

#initialize(*labels) ⇒ ConfusionMatrix

Creates a new, empty instance of a confusion matrix.

labels

a list of labels, typically strings or symbols. If provided, the first label is used as the default label, and all method calls must use one of the pre-defined labels.

Raises an ArgumentError if labels are provided but there are not at least two unique ones.
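For example (the label names are illustrative):

ConfusionMatrix.new(:positive, :negative)  # statistics default to the :positive label
ConfusionMatrix.new                        # no pre-defined labels; they are collected from results
ConfusionMatrix.new(:positive, :positive)  # raises ArgumentError: only one unique label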



# File 'lib/confusion_matrix.rb', line 24

def initialize(*labels)
  @matrix = {}
  @labels = labels.uniq
  if @labels.size == 1
    raise ArgumentError.new("If labels are provided, there must be at least two.")
  else # preset the matrix Hash
    @labels.each do |predefined|
      @matrix[predefined] = {}
      @labels.each do |observed|
        @matrix[predefined][observed] = 0
      end
    end
  end
end

Instance Method Details

#add_for(predicted, observed, n = 1) ⇒ Object

Adds n results (1 by default) to the matrix for a given (predicted, observed) pair of labels.

If the matrix was given a pre-defined list of labels on construction, then the given labels must come from that list. If no pre-defined list of labels was used in constructing the matrix, then labels are added to the matrix as they are first seen. Labels may be any value usable as a Hash key, although ideally they are strings or symbols.
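For example (illustrative labels):

cm = ConfusionMatrix.new
cm.add_for(:pos, :pos)     # record a single result
cm.add_for(:pos, :neg, 3)  # record three results at once
cm.count_for(:pos, :neg)   # => 3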



# File 'lib/confusion_matrix.rb', line 81

def add_for(predicted, observed, n = 1)
  validate_label predicted, observed
  unless n.is_a?(Integer) && n.positive?
    raise ArgumentError.new("add_for requires n to be a positive Integer, but got #{n}")
  end

  # ensure entries exist for this (predicted, observed) pair before incrementing
  unless @matrix.has_key?(predicted)
    @matrix[predicted] = {}
  end
  observations = @matrix[predicted]
  unless observations.has_key?(observed)
    observations[observed] = 0
  end

  @matrix[predicted][observed] += n
end

#count_for(predicted, observed) ⇒ Object

Returns the count for a given (predicted, observed) pair.

cm = ConfusionMatrix.new
cm.add_for(:pos, :neg)
cm.count_for(:pos, :neg) # => 1


# File 'lib/confusion_matrix.rb', line 68

def count_for(predicted, observed)
  validate_label predicted, observed
  observations = @matrix.fetch(predicted, {})
  observations.fetch(observed, 0)
end

#f_measure(label = @labels.first) ⇒ Object

The F-measure for a given label is the harmonic mean of the precision and recall for that label.

F = 2*(precision*recall)/(precision+recall)
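With the illustrative counts from the Overview (a=8, b=2, c=2, d=8), precision(:positive) and recall(:positive) are both 8/10 = 0.8, so f_measure(:positive) = 2*(0.8*0.8)/(0.8+0.8) = 0.8.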



# File 'lib/confusion_matrix.rb', line 152

def f_measure(label = @labels.first)
  validate_label label
  2*precision(label)*recall(label)/(precision(label) + recall(label))
end

#false_negative(label = @labels.first) ⇒ Object

Returns the number of observations of the given label which are incorrect.

For the example matrix, false_negative(:positive) is b.



# File 'lib/confusion_matrix.rb', line 102

def false_negative(label = @labels.first)
  validate_label label
  observations = @matrix.fetch(label, {})
  total = 0

  observations.each_pair do |key, count|
    if key != label 
      total += count
    end
  end

  total
end

#false_positive(label = @labels.first) ⇒ Object

Returns the number of observations incorrect with the given label.

For the example matrix, false_positive(:positive) is c.



# File 'lib/confusion_matrix.rb', line 120

def false_positive(label = @labels.first)
  validate_label label
  total = 0

  @matrix.each_pair do |key, observations|
    if key != label
      total += observations.fetch(label, 0)
    end
  end

  total
end

#false_rate(label = @labels.first) ⇒ Object

The false rate for a given label is the proportion of observations incorrect for that label, out of all those observations not originally of that label.

For the example matrix, false_rate(:positive) is c/(c+d).
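With the illustrative counts from the Overview, false_rate(:positive) = 2/(2+8) = 0.2.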



# File 'lib/confusion_matrix.rb', line 139

def false_rate(label = @labels.first)
  validate_label label
  fp = false_positive(label)
  tn = true_negative(label)

  divide(fp, fp+tn)
end

#geometric_meanObject

The geometric mean is the nth-root of the product of the true_rate for each label.

For the example matrix:

  • a1 = a/(a+b)

  • a2 = d/(c+d)

  • geometric mean = Math.sqrt(a1*a2)
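With the illustrative counts from the Overview, a1 = a2 = 8/10 = 0.8, so the geometric mean is Math.sqrt(0.8*0.8) = 0.8.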



# File 'lib/confusion_matrix.rb', line 165

def geometric_mean
  product = 1

  @matrix.each_key do |key|
    product *= true_rate(key)
  end

  product**(1.0/@matrix.size)
end

#kappa(label = @labels.first) ⇒ Object

The Kappa statistic compares the observed accuracy with the accuracy expected by chance agreement (the random accuracy).
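In terms of the example matrix, the random accuracy is ((d+c)*(d+b) + (b+a)*(c+a)) / (a+b+c+d)^2, and

kappa = (observed accuracy - random accuracy) / (1 - random accuracy)

With the illustrative counts from the Overview, the random accuracy is (10*10 + 10*10)/400 = 0.5 and the observed accuracy is 16/20 = 0.8, giving kappa(:positive) = (0.8 - 0.5)/(1 - 0.5) = 0.6.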



# File 'lib/confusion_matrix.rb', line 178

def kappa(label = @labels.first)
  validate_label label
  tp = true_positive(label)
  fn = false_negative(label)
  fp = false_positive(label)
  tn = true_negative(label)
  total = tp+fn+fp+tn

  total_accuracy = divide(tp+tn, tp+tn+fp+fn)
  random_accuracy = divide((tn+fp)*(tn+fn) + (fn+tp)*(fp+tp), total*total)

  divide(total_accuracy - random_accuracy, 1 - random_accuracy)
end

#labelsObject

Returns a list of labels used in the matrix.

cm = ConfusionMatrix.new
cm.add_for(:pos, :neg)
cm.labels # => [:neg, :pos]


# File 'lib/confusion_matrix.rb', line 45

def labels
  if @labels.size >= 2 # if we defined some labels, return them
    @labels
  else
    result = []

    @matrix.each_pair do |predicted, observations|
      result << predicted
      observations.each_key do |observed|
        result << observed
      end
    end

    result.uniq.sort
  end
end

#matthews_correlation(label = @labels.first) ⇒ Object

Matthews Correlation Coefficient is a measure of the quality of binary classifications.

For the example matrix, matthews_correlation(:positive) is (a*d - c*b) / sqrt((a+c)(a+b)(d+c)(d+b)).
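With the illustrative counts from the Overview, matthews_correlation(:positive) = (8*8 - 2*2) / Math.sqrt(10*10*10*10) = 60/100 = 0.6.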



# File 'lib/confusion_matrix.rb', line 198

def matthews_correlation(label = @labels.first)
  validate_label label
  tp = true_positive(label)
  fn = false_negative(label)
  fp = false_positive(label)
  tn = true_negative(label)

  divide(tp*tn - fp*fn, Math.sqrt((tp+fp)*(tp+fn)*(tn+fp)*(tn+fn)))
end

#overall_accuracyObject

The overall accuracy is the proportion of observations which are correctly labelled.

For the example matrix, overall_accuracy is (a+d)/(a+b+c+d).
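With the illustrative counts from the Overview, overall_accuracy = (8+8)/20 = 0.8.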



# File 'lib/confusion_matrix.rb', line 213

def overall_accuracy
  total_correct = 0

  @matrix.each_key do |key|
    total_correct += true_positive(key)
  end

  divide(total_correct, total)
end

#precision(label = @labels.first) ⇒ Object

The precision for a given label is the proportion of observations observed as that label which are correct.

For the example matrix, precision(:positive) is a/(a+c).
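With the illustrative counts from the Overview, precision(:positive) = 8/(8+2) = 0.8.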



# File 'lib/confusion_matrix.rb', line 228

def precision(label = @labels.first)
  validate_label label
  tp = true_positive(label)
  fp = false_positive(label)

  divide(tp, tp+fp)
end

#prevalence(label = @labels.first) ⇒ Object

The prevalence for a given label is the proportion of observations which were observed as that label, out of the total.

For the example matrix, prevalence(:positive) is (a+c)/(a+b+c+d).



# File 'lib/confusion_matrix.rb', line 241

def prevalence(label = @labels.first)
  validate_label label
  tp = true_positive(label)
  fn = false_negative(label)
  fp = false_positive(label)
  tn = true_negative(label)
  total = tp+fn+fp+tn

  divide(tp+fn, total)
end

#recall(label = @labels.first) ⇒ Object

The recall is another name for the true rate.



# File 'lib/confusion_matrix.rb', line 254

def recall(label = @labels.first)
  validate_label label
  true_rate(label)
end

#sensitivity(label = @labels.first) ⇒ Object

Sensitivity is another name for the true rate.



# File 'lib/confusion_matrix.rb', line 261

def sensitivity(label = @labels.first)
  validate_label label
  true_rate(label)
end

#specificity(label = @labels.first) ⇒ Object

The specificity for a given label is 1 - false_rate(label).

In the two-class case, specificity = 1 - false positive rate.
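With the illustrative counts from the Overview, specificity(:positive) = 1 - 0.2 = 0.8.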



# File 'lib/confusion_matrix.rb', line 270

def specificity(label = @labels.first)
  validate_label label
  1-false_rate(label)
end

#to_sObject

Returns a string representation of the matrix, formatted as a printable table.



# File 'lib/confusion_matrix.rb', line 278

def to_s
  ls = labels
  result = ""

  title_line = "Observed " 
  label_line = ""
  ls.each { |l| label_line << "#{l} " }
  label_line << " " while label_line.size < title_line.size
  title_line << " " while title_line.size < label_line.size
  result << title_line << "|\n" << label_line << "| Predicted\n"
  result << "-"*title_line.size << "+----------\n"

  ls.each do |l|
    count_line = ""
    ls.each_with_index do |m, i|
      count_line << "#{count_for(l, m)}".rjust(labels[i].size) << " "
    end
    result << count_line.ljust(title_line.size) << "| #{l}\n"
  end

  result
end

#totalObject

Returns the total number of observations referenced in the matrix.

For the example matrix, total is a+b+c+d.



# File 'lib/confusion_matrix.rb', line 305

def total
  total = 0

  @matrix.each_value do |observations|
    observations.each_value do |count|
      total += count
    end
  end

  total
end

#true_negative(label = @labels.first) ⇒ Object

Returns the number of observations NOT of the given label which are correct.

For the example matrix, true_negative(:positive) is d.



# File 'lib/confusion_matrix.rb', line 321

def true_negative(label = @labels.first)
  validate_label label
  total = 0

  @matrix.each_pair do |key, observations|
    if key != label 
      total += observations.fetch(key, 0)
    end
  end

  total
end

#true_positive(label = @labels.first) ⇒ Object

Returns the number of observations of the given label which are correct.

For the example matrix, true_positive(:positive) is a.



# File 'lib/confusion_matrix.rb', line 338

def true_positive(label = @labels.first)
  validate_label label
  observations = @matrix.fetch(label, {})
  observations.fetch(label, 0)
end

#true_rate(label = @labels.first) ⇒ Object

The true rate for a given label is the proportion of observations of that label which are correct.

For the example matrix, true_rate(:positive) is a/(a+b).
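With the illustrative counts from the Overview, true_rate(:positive) = 8/(8+2) = 0.8.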



# File 'lib/confusion_matrix.rb', line 349

def true_rate(label = @labels.first)
  validate_label label
  tp = true_positive(label)
  fn = false_negative(label)

  divide(tp, tp+fn)
end