Class: EMAlgorithm::MdGaussian

Inherits:
Model
  • Object
show all
Defined in:
lib/em_algorithm/models/md_gaussian.rb

Constant Summary

Constants inherited from Model

EMAlgorithm::Model::DIGIT

Instance Attribute Summary

Instance Method Summary

Methods inherited from Model

#pdf, #value_distribution, #value_distribution_to_gnuplot

Constructor Details

#initialize(mu = GSL::Vector[0.0, 0.0], sigma2 = GSL::Matrix[[1.0, 0.0], [0.0, 1.0]]) ⇒ MdGaussian

Returns a new instance of MdGaussian.



5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
# File 'lib/em_algorithm/models/md_gaussian.rb', line 5

# Build a multivariate Gaussian component.
#
# mu     - mean vector; its class must be exactly GSL::Vector (the strict
#          `class` comparison rejects subclasses as well as other types).
# sigma2 - covariance matrix; its class must be exactly GSL::Matrix and its
#          dimensions must be mu.size x mu.size.
#
# Raises ArgumentError when either argument has the wrong class or when the
# matrix dimensions do not match the mean vector. On success, sqrt(det(sigma2))
# and the inverse of sigma2 are cached so the density can be evaluated cheaply.
def initialize(mu = GSL::Vector[0.0, 0.0], sigma2 = GSL::Matrix[[1.0, 0.0], [0.0, 1.0]])
  raise ArgumentError, "mu should be GSL::Vector." unless mu.class == GSL::Vector
  @mu = mu

  raise ArgumentError, "sigma2 should be GSL::Matrix." unless sigma2.class == GSL::Matrix
  dim = @mu.size
  unless sigma2.size1 == dim && sigma2.size2 == dim
    raise ArgumentError, "The size of sigma2 matrix does not match with mu vector."
  end
  @sigma2 = sigma2

  # Cached for probability_density_function; `sqrt` is assumed to come from
  # Math being mixed in somewhere up the hierarchy (NOTE(review): confirm in Model).
  @sqrt_sigma2_det = sqrt(@sigma2.det)
  @sigma2_invert = @sigma2.invert
end

Instance Attribute Details

#mu ⇒ Object

Returns the value of attribute mu.



3
4
5
# File 'lib/em_algorithm/models/md_gaussian.rb', line 3

# Reader for the mean vector of this component (nil if never assigned).
def mu
  instance_variable_get(:@mu)
end

#sigma2 ⇒ Object

Returns the value of attribute sigma2.



3
4
5
# File 'lib/em_algorithm/models/md_gaussian.rb', line 3

# Reader for the covariance matrix of this component (nil if never assigned).
def sigma2
  instance_variable_get(:@sigma2)
end

Instance Method Details

#probability_density_function(x) ⇒ Object



22
23
24
# File 'lib/em_algorithm/models/md_gaussian.rb', line 22

# Evaluate the multivariate normal density at x:
#   exp(-(x - mu) * Sigma^-1 * (x - mu)^T / 2) / (sqrt(2*pi)^d * sqrt(det Sigma))
# using the inverse and sqrt-determinant cached by initialize / update_sigma2!.
# NOTE(review): x is presumably a row GSL::Vector like @mu, so that the
# row * matrix * column product below collapses to a scalar — confirm at callers.
def probability_density_function(x)
  deviation = x - @mu
  mahalanobis_sq = deviation * @sigma2_invert * deviation.trans
  normalizer = (sqrt(2.0 * PI) ** @mu.size) * @sqrt_sigma2_det
  exp(-mahalanobis_sq / 2.0) / normalizer
end

#to_gnuplot ⇒ Object



47
48
49
50
51
52
53
54
55
56
57
# File 'lib/em_algorithm/models/md_gaussian.rb', line 47

# Render this component as a gnuplot expression in x and y.
#
# For a 2-D component this returns the bivariate normal density
#   exp(-Q/2) / ((sqrt(2*pi))**2 * sqrt(det Sigma)),
# matching probability_density_function. Other dimensions cannot be plotted
# as a surface and fall back to the textual form "N(mu, sigma2)".
def to_gnuplot
  if @mu.size == 2
    # Q is the quadratic form (x - mu) * Sigma^-1 * (x - mu)^T, so its
    # coefficients come from the INVERSE covariance, not from @sigma2:
    #   with Sigma^-1 = [[p_x, p_xy], [p_yx, p_y]],
    #   Q = p_x*(x - mu_x)**2 + (p_xy + p_yx)*(x - mu_x)(y - mu_y) + p_y*(y - mu_y)**2
    inv_xx = @sigma2_invert[0, 0].round(DIGIT)
    inv_yy = @sigma2_invert[1, 1].round(DIGIT)
    inv_xy = @sigma2_invert[0, 1] + @sigma2_invert[1, 0]
    mu_x = @mu[0].round(DIGIT)
    mu_y = @mu[1].round(DIGIT)
    # Omit the cross term entirely when the coefficient is exactly zero.
    xy = "+(#{inv_xy.round(DIGIT)})*(x-(#{mu_x}))*(y-(#{mu_y}))" if inv_xy > 0 || inv_xy < 0
    "exp(-((#{inv_xx})*(x-(#{mu_x}))**2.0+(#{inv_yy})*(y-(#{mu_y}))**2.0#{xy})/2.0)/((sqrt(2.0*pi))**#{@mu.size}*(#{@sqrt_sigma2_det.round(DIGIT)}))"
  else
    "N(#{@mu.to_a.inspect}, #{@sigma2.to_a.inspect})"
  end
