머신러닝

[머신러닝] GridSearchCV

서노리 2022. 6. 8. 02:25
반응형

GridSearchCV

GridSearchCV란 교차 검증 점수를 기반으로 머신러닝 모델의 최적의 파라매터를 검색할 수 있는 클래스이다.

GridSearchCV 클래스를 사용하지 않는다면 다중 for문을 통해 최적의 파라매터를 찾아야 하는데 이는 제어할 파라매터의 종류가 많아질수록 가독성이 매우 떨어진다.

 

※ GridSearchCV의 하이퍼 파라매터

  • estimator : 예측기 객체 (Classifier, Regressor, Pipeline 등)
  • param_grid : 사용할 파라메터가 정의된 dictionary
  • cv : 교차검증 개수 (KFold 객체를 넣을 수도 있다)

GridSearchCV 예제

from sklearn.ensemble import GradientBoostingClassifier

# 테스트할 파라매터를 제외하고 모델을 정의해준다.
base_model=GradientBoostingClassifier(random_state=1)

# 모델의 학습에 사용할 파라메터의 정의
param_grid = {'learning_rate':[0.1, 0.2, 0.3, 1., 0.01],
              'max_depth':[1, 2, 3],
              'n_estimators':[100, 200, 300, 10, 50]}

from sklearn.model_selection KFold   
from sklearn.model_selection import GridSearchCV

cv=KFold(n_splits=5,shuffle=True,random_state=1)
grid_model = GridSearchCV(estimator=base_model,
                          param_grid=param_grid,
                          cv=cv,
                          n_jobs=-1)
grid_model.fit(X_train,y_train)

# 모든 하이퍼 파라메터를 조합하여 평가한 
# 가장 높은 교차검증 SCORE 값을 반환
print(f'best_score -> {grid_model.best_score_}')
# 가장 높은 교차검증 SCORE 가 어떤 
# 하이퍼 파라메터를 조합했을 때 만들어 졌는지 확인
print(f'best_params -> {grid_model.best_params_}')
# 가장 높은 교차검증 SCORE의 
# 하이퍼 파라메터를 사용하여 생성된 모델 객체를 반환
print(f'best_model -> {grid_model.best_estimator_}')

 

※ 조건부 매개변수를 사용하기 위한 param_grid 선언 방법

LogisticRegression 클래스의 solver와 penalty 파라매터는 조합에 따라 성공, 실패가 될 수 있기 때문에
일반적인 방식으로 param_grid를 선언해주면 에러가 발생한다.
따라서 다음과 같은 방법으로 param_grid를 선언해 사용해야 한다.

# 사용 방식 
# - [{조건부 매개변수 1}, {조건부 매개변수 2} ... ]
# - 아래의 매개변수 그리드는 
#   1번째 l1 penalty에 대한 매개변수 그리드
#   2번째 l2 penalty에 대한 매개변수 그리드
#   3번째 elasticnet penalty에 대한 매개변수 그리드

param_grid = [{'C':[1., 0.1, 0.01, 10., 100],
              'penalty':['l1'],
              'solver':['liblinear', 'saga']},
              {'C':[1., 0.1, 0.01, 10., 100],
              'penalty':['l2'],
              'solver':['lbfgs', 'sag', 'saga']},
              {'C':[1., 0.1, 0.01, 10., 100],
              'penalty':['elasticnet'],
              'solver':['saga']}]

 

반응형