Class: DecisionTree::ID3Tree
Instance Method Summary
- #build_rules(tree = @tree) ⇒ Object
- #graph(filename) ⇒ Object
-
#id3_continuous(data, attributes, attribute) ⇒ Object
ID3 for binary classification of continuous variables (e.g. healthy / sick based on temperature thresholds).
-
#id3_discrete(data, attributes, attribute) ⇒ Object
ID3 for discrete label cases.
- #id3_train(data, attributes, default, used = {}) ⇒ Object
-
#initialize(attributes, data, default, type) ⇒ ID3Tree
constructor
A new instance of ID3Tree.
- #predict(test) ⇒ Object
- #ruleset ⇒ Object
- #train(data = @data, attributes = @attributes, default = @default) ⇒ Object
Constructor Details
#initialize(attributes, data, default, type) ⇒ ID3Tree
Returns a new instance of ID3Tree.
45 46 47 48 |
# File 'lib/decisiontree/id3_tree.rb', line 45
#
# Stores the training inputs and resets the learner's state.
#
# @param attributes [Array] attribute names; other methods of this class map
#   a name to a sample column with attributes.index(name)
# @param data [Array<Array>] training samples
# @param default [Object] fallback classification
# @param type [Symbol] :discrete or :continuous (the two values this class
#   dispatches on)
def initialize(attributes, data, default, type)
  @type = type
  # Fresh, separate hashes: the learned tree and the per-attribute threshold log.
  @tree = {}
  @used = {}
  @data, @attributes, @default = data, attributes, default
end
Instance Method Details
#build_rules(tree = @tree) ⇒ Object
139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 |
# File 'lib/decisiontree/id3_tree.rb', line 139
#
# Flattens a (sub)tree into a flat list of rules, one per leaf.
#
# The tree is expected in the shape produced by #id3_train:
#   { node => { branch_value => subtree_or_leaf, ... } }
# Every leaf becomes Rule.new(@attributes, premises, leaf), where premises is
# the list of [node, branch_value] pairs on the path from the root to the leaf.
#
# @param tree [Hash] tree or subtree to flatten (defaults to the trained tree)
# @return [Array] the collected rules; empty for an empty tree
def build_rules(tree=@tree)
  # Before training @tree is {}, which has no root pair to destructure;
  # there are simply no rules yet.
  return [] if tree.is_a?(Hash) && tree.empty?

  node, branches = tree.to_a.first
  branches.to_a.each_with_object([]) do |(value, child), rules|
    if child.is_a?(Hash)
      # Prefix every rule of the subtree with the branch taken at this node.
      build_rules(child).each do |sub_rule|
        rule = sub_rule.clone
        rule.premises.unshift([node, value])
        rules << rule
      end
    else
      rules << Rule.new(@attributes, [[node, value]], child)
    end
  end
end
#graph(filename) ⇒ Object
128 129 130 131 |
# File 'lib/decisiontree/id3_tree.rb', line 128
#
# Writes the tree to "<filename>.png" through DotGraphPrinter, using the
# "png" output format.
#
# @param filename [String] output path without the ".png" extension
def graph(filename)
  # build_tree is defined outside this view; presumably it converts @tree into
  # the structure DotGraphPrinter expects.
  DotGraphPrinter.new(build_tree).write_to_file("#{filename}.png", "png")
end
#id3_continuous(data, attributes, attribute) ⇒ Object
ID3 for binary classification of continuous variables (e.g. healthy / sick based on temperature thresholds)
96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 |
# File 'lib/decisiontree/id3_tree.rb', line 96
#
# ID3 for binary classification of continuous variables (e.g. healthy / sick
# based on temperature thresholds).
#
# Candidate thresholds are the midpoints between consecutive distinct values of
# the attribute. Each candidate splits the data into ">= threshold" and
# "< threshold" and is scored by information gain.
#
# @param data [Array<Array>] samples; must respond to #classification, whose
#   result responds to #entropy (helpers defined outside this method)
# @param attributes [Array] attribute names, positionally matching sample columns
# @param attribute [Object] the attribute to evaluate
# @return [Array] [gain, threshold] of the best split, or [-1, -1] when the
#   attribute has fewer than two distinct values (including empty data)
def id3_continuous(data, attributes, attribute)
  column = attributes.index(attribute)
  values = data.collect { |d| d[column] }.uniq.sort

  # No split is possible without two distinct values. This also covers empty
  # data, where taking the max over zero candidates would yield nil.
  return [-1, -1] if values.size < 2

  thresholds = values.each_cons(2).map { |low, high| (low + high).to_f / 2 }
  #thresholds -= used[attribute] if used.has_key? attribute

  base_entropy = data.classification.entropy
  total = data.size

  scored = thresholds.collect do |threshold|
    above, below = data.partition { |d| d[column] >= threshold }
    pos = above.size.to_f / total
    neg = below.size.to_f / total
    [base_entropy - pos * above.classification.entropy - neg * below.classification.entropy, threshold]
  end

  scored.max { |a, b| a[0] <=> b[0] }
