Class: XGBoost::EarlyStopping
- Inherits: TrainingCallback
  - Ancestry: Object → TrainingCallback → XGBoost::EarlyStopping
- Defined in:
- lib/xgboost/early_stopping.rb
Instance Method Summary
- #after_iteration(model, epoch, evals_log) ⇒ Object
- #after_training(model) ⇒ Object
- #before_training(model) ⇒ Object
- #initialize(rounds:, metric_name: nil, data_name: nil, maximize: nil, save_best: false, min_delta: 0.0) ⇒ EarlyStopping — constructor; returns a new instance of EarlyStopping.
Methods inherited from TrainingCallback
Constructor Details
#initialize(rounds:, metric_name: nil, data_name: nil, maximize: nil, save_best: false, min_delta: 0.0) ⇒ EarlyStopping
Returns a new instance of EarlyStopping.
Source (lib/xgboost/early_stopping.rb, lines 3–26):
# File 'lib/xgboost/early_stopping.rb', line 3
# Builds an early-stopping callback: validates +min_delta+, records the
# caller's configuration and starts the per-instance bookkeeping at zero/empty.
#
# @param rounds stored as @rounds (presumably the patience consulted by
#   update_rounds, which is not shown here — confirm there)
# @param metric_name [String, nil] metric to monitor; when nil/false,
#   after_iteration falls back to the last metric logged for the dataset
# @param data_name [String, nil] dataset to monitor; when nil/false,
#   after_iteration falls back to the last dataset in the evaluation log
# @param maximize stored as @maximize unchanged (nil is kept as nil; how it
#   is interpreted is decided by code not shown here)
# @param save_best [Boolean] stored as @save_best and consulted by after_training
# @param min_delta [Numeric] stored as @min_delta; must not be negative
# @raise [ArgumentError] if +min_delta+ is less than 0
def initialize(
  rounds:,
  metric_name: nil,
  data_name: nil,
  maximize: nil,
  save_best: false,
  min_delta: 0.0
)
  # Reject an invalid threshold before any state is written.
  raise ArgumentError, "min_delta must be greater or equal to 0." if min_delta < 0

  # Assignment order is kept stable so #instance_variables / #inspect
  # list the ivars in the same order as before.
  @data = data_name
  @metric_name = metric_name
  @rounds = rounds
  @save_best = save_best
  @maximize = maximize
  @stopping_history = {} # not read by any method visible in this file
  @min_delta = min_delta
  @current_rounds = 0
  @best_scores = {}
  @starting_round = 0 # overwritten by before_training

  super()
end
Instance Method Details
#after_iteration(model, epoch, evals_log) ⇒ Object
Source (lib/xgboost/early_stopping.rb, lines 33–72):
# File 'lib/xgboost/early_stopping.rb', line 33
# Picks the newest score for the monitored dataset/metric after a boosting
# round and hands it to update_rounds; that call's result is returned.
#
# Dataset: @data, or the last dataset key of +evals_log+ when @data is nil/false.
# Metric:  @metric_name, or the last metric key logged for that dataset.
# Score:   the final entry of that metric's history.
#
# @param model passed through unchanged to update_rounds
# @param epoch [Integer] iteration index of the current training call; it is
#   offset by @starting_round (recorded in before_training) before use
# @param evals_log [Hash] dataset name => { metric name => score history }
#   (shape inferred from the indexing below)
# @return whatever update_rounds returns
# @raise [ArgumentError] if +evals_log+ is empty, or the chosen dataset or
#   metric is missing from the log
# @raise [TypeError] if the chosen dataset name is not a String
def after_iteration(model, epoch, evals_log)
  # @starting_round is the booster's round count when training began, so
  # this converts the per-call epoch into an absolute round number.
  absolute_epoch = epoch + @starting_round

  if evals_log.keys.empty?
    raise ArgumentError, "Must have at least 1 validation dataset for early stopping."
  end

  data_name = @data || evals_log.keys.last
  raise ArgumentError, "No dataset named: #{data_name}" unless evals_log.key?(data_name)
  unless data_name.is_a?(String)
    raise TypeError, "The name of the dataset should be a string. Got: #{data_name.class.name}"
  end

  data_log = evals_log[data_name]
  metric_name = @metric_name || data_log.keys.last
  raise ArgumentError, "No metric named: #{metric_name}" unless data_log.key?(metric_name)

  latest_score = data_log[metric_name].last
  update_rounds(latest_score, data_name, metric_name, model, absolute_epoch)
end
#after_training(model) ⇒ Object
Source (lib/xgboost/early_stopping.rb, lines 74–85):
# File 'lib/xgboost/early_stopping.rb', line 74
# Runs once training has finished. Without +save_best+ the model is handed
# back untouched.
#
# With +save_best+, the commented-out slice shows the intent to truncate the
# booster to its best iteration, but no slicing happens: the same model object
# is returned after its best_iteration and best_score are read and written
# back with the values just read.
#
# TODO(review): save_best therefore has no truncating effect yet. When the
# slice is enabled, note that Ruby's inclusive `..(best_iteration + 1)` keeps
# one more round than Python's `[: best_iteration + 1]`; the exclusive form
# `...(best_iteration + 1)` (or `..best_iteration`) matches it.
#
# @param model object responding to best_iteration / best_score and their writers
# @return the same +model+ object that was passed in
def after_training(model)
  return model unless @save_best

  iteration = model.best_iteration
  score = model.best_score
  # model = model[..(best_iteration + 1)]
  model.best_iteration = iteration
  model.best_score = score
  model
end
#before_training(model) ⇒ Object
Source (lib/xgboost/early_stopping.rb, lines 28–31):
# File 'lib/xgboost/early_stopping.rb', line 28
# Records how many rounds the model already has before this training call,
# which after_iteration adds to each epoch.
#
# @param model object responding to num_boosted_rounds
# @return the same +model+ object that was passed in
def before_training(model)
  model.tap { |booster| @starting_round = booster.num_boosted_rounds }
end