본문 바로가기

머신러닝

08. 지도학습 (분류, 회귀) - 결정 트리 : 연봉 수준 예측하기

[공부 자료 : Must Have 데싸노트의 실전에서 통하는 머신러닝]

# 결정 트리

- 결정 트리

결정 트리 : 트리 기반 모델의 기본 모델 (classifier와 regression 둘 다 가능하지만, 주로 분류 문제에 사용)

트리 모델 vs. 선형 모델

      선형 모델 : 각 변수에 대한 기울기값들의 최적화를 통해 모델을 학습

      트리 모델 : 각 변수의 특정 지점을 기준으로 데이터를 분류하여 예측 모델을 생성 (데이터를 무수하게 쪼개서 예측)

특징 

      다른 트리 기반 모델을 설명하려면 결정 트리를 알아야 한다

      트리 기반 모델은 딥러닝을 제외하고는 가장 유용하고 많이 사용되는 머신러닝 알고리즘 

      종속 변수가 연속형 데이터와 범주형 데이터인 경우 모두에 사용 가능

      모델링 결과를 시각화할 목적으로 가장 유용

      아웃라이어가 문제가 될 정도로 많을 때 유용

장점 

      데이터에 대한 가정이 없다 (nonparammetric model)

      (선형 모델은 정규 분포에 대한 가정이나 독립변수-종속변수의 선형 관계 등을 가정)

      아웃라이어에 영향을 거의 받지 않는다

      트리 그래프를 통해 직관적으로 이해하고 설명할 수있다 (시각화가 탁월하다)

단점 

      트리가 무한정 깊어지면 오버피팅 문제가 발생

      예측력이 상당히 떨어진다

용어

      루트 노드 : 시작 노드

      규칙 노드 : 중간 노드

      리프 노드 : 가장 끝에 있는 노드

      깊이 : 루트 노드로부터 리프 노드로 가는데 총 몇 단계가 걸렸는가

            부모 노드 : 상위 노드

            자식 노드 : 하위 노드

            형제 노드 : 같은 깊이에 있는 노드 

- 결정 트리 이해하기

결정 트리 : 특정 변수에 대한 특정 기준값으로 데이터를 계속 분류해가면서 유사한 그룹으로 묶어 예측값을 만든다

      좋은 모델을 만들려면 첫 번째 분류에 사용할 변수 선정기준점 정하기가 매우 중요

      분류 결정 트리회귀 결정 트리에 적용되는 로직이 다르다

- 분류 결정 트리 (DecisionTreeClassifier)

분류 결정 트리 : 각 노드순도가 가장 높은 방향으로 분류를 한다

      순도 : 한 노드 안에 한 종류의 목표값만 있는 상태에 대한 지표 (사과 3개, 복숭아 0개 > 사과 3개, 복숭아 3개)

      결정 트리는 순도를 체크하여 가지를 뻗는다

순도를 평가하는 지표 : 둘 다 결정 트리에서 비슷한 성능을 보여준다 (scikit learn에서는 기본으로 지니 인덱스를 사용)

      1. 지니 인덱스 (Gini Index) : 0에 가까울수록 순도가 높다 (최대 0.5)

            지니 인덱스는 각 노드에 대해서 게산되며, p는 노드 안에 특정 아이템의 비율을 의미한다 

            ex. 사과 2, 복숭아 2 → 0.5

            ex. 사과 0, 복숭아 4 → 0

      2. 교차 엔트로피 (Cross Entropy) : 0에 가까울수록 순도가 높다 (최대 1)

            ex. 사과 2, 복숭아 2 → 1

            ex. 사과 0, 복숭아 4 → 0

- 회귀 결정 트리 (DecisionTreeRegressor)

평가 기준 : MSE

      특정 x를 기준으로 나눈다 → 각 그룹의 y의 평균을 구한다 → 평균과 각 값의 차를 구한 후 MSE를 구한다

      가장 낮은 MSE값이 나오도록 노드를 분류한다

