Module: Rblearn::CrossValidation

Defined in:
lib/rblearn/CrossValidation.rb

Defined Under Namespace

Classes: KFold

Class Method Summary

Class Method Details

.train_test_split(x, y, test_size = 0.33) ⇒ Object

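x, y: NArray objects. Rows are selected by indexing with an array of integers, as in x[Array<Integer>, true].
test_size: fraction of rows assigned to the test set (default 0.33).
Returns [x_train, y_train, x_test, y_test].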

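# File 'lib/rblearn/CrossValidation.rb', line 6

def self.train_test_split(x, y, test_size=0.33)
  # Number of rows (samples) in x
  doc_size = x.shape[0]
  # Shuffle the row indices so the split is random
  random_indices = (0...doc_size).to_a.shuffle
  # Number of rows that go to the test set, rounded down
  endpoint = (doc_size * test_size).to_i
  train_indices = random_indices[endpoint..-1]
  test_indices = random_indices[0...endpoint]

  # `true` as the second index keeps every column of the selected rows
  [x[train_indices, true], y[train_indices, true], x[test_indices, true], y[test_indices, true]]
end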
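
The index logic of the method can be tried without NArray. The following is a minimal sketch, not part of Rblearn: `split_indices` is a hypothetical helper that reproduces only the shuffle-and-cut step on plain Ruby arrays, and the `random:` keyword is an added assumption so the shuffle can be seeded.

```ruby
# Hypothetical helper (not part of Rblearn): mirrors the index logic of
# train_test_split using plain Ruby arrays, so it runs without extra gems.
def split_indices(doc_size, test_size = 0.33, random: Random.new)
  shuffled = (0...doc_size).to_a.shuffle(random: random)
  endpoint = (doc_size * test_size).to_i # number of test rows, rounded down
  test_indices  = shuffled[0...endpoint]
  train_indices = shuffled[endpoint..-1]
  [train_indices, test_indices]
end

train, test = split_indices(10, 0.33, random: Random.new(42))
puts "train: #{train.size} rows, test: #{test.size} rows"
# prints "train: 7 rows, test: 3 rows"
```

Because the test count is rounded down, very small inputs can produce an empty test set: with 2 rows and test_size 0.33, the endpoint is 0 and every row lands in the training set.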