Supervised learning (machine learning) takes a known set of input data and known responses to the data, and seeks to build a predictor model that generates reasonable predictions for the response to new data.
Suppose you want to predict whether someone will have a heart attack within a year. You have data on previous patients, including age, weight, height, and blood pressure, and you know whether each of them had a heart attack within a year of their measurements. The problem is to combine all the existing data into a model that can predict whether a new person will have a heart attack within a year.
Supervised learning splits into two broad categories:
Classification for responses that can have just a few known values, such as 'true' or 'false'. Classification algorithms apply to nominal, not ordinal, response values.
Regression for responses that are a real number, such as miles per gallon for a particular car.
Sometimes it is hard to decide whether you have a classification problem or a regression problem. In that case, create a regression model first, because regression models are often more computationally efficient.
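While there are many Statistics Toolbox™ algorithms for supervised learning, most use the same basic workflow for obtaining a predictor model. (Detailed instructions on the steps for ensemble learning are in Framework for Ensemble Learning.) The steps for supervised learning are: prepare the data, choose an algorithm, fit a model, choose a validation method, examine the fit and update it until satisfied, and use the fitted model for predictions.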
All supervised learning methods start with an input data matrix, usually called X here. Each row of X represents one observation. Each column of X represents one variable, or predictor. Represent missing entries with NaN values in X. Statistics Toolbox supervised learning algorithms can handle NaN values, either by ignoring them or by ignoring any row with a NaN value.
You can use various data types for response data Y. Each element in Y represents the response to the corresponding row of X. Observations with missing Y data are ignored.
For regression, Y must be a numeric vector with the same number of elements as the number of rows of X.
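The data layout described above can be sketched as follows. This is a minimal example that assumes the carsmall sample data set shipped with Statistics Toolbox, whose variables Horsepower, Weight, and MPG each hold one value per car and use NaN for missing entries; the choice of predictors is only for illustration.

```matlab
% Arrange predictor data X and a numeric regression response Y.
% Assumes the carsmall sample data that ships with Statistics Toolbox.
load carsmall                  % loads Horsepower, Weight, MPG, ... (one row per car)

X = [Horsepower Weight];       % each row is one observation, each column one predictor
Y = MPG;                       % numeric response vector; NaN marks a missing value

% For regression, Y needs exactly one element per row of X.
assert(numel(Y) == size(X,1))

% Count observations with missing predictor or response values.
rowsWithMissingX = find(any(isnan(X), 2));
rowsWithMissingY = find(isnan(Y));
fprintf('%d rows have NaN predictors, %d rows have NaN responses\n', ...
    numel(rowsWithMissingX), numel(rowsWithMissingY))
```

No rows need to be deleted by hand; the fitting functions handle the NaN values themselves.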
For classification, Y can be any of these data types. The table also shows how to represent a missing entry for each type.
|Data Type||Missing Entry|
|Character array||Row of spaces|
|Cell array of strings||''|
|Logical vector||(Cannot represent)|
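As a small illustration of the missing-entry conventions in the table, the sketch below holds four class labels as a cell array of strings and as a character array, and shows that a logical vector has no way to flag a missing label. The label values are made up for the example.

```matlab
% Four class labels; the second one is missing.
Ycell = {'setosa'; ''; 'virginica'; 'setosa'};   % cell array of strings: '' marks missing
Ychar = char(Ycell);    % character array: char pads rows with spaces, so the
                        % empty label becomes a row of spaces
Ylog  = [true; false; true; true];               % logical vector: every entry is
                                                 % true or false, none can be "missing"

% Locate the missing entries in each representation.
missCell = cellfun(@isempty, Ycell);   % true where the string is ''
missChar = all(Ychar == ' ', 2);       % true where the whole row is blank
```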
There are tradeoffs between several characteristics of algorithms, such as:
Speed of training
Predictive accuracy on new data
Transparency or interpretability, meaning how easily you can understand the reasons an algorithm makes its predictions
The fitting function you use depends on the algorithm you choose.
|Algorithm||Fitting Function|
|Discriminant Analysis (classification)||ClassificationDiscriminant.fit|
|K-Nearest Neighbors (classification)||ClassificationKNN.fit|
|Naive Bayes (classification)||NaiveBayes.fit|
|Classification or Regression Ensembles||fitensemble|
|Classification or Regression Ensembles in Parallel||TreeBagger|
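A sketch of calling the fitting functions from the table on one data set follows. It assumes a Statistics Toolbox release contemporary with this page (where the static .fit methods and the NaiveBayes class are available) and the fisheriris sample data. The parameter values (5 neighbors, 100 boosting cycles, 50 bagged trees) are arbitrary choices for illustration. Because fisheriris has three classes, the boosting method used is 'AdaBoostM2', the multiclass variant.

```matlab
% Fit several classifiers to the same data with the functions listed above.
load fisheriris               % meas: 150-by-4 predictors, species: 150-by-1 cellstr
X = meas;
Y = species;

da  = ClassificationDiscriminant.fit(X, Y);            % discriminant analysis
knn = ClassificationKNN.fit(X, Y, 'NumNeighbors', 5);  % 5-nearest-neighbor classifier
nb  = NaiveBayes.fit(X, Y);                            % naive Bayes (normal distributions by default)
ens = fitensemble(X, Y, 'AdaBoostM2', 100, 'Tree');    % boosted ensemble of 100 trees
bag = TreeBagger(50, X, Y);                            % 50 bagged classification trees
```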
The three main methods to examine the accuracy of the resulting fitted model are:
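Examine the resubstitution error, which is the error of the model on the same data used to train it. This estimate is usually optimistic.
Examine the cross-validation error, which measures the error on portions of the data held out from training.
Examine the out-of-bag error for bagged decision trees, which evaluates each tree on the observations left out of that tree's bootstrap sample.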
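The three validation methods can be compared on one data set as sketched below. The example assumes the fisheriris sample data and a release where ClassificationTree.fit and the TreeBagger 'OOBPred' option are available (newer releases spell the option 'OOBPrediction'). crossval uses 10 folds by default, and oobError returns the cumulative out-of-bag error after each added tree.

```matlab
% Compare resubstitution, cross-validation, and out-of-bag error estimates.
load fisheriris
rng(1)                                     % reproducible folds and bootstrap samples

tree = ClassificationTree.fit(meas, species);

% 1) Resubstitution error: misclassification rate on the training data itself.
resubErr = resubLoss(tree);

% 2) Cross-validation error: 10-fold by default.
cvtree = crossval(tree);
cvErr  = kfoldLoss(cvtree);

% 3) Out-of-bag error for bagged trees; requires out-of-bag predictions to be stored.
bag    = TreeBagger(50, meas, species, 'OOBPred', 'on');
oobErr = oobError(bag);                    % error after 1, 2, ..., 50 trees

fprintf('resubstitution %.3f, 10-fold CV %.3f, OOB (50 trees) %.3f\n', ...
    resubErr, cvErr, oobErr(end))
```

The resubstitution figure is typically the smallest of the three, which is why the cross-validation and out-of-bag estimates are the better guides to accuracy on new data.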
After validating the model, you might want to change it for better accuracy or speed, or to reduce memory use.
Change fitting parameters to try to get a more accurate model.
Change fitting parameters to try to get a smaller model. This sometimes gives a model with more accuracy.
Try a different algorithm. The table Characteristics of Supervised Learning Algorithms, later on this page, can guide the choice.
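One way to trade model size against accuracy is to sweep a single fitting parameter and compare cross-validation error. The sketch below assumes the fisheriris data and the older 'MinLeaf' name-value pair of ClassificationTree.fit (the minimum number of observations per leaf; newer releases call it 'MinLeafSize'). The candidate leaf sizes are arbitrary.

```matlab
% Sweep the minimum leaf size of a classification tree and record
% model size (number of nodes) and 10-fold cross-validation error.
load fisheriris
rng(1)                                    % reproducible cross-validation partitions

leafSizes = [1 5 10 20];                  % candidate values for 'MinLeaf'
nNodes = zeros(size(leafSizes));
cvErr  = zeros(size(leafSizes));
for i = 1:numel(leafSizes)
    t = ClassificationTree.fit(meas, species, 'MinLeaf', leafSizes(i));
    nNodes(i) = t.NumNodes;               % total nodes in the fitted tree
    cvErr(i)  = kfoldLoss(crossval(t));   % 10-fold cross-validation error
end
disp([leafSizes' nNodes' cvErr'])         % columns: MinLeaf, nodes, CV error
```

Larger leaves force smaller trees; the cross-validation column shows whether the smaller tree loses accuracy or, as sometimes happens, gains it.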
When you are satisfied with a model of certain types, you can trim it using the appropriate compact method (compact for classification trees, classification ensembles, regression trees, regression ensembles, and discriminant analysis). compact removes the training data and, for trees, the pruning information, so the model uses less memory.
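The effect of compact can be checked directly, as in this sketch on the fisheriris data (assuming ClassificationTree.fit is available). The full tree stores its training data in the X property; the compact version does not, yet it makes the same predictions.

```matlab
% Compact a classification tree and confirm that predictions are unchanged.
load fisheriris
tree  = ClassificationTree.fit(meas, species);
ctree = compact(tree);                 % drops training data and pruning information

hasTrainingData = [isprop(tree, 'X') isprop(ctree, 'X')];   % full vs. compact
samePredictions = isequal(predict(tree, meas), predict(ctree, meas));
```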
To predict classification or regression response for most fitted models, use the predict method:
Ypredicted = predict(obj,Xnew)
obj is the fitted model object.
Xnew is the new input data.
Ypredicted is the predicted response, either classification or regression.
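A minimal end-to-end use of predict is sketched below, assuming the fisheriris data and ClassificationKNN.fit. The two rows of Xnew are made-up flower measurements (sepal length, sepal width, petal length, petal width, in cm), not observations from the data set.

```matlab
% Fit a nearest-neighbor classifier, then predict classes for new observations.
load fisheriris
obj = ClassificationKNN.fit(meas, species, 'NumNeighbors', 5);

Xnew = [5.0 3.5 1.4 0.2;               % hypothetical new flowers, one per row
        6.5 3.0 5.5 2.0];
Ypredicted = predict(obj, Xnew);       % cell array with one class name per row of Xnew
```

Because species is a cell array of strings, Ypredicted is also a cell array of strings; for a regression model it would be a numeric vector.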
This table shows typical characteristics of the various supervised learning algorithms. The characteristics in any particular case can vary from the listed ones. Use the table as a guide for your initial choice of algorithms, but be aware that the table can be inaccurate for some problems.
Characteristics of Supervised Learning Algorithms
|Algorithm||Predictive Accuracy||Fitting Speed||Prediction Speed||Memory Usage||Easy to Interpret||Handles Categorical Predictors|
|Ensembles||See Suggestions for Choosing an Appropriate Ensemble Algorithm and General Characteristics of Ensemble Algorithms|
* — SVM prediction speed and memory usage are good if there are few support vectors, but can be poor if there are many support vectors. When you use a kernel function, it can be difficult to interpret how SVM classifies data, though the default linear scheme is easy to interpret.
** — Naive Bayes speed and memory usage are good for simple distributions, but can be poor for kernel distributions and large data sets.
*** — Nearest Neighbor usually has good predictions in low dimensions, but can have poor predictions in high dimensions. For linear search, Nearest Neighbor does not perform any fitting. For kd-trees, Nearest Neighbor does perform fitting. Nearest Neighbor can have either continuous or categorical predictors, but not both.
**** — Discriminant Analysis is accurate when the modeling assumptions are satisfied (multivariate normal by class). Otherwise, the predictive accuracy varies.