본문 바로가기

머신러닝

05. 지도 학습 (분류) - 로지스틱 회귀 : 타이타닉 생존자 예측하기

[공부 자료 : Must Have 데싸노트의 실전에서 통하는 머신러닝]

# 로지스틱 회귀 (Logistic Regression)

- 로지스틱 회귀

로지스틱 회귀 : 두 가지로 나뉘는 분류 문제(이진 분류)를 다루는 알고리즘

      알고리즘의 근간을 선형 회귀 분석에 두고 있어서 선형 회귀 분석과 유사

      이름은 회귀이지만, 분류 문제에 사용되는 알고리즘

장점

      구현하기 쉽다

      계수(기울기)를 사용해 각 변수의 중요성을 쉽게 파악

단점

      선형 관계가 아닌 데이터에 대한 예측력이 떨어짐

이진 분류 문제에 선형 회귀가 아닌 로지즈스틱 회귀를 사용해야 하는 이유

https://ko.wikipedia.org/wiki/%EB%A1%9C%EC%A7%80%EC%8A%A4%ED%8B%B1_%ED%9A%8C%EA%B7%80

 

- 로지스틱 함수

# 로지스틱 회귀 - 타이타닉 생존자 예측

- 1. 데이터 불러오기

import pandas as pd

file_url = 'https://media.githubusercontent.com/media/musthave-ML10/data_source/main/titanic.csv'
data = pd.read_csv(file_url)

- 2. 데이터 확인하기

data.head()
data.info()
# RangeIndex: 889 entries, 0 to 888
# Data columns (total 9 columns):
#  #   Column    Non-Null Count  Dtype  
# ---  ------    --------------  -----  
#  0   Pclass    889 non-null    int64  
#  1   Name      889 non-null    object - 문자형 데이터
#  2   Sex       889 non-null    object - 문자형 데이터
#  3   Age       889 non-null    float64
#  4   SibSp     889 non-null    int64  
#  5   Parch     889 non-null    int64  
#  6   Ticket    889 non-null    object - 문자형 데이터
#  7   Embarked  889 non-null    object - 문자형 데이터
#  8   Survived  889 non-null    int64

info( ) 확인 결과 : 결측치는 없지만, 문자형 데이터가 있다

      Name, Sex, Ticket, Embarked

data.describe()

describe( ) 확인 결과

      Parch의 경우, min / 25% / 50% / 75%는 전부 0인데 max만 6.0 

      Parch의 대부분 데이터 값은 0이다 (동반 승객 없음)

데이터 확인 - 변수 간 상관관계 분석하기

corr( ) : 변수 간 상관관계 출력

      0.2 이하 : 상관관계 낮음

      0.2 ~ 0.4 : 낮은 수준의 상관관계

      0.4 ~ 0.6 : 중간 수준의 상관관계

      0.6 ~ 0.8 : 높은 수준의 상관관계

      0.8 이상 : 매우 높은 수준의 상관관계

data.corr()
# 상관관계 값이 클수록 강하다
# Parch(동반 부모/자녀 수)와 SibSp(동반 형제/배우자 수)는 상관관계가 가장 강하다
# 자료형이 object인 변수는 자동으로 빠진다
# 		Pclass		Age		SibSp		Parch		Survived
# Pclass	1.000000	-0.336512	0.081656	0.016824	-0.335549
# Age		-0.336512	1.000000	-0.232543	-0.171485	-0.069822
# SibSp		0.081656	-0.232543	1.000000	0.414542	-0.034040
# Parch		0.016824	-0.171485	0.414542	1.000000	0.083151
# Survived	-0.335549	-0.069822	-0.034040	0.083151	1.000000

heatmap( ) : 히트맵을 통해 변수 간 상관관계 파악

      히트맵 색상 배열 검색 : seaborn plaette

import matplotlib.pyplot as plt
import seaborn as sns

sns.heatmap(data.corr())
plt.show()

sns.heatmap(data.corr(), cmap = 'coolwarm') # cmap : 색상 변경
plt.show()

sns.heatmap(data.corr(), cmap = 'coolwarm', vmin = -1, vmax =1, annot = True) # 색상 범위 + 수치
plt.show()

- 3.1.  데이터 전처리 : 더미 변수와 원-핫 인코딩

자료형이 문자(object)인 변수 처리하기

      머신러닝 알고리즘에서는 문자열로 된 데이터를 이해하지 못한다

(1) 각 값 (특정 문자) 를 숫자로 대체하기

      봄(1), 여름(2), 가을(3), 겨울(4)

      선형 모델에 이 방법을 사용할 경우, 숫자가 상대적인 서열로 인식된다

