Class: Chainer::Links::Model::Classifier

Inherits:
Chain show all
Defined in:
lib/chainer/links/model/classifier.rb

Instance Attribute Summary collapse

Attributes inherited from Chainer::Link

#name

Instance Method Summary collapse

Methods inherited from Chain

#del_attr, #namedlinks, #namedparams, #params, #serialize, #set_attr

Methods inherited from Chainer::Link

#cleargrads, #del_attr, #init_scope, #namedlinks, #namedparams, #params, #register_persistent, #serialize, #set_attr, #within_init_scope

Constructor Details

#initialize(predictor, lossfun = Functions::Loss::SoftmaxCrossEntropy.method(:softmax_cross_entropy), accfun = Functions::Evaluation::Accuracy.method(:accuracy), label_key = -1) ⇒ Classifier

Returns a new instance of Classifier.

Parameters:

  • predictor (Chainer::Link)

    Predictor network.

  • lossfun (Function) (defaults to: Functions::Loss::SoftmaxCrossEntropy.method(:softmax_cross_entropy))

    Loss function.

  • accfun (Function) (defaults to: Functions::Evaluation::Accuracy.method(:accuracy))

    Function that computes accuracy.

  • label_key (Integer, String) (defaults to: -1)

    Key that identifies the label variable among the arguments. When it is an Integer, it is used as an index into the positional arguments; when it is a String, it is used as a key into the keyword arguments.



13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
# File 'lib/chainer/links/model/classifier.rb', line 13

# Builds a classifier around a predictor and the functions used to score it.
#
# @param predictor [Chainer::Link] predictor network
# @param lossfun [Function] loss function; stored as given
# @param accfun [Function] function that computes accuracy; stored as given
# @param label_key [Integer, String] key used by #call to pick the label out
#   of its arguments
# @raise [TypeError] if label_key is neither an Integer nor a String
def initialize(
  predictor,
  lossfun = Functions::Loss::SoftmaxCrossEntropy.method(:softmax_cross_entropy),
  accfun = Functions::Evaluation::Accuracy.method(:accuracy),
  label_key = -1
)
  super()

  if [Integer, String].none? { |type| label_key.is_a?(type) }
    raise TypeError, "label_key must be Integer or String, but is #{label_key.class}"
  end

  # Plain state is assigned before entering init_scope; only the predictor is
  # assigned inside the block.
  @label_key = label_key
  @compute_accuracy = true
  @lossfun, @accfun = lossfun, accfun
  @y = @loss = @accuracy = nil

  init_scope { @predictor = predictor }
end

Instance Attribute Details

#compute_accuracy ⇒ Object

Returns the value of attribute compute_accuracy.



5
6
7
# File 'lib/chainer/links/model/classifier.rb', line 5

# Reader for the flag that #call checks before computing and reporting
# accuracy.
#
# @return [Object] the current value of @compute_accuracy
def compute_accuracy
  @compute_accuracy
end

Instance Method Details

#call(*args, **kwargs) ⇒ Object



33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
# File 'lib/chainer/links/model/classifier.rb', line 33

# Runs the predictor on the given inputs and scores its output against the
# label.
#
# The label is removed from the arguments before the remaining ones are
# forwarded to the predictor. The loss (and the accuracy, when
# @compute_accuracy is truthy) is passed to Chainer::Reporter.save_report.
#
# @param args [Array] positional arguments; when the label key is an Integer
#   the label is taken from, and removed at, that index
# @param kwargs [Hash] keyword arguments; when the label key is a String the
#   label is taken from, and removed at, that key. Both the String itself and
#   its Symbol form (as produced by `t: value`) are accepted
# @return [Object] the value returned by the loss function
# @raise [IndexError] if an Integer label key is outside the positional arguments
# @raise [KeyError] if a String label key matches no keyword argument
def call(*args, **kwargs)
  case @label_key
  when Integer
    if @label_key < -args.size || @label_key >= args.size
      raise IndexError, "label_key #{@label_key} is out of bounds"
    end
    t = args.delete_at(@label_key)
  when String
    # Keywords written as `t: value` arrive keyed by Symbol, so a String-only
    # lookup would never match them. Prefer the exact String key, then fall
    # back to its Symbol form.
    key = [@label_key, @label_key.to_sym].find { |candidate| kwargs.key?(candidate) }
    raise KeyError, "label_key #{@label_key} is not found" if key.nil?
    t = kwargs.delete(key)
  end

  # Clear every result of the previous call so that a failure below cannot
  # leave a stale value behind.
  @y = nil
  @loss = nil
  @accuracy = nil

  @y = @predictor.(*args, **kwargs)
  @loss = @lossfun.call(@y, t)
  Chainer::Reporter.save_report({loss: @loss}, self)
  if @compute_accuracy
    @accuracy = @accfun.call(@y, t)
    Chainer::Reporter.save_report({accuracy: @accuracy}, self)
  end
  @loss
end