Class: Rumale::NearestNeighbors::KNeighborsRegressor
- Inherits: Object
- Includes: Base::BaseEstimator, Base::Regressor
- Defined in: lib/rumale/nearest_neighbors/k_neighbors_regressor.rb
Overview
KNeighborsRegressor implements a regressor based on the k-nearest neighbors rule: the predicted value for a sample is the mean of the target values of its k nearest training samples. The current implementation uses the Euclidean distance to find the neighbors.
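As a rough illustration of the rule described above, the following plain-Ruby sketch (standard library only, with nested Arrays in place of Numo::DFloat) predicts a single value as the mean target of the k training samples closest to the query in Euclidean distance. The method names `euclidean` and `knn_regress` are hypothetical and are not part of Rumale.

```ruby
# Plain-Ruby sketch of the k-nearest neighbors regression rule (not Rumale's code).

# Euclidean distance between two equal-length numeric arrays.
def euclidean(a, b)
  Math.sqrt(a.zip(b).sum { |ai, bi| (ai - bi)**2 })
end

# Predict a scalar for +query+ from training rows +xs+ and targets +ys+.
def knn_regress(xs, ys, query, k)
  k = [k, xs.size].min # never ask for more neighbors than samples
  nearest = xs.each_index.min_by(k) { |i| euclidean(xs[i], query) }
  nearest.sum { |i| ys[i] }.fdiv(k) # mean of the neighbors' targets
end

xs = [[0.0], [1.0], [2.0], [10.0]]
ys = [0.0, 1.0, 2.0, 10.0]
p knn_regress(xs, ys, [1.2], 3) # → 1.0
```

`min_by(k)` does not promise any particular order among equidistant samples; the library's `predict` breaks such ties by sample index instead.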
Instance Attribute Summary
- #prototypes ⇒ Numo::DFloat (readonly): Return the prototypes (the training samples stored by #fit) for the nearest neighbor regressor.
- #values ⇒ Numo::DFloat (readonly): Return the values (target values) of the prototypes.
Attributes included from Base::BaseEstimator
Instance Method Summary
- #fit(x, y) ⇒ KNeighborsRegressor: Fit the model with the given training data.
- #initialize(n_neighbors: 5) ⇒ KNeighborsRegressor (constructor): Create a new regressor with the nearest neighbor rule.
- #marshal_dump ⇒ Hash: Dump marshal data.
- #marshal_load(obj) ⇒ nil: Load marshal data.
- #predict(x) ⇒ Numo::DFloat: Predict values for samples.
Methods included from Base::Regressor
Constructor Details
#initialize(n_neighbors: 5) ⇒ KNeighborsRegressor
Create a new regressor with the nearest neighbor rule.
# File 'lib/rumale/nearest_neighbors/k_neighbors_regressor.rb', line 32
def initialize(n_neighbors: 5)
  check_params_integer(n_neighbors: n_neighbors)
  check_params_positive(n_neighbors: n_neighbors)
  @params = {}
  @params[:n_neighbors] = n_neighbors
  @prototypes = nil
  @values = nil
end
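The constructor only validates and stores the hyperparameter; no model state exists until #fit is called. The sketch below imitates that behavior with inline checks standing in for Rumale's `check_params_integer` and `check_params_positive` helpers. The class name `SketchKNNRegressor` and the choice of `TypeError` and `ArgumentError` are assumptions made for illustration, not Rumale's documented behavior.

```ruby
# Hypothetical stand-in for the constructor above (not Rumale's code).
class SketchKNNRegressor
  attr_reader :params, :prototypes, :values

  def initialize(n_neighbors: 5)
    # Stand-ins for check_params_integer / check_params_positive (assumed error classes).
    raise TypeError, 'n_neighbors must be an Integer' unless n_neighbors.is_a?(Integer)
    raise ArgumentError, 'n_neighbors must be positive' unless n_neighbors.positive?

    # Hyperparameters live in a Hash; the learned state stays empty until fit.
    @params = { n_neighbors: n_neighbors }
    @prototypes = nil
    @values = nil
  end
end

p SketchKNNRegressor.new(n_neighbors: 3).params[:n_neighbors] # → 3

begin
  SketchKNNRegressor.new(n_neighbors: 0)
rescue ArgumentError => e
  puts e.message # → n_neighbors must be positive
end
```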
Instance Attribute Details
#prototypes ⇒ Numo::DFloat (readonly)
Return the prototypes for the nearest neighbor regressor.
# File 'lib/rumale/nearest_neighbors/k_neighbors_regressor.rb', line 23
def prototypes
  @prototypes
end
#values ⇒ Numo::DFloat (readonly)
Return the values of the prototypes.
# File 'lib/rumale/nearest_neighbors/k_neighbors_regressor.rb', line 27
def values
  @values
end
Instance Method Details
#fit(x, y) ⇒ KNeighborsRegressor
Fit the model with the given training data.
# File 'lib/rumale/nearest_neighbors/k_neighbors_regressor.rb', line 46
def fit(x, y)
  check_sample_array(x)
  check_tvalue_array(y)
  check_sample_tvalue_size(x, y)
  @prototypes = x.dup
  @values = y.dup
  self
end
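Because k-nearest neighbors is a lazy learner, #fit performs no optimization: after validating the inputs it keeps copies of the samples and targets, which #predict later searches. A minimal sketch of why the copies matter, assuming nested Ruby Arrays in place of Numo arrays (`LazyFit` is a hypothetical name). Since a nested Array's `dup` is shallow, the sketch copies each row as well.

```ruby
# Sketch of a "lazy" fit (not Rumale's code): store private copies of the data.
class LazyFit
  attr_reader :prototypes, :values

  def fit(x, y)
    raise ArgumentError, 'x and y must have the same number of samples' unless x.size == y.size

    # Array#dup is shallow, so copy each row as well as the outer array.
    @prototypes = x.map(&:dup)
    @values = y.dup
    self # returning self allows chaining such as LazyFit.new.fit(x, y)
  end
end

x = [[0.0, 1.0], [2.0, 3.0]]
y = [1.0, 5.0]
model = LazyFit.new.fit(x, y)
x[0][0] = 99.0 # mutate the caller's data after fitting
p model.prototypes[0][0] # → 0.0
```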
#marshal_dump ⇒ Hash
Dump marshal data.
# File 'lib/rumale/nearest_neighbors/k_neighbors_regressor.rb', line 77
def marshal_dump
  { params: @params,
    prototypes: @prototypes,
    values: @values }
end
#marshal_load(obj) ⇒ nil
Load marshal data.
# File 'lib/rumale/nearest_neighbors/k_neighbors_regressor.rb', line 85
def marshal_load(obj)
  @params = obj[:params]
  @prototypes = obj[:prototypes]
  @values = obj[:values]
  nil
end
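Ruby's `Marshal` calls `marshal_dump` when serializing an object and, when loading, allocates a new instance without running `initialize` and passes the dumped data to `marshal_load`. Since the model's whole state is `@params`, `@prototypes`, and `@values`, the pair of methods above only needs to save and restore those three. The following sketch (hypothetical class `MarshalSketch`, plain Arrays instead of Numo) shows a round trip through the same hooks.

```ruby
# Sketch of Ruby's custom Marshal hooks, shaped like the pair above (not Rumale's code).
class MarshalSketch
  attr_reader :params, :prototypes, :values

  def initialize(n_neighbors: 5)
    @params = { n_neighbors: n_neighbors }
    @prototypes = nil
    @values = nil
  end

  def fit(x, y)
    @prototypes = x.map(&:dup)
    @values = y.dup
    self
  end

  # Marshal.dump calls this and serializes the returned Hash.
  def marshal_dump
    { params: @params, prototypes: @prototypes, values: @values }
  end

  # Marshal.load allocates an instance without calling initialize,
  # then passes the dumped Hash here to restore the state.
  def marshal_load(obj)
    @params = obj[:params]
    @prototypes = obj[:prototypes]
    @values = obj[:values]
    nil
  end
end

model = MarshalSketch.new(n_neighbors: 3).fit([[0.0], [1.0]], [0.5, 1.5])
restored = Marshal.load(Marshal.dump(model))
p restored.params[:n_neighbors] # → 3
p restored.values               # → [0.5, 1.5]
```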
#predict(x) ⇒ Numo::DFloat
Predict values for samples.
# File 'lib/rumale/nearest_neighbors/k_neighbors_regressor.rb', line 59
def predict(x)
  check_sample_array(x)
  # Initialize some variables.
  n_samples, = x.shape
  n_prototypes, n_outputs = @values.shape
  n_neighbors = [@params[:n_neighbors], n_prototypes].min
  # Calculate distance matrix.
  distance_matrix = PairwiseMetric.euclidean_distance(x, @prototypes)
  # Predict values for the given samples.
  predicted_values = Array.new(n_samples) do |n|
    neighbor_ids = distance_matrix[n, true].to_a.each_with_index.sort.map(&:last)[0...n_neighbors]
    n_outputs.nil? ? @values[neighbor_ids].mean : @values[neighbor_ids, true].mean(0).to_a
  end
  Numo::DFloat[*predicted_values]
end
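The listing above clamps `n_neighbors` to the number of stored samples, orders each row of the distance matrix with `each_with_index.sort.map(&:last)` (an argsort in which equal distances fall back to the lower sample index), and averages the neighbors' targets, column by column when `@values` is two-dimensional; for a one-dimensional `@values`, `n_outputs` is `nil`. A plain-Ruby sketch of the same steps, with nested Arrays standing in for Numo::DFloat; `knn_predict` and `euclidean` are hypothetical helpers, not Rumale methods.

```ruby
# Plain-Ruby sketch of the predict step above (not Rumale's code).

def euclidean(a, b)
  Math.sqrt(a.zip(b).sum { |ai, bi| (ai - bi)**2 })
end

# prototypes: Array of feature rows; values: Array of scalars (single output)
# or Array of rows (multi-output).
def knn_predict(queries, prototypes, values, n_neighbors)
  # Never ask for more neighbors than there are stored samples.
  k = [n_neighbors, prototypes.size].min
  multi_output = values.first.is_a?(Array)

  queries.map do |q|
    dists = prototypes.map { |pr| euclidean(q, pr) }
    # Sort [distance, index] pairs: ties in distance fall back to the lower index.
    ids = dists.each_with_index.sort.map(&:last).first(k)
    if multi_output
      # Column-wise mean over the neighbors' target rows.
      ids.map { |i| values[i] }.transpose.map { |col| col.sum / k.to_f }
    else
      ids.sum { |i| values[i] } / k.to_f
    end
  end
end

prototypes = [[0.0], [1.0], [3.0]]
p knn_predict([[0.9]], prototypes, [1.0, 2.0, 6.0], 2) # → [1.5]
p knn_predict([[0.9]], prototypes, [[1.0, 10.0], [2.0, 20.0], [6.0, 60.0]], 5) # → [[3.0, 30.0]]
```

Sorting a full row of distances costs O(m log m) per query for m stored samples; this is simple, but it may become slow for large training sets, and the same holds for the listing above.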