Method: Nimbus::ClassificationTree#build_node

Defined in:
lib/nimbus/classification_tree.rb

#build_node(individuals_ids, y_hat) ⇒ Object

Creates a node by drawing a random sample of the SNPs and computing the loss function of the split induced by each SNP in that sample.

  • If SNP_min, the sampled SNP with the smallest loss function, has a loss strictly lower than the loss function of the node, the individuals are split in two:

(the average of the 0, 1, 2 values of SNP_min over the individuals is computed, and the individuals are split into [<= avg] and [> avg]); the two new child nodes are then built from these halves.

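  • Otherwise, every individual in the node gets labeled with the average of the phenotype values of all of them. The same labeling is applied directly, without evaluating any SNP, when the node has fewer individuals than the minimum node size.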


40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
# File 'lib/nimbus/classification_tree.rb', line 40

# Builds the node that receives +individuals_ids+.
#
# Nodes with fewer than @node_min_size individuals are labeled straight away
# via label_node. Otherwise every SNP returned by snps_random_sample is tried:
# split_by_snp_avegare_value splits the individuals in two, and the split is
# scored with the size-weighted Gini index of the two halves. The lowest score
# is turned into a branch via build_branch only when it is strictly lower than
# the Gini index of the unsplit node; otherwise the node is labeled via
# label_node.
#
# @param individuals_ids [Array] ids of the individuals reaching this node
#   (used as keys into @id_to_fenotype)
# @param y_hat [Object] value forwarded unchanged to label_node/build_branch
#   (presumably the parent node's predicted class — TODO confirm)
# @return [Object] the result of build_branch or label_node
def build_node(individuals_ids, y_hat)
  individuals_count = individuals_ids.size
  return label_node(y_hat, individuals_ids) if individuals_count < @node_min_size

  # Loss of the node as-is: a split is only worth taking if it beats this.
  node_loss_function = Nimbus::LossFunctions.gini_index individuals_ids, @id_to_fenotype, @classes

  min_loss = node_loss_function
  best_snp = best_split = best_split_type = nil

  snps_random_sample.each do |snp|
    split, split_type = split_by_snp_avegare_value individuals_ids, snp
    gini_0 = Nimbus::LossFunctions.gini_index split[0], @id_to_fenotype, @classes
    gini_1 = Nimbus::LossFunctions.gini_index split[1], @id_to_fenotype, @classes
    # Size-weighted mean impurity of the two halves.
    loss_snp = (split[0].size * gini_0 + split[1].size * gini_1) / individuals_count
    # Strict comparison: on ties the earlier SNP in the sample is kept.
    next unless loss_snp < min_loss

    min_loss, best_snp, best_split, best_split_type = loss_snp, snp, split, split_type
  end

  # min_loss only drops below the node loss if some SNP improved on it,
  # so best_split is non-nil whenever execution gets past this guard.
  return label_node(y_hat, individuals_ids) unless min_loss < node_loss_function

  # The children's majority classes are needed only for the winning split, so
  # they are computed once here rather than for every candidate SNP.
  # NOTE(review): assumes majority_class is side-effect free — confirm.
  branch_labels = [best_split[0], best_split[1]].map do |ids|
    Nimbus::LossFunctions.majority_class(ids, @id_to_fenotype, @classes)
  end
  build_branch(best_snp, best_split, best_split_type, branch_labels, y_hat)
end