- 예측력과 설명력

예측력 : 모델 학습을 통해 얼마나 좋은 예측치를 보여주는가

설명력 : 학습된 모델을 얼마나 쉽게 해석할 수 있는가

      복잡한 알고리즘은 예측력이 증가하고 설명력이 떨어진다 (앞으로 배울 알고리즘, 딥러닝 등)

      간단한 알고리즘은 예측력이 떨어지고 설명력이 증가한다 (결정 트리, 회귀 분석 등)

예측력이 높은 알고리즘을 써야하는 경우

      사기 거래를 예측하는 모델에서는 얼마나 정확하게 사기 거래를 잡아내는가가 중요

설명력이 높은 알고리즘을 써야하는 경우

      특정 질병의 발병률에 대한 예측 모델에서는 발병률을 높이거나 억제하는 중요 요인을 밝히기 위해 설명력이 중요

- 오버피팅 (Overfitting) 문제

결정 트리는 최대한 정확하게 분류할 때까지 가지를 뻗어나갈 수 있다

      각각의 마지막 노드1개의 관측치만 들어갈 정도로 세밀하게 분류한다면 100%의 정확도로 분류가 가능할 것이다

      하지만 실제로 이런 모델은 좋은 모델이라고 할 수 없다

머신러닝 모델에서 train set에 대하여 지나치게 정확도가 높은 모델은 test set에서 오히려 안 좋은 결과를 보여준다

      언더피팅 (underfitting, 과소적합) : 모델이 train set에 대해서도 충분한 학습이 되지 않음

      오버피팅 (overfitting, 과대적합) : 모델이 train set에 대해 너무 과하게 학습이 됨

            머신러닝 알고리즘이 복잡해짐에 따라 오버피팅은 발생할 수밖에 없으며, 정확도를 일부러 낮춰야 한다

 편향-분산 트레이드오프 (Bias-Variance Tradeoff)

      편향 : 독립변수와 종속변수를 모델링한 알고리즘이 적절하지 못하거나 제대로 된 예측을 하지 못할 때 증가

            (ex. 언더피팅 : 편향이 높은 상태)

      분산 : train set에 있는 데이터의 노이즈에 의해 발생하는 오차

            (ex. 오버피팅 : train set을 지나치게 정확하게 따라가며 모든 노이즈를 모델에 포함에 분산이 높은 상태)

      오차 (모델의 전체 error) = 편향 + 분산

            모델의 복잡도가 증가할수록 편향이 낮아지고 분산이 높아진다 (오버피팅)

            모델의 복잡도가 낮아질수록 편향이 높아지고 분산이 낮아진다 (언더피팅)

            목표 : 편향과 분산의 합이 최소가 되는 복잡도를 찾는다

# 결정 트리 - 연봉 수준 예측하기

- 1. 데이터 불러오기

import pandas as pd
import numpy as np

file_url = 'https://media.githubusercontent.com/media/musthave-ML10/data_source/main/salary.csv'
data = pd.read_csv(file_url, skipinitialspace = True)
# skipinitialspace : 각 데이터의 첫 자리에 있는 공란을 지워준다
# 공란이 있을 경우, 이후 전처리가 제대로 되지 않는다

- 2. 데이터 분석하기

data.columns
# Index(['age', 'workclass', 'education', 'education-num', 'marital-status',
#        'occupation', 'relationship', 'race', 'sex', 'capital-gain',
#        'capital-loss', 'hours-per-week', 'native-country', 'class'],
#       dtype='object')

종속 변수 : class (이진분류 문제)

data['class'].unique()
# array(['<=50K', '>50K'], dtype=object)

data['class'].nunique()
# 2

결측치 확인, object 범주 확인