(2) 더미 변수를 사용하기 (원-핫 인코딩)

      남자(sex_male = 1, sex_female = 0), 여자(sex_male = 0, sex_female = 1)

      sex_male, sex_female 컬럼 : 더미 변수

원-핫 인코딩의 문제점 : 종류가 많을 경우 그만큼 col을 만들어야 한다

# 문자 데이터를 갖는 변수의 고유값 확인하기
data['Name'].nunique() # 889
data['Sex'].nunique() # 2
data['Ticket'].nunique() # 680
data['Embarked'].nunique() # 3

drop([col_1, col_2, ...], axis = 1) : DataFrame에서 해당 column을 삭제

      Name과 Ticket은 고유값이 너무 많지만, 종속변수에 영향을 주지 않는다 → 삭제

data = data.drop(['Name', 'Ticket'], axis = 1)
data.head()

get_dummies(data, columns = [col_1, col_2, ...]) : 원-핫 인코딩

      남은 문자 데이터를 갖는 변수(Sex, Embarked)에 대하여 더미 변수를 만든다

pd.get_dummies(data, columns = ['Sex', 'Embarked'])

원-핫 인코딩을 통해 더미 변수를 만들면 해당 변수의 고유값만큼 더미 변수를 만든다

      Sex_male을 알면 Sex_female은 자동으로 결정된다

      Embarked_Q와 Embarked_S를 알면 Embarekd_C는 자동으로 결정된다

get_dummies(data, columns = [col_1, col_2, ...], drop_first = True) : 첫 번째 더미변수를 자동 삭제

data = pd.get_dummies(data, columns = ['Sex', 'Embarked'], drop_first = True)

 

- 3.2. 데이터 전처리 : train_test_split

from sklearn.model_selection import train_test_split

X = data.drop('Survived', axis = 1) # data에서 Survived를 뺀게 독립변수
y = data['Survived'] # data의 Survived가 종속변수

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size = 0.2, random_state = 100)

- 4. 모델링

from sklearn.linear_model import LogisticRegression

model = LogisticRegression()
model.fit(X_train, y_train)

- 5. 예측하기

pred = model.predict(X_test)
pred 
# model.predict() : 예측한 결과를 0.5를 기준으로 1과 0으로 변환
# array([0, 1, 0, 1, 0, 1, 0, 0, 0, 1, 1, 1, 0, 1, 1, 0, 0, 0, 0, 0, 1, 1,
#        0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0,
#        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0,
#        1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 1,
#        0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0,
#        1, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 1, 1, 0, 0, 0, 1, 1,
#        0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0,
#        0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 0, 0, 0,
#        0, 0])

