티스토리 뷰

반응형

Decision Tree in R (분류분석)



> install.packages('rpart')    

> library(rpart)               # Decision Tree 분석 및 시각화를 위한 패키지


### 1. data sampling : train data set, test data set 분리

> library(doBy)  

> train <- sampleBy(~Species, frac = 0.7, data = iris)    # 70%의 랜덤 (train)데이터를 추출


# 70%에 포함된 train데이터의 행 번호를 추출하기 위한 사용자 함수 생성 

> f1 <- function(x) {

     as.numeric(strsplit(x, '\\.')[[1]][2])                        # (.)을 문자로 인식시키기 위해 \\ 사용

   }

> rn <- as.vector(sapply(rownames(train), f1))            # 70%에 포함된 train 데이터의 행 번호를 벡터형으로 저장 

> test <- iris[-rn,]                                                  # 나머지 30%를 test 데이터로 저장



### 2. 모델 생성

> m <- rpart(Species ~.,  # formula : Y(종속변수) ~ X(설명변수)      . 은 나머지 변수를 의미

           data = train)  # 모델 학습에 필요한 데이터 셋(70%의 train 데이터)

> m                         # 각 컬럼마다의 불순도 연산이 수행되므로 데이터가 커질수록 시간이 길어질 수 있음

 

n= 110 


node), split, n, loss, yval, (yprob)

      * denotes terminal node


1) root 110 71 versicolor (0.3272727 0.3545455 0.3181818)  

  2) Petal.Length< 2.45 36  0 setosa (1.0000000 0.0000000 0.0000000) *   그룹이 확정되는 경우 뒤에 별 표시

  3) Petal.Length>=2.45 74 35 versicolor (0.0000000 0.5270270 0.4729730)  

    6) Petal.Width< 1.75 42  4 versicolor (0.0000000 0.9047619 0.0952381) *

    7) Petal.Width>=1.75 32  1 virginica (0.0000000 0.0312500 0.9687500) *

# 질의, 총 건수, 오분류 건수, 대표 그룹 순으로 출력

# 위의 결과 해석

1) 루트 노드, 110건의 total 데이터, 71건의 오분류 데이터(불순도가 높음), versicolor는 대표 그룹 or (비율이 같을 경우)알파벳 빠른 순

  2) 36건 중 0건의 오분류, 대표그룹은 setosa

      => 총 36개로 분류했는데 오분류가 0이 나온 것을 보아, setosa는 무조건 Petal.Length로 비교하면 된다는 결론.  

  3) 남은 74건 중 35건의 오분류(불순도가 높음), 대표 그룹은 versicolor

      => 그룹이 반/반 섞여있어 불순도가 높으므로 세부 분류 필요

    6) 42건 중 4건의 오분류, 대표 그룹은 versicolor

      => 2번째 그룹이 비율이 높으므로, 2번째 그룹이 versicolor인 것으로 확인 가능

    7) 32건 중 1개의 오분류, 대표 그룹은 virginica

      => 3번째 그룹이 virginica인 것을 확인 가능, 

 

# 오분류 건수 > minsplit 라면 노드를 더 나눌 수 있습니다.

# Decision Tree를 통해 중요 컬럼 파악 가능. Petal.Length > Petal.Length > ... 



### 3. 모델 평가

# 3-1) 새로운 데이터 셋(test)에 대한 예측력 확인

> val_var <- predict(m, newdata = test, type = 'class') 

> sum(val_var == test$Species) / nrow(test) * 100        

# 모델의 평가 점수 : TRUE 개수 / test data의 전체 건수 * 100

[1] 97.5           # iris 데이터는 이미 4개의 중요 컬럼이 선택되었기때문에 높은 결과가 나옴


# 3-2) overfit(과대적합) 확인 : train set score >>> test set score  (train 데이터로의 결과가 10% 내외로 상당히 클 경우 재조정 필요)

# Y 인 종속변수가 데이터에 포함되어 있어도, 에러가 발생하지 않는 이유는 모델 자체가 필요한 설명 변수만 뽑아서 사용하기 때문

> val_var <- predict(m, newdata = train, type = 'class') 

> sum(val_var == train$Species) / nrow(train) * 100 

[1] 95.45455



### 4. 모델을 통한 예측

> iris[9,]     

Sepal.Length Sepal.Width Petal.Length Petal.Width Species

9          4.4         2.9          1.4         0.2  setosa

> new_data <- data.frame(Sepal.Length = 4.5, Sepal.Width = 3,      # iris의 9번째 컬럼과 유사한 test data를 적용

                                     Petal.Length = 1.2, Petal.Width = 0.3)

> predict(m, newdata = new_data, type = 'class')    # predict(모델, test 데이터(data frame))

     1 

setosa 

Levels: setosa versicolor virginica



### 5. 모델 시각화

> plot(m, compress = T, margin = 0.5)

> text(m, cex = 1.5)


# rpart.plot을 이용한 시각화

> install.packages("rpart.plot")

> library(rpart.plot)      

> prp(m, type=4, extra=2, digits=3)   # type : Tree의 타입 설정 (모든 노드에 레이블 표시)

                                                   # extra : Tree의 추가적인 정보 설정 (각 노드별 관측값 대비 올바르게 분류된 데이터의 비율 표시(올바른 분류의 수/해당 노드 전체 수))




+. 

#. Decision Tree의 검정(적절한 가지 수 선택)


방법 1) m$cptable       


