watex.models.GridSearch
- class watex.models.GridSearch(base_estimator, grid_params, cv=4, kind='GridSearchCV', scoring='nmse', verbose=0, **grid_kws)
Fine-tune hyperparameters using grid search methods.
The search evaluates combinations of hyperparameter values by cross-validation and keeps the best-scoring combination for future predictions.
- Parameters:
base_estimator (Callable) – Estimator to fit on the training set and labels, i.e. an object that implements a fit method (scikit-learn estimator API). Refer to https://scikit-learn.org/stable/modules/classes.html
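grid_params (list of dict) –
List of hyperparameters to be fine-tuned. Keys follow scikit-learn's <step>__<parameter> convention, so kpca__gamma targets the gamma parameter of a pipeline step named kpca. For instance:
grid_params=[dict(kpca__gamma=np.linspace(0.03, 0.05, 10), kpca__kernel=["rbf", "sigmoid"])]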
pipeline (Callable or Pipeline object) – If a pipeline is given, X is transformed by it accordingly; otherwise the evaluation uses the base estimator alone on the given X.
prefit (bool, default=False) – Whether the cross-validation score must be computed again: if False, it does not need to be recomputed; if True, it does.
cv (int, cross-validation splitter or iterable, default=4) – Cross-validation splitting strategy used by cross-validation-based routines. cv is also available in estimators such as multioutput.ClassifierChain or calibration.CalibratedClassifierCV, which use the predictions of one estimator as training data for another so as not to overfit the training supervision. Possible inputs for cv are usually:
* An integer, specifying the number of folds in K-fold cross-validation. K-fold is stratified over classes if the estimator is a classifier (determined by base.is_classifier) and the targets represent a binary or multiclass (but not multioutput) classification problem (determined by utils.multiclass.type_of_target).
* A cross-validation splitter instance. Refer to the scikit-learn User Guide for the available splitters.
* An iterable yielding train/test splits.
The default is 4-fold.
kind (str, default='GridSearchCV' or '1') – Kind of hyperparameter search: 1 for GridSearchCV (every combination in the grid is evaluated) or 2 for RandomizedSearchCV (a random sample of combinations is evaluated).
scoring (str or callable, default='nmse') – Score function to be maximized (usually by cross-validation) or, in some cases, multiple score functions to be reported. It can be a string accepted by sklearn.metrics.get_scorer() or a callable scorer; a scorer is not to be confused with an evaluation metric, as the latter has a more diverse API. scoring may also be set to None, in which case the estimator's score method is used. See "The scoring parameter: defining model evaluation rules" in the scikit-learn User Guide.
random_state (int, RandomState instance or None, default=None) – Controls the shuffling applied to the data before applying the split. Pass an int for reproducible output across multiple function calls.
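Each dict in grid_params describes one sub-grid. Assuming watex follows scikit-learn's ParameterGrid semantics, the candidates are the Cartesian product of the value lists inside each dict, and the sub-grids are then concatenated. The sketch below reproduces that expansion with the standard library only; expand_grid is a hypothetical helper for illustration, not part of watex.

```python
from itertools import product


def expand_grid(grid_params):
    """Hypothetical helper: expand a list of dicts into candidate settings.

    Assumes ParameterGrid-like semantics: Cartesian product of the value
    lists inside each dict, concatenated across the dicts.
    """
    candidates = []
    for sub_grid in grid_params:
        # Sort keys so the candidate order is deterministic.
        keys = sorted(sub_grid)
        for values in product(*(sub_grid[k] for k in keys)):
            candidates.append(dict(zip(keys, values)))
    return candidates


# Same grid as in the Examples section: 3 * 4 + 1 * 2 * 3 = 18 candidates.
grid_params = [
    dict(n_estimators=[3, 10, 30], max_features=[2, 4, 6, 8]),
    dict(bootstrap=[False], n_estimators=[3, 10], max_features=[2, 3, 4]),
]
candidates = expand_grid(grid_params)
print(len(candidates))  # 18
print(candidates[0])    # {'max_features': 2, 'n_estimators': 3}
```

With the default cv=4, an exhaustive search over these 18 candidates means 72 model fits, plus one final refit if the underlying search refits the best model as scikit-learn does by default.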
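When cv is an integer k, the data are split into k folds and each fold serves once as the validation set while the other k - 1 folds are used for training. The following minimal sketch shows unshuffled K-fold index splitting, assuming the same fold-size rule as scikit-learn's KFold (the first n_samples % k folds get one extra sample). kfold_indices is a hypothetical helper; the class stratification applied to classifiers is not modelled here.

```python
def kfold_indices(n_samples, n_splits=4):
    """Hypothetical helper: yield (train, test) index lists for K-fold CV.

    Folds are contiguous and unshuffled; the first n_samples % n_splits
    folds receive one extra sample, matching scikit-learn's KFold sizes.
    """
    if not 2 <= n_splits <= n_samples:
        raise ValueError("need 2 <= n_splits <= n_samples")
    base, extra = divmod(n_samples, n_splits)
    start = 0
    for fold in range(n_splits):
        size = base + (1 if fold < extra else 0)
        test = list(range(start, start + size))
        # Training indices are everything outside the current test fold.
        train = list(range(0, start)) + list(range(start + size, n_samples))
        yield train, test
        start += size


for train, test in kfold_indices(10, n_splits=4):
    print(len(train), test)
# 7 [0, 1, 2]
# 7 [3, 4, 5]
# 8 [6, 7]
# 8 [8, 9]
```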
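The kind parameter trades completeness for speed: GridSearchCV evaluates every candidate, whereas RandomizedSearchCV evaluates only n_iter of them. For grids given as plain lists, scikit-learn documents that randomized search samples without replacement. The sketch below imitates that for a single dict grid; sample_candidates is a hypothetical helper, and whether watex forwards n_iter through **grid_kws is an assumption.

```python
import random
from itertools import product


def sample_candidates(grid, n_iter, random_state=None):
    """Hypothetical helper: draw n_iter distinct settings from one dict grid.

    kind=1 (exhaustive) would evaluate all settings; kind=2 (randomized)
    evaluates only this sample. random_state makes the draw reproducible.
    """
    keys = sorted(grid)
    all_settings = [dict(zip(keys, vals))
                    for vals in product(*(grid[k] for k in keys))]
    rng = random.Random(random_state)
    # Sample without replacement and never ask for more than exist.
    return rng.sample(all_settings, min(n_iter, len(all_settings)))


grid = dict(n_estimators=[3, 10, 30], max_features=[2, 4, 6, 8])
picked = sample_candidates(grid, n_iter=5, random_state=42)
print(len(picked))  # 5 of the 12 settings an exhaustive search would try
```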
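Because the search maximizes the score, error metrics are negated so that a smaller error gives a larger score. Assuming the default 'nmse' stands for negative mean squared error (the convention behind scikit-learn's 'neg_mean_squared_error' scorer), the sketch below shows why maximizing it selects the more accurate model. neg_mse is a hypothetical helper, not a watex or scikit-learn function.

```python
def neg_mse(y_true, y_pred):
    """Hypothetical helper: negated mean squared error.

    Larger (closer to 0) is better, so a search that maximizes this
    score is minimizing the squared error.
    """
    if len(y_true) != len(y_pred) or not y_true:
        raise ValueError("inputs must be non-empty and of equal length")
    mse = sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)
    return -mse


y_true = [1.0, 2.0, 3.0]
good = neg_mse(y_true, [1.0, 2.0, 2.5])  # about -0.083
bad = neg_mse(y_true, [2.0, 0.0, 3.0])   # about -1.667
print(max([("good", good), ("bad", bad)], key=lambda kv: kv[1])[0])  # good
```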
Examples
>>> from pprint import pprint
>>> from watex.datasets import fetch_data
>>> from watex.models.validation import GridSearch
>>> from watex.exlib.sklearn import RandomForestClassifier
>>> X_prepared, y_prepared = fetch_data('bagoue prepared')
>>> grid_params = [
...     dict(n_estimators=[3, 10, 30], max_features=[2, 4, 6, 8]),
...     dict(bootstrap=[False], n_estimators=[3, 10], max_features=[2, 3, 4]),
... ]
>>> forest_clf = RandomForestClassifier()
>>> grid_search = GridSearch(forest_clf, grid_params)
>>> grid_search.fit(X=X_prepared, y=y_prepared)
>>> pprint(grid_search.best_params_)
{'max_features': 8, 'n_estimators': 30}
>>> pprint(grid_search.cv_results_)