Class: XGBoost::EarlyStopping

Inherits:
TrainingCallback
Defined in:
lib/xgboost/early_stopping.rb

Instance Method Summary

Methods inherited from TrainingCallback

#before_iteration

Constructor Details

#initialize(rounds:, metric_name: nil, data_name: nil, maximize: nil, save_best: false, min_delta: 0.0) ⇒ EarlyStopping

Returns a new instance of EarlyStopping.



# File 'lib/xgboost/early_stopping.rb', line 3

def initialize(
  rounds:,
  metric_name: nil,
  data_name: nil,
  maximize: nil,
  save_best: false,
  min_delta: 0.0
)
  @data = data_name
  @metric_name = metric_name
  @rounds = rounds
  @save_best = save_best
  @maximize = maximize
  @stopping_history = {}
  @min_delta = min_delta
  if @min_delta < 0
    raise ArgumentError, "min_delta must be greater or equal to 0."
  end

  @current_rounds = 0
  @best_scores = {}
  @starting_round = 0
  super()
end

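The callback is configured once and handed to training. Below is a minimal, hedged usage sketch: it assumes the gem is already loaded, that XGBoost.train accepts a callbacks: keyword mirroring the Python API, and that x_train/y_train/x_valid/y_valid are placeholder data.

early_stop = XGBoost::EarlyStopping.new(
  rounds: 10,            # stop after 10 rounds without improvement
  metric_name: "rmse",   # optional; defaults to the last reported metric
  data_name: "valid",    # optional; defaults to the last evaluation set
  save_best: true,
  min_delta: 1e-4        # smallest change that still counts as improvement
)

dtrain = XGBoost::DMatrix.new(x_train, label: y_train)
dvalid = XGBoost::DMatrix.new(x_valid, label: y_valid)

booster = XGBoost.train(
  {objective: "reg:squarederror"},
  dtrain,
  num_boost_round: 1000,
  evals: [[dtrain, "train"], [dvalid, "valid"]],
  callbacks: [early_stop]
)
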
Instance Method Details

#after_iteration(model, epoch, evals_log) ⇒ Object



# File 'lib/xgboost/early_stopping.rb', line 33

def after_iteration(model, epoch, evals_log)
  epoch += @starting_round
  msg = "Must have at least 1 validation dataset for early stopping."
  if evals_log.keys.length < 1
    raise ArgumentError, msg
  end

  # Get data name
  if @data
    data_name = @data
  else
    # Use the last one as default.
    data_name = evals_log.keys[-1]
  end
  if !evals_log.include?(data_name)
    raise ArgumentError, "No dataset named: #{data_name}"
  end

  if !data_name.is_a?(String)
    raise TypeError, "The name of the dataset should be a string. Got: #{data_name.class.name}"
  end
  data_log = evals_log[data_name]

  # Get metric name
  if @metric_name
    metric_name = @metric_name
  else
    # Use last metric by default.
    metric_name = data_log.keys[-1]
  end
  if !data_log.include?(metric_name)
    raise ArgumentError, "No metric named: #{metric_name}"
  end

  # The latest score
  score = data_log[metric_name][-1]
  update_rounds(
    score, data_name, metric_name, model, epoch
  )
end

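#after_iteration reads evals_log as a nested hash: dataset name => metric name => array of scores recorded so far. Unless data_name and metric_name were given to the constructor, it falls back to the last dataset and the last metric, then passes the newest score to the internal bookkeeping that counts non-improving rounds. An illustrative (hand-written, not library-produced) example of the structure it expects:

evals_log = {
  "train" => {"rmse" => [0.91, 0.74, 0.63]},
  "valid" => {"rmse" => [0.92, 0.81, 0.79]}
}

# With no data_name/metric_name configured, the callback picks the last
# dataset ("valid") and its last metric ("rmse"), so the score checked on
# this round is:
evals_log["valid"]["rmse"][-1]  # => 0.79
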
#after_training(model) ⇒ Object



# File 'lib/xgboost/early_stopping.rb', line 74

def after_training(model)
  if !@save_best
    return model
  end

  best_iteration = model.best_iteration
  best_score = model.best_score
  # model = model[..(best_iteration + 1)]
  model.best_iteration = best_iteration
  model.best_score = best_score
  model
end

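#after_training is a pass-through unless save_best was requested; note that the slice down to the best iteration is commented out in the source shown above, so the method's visible effect is to carry best_iteration and best_score on the returned booster. A small hedged sketch of reading those attributes after training with the callback from the earlier example:

booster.best_iteration  # round that produced the best validation score
booster.best_score      # the best validation score itself
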
#before_training(model) ⇒ Object



# File 'lib/xgboost/early_stopping.rb', line 28

def before_training(model)
  @starting_round = model.num_boosted_rounds
  model
end
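
#before_training records how many rounds the booster already contains, and #after_iteration adds that offset to its epoch argument, so the early-stopping history stays aligned when the callback runs against a booster that was trained before. A purely illustrative sketch of the bookkeeping:

starting_round = 50              # model.num_boosted_rounds before training
epoch = 0                        # first iteration of the new training run
effective_round = epoch + starting_round
# => 50; scores are tracked against the booster's true round count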