Class: Spark::Mllib::LogisticRegressionWithSGD

Inherits:
ClassificationMethodBase
Defined in:
lib/spark/mllib/classification/logistic_regression.rb

Constant Summary

DEFAULT_OPTIONS =
{
  iterations: 100,
  step: 1.0,
  mini_batch_fraction: 1.0,
  initial_weights: nil,
  reg_param: 0.01,
  reg_type: 'l2',
  intercept: false
}
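This page does not show where DEFAULT_OPTIONS is combined with the caller's options. The following is a minimal sketch that assumes a plain Hash#merge, in which keys the caller supplies override the defaults. merge_options is a hypothetical helper, and the copied constant only mirrors the values listed above.

```ruby
# Copy of the documented defaults (frozen so the sketch cannot mutate them).
DEFAULT_OPTIONS = {
  iterations: 100,
  step: 1.0,
  mini_batch_fraction: 1.0,
  initial_weights: nil,
  reg_param: 0.01,
  reg_type: 'l2',
  intercept: false
}.freeze

# Hypothetical helper: caller-supplied keys win, all others fall back to defaults.
def merge_options(options = {})
  DEFAULT_OPTIONS.merge(options)
end

opts = merge_options(iterations: 10, reg_type: 'l1')
# opts[:iterations] is 10, opts[:reg_type] is 'l1', opts[:step] stays 1.0
```

Hash#merge returns a new hash, so the frozen defaults are left untouched.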

Class Method Summary

.train(rdd, options = {}) ⇒ Object
Train a logistic regression model on the given data.

Class Method Details

.train(rdd, options = {}) ⇒ Object

Train a logistic regression model on the given data.
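For reference, the model being fitted is binary logistic regression. The following plain-Ruby sketch (no Spark) computes the predicted probability of label 1.0 and the per-example log-loss that SGD minimises. It assumes labels of 0.0 or 1.0, as in MLlib's binary logistic regression, and uses hypothetical helper names (sigmoid, probability, log_loss).

```ruby
# Logistic (sigmoid) function: maps a real-valued margin to (0, 1).
def sigmoid(z)
  1.0 / (1.0 + Math.exp(-z))
end

# P(label = 1.0 | x) = sigmoid(w . x)
def probability(weights, features)
  sigmoid(weights.zip(features).sum { |w, x| w * x })
end

# Log-loss for one example: -(y * log(p) + (1 - y) * log(1 - p))
def log_loss(weights, features, label)
  p1 = probability(weights, features)
  -(label * Math.log(p1) + (1.0 - label) * Math.log(1.0 - p1))
end

probability([0.0, 0.0], [2.0, -1.0])    # => 0.5
log_loss([0.0, 0.0], [2.0, -1.0], 1.0)  # => log(2), about 0.693
```

With all-zero weights every example gets probability 0.5, so the loss is log(2) regardless of the label.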

Arguments (every entry except rdd is a key of the options hash):

rdd

The training data, an RDD of LabeledPoint.

iterations

The number of iterations (default: 100).

step

The step size (learning rate) used in SGD (default: 1.0).

mini_batch_fraction

The fraction of the training data sampled for each SGD iteration (default: 1.0).

initial_weights

The initial weights (default: nil).

reg_param

The regularization parameter (default: 0.01).

reg_type

The type of regularizer used to train the model (default: “l2”).

Allowed values:

  • “l1” for using L1 regularization

  • “l2” for using L2 regularization

  • nil for no regularization

intercept

Boolean indicating whether to train on the augmented representation of the data, i.e. whether a bias (intercept) feature is added (default: false).
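The options above parameterize a mini-batch SGD loop. The following pure-Ruby sketch shows one iteration. It is not the ruby-spark implementation (training actually runs on the JVM). The update rules follow Spark MLlib's GradientDescent and updaters: the step is scaled by 1/sqrt(iteration), L2 applies weight decay, L1 applies soft-thresholding, and the bias is a constant 1.0 feature appended to each example. Point, augment, sample_batch, batch_gradient and update are hypothetical names.

```ruby
# Hypothetical stand-in for LabeledPoint: label is 0.0 or 1.0.
Point = Struct.new(:label, :features)

def dot(a, b)
  a.zip(b).sum { |x, y| x * y }
end

def sigmoid(z)
  1.0 / (1.0 + Math.exp(-z))
end

# intercept: true => append a constant 1.0 feature; its weight acts as the bias.
def augment(point)
  Point.new(point.label, point.features + [1.0])
end

# mini_batch_fraction: keep each point with this probability (1.0 keeps all).
def sample_batch(points, fraction, rng)
  points.select { rng.rand < fraction }
end

# Average log-loss gradient (sigmoid(w . x) - y) * x over the batch.
def batch_gradient(weights, batch)
  return Array.new(weights.size, 0.0) if batch.empty? # empty sample: no gradient
  sum = Array.new(weights.size, 0.0)
  batch.each do |pt|
    m = sigmoid(dot(weights, pt.features)) - pt.label
    pt.features.each_with_index { |x, j| sum[j] += m * x }
  end
  sum.map { |s| s / batch.size }
end

# Gradient step plus regularization; the effective step shrinks as 1/sqrt(iter).
def update(weights, grad, step, iter, reg_param, reg_type)
  eta = step / Math.sqrt(iter)
  case reg_type
  when 'l2' # weight decay, then gradient step
    weights.zip(grad).map { |w, g| w * (1.0 - eta * reg_param) - eta * g }
  when 'l1' # gradient step, then soft-threshold toward zero
    shrink = eta * reg_param
    weights.zip(grad).map { |w, g|
      v = w - eta * g
      v.positive? ? [v - shrink, 0.0].max : [v + shrink, 0.0].min
    }
  when nil # plain SGD step
    weights.zip(grad).map { |w, g| w - eta * g }
  else
    raise ArgumentError, "unknown reg_type: #{reg_type.inspect}"
  end
end

# One iteration with intercept: true, mini_batch_fraction: 1.0, reg_type: 'l2'.
rng = Random.new(42)
data = [Point.new(1.0, [2.0]), Point.new(0.0, [-2.0])].map { |pt| augment(pt) }
w = [0.0, 0.0]
batch = sample_batch(data, 1.0, rng)
w = update(w, batch_gradient(w, batch), 1.0, 1, 0.01, 'l2')
# w is now [1.0, 0.0]: the feature weight moved toward separating the classes
```

L1 soft-thresholding can drive small weights to exactly zero. L2 only scales them down.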



# File 'lib/spark/mllib/classification/logistic_regression.rb', line 138

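def self.train(rdd, options={})
  # Bare `super` forwards the same rdd and options hash to the parent class's train
  super

  # Run the SGD training on the JVM through the Java bridge (Spark.jb);
  # the call returns the learned weights and intercept
  weights, intercept = Spark.jb.call(RubyMLLibAPI.new, 'trainLogisticRegressionModelWithSGD', rdd,
                                     options[:iterations].to_i,
                                     options[:step].to_f,
                                     options[:mini_batch_fraction].to_f,
                                     options[:initial_weights],
                                     options[:reg_param].to_f,
                                     options[:reg_type],
                                     options[:intercept])

  # Wrap the result in a Ruby-side model object
  LogisticRegressionModel.new(weights, intercept)
end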
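train returns a LogisticRegressionModel built from the learned weights and intercept. This page does not document that model's prediction API. The following is a hypothetical pure-Ruby stand-in (ToyLogisticModel) that thresholds the predicted probability at 0.5, the default threshold in Spark MLlib's LogisticRegressionModel.

```ruby
# Hypothetical stand-in for the returned model; not the ruby-spark class.
class ToyLogisticModel
  attr_reader :weights, :intercept

  def initialize(weights, intercept, threshold: 0.5)
    @weights = weights
    @intercept = intercept
    @threshold = threshold
  end

  # Probability that the label is 1.0: sigmoid(w . x + intercept).
  def probability(features)
    margin = weights.zip(features).sum { |w, x| w * x } + intercept
    1.0 / (1.0 + Math.exp(-margin))
  end

  # Hard label: 1.0 if the probability exceeds the threshold, else 0.0.
  def predict(features)
    probability(features) > @threshold ? 1.0 : 0.0
  end
end

model = ToyLogisticModel.new([1.0, -1.0], 0.0)
model.predict([3.0, 1.0])  # margin 2.0, probability about 0.88 => 1.0
model.predict([0.0, 3.0])  # margin -3.0, probability about 0.05 => 0.0
```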