Class: LinearRegression

Inherits:
Object
Includes:
Tools::ClassifierMethods, Tools::DataMethods
Defined in:
lib/rubyml/linear_regression.rb

Overview

An ordinary least-squares linear regression model with a configurable number of folds for k-fold cross-validation.
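This class only stores the fold count. The splitting itself is handled by the included Tools::ClassifierMethods helpers (#generate_folds, #generate_train_set, #generate_test_set), whose exact behavior is not shown here. As a general illustration of k-fold cross-validation, the sketch below partitions row indices into `k` contiguous folds and pairs each test fold with the remaining training rows. `kfold_indices` is a hypothetical helper and is not part of rubyml.

```ruby
# Hypothetical helper (not part of rubyml): split row indices 0...n
# into k contiguous folds and return [train_indices, test_indices] pairs.
def kfold_indices(n, k)
  raise ArgumentError, "k must be between 1 and n" unless k.between?(1, n)

  indices = (0...n).to_a
  # Spread the remainder over the first (n % k) folds so sizes differ by at most 1.
  base, extra = n.divmod(k)
  folds = []
  start = 0
  k.times do |i|
    size = base + (i < extra ? 1 : 0)
    folds << indices[start, size]
    start += size
  end

  # Each fold serves once as the test set; the other folds form the training set.
  folds.each_with_index.map do |test, i|
    train = folds.each_with_index.reject { |_, j| j == i }.flat_map(&:first)
    [train, test]
  end
end

kfold_indices(10, 5).each do |train, test|
  puts "train=#{train.inspect} test=#{test.inspect}"
end
```

With `folds = 5` (the constructor default) each row is held out exactly once, and the model would be fit five times on the remaining 80% of the data.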

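Instance Attribute Summary

#accuracy ⇒ Object (readonly)
#folds ⇒ Object (readonly)
#precision ⇒ Object (readonly)
#theta ⇒ Object (readonly)

Instance Method Summary

#fit(x, y) ⇒ Object
#initialize(precision = 3, folds = 5) ⇒ LinearRegression (constructor)
#predict(x) ⇒ Object
#visualize(x, y) ⇒ Object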

Methods included from Tools::ClassifierMethods

#correct_count, #generate_folds, #generate_test_set, #generate_train_set, #handle_epsilon, #training_accuracy

Methods included from Tools::DataMethods

#bias_trick, #load_data, #mat_to_array, #plot, #plot_function, #separate_data

Constructor Details

#initialize(precision = 3, folds = 5) ⇒ LinearRegression

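Returns a new instance of LinearRegression. `precision` is the number of decimal places to which the fitted coefficients and the predictions are rounded; `folds` is the number of folds used for k-fold cross-validation. The internal `@epsilon` is set to 2.0 and cannot be changed through the constructor.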

# File 'lib/rubyml/linear_regression.rb', line 11

def initialize(precision = 3, folds = 5)
  @precision = precision
  @epsilon = 2.0
  @folds = folds
end

Instance Attribute Details

#accuracy ⇒ Object (readonly)

Returns the value of attribute accuracy.

# File 'lib/rubyml/linear_regression.rb', line 9

def accuracy
  @accuracy
end

#folds ⇒ Object (readonly)

Returns the number of folds used for k-fold cross-validation.

# File 'lib/rubyml/linear_regression.rb', line 9

def folds
  @folds
end

#precision ⇒ Object (readonly)

Returns the number of decimal places used to round the fitted coefficients and the predictions.

# File 'lib/rubyml/linear_regression.rb', line 9

def precision
  @precision
end

#theta ⇒ Object (readonly)

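Returns the coefficient matrix computed by #fit (including the bias term introduced by #bias_trick), or nil if #fit has not been called yet.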

# File 'lib/rubyml/linear_regression.rb', line 9

def theta
  @theta
end

Instance Method Details

#fit(x, y) ⇒ Object
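
Fits the model by solving the normal equation θ = (XᵀX)⁻¹Xᵀy, where X is x augmented with a bias column by #bias_trick, and stores the coefficients, rounded to `precision` decimal places, in #theta. Returns the rounded coefficient matrix. Matrix inversion raises Matrix::ErrNotRegular if XᵀX is singular, for example when two features are linearly dependent.
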
# File 'lib/rubyml/linear_regression.rb', line 17

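def fit(x, y)
  # Augment x with a bias column (see #bias_trick).
  x_mat = bias_trick(x)
  # Normal equation: theta = (X^T X)^-1 X^T y.
  @theta = ((x_mat.t * x_mat).inv * x_mat.t) * y
  # Round each coefficient to @precision decimal places.
  @theta = @theta.collect { |e| e.round(@precision) }
end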
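The normal-equation step in #fit can be reproduced with Ruby's `matrix` library alone. This sketch assumes that #bias_trick prepends a column of ones, so that the first coefficient is the intercept; rubyml's actual column order is not shown here, and `add_bias_column` and `normal_equation` are hypothetical names used only for illustration.

```ruby
require "matrix"

# Hypothetical stand-in for #bias_trick: prepend a column of ones so the
# first coefficient acts as the intercept (rubyml's own order may differ).
def add_bias_column(x)
  Matrix.rows(x.row_vectors.map { |row| [1] + row.to_a })
end

# Same computation as #fit: theta = (X^T X)^-1 X^T y, rounded to `precision` places.
def normal_equation(x, y, precision = 3)
  x_mat = add_bias_column(x)
  theta = (x_mat.t * x_mat).inv * x_mat.t * y
  theta.collect { |e| e.round(precision) }
end

# Four points lying exactly on y = 1 + 2x.
x = Matrix[[0], [1], [2], [3]]
y = Matrix[[1], [3], [5], [7]]
theta = normal_equation(x, y)
p theta.to_a.flatten.map(&:to_f)  # → [1.0, 2.0]
```

With Integer inputs, `Matrix#inv` works in exact Rational arithmetic, so the coefficients come out as Rationals; `to_f` is used here only for display.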

#predict(x) ⇒ Object
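
Predicts targets for x by augmenting it with #bias_trick and multiplying by #theta, rounding each prediction to `precision` decimal places. #fit must be called first, because #theta is nil until then.
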
# File 'lib/rubyml/linear_regression.rb', line 23

def predict(x)
  x_mat = bias_trick(x)
  (x_mat * @theta).collect { |e| e.round(@precision) }
end
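Prediction is a single matrix product of the bias-augmented inputs with θ. The sketch below mirrors #predict under the same assumption as before (a column of ones is prepended, so θ is ordered as [intercept, slope]); `add_bias_column` and `predict_with` are hypothetical helpers, not rubyml API.

```ruby
require "matrix"

# Hypothetical stand-in for #bias_trick (prepends a column of ones).
def add_bias_column(x)
  Matrix.rows(x.row_vectors.map { |row| [1] + row.to_a })
end

# Mirrors #predict: multiply the augmented inputs by theta, then round.
def predict_with(theta, x, precision = 3)
  (add_bias_column(x) * theta).collect { |e| e.round(precision) }
end

theta = Matrix[[1.0], [2.0]] # intercept 1, slope 2 (assumed order)
x_new = Matrix[[4], [0.5]]
p predict_with(theta, x_new).to_a.flatten  # → [9.0, 2.0]
```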

#visualize(x, y) ⇒ Object
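
Plots the data together with the fitted model: converts x, y, and #theta to plain arrays with #mat_to_array and passes them to #plot_function.
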
# File 'lib/rubyml/linear_regression.rb', line 28

def visualize(x, y)
  x = mat_to_array(x)
  y = mat_to_array(y)
  theta = mat_to_array(@theta)
  plot_function(x, y, theta)
end
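To draw the fitted line, a plotting routine needs θ as plain numbers and the line evaluated over the range of x. The sketch below assumes a single feature, θ ordered as [intercept, slope], and that #mat_to_array flattens a column matrix. `column_to_array` and `line_endpoints` are hypothetical helpers, not rubyml's #plot_function.

```ruby
require "matrix"

# Hypothetical stand-in for #mat_to_array: flatten a column matrix to an Array.
def column_to_array(m)
  m.to_a.flatten
end

# Endpoints of the fitted line over the range of x
# (single feature, theta assumed ordered as [intercept, slope]).
def line_endpoints(x_arr, theta_arr)
  intercept, slope = theta_arr
  [x_arr.min, x_arr.max].map { |xi| [xi, intercept + slope * xi] }
end

x = column_to_array(Matrix[[0], [1], [2], [3]])
theta = column_to_array(Matrix[[1.0], [2.0]])
p line_endpoints(x, theta)  # → [[0, 1.0], [3, 7.0]]
```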