data.info()
# RangeIndex: 48842 entries, 0 to 48841
# Data columns (total 14 columns):
#  #   Column          Non-Null Count  Dtype 
# ---  ------          --------------  ----- 
#  0   age             48842 non-null  int64 
#  1   workclass       46043 non-null  object - 결측치
#  2   education       48842 non-null  object
#  3   education-num   48842 non-null  int64 
#  4   marital-status  48842 non-null  object
#  5   occupation      46033 non-null  object - 결측치
#  6   relationship    48842 non-null  object
#  7   race            48842 non-null  object
#  8   sex             48842 non-null  object
#  9   capital-gain    48842 non-null  int64 
#  10  capital-loss    48842 non-null  int64 
#  11  hours-per-week  48842 non-null  int64 
#  12  native-country  47985 non-null  object - 결측치
#  13  class           48842 non-null  object

# object : 텍스트로 구성된 범주형 변수

데이터 통계 분석

data.describe()
# object형 데이터를 제거하고 출력

data.describe(include = 'all')
# object형 변수까지 포함하여 출력
# unique : object형 변수에서 고유값의 개수
# top : object형 변수에서 가장 많이 등장한 value
# freq : object형 변수에서 top value가 몇 번 나왔는가

- 3-1. 데이터 전처리 : 종속변수 형태 변경

# map({ })을 이용하여, 연봉이 50K 이하일 경우 '0', 50K 초과일 경우 '1'
data['class'] = data['class'].map({'<=50K':0, '>50K':1})
data['class']

# target 설정
y = data['class']

- 3-2. object형 변수 정보 확인

독립변수의 범주형 데이터에 대한 처리

      독립변수 중 범주형 데이터가 얼마나 있는지 확인하기

      .dtype : 변수의 자료형

for i in data.columns:
    print(i, ' - ', data[i].dtype)

# age  -  int64
# workclass  -  object
# education  -  object
# education-num  -  int64
# marital-status  -  object
# occupation  -  object
# relationship  -  object
# race  -  object
# sex  -  object
# capital-gain  -  int64
# capital-loss  -  int64
# hours-per-week  -  int64
# native-country  -  object
# class  -  float64

object형 변수의 이름을 따로 모으기

obj_list = [] # object형 변수 이름을 모을 리스트

for i in data.columns:
    if data[i].dtype == 'object':
        obj_list.append(i)

obj_list
# ['workclass',
#  'education',
#  'marital-status',
#  'occupation',
#  'relationship',
#  'race',
#  'sex',
#  'native-country']

범주형 변수(object형 변수)의 고유값 개수 확인하기

for i in obj_list:
    print(i, ' - ', data[i].nunique())

# workclass  -  8 (더미 변수로 바꾸기)
# education  -  16
# marital-status  -  7 (더미 변수로 바꾸기)
# occupation  -  14
# relationship  -  6 (더미 변수로 바꾸기)
# race  -  5 (더미 변수로 바꾸기)
# sex  -  2 (더미 변수로 바꾸기)
# native-country  -  41

범주형 변수는 고유값의 개수 만큼 더미 변수로 바꾸어 학습에 사용한다

      문제 : 고유값이 너무 많을 경우, 그만큼 변수가 많아진다

      해결 : 고유값이 너무 많은 변수는 추가 작업을 해주어 고유값의 개수 (nunique) 를 줄여준다

# 고유값이 10개 이상인 변수는 추가 작업을 해주자
for i in obj_list:
    if data[i].nunique() >= 10:
        print(i, ' - ', data[i].nunique())

# education  -  16
# occupation  -  14
# native-country  -  41

- 3-3. 전처리 : 범주형 변수 education 처리하기

value_counts(변수) : 해당 변수가 가진 고유값의 출현 빈도 확인

data['education'].value_counts()

# HS-grad         15784
# Some-college    10878
# Bachelors        8025
# Masters          2657
# Assoc-voc        2061
# 11th             1812
# Assoc-acdm       1601
# 10th             1389
# 7th-8th           955
# Prof-school       834
# 9th               756
# 12th              657
# Doctorate         594
# 5th-6th           509
# 1st-4th           247
# Preschool          83

