Analysis of Classification-based Error Functions
Mike Rimer
Dr. Tony Martinez
BYU Computer Science Dept.
18 March 2006
Overview
- Machine learning
- Teaching artificial neural networks with an error function
- Problems with conventional error functions
- CB algorithms
- Experimental results
- Conclusion and future work
Machine Learning
- Goal: automating the learning of problem domains
- Given a training sample from a problem domain, induce a hypothesis that is correct over the entire problem population
- The learned model is often used as a black box: input → f(x) → output
Teaching ANNs with an Error Function
- The error function guides the gradient descent procedure that trains a multi-layer perceptron (MLP) toward an optimal state
- Conventional error metrics are sum-squared error (SSE) and cross-entropy (CE)
- SSE is suited to function approximation
- CE is aimed at classification problems
- CB error functions [Rimer & Martinez 06] work better for classification
SSE, CE
- Both attempt to approximate 0-1 targets in order to represent making a decision (see the sketch below)
[Figure: outputs O1 and O2 on the 0-1 scale for a pattern labeled as class 2, with ERROR 1 and ERROR 2 the distances from each output to its target]
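A minimal sketch of the two conventional error functions, assuming sigmoid outputs and one-hot 0-1 targets; this Python/NumPy rendering and its variable names are illustrative, not the original implementation.

    import numpy as np

    def sse(outputs, targets):
        # Sum-squared error: squared distance from the outputs to the 0-1 targets.
        return np.sum((outputs - targets) ** 2)

    def cross_entropy(outputs, targets, eps=1e-12):
        # Cross-entropy for sigmoid outputs against 0-1 targets.
        outputs = np.clip(outputs, eps, 1.0 - eps)
        return -np.sum(targets * np.log(outputs)
                       + (1.0 - targets) * np.log(1.0 - outputs))

    # Pattern labeled as class 2, as in the figure: target output O2 should
    # approach 1, non-target output O1 should approach 0.
    outputs = np.array([0.3, 0.6])          # O1, O2
    targets = np.array([0.0, 1.0])
    print(sse(outputs, targets), cross_entropy(outputs, targets))

Both metrics reach their minimum only when the outputs hit the hard 0-1 targets exactly, which is what drives the weights to grow large.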
Issues with approximating hard targets
- Requires weights to be large to achieve optimality
- Leads to premature weight saturation
  - Weight decay, etc., can improve the situation
- Learns areas of the problem space unevenly and at different times during training
  - Makes global learning problematic
Classification-based Error Functions
- Designed to match the goal of learning a classification task more closely (i.e., correct classifications, not low error on 0-1 targets), avoiding premature weight saturation and discouraging overfitting
- CB1 [Rimer & Martinez 02, 06]
- CB2 [Rimer & Martinez 04]
- CB3 (submitted to ICML '06)
CB1
- Only backpropagates error on misclassified training patterns (sketch below)
[Figure: on the 0-1 scale, a correct pattern (target output T above non-target ~T) produces no error; a misclassified pattern produces an ERROR signal]
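A sketch of the CB1 idea, assuming one network output per class; the exact error expression here (distance to the best competing output) is an illustrative reading of the rule, not necessarily the paper's precise formulation.

    import numpy as np

    def cb1_error(outputs, target_idx):
        # CB1: backpropagate error only on misclassified training patterns.
        errors = np.zeros_like(outputs)
        top_other = np.delete(outputs, target_idx).max()
        if outputs[target_idx] <= top_other:          # misclassified (or tied)
            # Raise the target output toward the best competitor's value...
            errors[target_idx] = top_other - outputs[target_idx]
            for j in range(len(outputs)):
                # ...and lower every competitor that beats the target.
                if j != target_idx and outputs[j] >= outputs[target_idx]:
                    errors[j] = outputs[target_idx] - outputs[j]
        return errors  # all zeros for correctly classified patterns

    print(cb1_error(np.array([0.7, 0.4, 0.6]), target_idx=1))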
CB2
- Adds a confidence margin, μ, that is increased globally as training progresses (sketch below)
[Figure: three cases on the 0-1 scale: misclassified (ERROR); correct but not satisfying the margin μ (ERROR); correct and satisfying the margin (no error)]
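CB2 extends the same sketch with the global margin μ; the linear schedule that grows μ below is an assumption (the paper increases it globally as training progresses, but not necessarily this way).

    import numpy as np

    def cb2_error(outputs, target_idx, mu):
        # CB2: error is backpropagated when the pattern is misclassified OR
        # correct but not ahead of the best competitor by the margin mu.
        errors = np.zeros_like(outputs)
        top_other = np.delete(outputs, target_idx).max()
        if outputs[target_idx] < top_other + mu:
            errors[target_idx] = (top_other + mu) - outputs[target_idx]
            for j in range(len(outputs)):
                if j != target_idx and outputs[j] > outputs[target_idx] - mu:
                    errors[j] = (outputs[target_idx] - mu) - outputs[j]
        return errors

    # Illustrative schedule: mu grows over training, so a pattern that
    # satisfies the margin early may produce error again later.
    for epoch in (1, 3, 5):
        mu = min(1.0, 0.01 * epoch)
        print(mu, cb2_error(np.array([0.58, 0.60]), target_idx=1, mu=mu))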
CB3
- Learns a confidence Ci for each training pattern i as training progresses (sketch below)
- Patterns that are often misclassified have low confidence
- Patterns consistently classified correctly gain confidence
[Figure: three cases on the 0-1 scale: misclassified (ERROR); correct with learned low confidence Ci; correct with learned high confidence Ci]
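A sketch of the CB3 idea under the assumption that the learned per-pattern confidence Ci plays the role of that pattern's margin; both the error form and the confidence-update rule below are illustrative stand-ins, not the paper's exact equations.

    import numpy as np

    def cb3_error(outputs, target_idx, c_i):
        # CB3: like the margin-based error, but the margin is the learned
        # confidence c_i of this particular training pattern.
        errors = np.zeros_like(outputs)
        top_other = np.delete(outputs, target_idx).max()
        if outputs[target_idx] < top_other + c_i:
            # Misclassified, or correct but inside the learned margin.
            errors[target_idx] = (top_other + c_i) - outputs[target_idx]
        return errors

    def update_confidence(c_i, correct, step=0.05):
        # Illustrative update: consistently correct patterns gain confidence;
        # often-misclassified patterns lose it.
        return min(1.0, c_i + step) if correct else max(0.0, c_i - step)

    c = 0.5                                    # current confidence of pattern i
    out = np.array([0.3, 0.7])
    correct = out.argmax() == 1
    print(cb3_error(out, target_idx=1, c_i=c), update_confidence(c, correct))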
Neural Network Training
- Influenced by:
  - Initial parameter (weight) settings
  - Pattern presentation order (stochastic training)
  - Learning rate
  - Number of hidden nodes
- Goal of training:
  - High generalization
  - Low bias and variance
Experiments
- Empirical comparison of six error functions: SSE, CE, CE w/ WD, and CB1-3
- Eleven benchmark problems from the UC Irvine Machine Learning Repository: ann, balance, bcw, derm, ecoli, iono, iris, musk2, pima, sonar, wine
- Testing performed using stratified 10-fold cross-validation (sketched below)
- Model selection by hold-out set
- Results averaged over ten tests
- Learning rate LR = 0.1, momentum M = 0.7
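A minimal sketch of the evaluation protocol (stratified 10-fold cross-validation on a UCI dataset); scikit-learn and its MLPClassifier are substitutions assumed for illustration, since its built-in cross-entropy loss stands in for the six error functions actually compared.

    from sklearn.datasets import load_iris
    from sklearn.model_selection import StratifiedKFold
    from sklearn.neural_network import MLPClassifier

    X, y = load_iris(return_X_y=True)          # 'iris' is one of the 11 datasets
    skf = StratifiedKFold(n_splits=10, shuffle=True, random_state=0)

    scores = []
    for train_idx, test_idx in skf.split(X, y):
        # SGD with the slide's settings: learning rate 0.1, momentum 0.7.
        clf = MLPClassifier(hidden_layer_sizes=(10,), solver="sgd",
                            learning_rate_init=0.1, momentum=0.7,
                            max_iter=500, random_state=0)
        clf.fit(X[train_idx], y[train_idx])
        scores.append(clf.score(X[test_idx], y[test_idx]))
    print(sum(scores) / len(scores))           # mean 10-fold test accuracy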
Classifier output difference (COD)
- Evaluation of the behavioral difference of two hypotheses (e.g., classifiers) f and g:
  COD(f, g) = (1 / |T|) Σ_{x ∈ T} I(f(x) ≠ g(x))
- T is the test set
- I is the identity or characteristic function
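Given the definition above, COD reduces to the disagreement rate on the test set; a minimal sketch (function and variable names are illustrative):

    def cod(preds_a, preds_b):
        # Fraction of test-set patterns on which the two classifiers disagree:
        # 0 means behaviorally identical, 1 means they never agree.
        assert len(preds_a) == len(preds_b)
        return sum(a != b for a, b in zip(preds_a, preds_b)) / len(preds_a)

    print(cod([0, 1, 2, 1], [0, 2, 2, 1]))  # disagree on 1 of 4 patterns -> 0.25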
Robustness to initial network weights
- Averaged over 30 random runs on all datasets
[Table: % test accuracy, standard deviation, and training epochs for CB1, CB2, CB3, CE, CE w/ WD, and SSE]
Robustness to initial network weights
- Averaged over all tests
[Table: test error and COD for CB1, CB2, CB3, CE, CE w/ WD, and SSE]
Robustness to pattern presentation order
- Averaged over 30 random runs on all datasets
[Table: % test accuracy, standard deviation, and training epochs for CB1, CB2, CB3, CE, CE w/ WD, and SSE]
Robustness to pattern presentation order
- Averaged over all tests
[Table: test error and COD for CB1, CB2, CB3, CE, CE w/ WD, and SSE]
Robustness to learning rate
- Averages over learning rates varied from 0.01 to 0.3
[Table: test accuracy, standard deviation, and training epochs for CB1-3, SSE, CE, and CE w/ WD]
Robustness to learning rate
[Figure: test accuracy as the learning rate varies from 0.01 to 0.3, for the six error functions]
Robustness to number of hidden nodes
- Averages over varying the number of nodes in the hidden layer from …
[Table: test accuracy, standard deviation, and training epochs for CB1-3, SSE, CE, and CE w/ WD]
Robustness to number of hidden nodes
[Figure: test accuracy as the number of hidden nodes varies, for the six error functions]
Conclusion
- CB1-3 are generally more robust than SSE, CE, and CE w/ WD with respect to:
  - Initial weight settings
  - Pattern presentation order
  - Pattern variance
  - Learning rate
  - Number of hidden nodes
- CB3 performs best overall: it is the most robust and gives the most consistent results
Questions?