29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
|
# File 'lib/train_test_split.rb', line 29
# Randomly partitions a labelled data set into training, validation and test
# subsets, and separates every row's label (its last element) from its
# remaining feature values.
#
# The caller's array and its rows are not modified: a shuffled copy is
# partitioned and features are taken as new arrays, so the same data set can
# be split repeatedly.
#
# @param total_data_set [Array<Array>] rows whose last element is the label
# @param validation_size [Numeric] fraction of rows used for validation;
#   clamped to 0.0..1.0
# @param test_size [Numeric] fraction of rows used for testing; clamped to
#   0.0..1.0
# @return [Array] six arrays, in order: training features, training labels,
#   validation features, validation labels, test features, test labels
# @raise [StandardError] if the test or validation subset would be empty, or
#   if no rows would remain for training
def self.train_validation_test_split(total_data_set, validation_size = 0.15, test_size = 0.10)
  # Clamp each fraction on its own so that one out-of-range argument cannot
  # stop the other one from being corrected.
  test_size = test_size.clamp(0.0, 1.0)
  validation_size = validation_size.clamp(0.0, 1.0)

  total_count = total_data_set.length
  test_set_count = (total_count * test_size).floor
  validation_set_count = (total_count * validation_size).floor

  if test_set_count == 0
    raise StandardError, "Test size resulted in a test set of 0. Increase the test size."
  elsif test_set_count == total_count
    raise StandardError, "Test size resulted in a training set of 0. Decrease the test size."
  end

  # Rows held out for testing and validation together; whatever is left over
  # becomes the training set, so it must be strictly less than the total.
  holdout_count = test_set_count + validation_set_count

  if validation_set_count == 0
    raise StandardError, "validation size resulted in a validation set of 0. Increase the validation data size."
  elsif holdout_count >= total_count
    raise StandardError, "validation size resulted in a training set of 0. Decrease the validation data size."
  end

  shuffled_rows = total_data_set.shuffle

  # Exclusive ranges keep each subset at exactly its computed size.
  test_rows = shuffled_rows[0...test_set_count]
  validation_rows = shuffled_rows[test_set_count...holdout_count]
  training_rows = shuffled_rows.drop(holdout_count)

  # Returns [features, labels] for a subset without touching the input rows.
  split_labels = lambda do |rows|
    [rows.map { |row| row[0...-1] }, rows.map(&:last)]
  end

  training_set, training_labels = split_labels.call(training_rows)
  validation_set, validation_labels = split_labels.call(validation_rows)
  test_set, test_labels = split_labels.call(test_rows)

  [training_set, training_labels, validation_set, validation_labels, test_set, test_labels]
end
|