일반적으로 범주형 변수를 숫자로 바꿀때는 원-핫 인코딩을 사용한다

      숫자의 순서가 데이터의 분석에 영향을 줄 수 있기 때문

하지만 education 변수는 학력이기 때문에 value끼리의 서열화가 이미 정해져 있다

이미 education-num 변수가 있다

      education-num 변수가 실제로 education 변수를 서열순으로 숫자화 했는지 확인해야 한다

np.sort(data['education-num'].unique()) # education-num의 고유값을 오름차순으로 정렬
# array([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16])

1부터 16까지의 education-num이 실제 education의  고유값과 1:1로 매칭되는지, 어떤 값과 매칭되는지 확인하기

data[data['education-num'] == 1] # education-num이 1인 row만 필터링해서 출력

data[data['education-num'] == 1]['education'].unique() # education-num이 1인 row만 필터링한 새로운 데이터에서 education 변수의 고유값
# array(['Preschool'], dtype=object)
# 고유값이 Preschool 하나
# 1 : Preschool
for i in np.sort(data['education-num'].unique()):
    print(i, ' - ', data[data['education-num'] == i]['education'].unique())
# 1  -  ['Preschool']
# 2  -  ['1st-4th']
# 3  -  ['5th-6th']
# 4  -  ['7th-8th']
# 5  -  ['9th']
# 6  -  ['10th']
# 7  -  ['11th']
# 8  -  ['12th']
# 9  -  ['HS-grad']
# 10  -  ['Some-college']
# 11  -  ['Assoc-voc']
# 12  -  ['Assoc-acdm']
# 13  -  ['Bachelors']
# 14  -  ['Masters']
# 15  -  ['Prof-school']
# 16  -  ['Doctorate']

education 변수는 education-num 변수로 1:1 매칭이 되어있기 때문에, 생략해도 된다

data.drop('education', axis = 1, inplace = True)

- 3-4. 전처리 : 범주형 변수 occupation 처리

data['occupation'].value_counts()
# Prof-specialty       6172
# Craft-repair         6112
# Exec-managerial      6086
# Adm-clerical         5611
# Sales                5504
# Other-service        4923
# Machine-op-inspct    3022
# Transport-moving     2355
# Handlers-cleaners    2072
# Farming-fishing      1490
# Tech-support         1446
# Protective-serv       983
# Priv-house-serv       242
# Armed-Forces           15

유사한 직업군끼리 이미 묶여 있다 + 직업에는 서열이 없다

→ 14개의 더미 변수로 원-핫 인코딩

- 3-5. 전처리 : 범주형 변수 native-country 처리

data['native-country'].value_counts()
# ...
# United-States                 43832
# ...

# 총 41개국
# 48842개 값 중 43832개가 United-States

1. United-States가 약 90%의 비중을 차지하고 있다 

      Unitied-States 이외의 국가를 Others로 묶어줄 수 있다

      장점 : 데이터 간소화

      문제 : 정보가 그만큼 줄어든다 (예측 모델에서 United-States가 아닌 국가 사이에 큰 차이가 없을 경우 사용 가능)

2. 지역별로 묶을 수 있다

      North America, South America, Asia 등

      해당 지역에 속한 국가끼리 유사성을 보일 경우 사용 가능

유사성 : 국가별종속변수(class)에 대한 평균을 내었을 때, 해당 지역의 국가들끼리 유사한 결과를 보일 때

      groupby(col 명).mean( ) : 해당 col을 기준으로 그룹화 하여 평균을 낸다

data.groupby('native-country').mean().sort_values('class')
# native-country 변수의 고유값을 기준으로 그룹화
# 각 변수별로 평균을 구한다
# 결과를 class 변수를 기준으로 오름차순 정리 (작은 값 -> 큰 값)

