Class: Spark::Mllib::SVMWithSGD

Inherits:
ClassificationMethodBase
Defined in:
lib/spark/mllib/classification/svm.rb

Constant Summary

DEFAULT_OPTIONS =
{
  iterations: 100,
  step: 1.0,
  reg_param: 0.01,
  mini_batch_fraction: 1.0,
  initial_weights: nil,
  reg_type: 'l2',
  intercept: false
}
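The constant only lists fallback values. The sketch below shows one way a caller's options could be combined with it. It assumes a plain merge in which user-supplied keys win. The actual combining is done outside this method and is not shown on this page. The names `SVM_DEFAULT_OPTIONS` and `effective_options` are hypothetical.

```ruby
# Local copy of the defaults listed above, so the sketch runs without ruby-spark.
SVM_DEFAULT_OPTIONS = {
  iterations: 100,
  step: 1.0,
  reg_param: 0.01,
  mini_batch_fraction: 1.0,
  initial_weights: nil,
  reg_type: 'l2',
  intercept: false
}.freeze

# Hypothetical helper (assumption): any key the caller passes overrides the default,
# including an explicit nil (e.g. reg_type: nil to disable regularization).
def effective_options(user_options)
  SVM_DEFAULT_OPTIONS.merge(user_options)
end

opts = effective_options(iterations: 200, reg_type: 'l1')
```

Because `Hash#merge` keeps explicit `nil` values from its argument, passing `reg_type: nil` still selects "no regularization" instead of falling back to `'l2'`.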

Class Method Summary

  • .train(rdd, options = {}) ⇒ Object
    Train a support vector machine on the given data.

Class Method Details

.train(rdd, options = {}) ⇒ Object

Train a support vector machine on the given data and return an SVMModel.

rdd

The training data, an RDD of LabeledPoint.

iterations

The number of iterations (default: 100).

step

The step size used in SGD (default: 1.0).

reg_param

The regularization parameter (default: 0.01).

mini_batch_fraction

Fraction of the data used in each SGD iteration (default: 1.0).

initial_weights

The initial weights (default: nil).

reg_type

The type of regularizer used to train the model (default: “l2”).

Allowed values:

  • “l1” for using L1 regularization

  • “l2” for using L2 regularization

  • nil for no regularization

intercept

Boolean flag indicating whether to use the augmented representation of the training data, i.e. whether a bias (intercept) feature is added (default: false).
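The parameters above correspond to a standard mini-batch SGD loop for a linear SVM. The following single-process Ruby sketch of that textbook technique is for illustration only. It is not the code that ruby-spark or Spark's MLlib executes. The sketch makes several assumptions: the method name `train_svm_sgd`, the `[label, features]` input format, labels in {0, 1}, and the `step / sqrt(i)` step-size decay.

```ruby
# Illustrative mini-batch SGD for a linear SVM (hinge loss).
# Hypothetical sketch; not ruby-spark's or Spark's implementation.
# points: array of [label, features] with label in {0, 1} (assumption).
def train_svm_sgd(points, iterations: 100, step: 1.0, reg_param: 0.01,
                  mini_batch_fraction: 1.0, reg_type: 'l2', intercept: false,
                  rng: Random.new(42))
  # Augmented representation: append a constant 1.0 feature whose weight is the bias.
  data = points.map { |label, x| [label, intercept ? x + [1.0] : x] }
  w = Array.new(data.first[1].size, 0.0)

  (1..iterations).each do |i|
    # Sample roughly mini_batch_fraction of the data for this iteration.
    batch = data.select { rng.rand < mini_batch_fraction }
    next if batch.empty?

    # Average hinge-loss subgradient over the mini-batch.
    grad = Array.new(w.size, 0.0)
    batch.each do |label, x|
      y = 2.0 * label - 1.0 # map {0, 1} to {-1, +1}
      margin = y * w.zip(x).sum { |wj, xj| wj * xj }
      next unless margin < 1.0 # only points inside the margin contribute
      x.each_with_index { |xj, j| grad[j] -= y * xj }
    end
    grad.map! { |g| g / batch.size }

    eta = step / Math.sqrt(i) # decaying step size (assumption)
    case reg_type
    when 'l2' # shrink weights toward zero, then take the gradient step
      w = w.each_with_index.map { |wj, j| wj * (1.0 - eta * reg_param) - eta * grad[j] }
    when 'l1' # gradient step followed by soft-thresholding
      shrink = eta * reg_param
      w = w.each_with_index.map do |wj, j|
        v = wj - eta * grad[j]
        v.positive? ? [v - shrink, 0.0].max : [v + shrink, 0.0].min
      end
    when nil # plain SGD, no regularization
      w = w.each_with_index.map { |wj, j| wj - eta * grad[j] }
    else
      raise ArgumentError, "unknown reg_type: #{reg_type.inspect}"
    end
  end

  # With the augmented representation, the last weight is the intercept.
  intercept ? [w[0...-1], w[-1]] : [w, 0.0]
end
```

With `intercept: true`, the sketch learns the bias as an ordinary weight on the appended 1.0 feature and then returns it separately. This mirrors the "augmented representation" described for the `intercept` option.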

# File 'lib/spark/mllib/classification/svm.rb', line 118

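def self.train(rdd, options={})
  # Run the superclass's train(rdd, options) first (defined in
  # ClassificationMethodBase, not shown here), e.g. for shared input handling.
  super

  # Delegate training to the JVM-side RubyMLLibAPI, which returns the learned
  # weights and intercept. Numeric options are coerced with to_i / to_f.
  weights, intercept = Spark.jb.call(RubyMLLibAPI.new, 'trainSVMModelWithSGD', rdd,
                                     options[:iterations].to_i,
                                     options[:step].to_f,
                                     options[:reg_param].to_f,
                                     options[:mini_batch_fraction].to_f,
                                     options[:initial_weights],
                                     options[:reg_type],
                                     options[:intercept])

  # Wrap the learned parameters in a Ruby-side model object.
  SVMModel.new(weights, intercept)
end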
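`train` returns `SVMModel.new(weights, intercept)`, i.e. a linear model. The sketch below shows how such a model can score a point. It assumes the common convention of predicting class 1 when w·x + b exceeds a threshold of 0.0. The class `LinearSVMSketch`, its `threshold:` keyword, and that default are hypothetical and are not taken from this page or from ruby-spark's SVMModel.

```ruby
# Hypothetical stand-in for a trained linear SVM; not ruby-spark's SVMModel.
class LinearSVMSketch
  attr_reader :weights, :intercept

  def initialize(weights, intercept, threshold: 0.0)
    @weights = weights
    @intercept = intercept
    @threshold = threshold # assumed decision threshold
  end

  # Raw margin: dot(weights, features) + intercept.
  def margin(features)
    @weights.zip(features).sum { |w, x| w * x } + @intercept
  end

  # Class label 1 if the margin is above the threshold, otherwise 0.
  def predict(features)
    margin(features) > @threshold ? 1 : 0
  end
end

model = LinearSVMSketch.new([0.5, -0.25], 0.1)
model.predict([2.0, 1.0]) # margin = 1.0 - 0.25 + 0.1 = 0.85, so class 1
```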