end

#to_gnuplot_with_title(weight) ⇒ Object



59
60
61
62
63
64
65
# File 'lib/em_algorithm/models/md_gaussian.rb', line 59

# Render this component as a gnuplot plot clause with a legend title.
#
# weight - mixture weight shown in the title (rounded to DIGIT places).
#
# For 2-D components this returns the to_gnuplot expression followed by
# " w l lw 3 title 'weight*N(mu,sigma2)'". Every number in the title is
# rounded to DIGIT places. For other dimensions it returns only the textual
# form "N(mu, sigma2)".
def to_gnuplot_with_title(weight)
  return "N(#{@mu.to_a.inspect}, #{@sigma2.to_a.inspect})" unless @mu.size == 2

  plot = to_gnuplot
  mu_label = @mu.map { |m| m.round(DIGIT) }.to_a.inspect
  sigma2_label = @sigma2.map { |s| s.round(DIGIT) }.to_a.inspect
  title = "#{weight.round(DIGIT)}*N(#{mu_label},#{sigma2_label})"
  plot + " w l lw 3 title '#{title}'"
end

#update_average!(data_array, temp_weight, temp_weight_per_datum) ⇒ Object



26
27
28
29
30
31
# File 'lib/em_algorithm/models/md_gaussian.rb', line 26

# Re-estimate the mean as a weighted average of the data:
#   @mu = (sum_i temp_weight_per_datum[i] * data_array[i]) / temp_weight
#
# data_array            - indexable collection of data vectors (size and [] are used).
# temp_weight           - normaliser; presumably the sum of temp_weight_per_datum
#                         (not checked here — confirm with the caller).
# temp_weight_per_datum - per-datum weights, indexed in step with data_array.
#
# Returns the new mean (also stored in @mu).
def update_average!(data_array, temp_weight, temp_weight_per_datum)
  weighted_sum = GSL::Vector.alloc(@mu.size).set_zero
  data_array.size.times do |i|
    weighted_sum = weighted_sum + temp_weight_per_datum[i] * data_array[i]
  end
  @mu = weighted_sum / temp_weight
end

#update_parameters!(data_array, temp_weight, temp_weight_per_datum) ⇒ Object



42
43
44
45
# File 'lib/em_algorithm/models/md_gaussian.rb', line 42

# Re-estimate both parameters of this component from weighted data.
# The mean is updated first on purpose: update_sigma2! centres each datum
# on the freshly computed @mu. Returns whatever update_sigma2! returns.
def update_parameters!(data_array, temp_weight, temp_weight_per_datum)
  shared_args = [data_array, temp_weight, temp_weight_per_datum]
  update_average!(*shared_args)
  update_sigma2!(*shared_args)
end

#update_sigma2!(data_array, temp_weight, temp_weight_per_datum) ⇒ Object



33
34
35
36
37
38
39
40
# File 'lib/em_algorithm/models/md_gaussian.rb', line 33

# Re-estimate the covariance as the weighted average of outer products of
# the data centred on the current mean:
#   @sigma2 = (sum_i w_i * (x_i - mu)^T * (x_i - mu)) / temp_weight
# Then refresh the cached sqrt(det) and inverse used by the density.
#
# data_array            - indexable collection of data vectors (size and [] are used).
# temp_weight           - normaliser; presumably the sum of the per-datum weights.
# temp_weight_per_datum - per-datum weights, indexed in step with data_array.
#
# Uses @mu as it stands, so callers should update the mean first
# (update_parameters! does). Returns the new inverse covariance.
def update_sigma2!(data_array, temp_weight, temp_weight_per_datum)
  # Seeded with a Float scalar and promoted to a matrix by the first addition.
  scatter = 0.0
  data_array.size.times do |i|
    centered = data_array[i] - @mu
    # column (trans) times row yields the outer-product matrix
    scatter = scatter + temp_weight_per_datum[i] * centered.trans * centered
  end
  @sigma2 = scatter / temp_weight
  @sqrt_sigma2_det = sqrt(@sigma2.det)
  @sigma2_invert = @sigma2.invert
end