LDA (Linear Discriminant Analysis)
LDA는 선형 판별 분석법으로 불리며, PCA와 매우 유사하다.
LDA는 PCA와 유사하게 입력 데이터 셋을 저차원 공간에 투영해 차원을 축소하는 기법이지만,
중요한 차이는 LDA는 지도학습의 분류에서 사용하기 쉽도록 개별 클래스를 분별할 수 있는 기준을 최대한 유지하면서 결정 값 클래스를 최대한으로 분리 할 수 있는 축을 찾는다.
LDA는 특정 공간상에서 클래스 분리를 최대화 하는 축을 찾기 위해 클래스 간 분산과 클래스 내부 분산의 비율을 최대화 하는 방식으로 차원을 축소한다.
즉 , 클래스 간 분산은 최대한 크게 가져가고, 클래스 내부의 분산은 최대한 작게 가져가는 방식이다.
다음 그림은 좋은 클래스 분리를 위해 클래스 간 분산이 크고 클래스 내부 분산이 작은 것을 표현한 것이다.
일반적으로 LDA를 구하는 스텝은 PCA와 유사하나 가장 큰 차이점은 공분산 행렬이 아니라 위에 설명한 클래스 간 분산과 클래스 내부 분산 행렬을 생성한 뒤,
이 행렬에 기반해 고유벡터를 구하고 입력 데이터를 투영한다는 점이다.
LDA를 구하는 스텝은 다음과 같다.
클래스 내부에 클래스 간 분산 행렬을 구한다. 이 두 개의 행렬은 입력 데이터의 결정 값 클래스별로 개별 피처의 평균 벡터를 기반으로 구한다.
클래스 내부 분산 행렬을 S_w , 클래스 간 분산 행렬을 S_B 라고 한다면 다음 식으로 두 행렬을 고유 벡터로 분해할 수 있다.
고유값이 가장 큰 순으로 K개 추출한다.
고유값이 가장 큰 순으로 추출된 고유 벡터를 이용해 새롭게 입력 데이터를 변환한다.
붓꽃 데이터 셋에 LDA 적용하기
붓꽃 데이터 셋을 사이킷런의 LDA를 이용해 변환하고, 그 결과를 품종별로 시각화해 보자.
사이킷런은 LDA를 LinearDiscriminantAnalysis 클래스로 제공한다. 붓꽃 데이터 셋을 로드하고 표준 정규 분포로 스케일링 한다.
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.preprocessing import StandardScaler
from sklearn.datasets import load_iris
iris = load_iris()
iris_scaled = StandardScaler().fit_transform(iris.data)
lda = LinearDiscriminantAnalysis(n_components=2)
lda.fit(iris_scaled, iris.target)
iris_lda = lda.transform(iris_scaled)
print(iris_lda.shape)
(150, 2)
2개의 컴포넌트로 붓꽃 데이터를 LDA로 변환하겠습니다.
PCA와는 다르게 LDA에서 한 가지 유의해야할 점은 LDA는 실제로는 PCA와 다르게 비지도학습이 아닌 지도학습이라는 것이다.
즉, 클래스 값이 변환시에 필요하다는 의미다.
다음 LDA 객체에 fit() 메서드를 호출 할 때 결정값이 입력됐음에 유의 하자
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline
lda_columns=['lda_component_1','lda_component_2']
irisDF_lda = pd.DataFrame(iris_lda,columns=lda_columns)
irisDF_lda['target']=iris.target
#setosa는 세모, versicolor는 네모, virginica는 동그라미로 표현
markers=['^', 's', 'o']
#setosa의 target 값은 0, versicolor는 1, virginica는 2. 각 target 별로 다른 shape으로 scatter plot
for i, marker in enumerate(markers):
x_axis_data = irisDF_lda[irisDF_lda['target']==i]['lda_component_1']
y_axis_data = irisDF_lda[irisDF_lda['target']==i]['lda_component_2']
plt.scatter(x_axis_data, y_axis_data, marker=marker,label=iris.target_names[i])
plt.legend(loc='upper right')
plt.xlabel('lda_component_1')
plt.ylabel('lda_component_2')
plt.show()