Class: Chainer::Links::Model::Classifier
- Inherits: Chain
  - Object
  - Chainer::Link
  - Chain
  - Chainer::Links::Model::Classifier
- Defined in: lib/chainer/links/model/classifier.rb
Instance Attribute Summary
- #compute_accuracy ⇒ Object
  Returns the value of attribute compute_accuracy.
Attributes inherited from Chainer::Link
Instance Method Summary
- #call(*args, **kwargs) ⇒ Object
- #initialize(predictor, lossfun = Functions::Loss::SoftmaxCrossEntropy.method(:softmax_cross_entropy), accfun = Functions::Evaluation::Accuracy.method(:accuracy), label_key = -1) ⇒ Classifier (constructor)
  A new instance of Classifier.
Methods inherited from Chain
#del_attr, #namedlinks, #namedparams, #params, #serialize, #set_attr
Methods inherited from Chainer::Link
#cleargrads, #del_attr, #init_scope, #namedlinks, #namedparams, #params, #register_persistent, #serialize, #set_attr, #within_init_scope
Constructor Details
#initialize(predictor, lossfun = Functions::Loss::SoftmaxCrossEntropy.method(:softmax_cross_entropy), accfun = Functions::Evaluation::Accuracy.method(:accuracy), label_key = -1) ⇒ Classifier
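Returns a new instance of Classifier. predictor is the link that produces predictions; lossfun and accfun are callables invoked as call(y, t); label_key selects the label among the arguments later passed to #call (an Integer indexes the positional arguments, a String names a keyword argument). Any other label_key type raises TypeError.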
    # File 'lib/chainer/links/model/classifier.rb', line 13
    def initialize(predictor, lossfun=Functions::Loss::SoftmaxCrossEntropy.method(:softmax_cross_entropy), accfun=Functions::Evaluation::Accuracy.method(:accuracy), label_key=-1)
      super()
      unless label_key.is_a?(Integer) || label_key.is_a?(String)
        raise TypeError, "label_key must be Integer or String, but is #{label_key.class}"
      end
      @lossfun = lossfun
      @accfun = accfun
      @y = nil
      @loss = nil
      @accuracy = nil
      @compute_accuracy = true
      @label_key = label_key
      init_scope do
        @predictor = predictor
      end
    end
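The constructor only requires lossfun and accfun to respond to call(y, t); the defaults happen to be Method objects obtained with .method(...). The plain-Ruby sketch below does not load red-chainer: ToyLoss, ToyClassifier and the accuracy lambda are hypothetical stand-ins that mirror only the constructor's contract and its label_key check, to suggest that a Method object and a lambda should be interchangeable here.

```ruby
# Hypothetical loss module (not part of red-chainer): mean absolute difference.
module ToyLoss
  def self.mean_abs(y, t)
    y.zip(t).sum { |a, b| (a - b).abs }.fdiv(y.size)
  end
end

# Hypothetical class mirroring only the constructor checks of Classifier.
class ToyClassifier
  attr_reader :lossfun, :accfun, :label_key

  def initialize(predictor, lossfun = ToyLoss.method(:mean_abs),
                 accfun = ->(y, t) { y.zip(t).count { |a, b| a == b }.fdiv(y.size) },
                 label_key = -1)
    # Same validation as Classifier#initialize: only Integer or String keys.
    unless label_key.is_a?(Integer) || label_key.is_a?(String)
      raise TypeError, "label_key must be Integer or String, but is #{label_key.class}"
    end
    @predictor = predictor
    @lossfun = lossfun
    @accfun = accfun
    @label_key = label_key
  end
end

clf = ToyClassifier.new(->(x) { x })
# A Method object (lossfun) and a lambda (accfun) are both called as call(y, t).
p clf.lossfun.call([1, 2, 3], [1, 2, 4])  # => 0.3333333333333333
p clf.accfun.call([1, 2, 3], [1, 2, 4])   # => 0.6666666666666666

begin
  ToyClassifier.new(->(x) { x }, ToyLoss.method(:mean_abs), nil, :label)
rescue TypeError => e
  puts e.message  # => label_key must be Integer or String, but is Symbol
end
```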
Instance Attribute Details
#compute_accuracy ⇒ Object
Returns the value of attribute compute_accuracy.
    # File 'lib/chainer/links/model/classifier.rb', line 5
    def compute_accuracy
      @compute_accuracy
    end
Instance Method Details
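#call(*args, **kwargs) ⇒ Object
Removes the label t from the arguments (by position when label_key is an Integer, by keyword when it is a String), runs the predictor on the remaining arguments, computes the loss with lossfun and reports it as loss. When compute_accuracy is true, it also computes the accuracy with accfun and reports it as accuracy. Returns the loss. Raises IndexError if an Integer label_key is out of bounds, or KeyError if a String label_key is not among the keyword arguments.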
    # File 'lib/chainer/links/model/classifier.rb', line 33
    def call(*args, **kwargs)
      if @label_key.is_a?(Integer)
        raise IndexError, "label_key #{@label_key} is out of bounds" if @label_key < -args.size || @label_key >= args.size
        t = args.slice!(@label_key)
      elsif @label_key.is_a?(String)
        raise KeyError, "label_key #{@label_key} is not found" unless kwargs.has_key?(@label_key)
        t = kwargs[@label_key]
        kwargs.delete(@label_key)
      end
      @y = nil
      @accuracy = nil
      @y = @predictor.(*args, **kwargs)
      @loss = @lossfun.call(@y, t)
      Chainer::Reporter.save_report({loss: @loss}, self)
      if @compute_accuracy
        @accuracy = @accfun.call(@y, t)
        Chainer::Reporter.save_report({accuracy: @accuracy}, self)
      end
      @loss
    end
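To make the control flow of #call concrete without loading red-chainer, here is a plain-Ruby sketch. extract_label and classify_step are hypothetical helpers, the lambdas are toy stand-ins for a predictor, a loss and an accuracy function, and the reports hash only stands in for Chainer::Reporter.save_report. Unlike the real method, extract_label works on copies of args and kwargs rather than mutating them.

```ruby
# Hypothetical helper mirroring the label-extraction branch of Classifier#call.
# Returns [label, remaining_args, remaining_kwargs]; the inputs are not mutated.
def extract_label(args, kwargs, label_key)
  args = args.dup
  kwargs = kwargs.dup
  if label_key.is_a?(Integer)
    if label_key < -args.size || label_key >= args.size
      raise IndexError, "label_key #{label_key} is out of bounds"
    end
    t = args.slice!(label_key)
  elsif label_key.is_a?(String)
    raise KeyError, "label_key #{label_key} is not found" unless kwargs.key?(label_key)
    t = kwargs.delete(label_key)
  end
  [t, args, kwargs]
end

# Hypothetical end-to-end step: predict, compute the loss, optionally accuracy.
# `reports` is only a stand-in for Chainer::Reporter.save_report.
def classify_step(predictor, lossfun, accfun, args, kwargs, label_key,
                  compute_accuracy: true, reports: {})
  t, args, kwargs = extract_label(args, kwargs, label_key)
  y = predictor.call(*args, **kwargs)
  loss = lossfun.call(y, t)
  reports[:loss] = loss
  reports[:accuracy] = accfun.call(y, t) if compute_accuracy
  loss
end

# Toy stand-ins: a threshold "model", a count of wrong labels, and accuracy.
predictor = ->(x) { x.map { |v| v > 0 ? 1 : 0 } }
lossfun   = ->(y, t) { y.zip(t).count { |a, b| a != b } }
accfun    = ->(y, t) { y.zip(t).count { |a, b| a == b }.fdiv(t.size) }

x = [0.5, -1.0, 2.0, -0.3]
t = [1, 0, 0, 0]
reports = {}
# label_key = -1: the last positional argument is treated as the label.
loss = classify_step(predictor, lossfun, accfun, [x, t], {}, -1, reports: reports)
p loss               # => 1
p reports[:accuracy] # => 0.75
```

With the default label_key of -1, the caller passes the inputs and the labels positionally, labels last, which is why the sketch passes [x, t].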