티스토리 뷰
K-NN 알고리즘 적용 및 매개변수 튜닝
K-NN(K-Nearest Neighbors) 알고리즘은 새로운 관측치와 기존 데이터와의 거리 연산를 통한 분류가 목적이므로 train data / test data가 동시에 들어가는게 특징입니다.
또한, 알고리즘 자체에 predict() 기능을 보유하고 있어서 예측 및 평가에 predict 함수를 사용하지 않아도 자체적으로 수행해줍니다.
install.packages("class")
library(class)
knn(train, # 모델 평가용 데이터 중 예측 변수
test, # 예측용 데이터 (예측을 훈련과 동시에 가능)
cl, # 분류(class) 변수
k=n, # k 설정 (근접한 k개의 데이터까지 확인)
prob=TRUE) # 분류 비율 설명 (근접한 k개 데이터의 그룹 비율, 어느 그룹이 많이 포함되었는지)
### 데이터 불러오기
> cancer <- read.csv("cancer.csv") # cancer data는 174개의 row, 31개의 column
> table(cancer$diagnosis)
Benign Malignant # 그룹(class) 별 균등 작업 필요
357 212
> library(caret)
> cancer <- upSample(cancer[,-c(1,2)], cancer$diagnosis, yname = 'diagnosis') # upSampling
> table(cancer$diagnosis)
Benign Malignant
357 357
### 1. randomForest 모델을 활용한 핵심 설명변수 선택
# knn은 feature selection이 모델에 자동 적용되지 않으므로, 정확도(score) 상승을 위해 핵심 변수 선택 필요
> library(randomForest) # randomForeset 모델에 데이터를 적용하여 핵심 변수 확인
> forest_m <- randomForest(cancer$diagnosis ~ ., data=cancer)
> impor_val <- names(forest_m$importance[order(forest_m$importance, decreasing = T),][1:5]) # 중요도 Top 5 변수 추출
> impor_val # cancer data의 핵심 변수
[1] "perimeter_worst" "concave_points_worst" "radius_worst" "area_worst"
[5] "concave_points_mean"
> cancer2 <- cancer[,c(impor_val, 'diagnosis')] # 31개의 설명변수 중 선택한 핵심 변수만 추출
> cancer2
perimeter_worst concave_points_worst radius_worst area_worst concave_points_mean diagnosis
1 99.70 0.128800 15.110 711.2 0.047810 Benign
2 96.09 0.072830 14.500 630.5 0.031100 Benign
3 65.13 0.062270 10.230 314.9 0.020760 Benign
....
### 2. knn 알고리즘 사용을 위한 표준화(standardization) 작업
> cancer2[,-length(cancer2)] <- apply(cancer2[,-length(cancer2)], 2, scale)
> cancer2
perimeter_worst concave_points_worst radius_worst area_worst concave_points_mean diagnosis
1 -0.412872793 0.003704444 -0.42523184 -0.4598226 -0.229525866 Benign
2 -0.512617398 -0.810686348 -0.54250400 -0.5881740 -0.625081270 Benign
3 -1.368044814 -0.964339497 -1.36340919 -1.0901279 -0.869847451 Benign
...
### 3. data equal sampling
> rn <- createDataPartition(y=cancer2$diagnosis, p = 0.7, list = F) # row 색인을 위해 list = F
> train <- cancer2[rn,]
> test <- cancer2[-rn,]
> table(train$diagnosis)
Benign Malignant
250 250
> table(test$diagnosis)
Benign Malignant
107 107
### 4. model fit
> train_x <- train[,!colnames(train) %in% 'diagnosis'] # 설명변수 train data # 특정 컬럼명 제외 색인 방식(선호)
> train_y <- train[,'diagnosis'] # 종속변수 train data
> test_x <- test[,-length(test)] # 설명변수 test data (모델 적용용) # length를 사용한 특정 컬럼 제외 색인 방식
> test_y <- test[,'diagnosis'] # 종속변수 test data (정답 확인용)
> library(class)
> knn_m <- knn(train = train_x, test = test_x, cl = train_y, k=3, prob=TRUE)
> knn_m # 근접한 k에 포함된 이웃 데이터가 모두 같은 그룹이라면 비율이 1로 표현
[1] Benign Benign Benign Benign Benign Malignant Benign Benign Benign Benign
...
[1] 1.0000000 1.0000000 1.0000000 1.0000000 1.0000000 0.6000000 1.0000000 1.0000000 1.0000000 1.0000000
....
### 5. model 평가
> sum(knn_m == test_y) / nrow(test_x) * 100 # k=3일 경우 정확도(score)
[1] 95.3271
### 6. k의 수 튜닝
# 1) test data set score
> score_test <- c()
> for(i in 1:20) {
knn_m <- knn(train = train_x, test = test_x, cl = train_y, k=i, prob=TRUE)
score_test <- c(score_test, sum(knn_m == test_y) / nrow(test_x) * 100)
}
> score_test
[1] 95.79439 94.39252 95.32710 95.32710 95.79439 94.85981 94.39252 93.45794 94.85981 93.92523 94.85981 93.92523 93.92523 93.45794 94.85981 94.85981 94.39252 93.92523 93.92523 93.92523
# 2) train data set score
> score_train <- c()
> for(i in 1:20) {
knn_m <- knn(train = train_x, test = train_x, cl = train_y, k=i, prob=TRUE)
score_train <- c(score_train, sum(knn_m == train_y) / nrow(train_x) * 100)
}
> score_train
[1] 100.0 96.8 96.0 94.6 95.6 95.4 95.2 95.4 94.6 94.8 95.4 95.2 94.2 94.2 94.8 95.0 94.6 94.4 94.6 95.2
# k의 증가에 따른 정확도
> plot(score_test, type = 'o', xlab = 'number of k', col='red', ylim = c(90,100))
> lines(score_train, type = 'o')
# 정확도를 높일지 or 과대적합(overfit)을 줄일지에 대한 모델러의 판단에 의해 k값이 설정됩니다.
위 그래프를 보았을 때, 정확도와 과대적합이 무난한(?) k값은 4인 것 같습니다.
확인해보니 96.26168 의 정확도를 보여주네요 !
참고: KIC 캠퍼스 머신러닝기반의 빅데이터분석 양성과정
'R > Analysis' 카테고리의 다른 글
[R 분석] 비계층적 군집 분석(k-means clustering) (0) | 2019.01.22 |
---|---|
[R 분석] 계층적 군집 분석(hierarchical clustering) (0) | 2019.01.21 |
[R 분석] 중요도가 높은 핵심 변수 선택하기 (0) | 2019.01.18 |
[R 분석] Random Forest 매개변수 튜닝 (1) | 2019.01.18 |
[R 분석] Random Forest 알고리즘 (0) | 2019.01.17 |