Module: Rblearn::CrossValidation
Defined in: lib/rblearn/CrossValidation.rb
Defined Under Namespace
Classes: KFold
Class Method Summary
.train_test_split(x, y, test_size = 0.33) ⇒ Object
Randomly splits the rows of the NArray matrices x and y into training and test sets; rows are selected by slicing x[Array<Integer>, true].
Class Method Details
.train_test_split(x, y, test_size = 0.33) ⇒ Object
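Splits x and y into training and test sets by shuffling the row indices.
x, y: NArray objects with the same number of rows. Both are sliced as matrices with x[Array<Integer>, true], so y must be two-dimensional as well.
test_size: fraction of rows assigned to the test set. The test row count is (rows * test_size).to_i, i.e. rounded down.
Returns [x_train, y_train, x_test, y_test]. Because Array#shuffle is called without a seed, each call produces a different split.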
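In NArray, x[indices, true] picks the rows listed in indices (in that order) and keeps every column, since true selects the whole axis. The snippet below is a gem-free analogue on a plain Ruby Array of rows, so it runs without NArray installed. The helper name select_rows is hypothetical and is not part of rblearn.

```ruby
# Hypothetical helper: plain-Ruby analogue of NArray's x[indices, true].
# Returns the rows at the given indices, in the given order, with all columns kept.
def select_rows(matrix, indices)
  matrix.values_at(*indices)
end

x = [
  [1.0, 2.0], # row 0
  [3.0, 4.0], # row 1
  [5.0, 6.0]  # row 2
]

selected = select_rows(x, [2, 0])
p selected # prints [[5.0, 6.0], [1.0, 2.0]]
```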
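# File 'lib/rblearn/CrossValidation.rb', line 6

def self.train_test_split(x, y, test_size=0.33)
  # Number of rows (samples) along the first axis.
  doc_size = x.shape[0]
  # Shuffle the row indices once so train and test never share a row.
  random_indices = (0...doc_size).to_a.shuffle
  # The first `endpoint` shuffled indices go to the test set.
  endpoint = (doc_size * test_size).to_i
  train_indices = random_indices[endpoint..-1]
  test_indices = random_indices[0...endpoint]
  # Slice the same rows out of x and y: [x_train, y_train, x_test, y_test].
  return [x[train_indices, true], y[train_indices, true], x[test_indices, true], y[test_indices, true]]
end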
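The shuffle-and-slice logic above can be traced without NArray. The following is a minimal sketch, assuming rows are held in plain Ruby Arrays instead of NArray matrices. The method name train_test_split_rows, the row-count check, and the random: keyword (used to make the split reproducible) are hypothetical additions for illustration; rblearn's method takes no seed.

```ruby
# Sketch of the train_test_split algorithm on Arrays of rows (not rblearn's API).
# `random:` is a hypothetical keyword so the split can be reproduced in examples.
def train_test_split_rows(x, y, test_size = 0.33, random: Random.new)
  # Added check (not in the original method): x and y must pair up row by row.
  raise ArgumentError, "x and y must have the same number of rows" unless x.size == y.size

  doc_size = x.size
  # Shuffle all row indices once, so a row lands in exactly one of the two sets.
  random_indices = (0...doc_size).to_a.shuffle(random: random)
  # Test row count, rounded down exactly as in the original (Float#to_i).
  endpoint = (doc_size * test_size).to_i
  test_indices = random_indices[0...endpoint]
  train_indices = random_indices[endpoint..-1]

  [x.values_at(*train_indices), y.values_at(*train_indices),
   x.values_at(*test_indices), y.values_at(*test_indices)]
end

x = (0...10).map { |i| [i, i * 10] } # 10 rows, 2 features
y = (0...10).map { |i| [i % 2] }     # 10 rows, 1 label column
x_train, y_train, x_test, y_test = train_test_split_rows(x, y, random: Random.new(42))
puts "train: #{x_train.size}, test: #{x_test.size}" # prints train: 7, test: 3
```

With the default test_size of 0.33 and 10 rows, (10 * 0.33).to_i is 3, so three rows go to the test set. Passing the same seeded Random gives the same split on every run, which the original method cannot do.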