Class: ConfusionMatrix
- Inherits: Object
- Defined in: lib/confusion_matrix.rb
Overview
Instances of this class hold the confusion matrix information. The object is designed to be called incrementally, as results are received. At any point, statistics may be obtained from the current results.
A two-label confusion matrix example is:
  Observed   Observed  |
  Positive   Negative  | Predicted
-----------------------+------------
     a          b      | Positive
     c          d      | Negative
Statistical methods will be described with reference to this example.
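As a sketch of the incremental workflow (the require path is an assumption based on the file location above, and the counts are arbitrary), the example matrix can be built one result at a time and queried at any point:

require 'confusion_matrix'   # assumed require path, from lib/confusion_matrix.rb

cm = ConfusionMatrix.new(:positive, :negative)

# record results as they arrive: add_for(predicted, observed)
cm.add_for(:positive, :positive)   # one more count in cell a
cm.add_for(:positive, :negative)   # one more count in cell b
cm.add_for(:negative, :positive)   # one more count in cell c
cm.add_for(:negative, :negative)   # one more count in cell d

# statistics may be read at any point
cm.total                      # => 4
cm.true_positive(:positive)   # => 1  (cell a)
cm.overall_accuracy           # => (1 + 1)/4 = 0.5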
Instance Method Summary
- #add_for(predicted, observed, n = 1) ⇒ Object
  Adds one result to the matrix for a given (predicted, observed) pair of labels.
- #count_for(predicted, observed) ⇒ Object
  Returns the count for a given (predicted, observed) pair.
- #f_measure(label = @labels.first) ⇒ Object
  The F-measure for a given label is the harmonic mean of the precision and recall for that label.
- #false_negative(label = @labels.first) ⇒ Object
  Returns the number of observations of the given label which are incorrect.
- #false_positive(label = @labels.first) ⇒ Object
  Returns the number of observations incorrect with the given label.
- #false_rate(label = @labels.first) ⇒ Object
  The false rate for a given label is the proportion of observations incorrect for that label, out of all those observations not originally of that label.
- #geometric_mean ⇒ Object
  The geometric mean is the nth-root of the product of the true_rate for each label.
- #initialize(*labels) ⇒ ConfusionMatrix (constructor)
  Creates a new, empty instance of a confusion matrix.
- #kappa(label = @labels.first) ⇒ Object
  The Kappa statistic compares the observed accuracy with an expected accuracy.
- #labels ⇒ Object
  Returns a list of labels used in the matrix.
- #matthews_correlation(label = @labels.first) ⇒ Object
  Matthews Correlation Coefficient is a measure of the quality of binary classifications.
- #overall_accuracy ⇒ Object
  The overall accuracy is the proportion of observations which are correctly labelled.
- #precision(label = @labels.first) ⇒ Object
  The precision for a given label is the proportion of observations observed as that label which are correct.
- #prevalence(label = @labels.first) ⇒ Object
  The prevalence for a given label is the proportion of observations which were observed as of that label, out of the total.
- #recall(label = @labels.first) ⇒ Object
  The recall is another name for the true rate.
- #sensitivity(label = @labels.first) ⇒ Object
  Sensitivity is another name for the true rate.
- #specificity(label = @labels.first) ⇒ Object
  The specificity for a given label is 1 - false_rate(label).
- #to_s ⇒ Object
  Returns the table in a string format, representing the entries as a printable table.
- #total ⇒ Object
  Returns the total number of observations referenced in the matrix.
- #true_negative(label = @labels.first) ⇒ Object
  Returns the number of observations NOT of the given label which are correct.
- #true_positive(label = @labels.first) ⇒ Object
  Returns the number of observations of the given label which are correct.
- #true_rate(label = @labels.first) ⇒ Object
  The true rate for a given label is the proportion of observations of that label which are correct.
Constructor Details
#initialize(*labels) ⇒ ConfusionMatrix
Creates a new, empty instance of a confusion matrix.
- labels: a list of strings or symbols. If provided, the first label is used as the default label, and all method calls must use one of the pre-defined labels.
Raises an ArgumentError if labels are provided but there are not at least two unique labels.
# File 'lib/confusion_matrix.rb', line 24

def initialize(*labels)
  @matrix = {}
  @labels = labels.uniq
  if @labels.size == 1
    raise ArgumentError.new("If labels are provided, there must be at least two.")
  else # preset the matrix Hash
    @labels.each do |predefined|
      @matrix[predefined] = {}
      @labels.each do |observed|
        @matrix[predefined][observed] = 0
      end
    end
  end
end
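For instance, both construction styles described above behave as follows (a minimal sketch of the documented behaviour):

ConfusionMatrix.new                        # no labels: they are collected as results are added
ConfusionMatrix.new(:positive, :negative)  # fixed labels: :positive becomes the default label
ConfusionMatrix.new(:positive)             # raises ArgumentError (at least two unique labels needed)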
Instance Method Details
#add_for(predicted, observed, n = 1) ⇒ Object
Adds one result to the matrix for a given (predicted, observed) pair of labels.
If the matrix was given a pre-defined list of labels on construction, then these given labels must be from the pre-defined list. If no pre-defined list of labels was used in constructing the matrix, then labels will be added to the matrix. Labels may be any hashable value, although ideally they are strings or symbols.
# File 'lib/confusion_matrix.rb', line 81

def add_for(predicted, observed, n = 1)
  validate_label predicted, observed
  unless @matrix.has_key?(predicted)
    @matrix[predicted] = {}
  end
  observations = @matrix[predicted]
  unless observations.has_key?(observed)
    observations[observed] = 0
  end
  unless n.class == Integer and n.positive?
    raise ArgumentError.new("add_for requires n to be a positive Integer, but got #{n}")
  end
  @matrix[predicted][observed] += n
end
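A short sketch of both cases described above, including the batch count n and the error raised for a non-positive n:

cm = ConfusionMatrix.new      # no pre-defined labels, so new labels are accepted
cm.add_for(:pos, :pos)        # adds a single result
cm.add_for(:pos, :neg, 5)     # adds five results at once
cm.count_for(:pos, :neg)      # => 5
cm.add_for(:pos, :neg, 0)     # raises ArgumentError: n must be a positive Integer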
#count_for(predicted, observed) ⇒ Object
Returns the count for a given (predicted, observed) pair.
cm = ConfusionMatrix.new
cm.add_for(:pos, :neg)
cm.count_for(:pos, :neg) # => 1
# File 'lib/confusion_matrix.rb', line 68

def count_for(predicted, observed)
  validate_label predicted, observed
  observations = @matrix.fetch(predicted, {})
  observations.fetch(observed, 0)
end
#f_measure(label = @labels.first) ⇒ Object
The F-measure for a given label is the harmonic mean of the precision and recall for that label.
F = 2*(precision*recall)/(precision+recall)
# File 'lib/confusion_matrix.rb', line 152

def f_measure(label = @labels.first)
  validate_label label
  2*precision(label)*recall(label)/(precision(label) + recall(label))
end
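As a worked sketch with arbitrary counts (a = 6, b = 4, c = 2, d = 8, using the cell names from the example matrix): precision(:positive) is 6/8 = 0.75, recall(:positive) is 6/10 = 0.6, so the F-measure is 0.9/1.35, roughly 0.667:

cm = ConfusionMatrix.new(:positive, :negative)
cm.add_for(:positive, :positive, 6)   # a
cm.add_for(:positive, :negative, 4)   # b
cm.add_for(:negative, :positive, 2)   # c
cm.add_for(:negative, :negative, 8)   # d
cm.f_measure(:positive)   # => 2*(0.75*0.6)/(0.75 + 0.6), approximately 0.667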
#false_negative(label = @labels.first) ⇒ Object
Returns the number of observations of the given label which are incorrect.
For the example matrix, false_negative(:positive) is b.
# File 'lib/confusion_matrix.rb', line 102

def false_negative(label = @labels.first)
  validate_label label
  observations = @matrix.fetch(label, {})
  total = 0
  observations.each_pair do |key, count|
    if key != label
      total += count
    end
  end
  total
end
#false_positive(label = @labels.first) ⇒ Object
Returns the number of observations incorrect with the given label.
For the example matrix, false_positive(:positive) is c.
# File 'lib/confusion_matrix.rb', line 120

def false_positive(label = @labels.first)
  validate_label label
  total = 0
  @matrix.each_pair do |key, observations|
    if key != label
      total += observations.fetch(label, 0)
    end
  end
  total
end
#false_rate(label = @labels.first) ⇒ Object
The false rate for a given label is the proportion of observations incorrect for that label, out of all those observations not originally of that label.
For the example matrix, false_rate(:positive) is c/(c+d).
# File 'lib/confusion_matrix.rb', line 139

def false_rate(label = @labels.first)
  validate_label label
  fp = false_positive(label)
  tn = true_negative(label)
  divide(fp, fp+tn)
end
#geometric_mean ⇒ Object
The geometric mean is the nth-root of the product of the true_rate for each label.
For the example matrix:
- a1 = a/(a+b)
- a2 = d/(c+d)
- geometric mean = Math.sqrt(a1*a2)
# File 'lib/confusion_matrix.rb', line 165

def geometric_mean
  product = 1
  @matrix.each_key do |key|
    product *= true_rate(key)
  end
  product**(1.0/@matrix.size)
end
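A sketch with arbitrary counts chosen so that true_rate(:positive) is 0.8 and true_rate(:negative) is 0.5; the geometric mean is then the square root of their product:

cm = ConfusionMatrix.new(:positive, :negative)
cm.add_for(:positive, :positive, 8)   # a
cm.add_for(:positive, :negative, 2)   # b => true_rate(:positive) = 8/10 = 0.8
cm.add_for(:negative, :positive, 5)   # c
cm.add_for(:negative, :negative, 5)   # d => true_rate(:negative) = 5/10 = 0.5
cm.geometric_mean   # => Math.sqrt(0.8 * 0.5), approximately 0.632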
#kappa(label = @labels.first) ⇒ Object
The Kappa statistic compares the observed accuracy with an expected accuracy.
# File 'lib/confusion_matrix.rb', line 178

def kappa(label = @labels.first)
  validate_label label
  tp = true_positive(label)
  fn = false_negative(label)
  fp = false_positive(label)
  tn = true_negative(label)
  total = tp+fn+fp+tn
  total_accuracy = divide(tp+tn, tp+tn+fp+fn)
  random_accuracy = divide((tn+fp)*(tn+fn) + (fn+tp)*(fp+tp), total*total)
  divide(total_accuracy - random_accuracy, 1 - random_accuracy)
end
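In terms of the example matrix, the implementation above uses the observed accuracy (a+d)/(a+b+c+d) and the expected accuracy ((a+b)*(a+c) + (c+d)*(b+d))/(a+b+c+d)^2, and returns their normalised difference. A worked sketch with arbitrary counts:

cm = ConfusionMatrix.new(:positive, :negative)
cm.add_for(:positive, :positive, 4)   # a
cm.add_for(:positive, :negative, 1)   # b
cm.add_for(:negative, :positive, 2)   # c
cm.add_for(:negative, :negative, 3)   # d
# observed accuracy = (4 + 3)/10          = 0.7
# expected accuracy = (5*6 + 5*4)/(10*10) = 0.5
cm.kappa(:positive)   # => (0.7 - 0.5)/(1 - 0.5) = 0.4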
#labels ⇒ Object
Returns a list of labels used in the matrix.
cm = ConfusionMatrix.new
cm.add_for(:pos, :neg)
cm.labels # => [:neg, :pos]
# File 'lib/confusion_matrix.rb', line 45

def labels
  if @labels.size >= 2 # if we defined some labels, return them
    @labels
  else
    result = []
    @matrix.each_pair do |key, observed|
      result << key
      observed.each_key do |key|
        result << key
      end
    end
    result.uniq.sort
  end
end
#matthews_correlation(label = @labels.first) ⇒ Object
Matthews Correlation Coefficient is a measure of the quality of binary classifications.
For the example matrix, matthews_correlation(:positive) is (a*d - c*b) / sqrt((a+c)(a+b)(d+c)(d+b)).
# File 'lib/confusion_matrix.rb', line 198

def matthews_correlation(label = @labels.first)
  validate_label label
  tp = true_positive(label)
  fn = false_negative(label)
  fp = false_positive(label)
  tn = true_negative(label)
  divide(tp*tn - fp*fn, Math.sqrt((tp+fp)*(tp+fn)*(tn+fp)*(tn+fn)))
end
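A worked sketch with arbitrary counts, following the formula above (the coefficient ranges from -1 to +1, with +1 indicating perfect agreement):

cm = ConfusionMatrix.new(:positive, :negative)
cm.add_for(:positive, :positive, 6)   # a
cm.add_for(:positive, :negative, 2)   # b
cm.add_for(:negative, :positive, 2)   # c
cm.add_for(:negative, :negative, 6)   # d
cm.matthews_correlation(:positive)   # => (6*6 - 2*2) / Math.sqrt(8*8*8*8) = 0.5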
#overall_accuracy ⇒ Object
The overall accuracy is the proportion of observations which are correctly labelled.
For the example matrix, overall_accuracy is (a+d)/(a+b+c+d).
# File 'lib/confusion_matrix.rb', line 213

def overall_accuracy
  total_correct = 0
  @matrix.each_pair do |key, observations|
    total_correct += true_positive(key)
  end
  divide(total_correct, total)
end
#precision(label = @labels.first) ⇒ Object
The precision for a given label is the proportion of observations observed as that label which are correct.
For the example matrix, precision(:positive) is a/(a+c).
# File 'lib/confusion_matrix.rb', line 228

def precision(label = @labels.first)
  validate_label label
  tp = true_positive(label)
  fp = false_positive(label)
  divide(tp, tp+fp)
end
#prevalence(label = @labels.first) ⇒ Object
The prevalence for a given label is the proportion of observations which were observed as of that label, out of the total.
For the example matrix, prevalence(:positive) is (a+c)/(a+b+c+d).
# File 'lib/confusion_matrix.rb', line 241

def prevalence(label = @labels.first)
  validate_label label
  tp = true_positive(label)
  fn = false_negative(label)
  fp = false_positive(label)
  tn = true_negative(label)
  total = tp+fn+fp+tn
  divide(tp+fn, total)
end
#recall(label = @labels.first) ⇒ Object
The recall is another name for the true rate.
# File 'lib/confusion_matrix.rb', line 254

def recall(label = @labels.first)
  validate_label label
  true_rate(label)
end
#sensitivity(label = @labels.first) ⇒ Object
Sensitivity is another name for the true rate.
# File 'lib/confusion_matrix.rb', line 261

def sensitivity(label = @labels.first)
  validate_label label
  true_rate(label)
end
#specificity(label = @labels.first) ⇒ Object
The specificity for a given label is 1 - false_rate(label).
In the two-class case, specificity = 1 - false_positive_rate.
# File 'lib/confusion_matrix.rb', line 270

def specificity(label = @labels.first)
  validate_label label
  1-false_rate(label)
end
#to_s ⇒ Object
Returns the table in a string format, representing the entries as a printable table.
# File 'lib/confusion_matrix.rb', line 278

def to_s
  ls = labels
  result = ""

  title_line = "Observed "
  label_line = ""
  ls.each { |l| label_line << "#{l} " }
  label_line << " " while label_line.size < title_line.size
  title_line << " " while title_line.size < label_line.size

  result << title_line << "|\n" << label_line << "| Predicted\n"
  result << "-"*title_line.size << "+----------\n"

  ls.each do |l|
    count_line = ""
    ls.each_with_index do |m, i|
      count_line << "#{count_for(l, m)}".rjust(labels[i].size) << " "
    end
    result << count_line.ljust(title_line.size) << "| #{l}\n"
  end

  result
end
#total ⇒ Object
Returns the total number of observations referenced in the matrix.
For the example matrix, total is a+b+c+d.
# File 'lib/confusion_matrix.rb', line 305

def total
  total = 0
  @matrix.each_value do |observations|
    observations.each_value do |count|
      total += count
    end
  end
  total
end
#true_negative(label = @labels.first) ⇒ Object
Returns the number of observations NOT of the given label which are correct.
For the example matrix, true_negative(:positive) is d.
# File 'lib/confusion_matrix.rb', line 321

def true_negative(label = @labels.first)
  validate_label label
  total = 0
  @matrix.each_pair do |key, observations|
    if key != label
      total += observations.fetch(key, 0)
    end
  end
  total
end
#true_positive(label = @labels.first) ⇒ Object
Returns the number of observations of the given label which are correct.
For the example matrix, true_positive(:positive) is a.
# File 'lib/confusion_matrix.rb', line 338

def true_positive(label = @labels.first)
  validate_label label
  observations = @matrix.fetch(label, {})
  observations.fetch(label, 0)
end
#true_rate(label = @labels.first) ⇒ Object
The true rate for a given label is the proportion of observations of that label which are correct.
For the example matrix, true_rate(:positive) is a/(a+b).
# File 'lib/confusion_matrix.rb', line 349

def true_rate(label = @labels.first)
  validate_label label
  tp = true_positive(label)
  fn = false_negative(label)
  divide(tp, tp+fn)
end
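Since recall and sensitivity are documented above as other names for the true rate, the three calls below return the same value (a minimal sketch with arbitrary counts):

cm = ConfusionMatrix.new(:positive, :negative)
cm.add_for(:positive, :positive, 8)   # a
cm.add_for(:positive, :negative, 2)   # b
cm.true_rate(:positive)     # => 8/(8 + 2) = 0.8
cm.recall(:positive)        # => 0.8
cm.sensitivity(:positive)   # => 0.8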