plot_learning_curve#
- scikitplot.api.estimators.plot_learning_curve(estimator, X, y, *, train_sizes=None, cv=None, scoring=None, n_jobs=None, verbose=0, shuffle=False, random_state=None, fit_params=None, title='Learning Curves', ax=None, fig=None, figsize=None, title_fontsize='large', text_fontsize='medium', **kwargs)[source]#
Generates a plot of the train and test learning curves for a classifier.
The learning curves plot the performance of a classifier as a function of the number of training samples. This helps in understanding how well the classifier performs with different amounts of training data.
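As the References note, this function builds on scikit-learn's `learning_curve`. A minimal sketch of the underlying computation (the plotting layer is omitted; the data, classifier, and parameter values are illustrative choices, not requirements):

```python
import numpy as np
from sklearn.datasets import load_digits
from sklearn.model_selection import learning_curve
from sklearn.naive_bayes import GaussianNB

X, y = load_digits(return_X_y=True)

# Score the classifier at 5 increasing training-set sizes,
# each evaluated with 5-fold cross-validation.
train_sizes, train_scores, test_scores = learning_curve(
    GaussianNB(),
    X,
    y,
    train_sizes=np.linspace(0.1, 1.0, 5),  # same default as plot_learning_curve
    cv=5,
)

# train_scores / test_scores have one row per training size
# and one column per CV fold; the plot shows their per-row means.
print(train_sizes.shape, train_scores.shape, test_scores.shape)
```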
- Parameters:
- estimator : object type that implements the “fit” method
An object of that type which is cloned for each validation. It must also implement “predict” unless scoring is a callable that doesn’t rely on “predict” to compute a score.
- X : array-like, shape (n_samples, n_features)
Training data, where n_samples is the number of samples and n_features is the number of features.
- y : array-like, shape (n_samples,) or (n_samples, n_features), optional
Target relative to X for classification or regression. None for unsupervised learning.
- train_sizes : iterable, optional
Determines the training sizes used to plot the learning curve. If None, np.linspace(.1, 1.0, 5) is used.
- cv : int, cross-validation generator, iterable or None, optional, default=None
Determines the cross-validation splitting strategy. Possible inputs for cv are:
- None, to use the default 5-fold cross-validation,
- integer, to specify the number of folds,
- a CV splitter,
- an iterable that generates (train, test) splits as arrays of indices.
For integer/None inputs, if the estimator is a classifier and y is either binary or multiclass, StratifiedKFold is used. In all other cases, KFold is used. Refer to the scikit-learn User Guide for the various cross-validation strategies that can be used here.
- scoring : str, callable, or None, optional, default=None
A string (see the scikit-learn model evaluation documentation) or a scorer callable object/function with signature scorer(estimator, X, y).
- n_jobs : int, optional, default=None
Number of jobs to run in parallel. Training the estimator and computing the score are parallelized over the different training and test sets. None means 1 unless in a joblib.parallel_backend context. -1 means using all processors. See the scikit-learn Glossary for more details.
- verbose : int, optional, default=0
Controls the verbosity: the higher, the more messages.
- shuffle : bool, optional, default=False
Whether to shuffle the training data before splitting using cross-validation.
- random_state : int or RandomState instance, optional
Pseudo-random number generator state used for random sampling.
- fit_params : dict, optional, default=None
Parameters to pass to the fit method of the estimator.
Added in version 0.3.9.
- title : str, optional, default=”Learning Curves”
Title of the generated plot.
- ax : matplotlib.axes.Axes, optional, default=None
The axes to plot the figure on. If None, the current axes are used (or generated if required).
- fig : matplotlib.pyplot.figure, optional, default=None
The figure to plot the visualizer on. If None, the current figure is used (or generated if required).
- figsize : tuple of int, optional, default=None
Tuple denoting the figure size of the plot, e.g., (6, 6).
- title_fontsize : str or int, optional, default=’large’
Font size for the plot title. Use e.g., “small”, “medium”, “large” or integer values.
- text_fontsize : str or int, optional, default=’medium’
Font size for the text in the plot. Use e.g., “small”, “medium”, “large” or integer values.
- **kwargs : dict
Generic keyword arguments.
- Returns:
- matplotlib.axes.Axes
The axes on which the plot was drawn.
References
- “scikit-learn learning_curve”
Examples
>>> from sklearn.datasets import load_digits as data_10_classes
>>> from sklearn.model_selection import train_test_split
>>> from sklearn.naive_bayes import GaussianNB
>>> import scikitplot as skplt
>>> X, y = data_10_classes(return_X_y=True, as_frame=False)
>>> X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.5, random_state=0)
>>> model = GaussianNB()
>>> model.fit(X_train, y_train)
>>> skplt.estimators.plot_learning_curve(
...     model, X_val, y_val,
... );
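The curves drawn are essentially the per-size mean train and cross-validation scores with a band of one standard deviation across folds. A hand-rolled sketch of that rendering with plain matplotlib (dataset, classifier, and styling are illustrative, not what scikit-plot does verbatim):

```python
import numpy as np
import matplotlib
matplotlib.use("Agg")  # non-interactive backend for scripted use
import matplotlib.pyplot as plt
from sklearn.datasets import load_digits
from sklearn.model_selection import learning_curve
from sklearn.naive_bayes import GaussianNB

X, y = load_digits(return_X_y=True)
sizes, train_scores, test_scores = learning_curve(
    GaussianNB(), X, y, train_sizes=np.linspace(0.1, 1.0, 5), cv=5
)

fig, ax = plt.subplots(figsize=(6, 6))
for scores, label in (
    (train_scores, "Training score"),
    (test_scores, "Cross-validation score"),
):
    mean, std = scores.mean(axis=1), scores.std(axis=1)
    ax.plot(sizes, mean, "o-", label=label)        # mean score per training size
    ax.fill_between(sizes, mean - std, mean + std, alpha=0.2)  # +/- 1 std band
ax.set_title("Learning Curves")
ax.set_xlabel("Training examples")
ax.set_ylabel("Score")
ax.legend(loc="best")
```

Passing your own `ax` (as `plot_learning_curve` allows) lets you embed the result in a larger figure grid.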
Gallery examples#
plot_learning_curve with examples