티스토리 뷰
Decision Tree in R (분류분석)
> install.packages('rpart')
> library(rpart) # Decision Tree 분석 및 시각화를 위한 패키지
### 1. data sampling : train data set, test data set 분리
> library(doBy)
> train <- sampleBy(~Species, frac = 0.7, data = iris) # 70%의 랜덤 (train)데이터를 추출
# 70%에 포함된 train데이터의 행 번호를 추출하기 위한 사용자 함수 생성
> f1 <- function(x) {
as.numeric(strsplit(x, '\\.')[[1]][2]) # (.)을 문자로 인식시키기 위해 \\ 사용
}
> rn <- as.vector(sapply(rownames(train), f1)) # 70%에 포함된 train 데이터의 행 번호를 벡터형으로 저장
> test <- iris[-rn,] # 나머지 30%를 test 데이터로 저장
### 2. 모델 생성
> m <- rpart(Species ~., # formula : Y(종속변수) ~ X(설명변수) . 은 나머지 변수를 의미
data = train) # 모델 학습에 필요한 데이터 셋(70%의 train 데이터)
> m # 각 컬럼마다의 불순도 연산이 수행되므로 데이터가 커질수록 시간이 길어질 수 있음
n= 110
node), split, n, loss, yval, (yprob)
* denotes terminal node
1) root 110 71 versicolor (0.3272727 0.3545455 0.3181818)
2) Petal.Length< 2.45 36 0 setosa (1.0000000 0.0000000 0.0000000) * # 그룹이 확정되는 경우 뒤에 별 표시
3) Petal.Length>=2.45 74 35 versicolor (0.0000000 0.5270270 0.4729730)
6) Petal.Width< 1.75 42 4 versicolor (0.0000000 0.9047619 0.0952381) *
7) Petal.Width>=1.75 32 1 virginica (0.0000000 0.0312500 0.9687500) *
# 질의, 총 건수, 오분류 건수, 대표 그룹 순으로 출력
# 위의 결과 해석
1) 루트 노드, 110건의 total 데이터, 71건의 오분류 데이터(불순도가 높음), versicolor는 대표 그룹 or (비율이 같을 경우)알파벳 빠른 순
2) 36건 중 0건의 오분류, 대표그룹은 setosa
=> 총 36개로 분류했는데 오분류가 0이 나온 것을 보아, setosa는 무조건 Petal.Length로 비교하면 된다는 결론.
3) 남은 74건 중 35건의 오분류(불순도가 높음), 대표 그룹은 versicolor
=> 그룹이 반/반 섞여있어 불순도가 높으므로 세부 분류 필요
6) 42건 중 4건의 오분류, 대표 그룹은 versicolor
=> 2번째 그룹이 비율이 높으므로, 2번째 그룹이 versicolor인 것으로 확인 가능
7) 32건 중 1개의 오분류, 대표 그룹은 virginica
=> 3번째 그룹이 virginica인 것을 확인 가능,
# 오분류 건수 > minsplit 라면 노드를 더 나눌 수 있습니다.
# Decision Tree를 통해 중요 컬럼 파악 가능. Petal.Length > Petal.Length > ...
### 3. 모델 평가
# 3-1) 새로운 데이터 셋(test)에 대한 예측력 확인
> val_var <- predict(m, newdata = test, type = 'class')
> sum(val_var == test$Species) / nrow(test) * 100
# 모델의 평가 점수 : TRUE 개수 / test data의 전체 건수 * 100
[1] 97.5 # iris 데이터는 이미 4개의 중요 컬럼이 선택되었기때문에 높은 결과가 나옴
# 3-2) overfit(과대적합) 확인 : train set score >>> test set score (train 데이터로의 결과가 10% 내외로 상당히 클 경우 재조정 필요)
# Y 인 종속변수가 데이터에 포함되어 있어도, 에러가 발생하지 않는 이유는 모델 자체가 필요한 설명 변수만 뽑아서 사용하기 때문
> val_var <- predict(m, newdata = train, type = 'class')
> sum(val_var == train$Species) / nrow(train) * 100
[1] 95.45455
### 4. 모델을 통한 예측
> iris[9,]
Sepal.Length Sepal.Width Petal.Length Petal.Width Species
9 4.4 2.9 1.4 0.2 setosa
> new_data <- data.frame(Sepal.Length = 4.5, Sepal.Width = 3, # iris의 9번째 컬럼과 유사한 test data를 적용
Petal.Length = 1.2, Petal.Width = 0.3)
> predict(m, newdata = new_data, type = 'class') # predict(모델, test 데이터(data frame))
1
setosa
Levels: setosa versicolor virginica
### 5. 모델 시각화
> plot(m, compress = T, margin = 0.5)
> text(m, cex = 1.5)
# rpart.plot을 이용한 시각화
> install.packages("rpart.plot")
> library(rpart.plot)
> prp(m, type=4, extra=2, digits=3) # type : Tree의 타입 설정 (모든 노드에 레이블 표시)
+.
#. Decision Tree의 검정(적절한 가지 수 선택)
방법 1) m$cptable
> m$cptable # cp : 복잡성
CP nsplit rel error xerror xstd # nsplist : 가지의 분기 수 (nsplit+1 의 리프 노드가 생성)
1 0.5147059 0 1.00000000 1.0000000 0.07493313 # rel error : 오류율
2 0.4411765 1 0.48529412 0.6323529 0.07526037 # xerror : 교차 검증 오류
3 0.0100000 2 0.04411765 0.1029412 0.03764977 # xstd : 교차 검증 오류의 표준오차
=> 두 번의 분기(nsplit)가 발생할 때 오류율(rel error)이 가장 적으므로 두 번의 분기가 가장 적절
tree 생성 후 튜닝이 필요할 때, cp 값에 대한 범위 조정이 힘든데, 이 경우 cptable을 확인하여 'cp를 몇로 조정해봐야겠구나' 라는 힌트를 얻을 수 있음
ex) 만일 위 테이블을 보고 cp 를 0.01로 설정했는데 과대적합(overfit)이 발생한다면, 그 전 cp 값인 0.44로 조정해볼 수 있음
방법 2) plotcp(m)
> plotcp(m)
=> cp값으로 표현된 포인트들(위 결과는 4개의 포인트) 중 선 하단에서 선에 가까운 포인트(초록 원)의 y축(cp 값 : 약 0.01)을 선택합니다.
선 하단에 여러개의 포인트가 있다면 선 하단의 가장 왼쪽에 위치한 포인트의 cp를 선택하는 것이 바람직합니다.
선 하단에 포인터가 있다고 하더라도, 오른쪽으로 갈 수록 Error율은 적어지지만 트리의 크기가 커지게 됩니다.
트리의 크기가 커짐(=복잡해짐) -> 과대적합(overfit) 발생 -> 일반화하기 어려워짐
결론으로, 세 번째 포인트의 cp가 좋은 결과를 나타낸다고 확인 가능
#. 모델 정보 확인
> m$control
$minsplit
[1] 20
$minbucket
[1] 7
$cp
[1] 0.01
...
$maxdepth
[1] 30
...
모델의 control 컬럼을 통해 현재 모델에 적용된 모든 매개변수의 기본값을 리스트 형식으로 확인할 수 있습니다.
< minbucket >
- minbucket 값보다 오분류 건수가 많을 경우 depth가 추가 되겠죠? (=더 정확하게 분리)
- minbucket 가 작을수록 트리가 잘게 쪼개지므로 모델이 복잡해집니다. => train 데이터에 대한 높은 예측율 => but. 새로운 데이터 적용 시 예측율 낮아짐(오버핏 발생)
< maxdepth >
- maxdepth = 1 이면 중복 컬럼을 노드에 사용하지 않겠다는 것을 의미합니다.
- maxdepth 중복 허용 시(= 2 이상) 모델이 복잡해집니다. 이는, maxdepth 를 작게 설정함으로써 모델이 복잡해지는 것을 방지할 수 있습니다.
- 하지만, maxdepth 값이 어느정도 이상의 값에서는 증가효과가 없어져버립니다.
즉, maxdepth 를 5~10 이상 올려봐야 의미가 없겠죠, 그렇다면 1~3, 3~4 이렇게 짧은 구간을 테스트해야 좋습니다.
< cp >
- cp는 0~1 사이의 값을 가집니다. cp 의 범위는 m$cptable 을 통해 확인가능합니다.
#. 모델의 매개변수 변경
rpart.control(minsplit = 20, minbucket = round(minsplit/3), cp = 0.01, # minsplit 조정 시 minbucket이 자동으로 조정
maxcompete = 4, maxsurrogate = 5, usesurrogate = 2, xval = 10,
surrogatestyle = 0, maxdepth = 30, ...)
> m <- rpart(Species ~ . ,
+ data = iris)
> m # 초기 모델의 결과
n= 150
node), split, n, loss, yval, (yprob)
* denotes terminal node
1) root 150 100 setosa (0.33333333 0.33333333 0.33333333)
2) Petal.Length< 2.45 50 0 setosa (1.00000000 0.00000000 0.00000000) *
3) Petal.Length>=2.45 100 50 versicolor (0.00000000 0.50000000 0.50000000)
6) Petal.Width< 1.75 54 5 versicolor (0.00000000 0.90740741 0.09259259) *
7) Petal.Width>=1.75 46 1 virginica (0.00000000 0.02173913 0.97826087) *
> m <- rpart(Species ~ . ,
data = iris,
control = rpart.control(minbucket = 2))
> m # minbucket 값을 줄인 모델의 결과
n= 150
node), split, n, loss, yval, (yprob)
* denotes terminal node
1) root 150 100 setosa (0.33333333 0.33333333 0.33333333)
2) Petal.Length< 2.45 50 0 setosa (1.00000000 0.00000000 0.00000000) *
3) Petal.Length>=2.45 100 50 versicolor (0.00000000 0.50000000 0.50000000)
6) Petal.Width< 1.75 54 5 versicolor (0.00000000 0.90740741 0.09259259)
12) Petal.Length< 4.95 48 1 versicolor (0.00000000 0.97916667 0.02083333) *
13) Petal.Length>=4.95 6 2 virginica (0.00000000 0.33333333 0.66666667) *
7) Petal.Width>=1.75 46 1 virginica (0.00000000 0.02173913 0.97826087) *
minbucket 값이 defualt 7에서 2로 줄어들게 되어, 트리가 잘게 쪼개진 것을 확인할 수 있습니다.
참고: KIC 캠퍼스 머신러닝기반의 빅데이터분석 양성과정
'R > Analysis' 카테고리의 다른 글
[R 분석] Random Forest 알고리즘 (0) | 2019.01.17 |
---|---|
[R 분석] Decision Tree 매개변수 튜닝 (0) | 2019.01.17 |
[R 분석] 조건부 추론 나무 (0) | 2019.01.16 |
[R 분석] 종속변수의 그룹(class) 별 데이터 개수 균등하게 맞추기 (0) | 2019.01.16 |
[R 분석] 지도학습을 위한 데이터 샘플링 (0) | 2019.01.16 |