pred_proba = model.predict_proba(X_test)
pred_proba
# model.predict_proba : 로지스틱 회귀를 통해 진짜로 분석한 결과 [0일 확률, 1일 확률]
# array([[0.92551972, 0.07448028],
#        [0.37708123, 0.62291877],
#        [0.92277744, 0.07722256],
#        [0.45019258, 0.54980742],
#        [0.88480543, 0.11519457],
#        [0.18748259, 0.81251741],
#        [0.90706312, 0.09293688],
#        [0.79493606, 0.20506394],
#        [0.89888332, 0.10111668],
#        [0.3494419 , 0.6505581 ],
#        [0.37120425, 0.62879575],
#        [0.44706889, 0.55293111],
#        [0.64951334, 0.35048666],
#        [0.26773463, 0.73226537],
#        [0.08265115, 0.91734885],
#        [0.88509364, 0.11490636],
#        [0.83830721, 0.16169279],
#        [0.60594142, 0.39405858],
#        [0.67924665, 0.32075335],
#        [0.88480543, 0.11519457],
#        [0.25098448, 0.74901552],
#        [0.25399704, 0.74600296],
#        [0.92538289, 0.07461711],
#        [0.95908948, 0.04091052],
#        [0.05454845, 0.94545155],
#        [0.1865422 , 0.8134578 ],
#        [0.2913128 , 0.7086872 ],
#        [0.90010958, 0.09989042],
#        [0.93769962, 0.06230038],
#        [0.67181047, 0.32818953],
#        [0.07905091, 0.92094909],
#        [0.29185877, 0.70814123],
#        [0.97121737, 0.02878263],
#        [0.95421161, 0.04578839],
#        [0.87202375, 0.12797625],
#        [0.86750169, 0.13249831],
#        [0.35307916, 0.64692084],
#        [0.51574066, 0.48425934],
#        [0.94423623, 0.05576377],
#        [0.6439179 , 0.3560821 ],
#        [0.73636615, 0.26363385],
#        [0.92551972, 0.07448028],
#        [0.44915273, 0.55084727],
#        [0.6934977 , 0.3065023 ],
#        [0.88067321, 0.11932679],
#        [0.94434057, 0.05565943],
#        [0.92538289, 0.07461711],
#        [0.90010958, 0.09989042],
#        [0.93123745, 0.06876255],
#        [0.93071254, 0.06928746],
#        [0.89212146, 0.10787854],
#        [0.86284498, 0.13715502],
#        [0.93580721, 0.06419279],
#        [0.72282674, 0.27717326],
#        [0.88881262, 0.11118738],
#        [0.06606613, 0.93393387],
#        [0.95106039, 0.04893961],
#        [0.58672799, 0.41327201],
#        [0.1528808 , 0.8471192 ],
#        [0.37120425, 0.62879575],
#        [0.87202375, 0.12797625],
#        [0.95757652, 0.04242348],
#        [0.85311876, 0.14688124],
#        [0.25475842, 0.74524158],
#        [0.21908075, 0.78091925],
#        [0.91967866, 0.08032134],
#        [0.24328591, 0.75671409],
#        [0.9166796 , 0.0833204 ],
#        [0.62400055, 0.37599945],
#        [0.95243526, 0.04756474],
#        [0.55932145, 0.44067855],
#        [0.52829904, 0.47170096],
#        [0.15324777, 0.84675223],
#        [0.90010958, 0.09989042],
#        [0.9914523 , 0.0085477 ],
#        [0.90010958, 0.09989042],
#        [0.73422093, 0.26577907],
#        [0.90010958, 0.09989042],
#        [0.65854683, 0.34145317],
#        [0.12193454, 0.87806546],
#        [0.80595185, 0.19404815],
#        [0.84858579, 0.15141421],
#        [0.88509364, 0.11490636],
#        [0.05676314, 0.94323686],
#        [0.03864687, 0.96135313],
#        [0.87686787, 0.12313213],
#        [0.87373624, 0.12626376],
#        [0.03835144, 0.96164856],
#        [0.90706312, 0.09293688],
#        [0.25998027, 0.74001973],
#        [0.24884386, 0.75115614],
#        [0.87233933, 0.12766067],
#        [0.73422093, 0.26577907],
#        [0.88509364, 0.11490636],
#        [0.82850745, 0.17149255],
#        [0.6439179 , 0.3560821 ],
#        [0.86750169, 0.13249831],
#        [0.38850399, 0.61149601],
#        [0.91357908, 0.08642092],
#        [0.92551972, 0.07448028],
#        [0.46538808, 0.53461192],
#        [0.85002327, 0.14997673],
#        [0.91357908, 0.08642092],
#        [0.8174557 , 0.1825443 ],
#        [0.90010958, 0.09989042],
#        [0.63112636, 0.36887364],
#        [0.88509364, 0.11490636],
#        [0.86750169, 0.13249831],
#        [0.21099528, 0.78900472],
#        [0.89269732, 0.10730268],
#        [0.26773463, 0.73226537],
#        [0.34501054, 0.65498946],
#        [0.45855693, 0.54144307],
#        [0.63007144, 0.36992856],
#        [0.86284498, 0.13715502],
#        [0.88881262, 0.11118738],
#        [0.79858072, 0.20141928],
#        [0.43772157, 0.56227843],
#        [0.87686787, 0.12313213],
#        [0.90010958, 0.09989042],
#        [0.85347313, 0.14652687],
#        [0.64133848, 0.35866152],
#        [0.25971571, 0.74028429],
#        [0.28664966, 0.71335034],
#        [0.88509364, 0.11490636],
#        [0.31660217, 0.68339783],
#        [0.2106653 , 0.7893347 ],
#        [0.71020326, 0.28979674],
#        [0.90066433, 0.09933567],
#        [0.56089534, 0.43910466],
#        [0.09818447, 0.90181553],
#        [0.37120425, 0.62879575],
#        [0.79175549, 0.20824451],
#        [0.52720453, 0.47279547],
#        [0.82878906, 0.17121094],
#        [0.91357908, 0.08642092],
#        [0.88509364, 0.11490636],
#        [0.64951334, 0.35048666],
#        [0.427921  , 0.572079  ],
#        [0.91395841, 0.08604159],
#        [0.0526362 , 0.9473638 ],
#        [0.94465497, 0.05534503],
#        [0.37120425, 0.62879575],
#        [0.19674462, 0.80325538],
#        [0.0436084 , 0.9563916 ],
#        [0.90010958, 0.09989042],
#        [0.77828343, 0.22171657],
#        [0.91357908, 0.08642092],
#        [0.04189362, 0.95810638],
#        [0.73422093, 0.26577907],
#        [0.80989304, 0.19010696],
#        [0.90010958, 0.09989042],
#        [0.87641346, 0.12358654],
#        [0.90010958, 0.09989042],
#        [0.93324328, 0.06675672],
#        [0.37120425, 0.62879575],
#        [0.26600736, 0.73399264],
#        [0.04879379, 0.95120621],
#        [0.46953101, 0.53046899],
#        [0.48746283, 0.51253717],
#        [0.19927935, 0.80072065],
#        [0.68098294, 0.31901706],
#        [0.91395841, 0.08604159],
#        [0.45795401, 0.54204599],
#        [0.52889345, 0.47110655],
#        [0.38760663, 0.61239337],
#        [0.36748634, 0.63251366],
#        [0.13800541, 0.86199459],
#        [0.90010958, 0.09989042],
#        [0.09303553, 0.90696447],
#        [0.05476556, 0.94523444],
#        [0.40259165, 0.59740835],
#        [0.45855693, 0.54144307],
#        [0.95757652, 0.04242348],
#        [0.87202375, 0.12797625],
#        [0.9524331 , 0.0475669 ],
#        [0.90691164, 0.09308836],
#        [0.50836634, 0.49163366]])