> m$cptable                                                                      # cp : 복잡성

          CP    nsplit   rel error        xerror         xstd                    # nsplist : 가지의 분기 수 (nsplit+1 의 리프 노드가 생성)

 1  0.5147059   0   1.00000000   1.0000000   0.07493313              # rel error : 오류율

 2  0.4411765   1   0.48529412   0.6323529   0.07526037              # xerror : 교차 검증 오류

 3  0.0100000   2   0.04411765   0.1029412  0.03764977              # xstd : 교차 검증 오류의 표준오차


=> 두 번의 분기(nsplit)가 발생할 때 오류율(rel error)이 가장 적으므로 두 번의 분기가 가장 적절

tree 생성 후 튜닝이 필요할 때, cp 값에 대한 범위 조정이 힘든데, 이 경우 cptable을 확인하여 'cp를 몇로 조정해봐야겠구나' 라는 힌트를 얻을 수 있음

ex) 만일 위 테이블을 보고 cp 를 0.01로 설정했는데 과대적합(overfit)이 발생한다면, 그 전 cp 값인 0.44로 조정해볼 수 있음



방법 2plotcp(m)


> plotcp(m)


=> cp값으로 표현된 포인트들(위 결과는 4개의 포인트)선 하단에서 선에 가까운 포인트(초록 원)의 y축(cp 값 : 약 0.01)을 선택합니다.

선 하단에 여러개의 포인트가 있다면 선 하단의 가장 왼쪽에 위치한 포인트의 cp를 선택하는 것이 바람직합니다.

선 하단에 포인터가 있다고 하더라도, 오른쪽으로 갈 수록 Error율은 적어지지만 트리의 크기가 커지게 됩니다. 

트리의 크기가 커짐(=복잡해짐) -> 과대적합(overfit) 발생 -> 일반화하기 어려워짐

결론으로, 세 번째 포인트의 cp가 좋은 결과를 나타낸다고 확인 가능




#. 모델 정보 확인


> m$control

$minsplit

[1] 20

$minbucket

[1] 7

$cp

[1] 0.01

...

$maxdepth

[1] 30

...


모델의 control 컬럼을 통해 현재 모델에 적용된 모든 매개변수의 기본값을 리스트 형식으로 확인할 수 있습니다.


< minbucket >

- minbucket 값보다 오분류 건수가 많을 경우 depth가 추가 되겠죠? (=더 정확하게 분리)

- minbucket 가 작을수록 트리가 잘게 쪼개지므로 모델이 복잡해집니다. => train 데이터에 대한 높은 예측율 => but. 새로운 데이터 적용 시 예측율 낮아짐(오버핏 발생)


maxdepth >

- maxdepth = 1 이면 중복 컬럼을 노드에 사용하지 않겠다는 것을 의미합니다. 

- maxdepth 중복 허용 시(= 2 이상) 모델이 복잡해집니다. 이는, maxdepth 를 작게 설정함으로써 모델이 복잡해지는 것을 방지할 수 있습니다.

- 하지만, maxdepth 값이 어느정도 이상의 값에서는 증가효과가 없어져버립니다. 

즉, maxdepth 를  5~10 이상 올려봐야 의미가 없겠죠, 그렇다면 1~3, 3~4 이렇게 짧은 구간을 테스트해야 좋습니다.


< cp >

- cp는 0~1 사이의 값을 가집니다. cp 의 범위는 m$cptable 을 통해 확인가능합니다. 




#. 모델의 매개변수 변경


rpart.control(minsplit = 20, minbucket = round(minsplit/3), cp = 0.01,   # minsplit 조정 시 minbucket이 자동으로 조정

                  maxcompete = 4, maxsurrogate = 5, usesurrogate = 2, xval = 10,

                  surrogatestyle = 0, maxdepth = 30, ...)


> m <- rpart(Species ~ . ,

+            data = iris)

> m                                       # 초기 모델의 결과

n= 150 


node), split, n, loss, yval, (yprob)

      * denotes terminal node


1) root 150 100 setosa (0.33333333 0.33333333 0.33333333)  

  2) Petal.Length< 2.45 50   0 setosa (1.00000000 0.00000000 0.00000000) *

  3) Petal.Length>=2.45 100  50 versicolor (0.00000000 0.50000000 0.50000000)  

    6) Petal.Width< 1.75 54   5 versicolor (0.00000000 0.90740741 0.09259259) *

    7) Petal.Width>=1.75 46   1 virginica (0.00000000 0.02173913 0.97826087) *


> m <- rpart(Species ~ . ,

           data = iris,

           control = rpart.control(minbucket = 2))

> m                                          # minbucket 값을 줄인 모델의 결과


n= 150 


node), split, n, loss, yval, (yprob)

      * denotes terminal node


 1) root 150 100 setosa (0.33333333 0.33333333 0.33333333)  

   2) Petal.Length< 2.45 50   0 setosa (1.00000000 0.00000000 0.00000000) *

   3) Petal.Length>=2.45 100  50 versicolor (0.00000000 0.50000000 0.50000000)  

     6) Petal.Width< 1.75 54   5 versicolor (0.00000000 0.90740741 0.09259259)  

      12) Petal.Length< 4.95 48   1 versicolor (0.00000000 0.97916667 0.02083333) *

      13) Petal.Length>=4.95 6   2 virginica (0.00000000 0.33333333 0.66666667) *

     7) Petal.Width>=1.75 46   1 virginica (0.00000000 0.02173913 0.97826087) *


minbucket 값이 defualt 7에서 2로 줄어들게 되어, 트리가 잘게 쪼개진 것을 확인할 수 있습니다. 




참고: KIC 캠퍼스 머신러닝기반의 빅데이터분석 양성과정

반응형
댓글
최근에 올라온 글
최근에 달린 댓글
Total
Today
Yesterday