2-1. United-State와 Others로 묶을 수 있는가?

      United-State의 class 평균 값은 0.243977

      다른 나라들의 class 평균 값은 United-State보다 작기도 하고, 크기도 하다

      → Others로 묶으면 안된다

2-2. 지역별로 묶을 수 있는가?

      class 평균이 가장 높은 나라는 France

      Holand의 class 평균이 0 (관측치가 하나 뿐), Portugal의 class 평균이 0.179104

      유럽 국가들의 class 평균도 모여있지 않다

      → 지역별로 묶으면 안된다

결론 : 더미 변수를 사용하지 않고 (변수가 너무 많아진다) 국가명을 숫자로 변환하여 변수의 갯수를 유지한다

      일반적으로 범주형 데이터를 무작정 숫자로 치환하여 모델링할 때 문제점 : 숫자들을 연속적으로(크고 작음) 받아들임

      단, 트리 기반 모델을 사용할 경우에는 범주형 데이터를 숫자로 치환하여 사용해도 괜찮다

      → 트리 기반 모델은 연속된 숫자들도 연속적으로 받아들이지 않고, 일정 구간을 나누어 받아 들인다 

      → 트리가 충분히 깊어지면 범주형 변수를 숫자로 바꾸어도 문제가 되지 않는다

범주형 데이터를 숫자로 치환하는 다양한 방법

      1. 랜덤하게 번호를 부여하기 : United-State = 1, Peru = 2, Guatemala = 3, ...

      2. 국가별로 value_counts( )로 확인한 숫자를 부여하기 : 같은 수의 국가가 있을 경우 사용 불가

      3. class의 평균 값을 넣어주기

class의 평균 값을 국가별로 치환할 경우 발생하는 문제점

      예상하려는 목표값을 독립변수의 일환으로 반영한다

      답을 억지로 밀어 넣는게 되기 때문에 모델링 시 overfitting이 발생할 수 있다

# native-country 별 class의 평균값을 구하기
country_group = data.groupby('native-country').mean()['class']
country_group

# native-country
# Cambodia                      0.321429
# Canada                        0.346154
# China                         0.295082
# Columbia                      0.047059
# Cuba                          0.246377
# ...

country_group에는 국가별 class의 평균값이 저장되어 있다 → 기존 data에 새로운 변수로 합쳐야 한다

      현재 country_group에서 국가 이름은 변수가 아니라 인덱스로 되어 있다

      기존 data에 합치기 위해서는 국가 이름이 key가 되어야 한다

reset_index( ) : 인덱스를 변수로 빼낸다

country_group = country_group.reset_index()
country_group

# index		native-country	class
# 0		Cambodia	0.32142857142857145
# 1		Canada		0.34615384615384615
# 2		China		0.29508196721311475
# ...

merge( ) : 두 DataFrame을 합친다

data = data.merge(country_group, on = 'native-country', how = 'left')
# data DataFrame에 country_group DataFrame을 합친다
# 합치는 기준은 native-country 변수
# data DataFrame의 왼쪽에 country_group DataFrame을 합친다
data

문제 : 기존의 class 변수 외에 class_x (0과 1), class_y (국가별 class의 평균값) 변수가 생김

      기존의 data에도 class 변수가 있고, 추가하는 country_group에도 class 변수가 있기 때문

해결 : class_y 변수를 naitve-country 변수 대신 사용하고, class_x를 class 변수로 사용한다

      1. native-country 변수를 삭제하고, class_y의 변수명을 native-country로 바꾼다

      2. class_x를 class로 바꾼다

DataFrame.rename( ) : col명 또는 index명을 dictonary 맵핑을 통해 바꾼다

data.drop('native-country', axis = 1,inplace = True)
data.drop('class', axis = 1, inplace = True)
data = data.rename(columns = {'class_x':'class', 'class_y':'native-country'})
data

- 3-6. 전처리 : 결측치 처리 및 더미 변수 변환

결측치가 있는 변수 및 비율 확인하기

