Class: Spark::Mllib::LogisticRegressionWithLBFGS

Inherits:
ClassificationMethodBase
Defined in:
lib/spark/mllib/classification/logistic_regression.rb

Constant Summary

DEFAULT_OPTIONS =
{
  iterations: 100,
  initial_weights: nil,
  reg_param: 0.01,
  reg_type: 'l2',
  intercept: false,
  corrections: 10,
  tolerance: 0.0001
}
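Options passed to train are keys of a hash, and DEFAULT_OPTIONS supplies a value for every key. A minimal sketch of how caller-supplied options can be layered over these defaults with Hash#merge follows. Whether ClassificationMethodBase.train combines them exactly this way is an assumption, and effective_options is a hypothetical helper, not part of this library.

```ruby
# Local copy of the defaults listed above (frozen so they cannot be mutated).
DEFAULT_OPTIONS = {
  iterations: 100,
  initial_weights: nil,
  reg_param: 0.01,
  reg_type: 'l2',
  intercept: false,
  corrections: 10,
  tolerance: 0.0001
}.freeze

# Hypothetical helper: caller keys override defaults, missing keys keep defaults.
def effective_options(options = {})
  DEFAULT_OPTIONS.merge(options)
end

opts = effective_options(iterations: 50, reg_type: 'l1')
puts opts[:iterations] # 50
puts opts[:reg_type]   # l1
```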

Class Method Summary

Class Method Details

.train(rdd, options = {}) ⇒ Object

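Train a logistic regression model on the given data using the L-BFGS optimizer, and return a LogisticRegressionModel built from the learned weights and intercept.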
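For orientation, here is a plain-Ruby sketch of the objective a regularized logistic regression trainer typically minimizes: the mean logistic loss over labeled points with 0/1 labels, plus a penalty chosen by reg_type and scaled by reg_param. The penalty conventions (0.5·λ·‖w‖² for L2, λ·‖w‖₁ for L1) follow Spark MLlib's updaters but are an assumption here; the helper names and the [label, features] pair layout are hypothetical stand-ins for LabeledPoint, not this library's API.

```ruby
# Dot product of two equal-length arrays.
def dot(a, b)
  a.zip(b).sum { |x, y| x * y }
end

# Logistic loss for one example with label y in {0, 1}:
# -y*log(p) - (1-y)*log(1-p) with p = sigmoid(m), rewritten as log(1 + e^m) - y*m.
# Not overflow-safe for very large margins; fine for a sketch.
def logistic_loss(weights, features, label)
  margin = dot(weights, features)
  Math.log(1.0 + Math.exp(margin)) - label * margin
end

# Penalty selected by reg_type (assumed conventions, see above).
def penalty(weights, reg_param, reg_type)
  case reg_type
  when 'l1' then reg_param * weights.sum(&:abs)
  when 'l2' then 0.5 * reg_param * weights.sum { |w| w * w }
  when nil then 0.0
  else raise ArgumentError, "unknown reg_type: #{reg_type.inspect}"
  end
end

# Mean data loss plus penalty; data is an array of [label, features] pairs.
def objective(weights, data, reg_param: 0.01, reg_type: 'l2')
  data_loss = data.sum { |label, x| logistic_loss(weights, x, label) } / data.size
  data_loss + penalty(weights, reg_param, reg_type)
end

data = [[1, [1.0, 2.0]], [0, [0.5, -1.0]]]
puts objective([0.0, 0.0], data) # log(2), about 0.6931, when all weights are zero
```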

Arguments (every argument except rdd is passed as a key of the options hash):

rdd

The training data, an RDD of LabeledPoint.

iterations

The maximum number of iterations (default: 100).

initial_weights

The initial weights (default: nil).

reg_param

The regularization parameter (default: 0.01).

reg_type

The type of regularizer used to train the model (default: "l2").

Allowed values:

  • "l1" for L1 regularization

  • "l2" for L2 regularization

  • nil for no regularization

intercept

Boolean flag indicating whether to use the augmented representation of the training data, i.e. whether a bias (intercept) feature is added (default: false).

corrections

The number of corrections used in the L-BFGS update (default: 10).

tolerance

The convergence tolerance of iterations for L-BFGS (default: 0.0001).
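The augmented representation behind the intercept option appends a constant feature to every example so that its weight plays the role of the bias term. A minimal sketch, assuming a trailing 1.0 (the convention of Spark MLlib's appendBias); augment and split_intercept are hypothetical helpers.

```ruby
# Append a constant 1.0 "bias feature" to a feature vector.
def augment(features)
  features + [1.0]
end

# Split weights learned on augmented data into [weights, intercept]:
# the last coefficient is the one multiplying the constant 1.0.
def split_intercept(augmented_weights)
  [augmented_weights[0...-1], augmented_weights[-1]]
end

w_aug = [0.5, -0.25, 2.0]
w, b = split_intercept(w_aug)
x = [4.0, 8.0]
# The augmented dot product equals w . x + b.
lhs = augment(x).zip(w_aug).sum { |xi, wi| xi * wi }
rhs = x.zip(w).sum { |xi, wi| xi * wi } + b
puts lhs == rhs # true
```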

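The corrections value is the L-BFGS history size m: only the m most recent pairs s = x_new - x_old and y = grad_new - grad_old are kept, and they stand in for the inverse Hessian. The standard two-loop recursion turns that history into a search direction. Below is a self-contained Ruby sketch of that textbook recursion, not the JVM implementation this class calls; LbfgsHistory is a hypothetical name.

```ruby
# Keeps at most m correction pairs (s, y), oldest first, and applies the
# two-loop recursion to compute the search direction -H*g.
class LbfgsHistory
  def initialize(m)
    @m = m
    @pairs = []
  end

  def size
    @pairs.size
  end

  # Record a new pair; drop the oldest once more than m are stored.
  def push(s, y)
    @pairs << [s, y]
    @pairs.shift if @pairs.size > @m
  end

  def direction(g)
    q = g.dup
    alphas = [] # filled newest to oldest
    @pairs.reverse_each do |s, y|
      rho = 1.0 / dot(y, s)
      alpha = rho * dot(s, q)
      q = axpy(-alpha, y, q)
      alphas << [alpha, rho]
    end
    # Initial scaling gamma = s'y / y'y from the newest pair (identity if empty).
    gamma = if @pairs.empty?
              1.0
            else
              s_new, y_new = @pairs.last
              dot(s_new, y_new) / dot(y_new, y_new)
            end
    r = q.map { |v| gamma * v }
    @pairs.each_with_index do |(s, y), i|
      alpha, rho = alphas[@pairs.size - 1 - i]
      beta = rho * dot(y, r)
      r = axpy(alpha - beta, s, r)
    end
    r.map { |v| -v }
  end

  private

  def dot(a, b)
    a.zip(b).sum { |x, y| x * y }
  end

  # Elementwise a*x + y.
  def axpy(a, x, y)
    x.zip(y).map { |xi, yi| a * xi + yi }
  end
end

history = LbfgsHistory.new(10) # corrections: 10 keeps up to ten pairs
history.push([1.0], [2.0])     # one step on f(x) = x^2 (gradient 2x)
p history.direction([4.0])     # [-2.0], the Newton step for this quadratic
```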

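iterations caps the number of optimizer steps, while tolerance can stop the run earlier once progress stalls. The exact test applied on the JVM side is not shown on this page; the sketch below assumes a relative-improvement criterion on the loss, which is one common choice. run_until_converged and the toy objective are hypothetical.

```ruby
# Calls the block (current point -> [next_point, next_loss]) until max_iter
# steps are taken or the relative loss improvement drops below tol.
# Returns [point, loss, iterations_used].
def run_until_converged(x0, loss0, max_iter:, tol:)
  x, loss = x0, loss0
  iters = 0
  while iters < max_iter
    x_next, loss_next = yield(x)
    iters += 1
    improvement = (loss - loss_next).abs / [loss.abs, 1e-12].max
    x, loss = x_next, loss_next
    break if improvement < tol
  end
  [x, loss, iters]
end

# Toy objective f(x) = x^2 + 1 minimized by gradient steps of size 0.25.
f = ->(v) { v * v + 1.0 }
x, loss, iters = run_until_converged(1.0, f.(1.0), max_iter: 100, tol: 0.0001) do |xc|
  xn = xc - 0.25 * (2 * xc)
  [xn, f.(xn)]
end
puts iters # stops well before the 100-iteration cap
```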

# File 'lib/spark/mllib/classification/logistic_regression.rb', line 206

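def self.train(rdd, options={})
  # Call ClassificationMethodBase.train with the same arguments
  super

  # Train on the JVM side; numeric options are coerced to Integer/Float first.
  # The call returns the learned weights and the intercept.
  weights, intercept = Spark.jb.call(RubyMLLibAPI.new, 'trainLogisticRegressionModelWithLBFGS', rdd,
                                     options[:iterations].to_i,
                                     options[:initial_weights],
                                     options[:reg_param].to_f,
                                     options[:reg_type],
                                     options[:intercept],
                                     options[:corrections].to_i,
                                     options[:tolerance].to_f)

  # Wrap the result in a Ruby-side model object
  LogisticRegressionModel.new(weights, intercept)
end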
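train wraps the returned weights and intercept in a LogisticRegressionModel. The usual way such a model scores a point is sketched below in plain Ruby: the positive-class probability is sigmoid(w·x + b), and the predicted label is 1 when it exceeds a threshold. The 0.5 threshold (Spark MLlib's default) and the SketchLogisticModel class are assumptions for illustration, not this library's LogisticRegressionModel API.

```ruby
# Hypothetical stand-in for a trained binary logistic regression model.
class SketchLogisticModel
  def initialize(weights, intercept, threshold: 0.5)
    @weights = weights
    @intercept = intercept
    @threshold = threshold
  end

  # Positive-class probability: sigmoid(w . x + b).
  def probability(features)
    margin = @weights.zip(features).sum { |w, x| w * x } + @intercept
    1.0 / (1.0 + Math.exp(-margin))
  end

  # Label 1 if the probability exceeds the threshold, else 0.
  def predict(features)
    probability(features) > @threshold ? 1 : 0
  end
end

model = SketchLogisticModel.new([2.0, -1.0], 0.5)
puts model.predict([1.0, 0.5])  # margin 2.0, so 1
puts model.predict([-1.0, 1.0]) # margin -2.5, so 0
```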