Class: IrtRuby::RaschModel

Inherits:
Object
  • Object
show all
Defined in:
lib/irt_ruby/rasch_model.rb

Overview

A class representing the Rasch model for Item Response Theory, in which the probability of a correct response is the logistic function of (ability - difficulty). Incorporates:

  • Adaptive learning rate (the step size is multiplied by decay_factor whenever an update lowers the log-likelihood)

  • Missing data handling (nil responses are skipped, or scored as incorrect or correct, according to missing_strategy)

  • Multiple convergence checks (log-likelihood + parameter updates)

Constant Summary collapse

MISSING_STRATEGIES =
%i[ignore treat_as_incorrect treat_as_correct].freeze

Instance Method Summary collapse

Constructor Details

#initialize(data, max_iter: 1000, tolerance: 1e-6, param_tolerance: 1e-6, learning_rate: 0.01, decay_factor: 0.5, missing_strategy: :ignore) ⇒ RaschModel

Returns a new instance of RaschModel.

Raises:

  • (ArgumentError)


12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
# File 'lib/irt_ruby/rasch_model.rb', line 12

# Builds a Rasch model over a response matrix and seeds its parameters.
#
# One ability is created per row of +data+ and one difficulty per column;
# the column count is taken from the first row. Both parameter sets start
# as small random values in -0.25..0.25.
#
# @param data [#to_a] Matrix or array-of-arrays of responses (0/1, or nil for missing)
# @param max_iter [Integer] upper bound on the number of gradient iterations run by #fit
# @param tolerance [Float] log-likelihood change below which #fit may stop
# @param param_tolerance [Float] mean absolute parameter change below which #fit may stop
# @param learning_rate [Float] initial gradient step size
# @param decay_factor [Float] multiplier applied to the step size when an update lowers the log-likelihood
# @param missing_strategy [Symbol] one of MISSING_STRATEGIES (:ignore, :treat_as_incorrect, :treat_as_correct)
# @raise [ArgumentError] if +missing_strategy+ is not recognised, or +data+ has no rows
def initialize(data,
               max_iter: 1000,
               tolerance: 1e-6,
               param_tolerance: 1e-6,
               learning_rate: 0.01,
               decay_factor: 0.5,
               missing_strategy: :ignore)
  # Validate options before touching the data so a bad option is reported first.
  unless MISSING_STRATEGIES.include?(missing_strategy)
    raise ArgumentError, "missing_strategy must be one of #{MISSING_STRATEGIES}"
  end

  @data = data
  @data_array = data.to_a

  # Without this guard `first` is nil and the column count lookup fails with
  # an unhelpful NoMethodError.
  raise ArgumentError, "data must contain at least one row of responses" if @data_array.empty?

  num_rows = @data_array.size
  num_cols = @data_array.first.size

  @missing_strategy = missing_strategy

  # Initialize parameters near zero
  @abilities    = Array.new(num_rows) { rand(-0.25..0.25) }
  @difficulties = Array.new(num_cols) { rand(-0.25..0.25) }

  @max_iter        = max_iter
  @tolerance       = tolerance
  @param_tolerance = param_tolerance
  @learning_rate   = learning_rate
  @decay_factor    = decay_factor
end

Instance Method Details

#apply_gradient_update(grad_abilities, grad_difficulties) ⇒ Object



97
98
99
100
101
102
103
104
105
106
107
108
109
110
# File 'lib/irt_ruby/rasch_model.rb', line 97

# Takes one gradient step on every ability and difficulty, in place.
#
# Each parameter is moved by +@learning_rate+ times its gradient entry.
#
# @param grad_abilities [Array<Numeric>] gradient per ability, indexed like +@abilities+
# @param grad_difficulties [Array<Numeric>] gradient per difficulty, indexed like +@difficulties+
# @return [Array(Array, Array)] copies of the abilities and difficulties as they were before the step
def apply_gradient_update(grad_abilities, grad_difficulties)
  # Snapshot both parameter sets first so the caller can roll the step back.
  snapshot = [@abilities.dup, @difficulties.dup]

  @abilities.map!.with_index { |theta, idx| theta + @learning_rate * grad_abilities[idx] }
  @difficulties.map!.with_index { |beta, idx| beta + @learning_rate * grad_difficulties[idx] }

  snapshot
end

#average_param_update(old_abilities, old_difficulties) ⇒ Object



112
113
114
115
116
117
118
119
120
121
# File 'lib/irt_ruby/rasch_model.rb', line 112