data.isna().mean()
# age               0.000000
# workclass         0.057307 - 결측치
# education-num     0.000000
# marital-status    0.000000
# occupation        0.057512 - 결측치
# relationship      0.000000
# race              0.000000
# sex               0.000000
# capital-gain      0.000000
# capital-loss      0.000000
# hours-per-week    0.000000
# class             0.000000
# native-country    0.017546 - 결측치
# dtype: float64

1. native-country : 우리가 이미 각 국가별 class의 평균 값으로 대체하였다

      우리가 이미 class의 평균 값으로 대체했기 때문에 mean( )이나 median( ) 으로 결측치를 채워도 문제가 되지 않는다

      트리 기반 모델에서는 완전 다른 값(주로 -9, -99)를 사용해도 문제가 되지 않아서 많이 사용한다

      (선형 모델에서 완전 다른 값을 채울 경우 데이터 왜곡이 발생하기 때문에 안된다)

data['native-country'] = data['native-country'].fillna(-99)

2. workclass : 범주형 변수로, nunique( )가 크지 않아서 더미 변수를 사용해 원-핫 인코딩을 할 예정

      결측치를 먼저 채우고 원-핫 인코딩을 진행

      범주형 변수이기 때문에 평균치로 해결할 수 없다 : 특정 텍스트를 채워주거나, dropna( )를 통해 해당 row를 제거

data['workclass'].value_counts() # 고유값 출현 빈도
# Private             33906 - 압도적
# Self-emp-not-inc     3862
# Local-gov            3136
# State-gov            1981
# Self-emp-inc         1695
# Federal-gov          1432
# Without-pay            21
# Never-worked           10
# Name: workclass, dtype: int64

      private의 비율이 압도적 → 결측치를 private로 채운다

data['workclass'] = data['workclass'].fillna('Private')

3. occupation : 범주형 변수로, 직업 사이에는 서열이 없기 때문에 더미 변수를 사용해 원-핫 인코딩을 할 예정

      범주형 변수이기 때문에 평균치로 해결할 수 없다 : 특정 텍스트를 채워주거나, dropna( )를 통해 해당 row를 제거

data['occupation'].value_counts() # 고유값 출현 빈도
# Prof-specialty       6172
# Craft-repair         6112
# Exec-managerial      6086
# Adm-clerical         5611
# Sales                5504
# Other-service        4923
# Machine-op-inspct    3022
# Transport-moving     2355
# Handlers-cleaners    2072
# Farming-fishing      1490
# Tech-support         1446
# Protective-serv       983
# Priv-house-serv       242
# Armed-Forces           15
# Name: occupation, dtype: int64

      어떤 특정 값의 출현 빈도가 높다고 할 수 없다 : 별도의 텍스트 'Unknown'으로 결측치를 채운다

data['occupation'] = data['occupation'].fillna('Unknown')

범주형 데이터의 결측치를 전부 처리하였기 때문에, data의 범주형 데이터를 원-핫 인코딩을 통해 더미 변수로 변환

data = pd.get_dummies(data, drop_first = True)

- 4. 모델링 

결정 트리 모델

      DecisionTreeRegressor : 연속형 변수용

      DecisionTreeClassifier : 범주형 변수용 

from sklearn.model_selection import train_test_split

X = data.drop('class', axis = 1)
y = data['class']

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size = 0.4, random_state = 100)
# 데이터의 규모가 비교적 크기 때문에 test_size를 0.4 (40%)로 할당

from sklearn.tree import DecisionTreeClassifier

model = DecisionTreeClassifier()
model.fit(X_train, y_train)

5. 예측

pred = model.predict(X_test)
pred

6. 평가

from sklearn.metrics import accuracy_score

accuracy_score(y_test, pred)
# 0.815222398525874

7. 결정 트리 매개변수 튜닝

결정 트리의 단점 : 트리의 깊이가 깊어질수록(수없이 많은 노드를 분류하여 모델을 만들수록) 오버피팅 발생 가능성 증가