end
#id3_discrete(data, attributes, attribute) ⇒ Object
ID3 for discrete label cases
116 117 118 119 120 121 122 |
# File 'lib/decisiontree/id3_tree.rb', line 116
#
# ID3 for discrete label cases.
#
# Scores an attribute by information gain: the entropy of the whole data set
# minus the size-weighted entropy of the subsets obtained by grouping samples
# on each distinct value of the attribute.
#
# @param data [Array<Array>] samples; must respond to #classification, whose
#   result responds to #entropy (helpers defined outside this method)
# @param attributes [Array] attribute names, positionally matching sample columns
# @param attribute [Object] the attribute to evaluate
# @return [Array] [gain, column index of the attribute]
def id3_discrete(data, attributes, attribute)
  column = attributes.index(attribute)
  distinct = data.map { |row| row[column] }.uniq.sort

  remainder = 0
  distinct.each do |value|
    subset = data.select { |row| row[column] == value }
    remainder = (subset.size.to_f / data.size) * subset.classification.entropy + remainder
  end

  [data.classification.entropy - remainder, column]
end
#id3_train(data, attributes, default, used = {}) ⇒ Object
59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 |
# File 'lib/decisiontree/id3_tree.rb', line 59
#
# Recursively builds the decision tree for the given samples.
#
# The result is either a leaf (a classification value) or a one-entry Hash
#   { Node => { branch => subtree_or_leaf, ... } }
# where branches are '>=' / '<' for :continuous trees and the attribute's
# distinct values for :discrete trees.
#
# @param data [Array<Array>] samples; the last element of each is its
#   classification
# @param attributes [Array] attribute names, positionally matching sample
#   columns. The list is passed unchanged to every recursive call because
#   columns are located with attributes.index(name); removing entries would
#   shift those indexes.
# @param default [Object] returned when data is empty
# @param used [Hash] accepted for compatibility; not read by this method
# @return [Object, Hash] a leaf classification or a subtree
def id3_train(data, attributes, default, used={})
  return default if data.empty?

  # return classification if all examples have the same classification
  return data.first.last if data.classification.uniq.size == 1

  # Choose a fitness algorithm
  fitness = case @type
            when :discrete   then lambda { |d, attrs, attr| id3_discrete(d, attrs, attr) }
            when :continuous then lambda { |d, attrs, attr| id3_continuous(d, attrs, attr) }
            end

  # Choose best attribute (1. enumerate all attributes / 2. Pick best attribute)
  performance = attributes.collect { |attribute| fitness.call(data, attributes, attribute) }
  max = performance.max { |a, b| a[0] <=> b[0] }
  best = Node.new(attributes[performance.index(max)], max[1], max[0])
  best.threshold = nil if @type == :discrete
  @used[best.attribute] = @used.fetch(best.attribute, []) + [best.threshold]

  column = attributes.index(best.attribute)
  # Majority class of this node, used as the default for empty partitions.
  # Falls back to 0 when #mode is unavailable or raises (best effort, as before).
  fallback = (data.classification.mode rescue 0)

  # A split that leaves every sample in one partition makes no progress:
  # recursing on it would repeat this exact call forever. Turn such a branch
  # into a majority-class leaf instead.
  grow = lambda do |examples|
    examples.size == data.size ? fallback : id3_train(examples, attributes, fallback)
  end

  tree = { best => {} }
  case @type
  when :continuous
    labels = ['>=', '<']
    data.partition { |d| d[column] >= best.threshold }.each_with_index do |examples, i|
      tree[best][String.new(labels[i])] = grow.call(examples)
    end
  when :discrete
    values = data.collect { |d| d[column] }.uniq.sort
    values.each do |value|
      tree[best][value] = grow.call(data.select { |d| d[column] == value })
    end
  end

  tree
end
#predict(test) ⇒ Object
124 125 126 |
# File 'lib/decisiontree/id3_tree.rb', line 124
#
# Classifies one sample by walking the trained tree.
#
# @param test [Array] the sample to classify
# @return [Array] a two-element array: the value returned by the descend
#   helper for this tree type, followed by the constant 1
def predict(test)
  outcome = if @type == :discrete
              descend_discrete(@tree, test)
            else
              descend_continuous(@tree, test)
            end
  [outcome, 1]
end
#ruleset ⇒ Object
133 134 135 136 137 |
# File 'lib/decisiontree/id3_tree.rb', line 133
#
# Wraps the rules derived from the current tree in a Ruleset built with the
# same attributes, data, default and type as this tree.
#
# @return [Ruleset] a new ruleset whose rules come from #build_rules
def ruleset
  Ruleset.new(@attributes, @data, @default, @type).tap do |set|
    set.rules = build_rules
  end
end
#train(data = @data, attributes = @attributes, default = @default) ⇒ Object
50 51 52 53 54 55 56 57 |
# File 'lib/decisiontree/id3_tree.rb', line 50
#
# Resets the learner with the given inputs and builds the tree.
#
# @param data [Array<Array>] samples; the last element of each is its
#   classification, everything before it is the attribute vector
# @param attributes [Array] attribute names
# @param default [Object] fallback classification
# @return [Object] the trained tree (also stored in @tree)
def train(data=@data, attributes=@attributes, default=@default)
  initialize(attributes, data, default, @type)

  # Remove samples with same attributes leaving most common classification
  votes = {}
  data.each do |row|
    features = row.slice(0..-2)
    votes[features] ||= Hash.new(0)
    votes[features][row.last] += 1
  end
  deduped = votes.map do |features, counts|
    features + [counts.sort_by { |_label, count| count }.last.first]
  end

  @tree = id3_train(deduped, attributes, default)
end