# Mean absolute change of all parameters relative to an earlier snapshot.
#
# Abilities and difficulties are pooled into a single average.
#
# @param old_abilities [Array<Numeric>] earlier abilities, indexed like +@abilities+
# @param old_difficulties [Array<Numeric>] earlier difficulties, indexed like +@difficulties+
# @return [Numeric] sum of absolute differences divided by the total parameter count
def average_param_update(old_abilities, old_difficulties)
  ability_steps = @abilities.each_with_index.map do |current, idx|
    (current - old_abilities[idx]).abs
  end
  difficulty_steps = @difficulties.each_with_index.map do |current, idx|
    (current - old_difficulties[idx]).abs
  end

  steps = ability_steps + difficulty_steps
  steps.sum / steps.size
end

#compute_gradientObject



77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
# File 'lib/irt_ruby/rasch_model.rb', line 77

# Gradient of the log-likelihood with respect to every ability and difficulty.
#
# For each response that is not skipped, the residual (observed value minus
# predicted probability) is added to that row's ability gradient and
# subtracted from that column's difficulty gradient.
#
# @return [Array(Array<Float>, Array<Float>)] ability gradients, then difficulty gradients
def compute_gradient
  ability_grad    = Array.new(@abilities.size) { 0.0 }
  difficulty_grad = Array.new(@difficulties.size) { 0.0 }

  @data_array.each_with_index do |responses, person|
    responses.each_with_index do |raw, item|
      observed, skipped = resolve_missing(raw)
      next if skipped

      residual = observed - sigmoid(@abilities[person] - @difficulties[item])

      ability_grad[person]  += residual
      difficulty_grad[item] -= residual
    end
  end

  [ability_grad, difficulty_grad]
end

#fitObject



123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
# File 'lib/irt_ruby/rasch_model.rb', line 123

# Fits abilities and difficulties by gradient ascent on the log-likelihood.
#
# Runs at most +@max_iter+ iterations. A step that lowers the log-likelihood
# is undone and the learning rate is multiplied by +@decay_factor+. Otherwise
# the step is kept, and iteration stops early once both the log-likelihood
# change and the mean parameter change fall below their tolerances.
#
# @return [Hash{Symbol => Array}] the fitted parameters under +:abilities+ and +:difficulties+
def fit
  best_ll = log_likelihood

  @max_iter.times do
    ability_grad, difficulty_grad = compute_gradient
    prior_abilities, prior_difficulties = apply_gradient_update(ability_grad, difficulty_grad)

    candidate_ll = log_likelihood
    mean_step    = average_param_update(prior_abilities, prior_difficulties)

    if candidate_ll < best_ll
      # The step made things worse: restore the snapshot and shrink the step size.
      @abilities    = prior_abilities
      @difficulties = prior_difficulties
      @learning_rate *= @decay_factor
      next
    end

    converged = (candidate_ll - best_ll).abs < @tolerance && mean_step < @param_tolerance
    break if converged

    best_ll = candidate_ll
  end

  { abilities: @abilities, difficulties: @difficulties }
end

#log_likelihoodObject



59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
# File 'lib/irt_ruby/rasch_model.rb', line 59

# Log-likelihood of the response data under the current parameters.
#
# Each response that is not skipped contributes the log of the predicted
# probability of the observed outcome. A small constant (1e-15) is added
# before taking the log so a probability of exactly zero does not yield -Infinity.
#
# @return [Float] the summed log-likelihood (0.0 when nothing contributes)
def log_likelihood
  @data_array.each_with_index.inject(0.0) do |running, (responses, person)|
    responses.each_with_index.inject(running) do |acc, (raw, item)|
      observed, skipped = resolve_missing(raw)
      next acc if skipped

      p_correct  = sigmoid(@abilities[person] - @difficulties[item])
      likelihood = observed == 1 ? p_correct : 1 - p_correct

      acc + Math.log(likelihood + 1e-15)
    end
  end
end

#resolve_missing(resp) ⇒ Object



46
47
48
49
50
51
52
53
54
55
56
57
# File 'lib/irt_ruby/rasch_model.rb', line 46

# Maps a raw response to the value to score and a flag saying whether to skip it.
#
# Non-nil responses pass through unchanged. A nil response is resolved
# according to +@missing_strategy+.
#
# @param resp [Object, nil] the raw response (nil means missing)
# @return [Array(Object, Boolean), nil] +[value, skip]+; nil when +resp+ is nil
#   and +@missing_strategy+ is not one of the three known strategies
def resolve_missing(resp)
  if resp.nil?
    {
      ignore: [nil, true],
      treat_as_incorrect: [0, false],
      treat_as_correct: [1, false]
    }[@missing_strategy]
  else
    [resp, false]
  end
end

#sigmoid(x) ⇒ Object



42
43
44
# File 'lib/irt_ruby/rasch_model.rb', line 42

# Logistic function: maps a real number to a value between 0 and 1.
#
# @param x [Numeric] the input value
# @return [Float] 1 / (1 + e^-x)
def sigmoid(x)
  denominator = 1.0 + Math.exp(-x)
  1.0 / denominator
end