1. 매개변수 튜닝 없이 학습하기 : 최대한 은 수준으로 트리를 생성

model = DecisionTreeClassifier()
model.fit(X_train, y_train)
train_pred = model.predict(X_train)
test_pred = model.predict(X_test) 
print('Train score : ', accuracy_score(y_train, train_pred)) # Train score :  0.9780242279474493
print('Test Score : ', accuracy_score(y_test, test_pred)) # Test Score :  0.8151712135947177

test set에 비해 train set의 점수가 압도적으로 높다 : 오버피팅 가능성 높다

2. 매개변수 튜닝하기 : max_depth

for i in range(1, 11):
    model = DecisionTreeClassifier(max_depth = i)
    model.fit(X_train, y_train)
    train_pred = model.predict(X_train)
    test_pred = model.predict(X_test)
    print('max_depth가 {}일 때'.format(i))
    print(' Train score : ', accuracy_score(y_train, train_pred))
    print(' Test score : ', accuracy_score(y_test, test_pred))

# max_depth가 1일 때
#  Train score :  0.7619860092134448
#  Test score :  0.7588166043916671
# max_depth가 2일 때
#  Train score :  0.8292100324176762
#  Test score :  0.8287864052822849
# max_depth가 3일 때
#  Train score :  0.8452141272820338
#  Test score :  0.8424527819010084
# max_depth가 4일 때
#  Train score :  0.8456577375874424
#  Test score :  0.8428110764191022
# max_depth가 5일 때
#  Train score :  0.8540180856509129
#  Test score :  0.8499769667809797
# max_depth가 6일 때
#  Train score :  0.8576352158334755
#  Test score :  0.8528945078568869
# max_depth가 7일 때
#  Train score :  0.8598532673605187
#  Test score :  0.8543788708604186
# max_depth가 8일 때
#  Train score :  0.8645964852414264
#  Test score :  0.8559144187951068
# max_depth가 9일 때
#  Train score :  0.8684866063811636
#  Test score :  0.8555049393458566
# max_depth가 10일 때
#  Train score :  0.8727520900870158
#  Test score :  0.8500793366432922

max_depth가 8일 때, train score가 낮으면서도 test score는 향상

- 8. 트리 그래프 그리기

import matplotlib.pyplot as plt
from sklearn.tree import plot_tree

model = DecisionTreeClassifier(max_depth = 8)
model.fit(X_train, y_train)

plt.figure(figsize = (30, 15)) # 그래프 크기 설정
plot_tree(model) # 트리 그래프 
plt.show() # 불필요한 문자 삭제

트리의 윗부분 (깊이 3단계) 까지만 출력하고, 폰트 크기를 조정하기

plt.figure(figsize = (30, 15))
plot_tree(model, max_depth = 3, fontsize = 15)
plt.show()

원래 변수 이름으로 그래프 그리기

plt.figure(figsize = (30, 15))
plot_tree(model, max_depth = 3, fontsize = 15, feature_names = X_train.columns)
# feature_names : 변수 이름을 추가하기
plt.show()

첫 번째 노드 분석하기

      gini : 지니 계수 (0.363)

      samples : 총 데이터 수 (29,305)

      value : 목표값이 0인 것이 22,330개, 1인 것이 6,975개

      marital-status_Married-civ-spouse <= 0.5 : 분류 기준 (결정 트리는 해당 기준을 만족하면 왼쪽, 아니면 오른쪽)

            marital-status_Married-civ-spouse는 marital-status에서 파생된 더미 변수

            혼인 상태가 Married-civ-spouse에 해당하지 않으면 (0 <= 0.5) 왼쪽으로 분류

            혼인 상태가 Married-civ-spouse에 해당하면 (not 1 <= 0.5) 오른쪽으로 분류

결정 트리 시각화

      1. plottree

      2. graphviz

      3. 결정 경계 시각화