Class: XGBoost::CallbackContainer

Inherits:
Object
  • Object
show all
Defined in:
lib/xgboost/callback_container.rb

Instance Attribute Summary collapse

Instance Method Summary collapse

Constructor Details

#initialize(callbacks, is_cv: false) ⇒ CallbackContainer

Returns a new instance of CallbackContainer.



5
6
7
8
9
10
11
12
13
14
15
# File 'lib/xgboost/callback_container.rb', line 5

# Builds a container around the given callbacks.
#
# Every element of +callbacks+ must be a TrainingCallback; the first one that
# is not aborts construction. The history starts out as an empty Hash.
#
# @param callbacks [#each] callbacks to run; stored as given (not copied)
# @param is_cv [Boolean] stored as +@is_cv+ and consulted by the other hooks
#   to choose between the cross-validation and plain training paths
# @raise [TypeError] if any element is not a TrainingCallback
def initialize(callbacks, is_cv: false)
  @callbacks = callbacks
  callbacks.each do |candidate|
    next if candidate.is_a?(TrainingCallback)

    raise TypeError, "callback must be an instance of XGBoost::TrainingCallback"
  end

  @history = {}
  @is_cv = is_cv
end

Instance Attribute Details

#aggregated_cv ⇒ Object (readonly)

Returns the value of attribute aggregated_cv.



3
4
5
# File 'lib/xgboost/callback_container.rb', line 3

# Reader for the +@aggregated_cv+ instance variable.
#
# @return [Object, nil] the value #after_iteration most recently stored from
#   +aggcv+ when running in cross-validation mode; nil while unassigned
def aggregated_cv
  @aggregated_cv
end

#history ⇒ Object (readonly)

Returns the value of attribute history.



3
4
5
# File 'lib/xgboost/callback_container.rb', line 3

# Reader for the +@history+ instance variable.
#
# The constructor sets it to an empty Hash, and the same object is handed to
# each callback's per-iteration hooks.
# NOTE(review): presumably filled in by +update_history+, which is not
# visible here — confirm its layout there.
#
# @return [Hash] the evaluation history (returned directly, not copied)
def history
  @history
end

Instance Method Details

#after_iteration(model, epoch, dtrain, evals) ⇒ Object



55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
# File 'lib/xgboost/callback_container.rb', line 55

# Evaluates the model for this epoch, records the result in the history and
# then runs every callback's +after_iteration+ hook.
#
# In cross-validation mode the scores come from +model.eval_set(epoch)+, are
# aggregated with +aggcv+ and also kept in +@aggregated_cv+. Otherwise the
# model is evaluated on +evals+ and the result goes through +parse_eval_str+.
#
# @param model [Object] must respond to +eval_set+
# @param epoch [Integer] current iteration, forwarded to the model and callbacks
# @param dtrain [Object] accepted for interface symmetry; not used here
# @param evals [Array, nil] pairs whose second element is the dataset name;
#   nil is treated as an empty list (ignored in cross-validation mode)
# @return [Boolean] true as soon as one callback returns a truthy value;
#   callbacks after that one are not invoked for this epoch
# @raise [ArgumentError] if a dataset name contains "-"
def after_iteration(model, epoch, dtrain, evals)
  metrics =
    if @is_cv
      @aggregated_cv = aggcv(model.eval_set(epoch))
    else
      evals ||= []
      evals.each do |_dataset, label|
        raise ArgumentError, "Dataset name should not contain `-`" if label.include?("-")
      end
      parse_eval_str(model.eval_set(evals, epoch))
    end
  update_history(metrics, epoch)

  @callbacks.any? { |callback| callback.after_iteration(model, epoch, @history) }
end

#after_training(model) ⇒ Object



33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
# File 'lib/xgboost/callback_container.rb', line 33

# Passes the model through each callback's +after_training+ hook in order,
# feeding every callback the value returned by the previous one.
#
# Each returned value must be a PackedBooster in cross-validation mode and a
# Booster otherwise.
#
# @param model [Object] the model handed to the first callback
# @return [Object] the value returned by the last callback, or +model+ itself
#   when there are no callbacks
# @raise [TypeError] if a callback returns something of the wrong class
def after_training(model)
  @callbacks.reduce(model) do |current, callback|
    returned = callback.after_training(current)
    required = @is_cv ? PackedBooster : Booster
    raise TypeError, "after_training should return the model" unless returned.is_a?(required)

    returned
  end
end

#before_iteration(model, epoch, dtrain, evals) ⇒ Object



49
50
51
52
53
# File 'lib/xgboost/callback_container.rb', line 49

# Runs each callback's +before_iteration+ hook with the model, the epoch and
# the current history.
#
# @param model [Object] forwarded to every callback
# @param epoch [Integer] forwarded to every callback
# @param dtrain [Object] accepted but not used here
# @param evals [Object] accepted but not used here
# @return [Boolean] true as soon as one callback returns a truthy value;
#   callbacks after that one are not invoked for this epoch
def before_iteration(model, epoch, dtrain, evals)
  @callbacks.any? { |hook| hook.before_iteration(model, epoch, @history) }
end

#before_training(model) ⇒ Object



17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
# File 'lib/xgboost/callback_container.rb', line 17

# Passes the model through each callback's +before_training+ hook in order,
# feeding every callback the value returned by the previous one.
#
# Each returned value must be a PackedBooster in cross-validation mode and a
# Booster otherwise.
#
# @param model [Object] the model handed to the first callback
# @return [Object] the value returned by the last callback, or +model+ itself
#   when there are no callbacks
# @raise [TypeError] if a callback returns something of the wrong class
def before_training(model)
  @callbacks.reduce(model) do |current, callback|
    returned = callback.before_training(current)
    required = @is_cv ? PackedBooster : Booster
    raise TypeError, "before_training should return the model" unless returned.is_a?(required)

    returned
  end
end