- 6. 평가하기

이진 분류 (binary classification) : Survived는 살았는가/죽었는가 

      RMSE는 이진 분류의 평가에 적합하지 않다

이진 분류 평가 지표

      1. 정확도 (accuracy)

      2. 오차 행렬

      3. 특이도

      4. AUC

      5. 정밀도 (precision)

      6. 재현율 (recall)

      7. F1 Score

      8. 민감도

      9. 특이도

정확도 : 가장 간단

      test set 100개 중 90개를 맞추면 정확도 0.9

from sklearn.metrics import accuracy_score

accuracy_score(y_test, pred)
# 0.7808988764044944

- 7. 로지스틱 회귀 모델 분석하기

model.coef_
# array([[-1.18229807, -0.03992439, -0.32137838,  0.00798081, -2.56862996,
#         -0.07847763, -0.23534439]])

model.coef_[0]
# array([-1.18229807, -0.03992439, -0.32137838,  0.00798081, -2.56862996,
#        -0.07847763, -0.23534439])

pd.Series(model.coef_[0], index = X.columns)
# Pclass       -1.182298
# Age          -0.039924
# SibSp        -0.321378
# Parch         0.007981
# Sex_male     -2.568630
# Embarked_Q   -0.078478
# Embarked_S   -0.235344
# dtype: float64

Survived가 1이여야 생존이다

      Pclass, Age, SibSp, Sex_male, Embarked_Q, Embarked_S가 낮을수록 생존(1)에 유리하다 (w < 0)

      Parch가 높을수록 생존(1)에 유리하다 (w > 0)

로지스틱 회귀에서는 선형 회귀처럼 단순히 계수를 기울기처럼 곱하여 수식을 만들면 안된다

# 피처 엔지니어링

피처 엔지니어링 : 기존의 데이터를 활용하여 더 나은 변수를 만드는 것

      더미 변수를 만드는 원-핫 인코딩도 피처 엔지니어링에 포함된다

- 선형 모델의 다중공산성(Multicollinearity) 문제

선형 회귀 분석, 로지스틱 회귀 분석 등 선형 모델에서 발생하는 문제
      독립변수 사이에 상관관계가 높을 때 발생

독립변수 A, B와 종속변수 y가 있을 때 A와 B가 모두 y를 양의 방향으로 이끈다면

      만약 A와 B의 상관관계가 높을 경우, y가 증가한 이유가 A 때문인지 B 때문인지 명확히 설명할 수 없다

다중공산성 문제를 해결하는 방법

      1. 상관관계가 높은 변수 중 하나를 제거

      2. 둘을 모두 포괄하는 새로운 변수를 생성

      3. PCA(Principal Component Analysis, 주성분 분석)와 같은 차원 축소를 수행

타이타닉 데이터에서 상관관계가 높은 ParchSibSp를 새로운 변수로 만들기 : 피처 엔지니어링

      Parch : 동반 부모/자식 수

      SibSp : 동반 형제/자매/배우자 수

# Parch와 SibSp를 합친, Family col 만들기
data['Family'] = data['SibSp'] + data['Parch']
data.drop(['SibSp', 'Parch'], axis = 1, inplace = True) # SibSp col, Parch col을 삭제하고 data에 반영하기
data.head()

X = data.drop('Survived', axis = 1)
y = data['Survived']

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size = 0.2, random_state = 100)

model = LogisticRegression()
model.fit(X_train, y_train)

pred = model.predict(X_test)

accuracy_score(y_test, pred)
# 0.7921348314606742

피처 엔지니어링을 통해 정확도가 0.78에서 0.79로 소폭 상승