Method: Nimbus::ClassificationTree#build_node
- Defined in: lib/nimbus/classification_tree.rb
#build_node(individuals_ids, y_hat) ⇒ Object
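Creates a node by taking a random sample of the SNPs and computing the loss function (the Gini index in this class) for the split induced by each SNP in that sample. A node with fewer than @node_min_size individuals is labeled directly, without searching for a split.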
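The loss comparison described above can be sketched as follows. This is a minimal illustration, not the Nimbus implementation: `gini_impurity` and `split_loss` are hypothetical helpers written for this example, and the sketch assumes the loss of a split is the size-weighted average of the Gini impurity of its two halves, as the `loss_snp` expression in the method body suggests.

```ruby
# Hypothetical stand-in for a Gini impurity helper (not Nimbus::LossFunctions).
# ids: array of individual ids; id_to_class: hash mapping id => class label.
def gini_impurity(ids, id_to_class)
  return 0.0 if ids.empty?
  class_sizes = ids.group_by { |id| id_to_class[id] }.values.map(&:size)
  1.0 - class_sizes.sum { |n| (n.to_f / ids.size)**2 }
end

# Size-weighted impurity of a two-way split, mirroring the loss_snp expression.
def split_loss(left_ids, right_ids, id_to_class)
  total = left_ids.size + right_ids.size
  (left_ids.size * gini_impurity(left_ids, id_to_class) +
   right_ids.size * gini_impurity(right_ids, id_to_class)) / total
end

id_to_class = { 1 => 'a', 2 => 'a', 3 => 'b', 4 => 'b' }
node_loss  = gini_impurity([1, 2, 3, 4], id_to_class)  # two balanced classes
pure_split = split_loss([1, 2], [3, 4], id_to_class)   # each half is a single class
```

A split is only worth keeping when its loss is strictly lower than `node_loss`, which is the comparison `build_node` makes before branching.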
- If SNP_min, the SNP with the smallest loss, has a loss lower than that of the node itself, the individuals are split in two: the average of the 0, 1, 2 values of SNP_min across the individuals is computed, and they are divided into [<= avg] and [> avg]. The two new nodes are then built.
- Otherwise every individual in the node is labeled with y_hat, the value passed in for the node; for child nodes the source computes it with majority_class over the phenotype values.
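The split rule and the labeling fallback can be illustrated with a short sketch. The helpers `split_by_average` and `majority_label` are hypothetical names introduced here and are not the library's `split_by_snp_avegare_value` or `majority_class` methods; the sketch assumes genotypes are stored as a hash from individual id to a 0, 1 or 2 value.

```ruby
# Hypothetical helper: split ids at the mean genotype value of one SNP.
# Returns [ids with value <= avg, ids with value > avg].
def split_by_average(ids, genotype_of)
  avg = ids.sum { |id| genotype_of[id] }.to_f / ids.size
  ids.partition { |id| genotype_of[id] <= avg }
end

# Hypothetical helper: the most frequent class label among ids.
def majority_label(ids, id_to_class)
  groups = ids.group_by { |id| id_to_class[id] }
  groups.max_by { |_label, members| members.size }&.first
end

genotype_of = { 1 => 0, 2 => 0, 3 => 1, 4 => 2 }   # average is 0.75
id_to_class = { 1 => 'a', 2 => 'a', 3 => 'b', 4 => 'a' }

left, right = split_by_average([1, 2, 3, 4], genotype_of)
node_label  = majority_label([1, 2, 3, 4], id_to_class)
```

When every individual carries the same genotype, the right-hand group is empty; this sketch does not treat that case specially.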
    # File 'lib/nimbus/classification_tree.rb', line 40

    def build_node(individuals_ids, y_hat)
      # General loss function value for the node
      individuals_count = individuals_ids.size
      return label_node(y_hat, individuals_ids) if individuals_count < @node_min_size
      node_loss_function = Nimbus::LossFunctions.gini_index individuals_ids, @id_to_fenotype, @classes

      # Finding the SNP that minimizes loss function
      snps = snps_random_sample
      min_loss, min_SNP, split, split_type, ginis = node_loss_function, nil, nil, nil, nil

      snps.each do |snp|
        individuals_split_by_snp_value, node_split_type = split_by_snp_avegare_value individuals_ids, snp
        y_hat_0 = Nimbus::LossFunctions.majority_class(individuals_split_by_snp_value[0], @id_to_fenotype, @classes)
        y_hat_1 = Nimbus::LossFunctions.majority_class(individuals_split_by_snp_value[1], @id_to_fenotype, @classes)

        gini_0 = Nimbus::LossFunctions.gini_index individuals_split_by_snp_value[0], @id_to_fenotype, @classes
        gini_1 = Nimbus::LossFunctions.gini_index individuals_split_by_snp_value[1], @id_to_fenotype, @classes
        loss_snp = (individuals_split_by_snp_value[0].size * gini_0 +
                    individuals_split_by_snp_value[1].size * gini_1) / individuals_count

        min_loss, min_SNP, split, split_type, ginis = loss_snp, snp, individuals_split_by_snp_value, node_split_type, [y_hat_0, y_hat_1] if loss_snp < min_loss
      end

      return build_branch(min_SNP, split, split_type, ginis, y_hat) if min_loss < node_loss_function
      return label_node(y_hat, individuals_ids)
    end