KNN

Author

Ricardo Accioly

Published

October 28, 2025

KNN

O KNN é um algoritmo muito simples no qual cada observação é prevista com base em sua “semelhança” com outras observações. Ao contrário da maioria dos métodos, KNN é um algoritmo baseado na memória e não pode ser resumido por um modelo de forma fechada. Isso significa que as amostras de treinamento são necessárias no tempo de execução e as previsões são feitas diretamente das relações amostrais. Consequentemente, os KNNs também são conhecidos como aprendizes preguiçosos

Carregando Bibliotecas

#>  default    student       balance           income     
#>  No :9667   No :7056   Min.   :   0.0   Min.   :  772  
#>  Yes: 333   Yes:2944   1st Qu.: 481.7   1st Qu.:21340  
#>                        Median : 823.6   Median :34553  
#>                        Mean   : 835.4   Mean   :33517  
#>                        3rd Qu.:1166.3   3rd Qu.:43808  
#>                        Max.   :2654.3   Max.   :73554
str(Default)
#> 'data.frame':    10000 obs. of  4 variables:
#>  $ default: Factor w/ 2 levels "No","Yes": 1 1 1 1 1 1 1 1 1 1 ...
#>  $ student: Factor w/ 2 levels "No","Yes": 1 2 1 1 1 2 1 2 1 1 ...
#>  $ balance: num  730 817 1074 529 786 ...
#>  $ income : num  44362 12106 31767 35704 38463 ...
head(Default)
#>   default student   balance    income
#> 1      No      No  729.5265 44361.625
#> 2      No     Yes  817.1804 12106.135
#> 3      No      No 1073.5492 31767.139
#> 4      No      No  529.2506 35704.494
#> 5      No      No  785.6559 38463.496
#> 6      No     Yes  919.5885  7491.559

Manipulando os dados

credito <- tibble(Default)
summary(credito)
#>  default    student       balance           income     
#>  No :9667   No :7056   Min.   :   0.0   Min.   :  772  
#>  Yes: 333   Yes:2944   1st Qu.: 481.7   1st Qu.:21340  
#>                        Median : 823.6   Median :34553  
#>                        Mean   : 835.4   Mean   :33517  
#>                        3rd Qu.:1166.3   3rd Qu.:43808  
#>                        Max.   :2654.3   Max.   :73554
# renomeando colunas
credito <- credito %>% 
                rename( inadimplente = default, estudante = student, balanco = balance,
                receita = income)
credito <- credito %>% mutate( inadimplente =  case_when(
                           inadimplente == "No"  ~ "Nao",
                           inadimplente == "Yes" ~ "Sim"
                          )) %>% mutate(inadimplente = factor(inadimplente))
credito <- credito %>% mutate( estudante =  case_when(
                           estudante == "No"  ~ 0,
                           estudante == "Yes" ~ 1
                          )) 

str(credito)
#> tibble [10,000 × 4] (S3: tbl_df/tbl/data.frame)
#>  $ inadimplente: Factor w/ 2 levels "Nao","Sim": 1 1 1 1 1 1 1 1 1 1 ...
#>  $ estudante   : num [1:10000] 0 1 0 0 0 1 0 1 0 0 ...
#>  $ balanco     : num [1:10000] 730 817 1074 529 786 ...
#>  $ receita     : num [1:10000] 44362 12106 31767 35704 38463 ...
summary(credito)
#>  inadimplente   estudante         balanco          receita     
#>  Nao:9667     Min.   :0.0000   Min.   :   0.0   Min.   :  772  
#>  Sim: 333     1st Qu.:0.0000   1st Qu.: 481.7   1st Qu.:21340  
#>               Median :0.0000   Median : 823.6   Median :34553  
#>               Mean   :0.2944   Mean   : 835.4   Mean   :33517  
#>               3rd Qu.:1.0000   3rd Qu.:1166.3   3rd Qu.:43808  
#>               Max.   :1.0000   Max.   :2654.3   Max.   :73554

Graficos de Densidade

Vamos agora explorar os dados originais para termos algum visão do comportamento das variáveis explicativas e a variável dependente.

featurePlot(x = credito[, c("balanco", "receita", "estudante")], 
            y = credito$inadimplente,
            plot = "density", 
            scales = list(x = list(relation = "free"), 
                          y = list(relation = "free")), 
            adjust = 1.5, 
            pch = "|", 
            layout = c(2, 1), 
            auto.key = list(columns = 2))

Avaliando o comportamento das variáveis em função do status (inadimplente / estudante)

p1 <- ggplot(credito, aes(x=inadimplente, y=balanco, color=inadimplente)) +
  geom_boxplot()
p2 <- ggplot(credito, aes(x=inadimplente, y=receita, color=inadimplente)) +
  geom_boxplot()
p3 <- ggplot(credito, aes(x=as.factor(estudante), y=balanco, color=as.factor(estudante))) +
  geom_boxplot()
p4 <- ggplot(credito, aes(x=as.factor(estudante), y=receita, color=as.factor(estudante))) +
  geom_boxplot()
(p1 + p2) / (p3 + p4)

Balanço vs Receita

ggplot(data = credito, aes(x=balanco,  y = receita, col = inadimplente)) + geom_point() 

KNN

Vamos usar a função knn da biblioteca caret que tem ótimas funcionalidades. Observem que a saída pode ser as classes ou as probabilidades de pertencer a uma classe

Como o KNN usa as distancias entre os pontos ele é afetado pela escala dos dados, portanto, é necessário que os dados sejam normalizados (padronizados) para eliminar este efeito.

Quando temos diversas variáveis explicativas em diferentes escalas, em geral, elas devem ser transformadas para ter media zero e desvio padrão 1

Criando conjuntos de treino e teste e normalizando variáveis

set.seed(2025)
y <- credito$inadimplente
credito_split <- createDataPartition(y, times = 1, p = 0.10, list = FALSE)

conj_treino <- credito[-credito_split,]
conj_treino[,3:4] <- scale(conj_treino[,3:4]) # scale normaliza
conj_teste <- credito[credito_split,]
conj_teste[,3:4] <- scale(conj_teste[, 3:4])
                           
summary(conj_treino)
#>  inadimplente   estudante         balanco            receita       
#>  Nao:8700     Min.   :0.0000   Min.   :-1.73003   Min.   :-2.4534  
#>  Sim: 299     1st Qu.:0.0000   1st Qu.:-0.72954   1st Qu.:-0.9140  
#>               Median :0.0000   Median :-0.03104   Median : 0.0764  
#>               Mean   :0.2946   Mean   : 0.00000   Mean   : 0.0000  
#>               3rd Qu.:1.0000   3rd Qu.: 0.68437   3rd Qu.: 0.7707  
#>               Max.   :1.0000   Max.   : 3.76117   Max.   : 2.9933
summary(conj_teste)
#>  inadimplente   estudante         balanco            receita        
#>  Nao:967      Min.   :0.0000   Min.   :-1.69940   Min.   :-2.12069  
#>  Sim: 34      1st Qu.:0.0000   1st Qu.:-0.76771   1st Qu.:-0.90454  
#>               Median :0.0000   Median : 0.02552   Median : 0.06997  
#>               Mean   :0.2927   Mean   : 0.00000   Mean   : 0.00000  
#>               3rd Qu.:1.0000   3rd Qu.: 0.67537   3rd Qu.: 0.73936  
#>               Max.   :1.0000   Max.   : 2.86197   Max.   : 2.86471

1a Modelo

Vamos usar a regra da raiz quadrada do tamanho da amostra para definir o número de vizinhos do KNN.

k <- round(sqrt(nrow(conj_treino)),0) # número de vizinhos
k
#> [1] 95
set.seed(2025)
t_knn1 <- knn3(inadimplente ~ balanco + receita + estudante, data = conj_treino, k = k)
t_knn1
#> 95-nearest neighbor model
#> Training set outcome distribution:
#> 
#>  Nao  Sim 
#> 8700  299

Avaliando o modelo

Através da função matriz de confusão do pacote caret conseguimos obter as principais medidas de avaliação de um modelo de classificação.

Veja que a acurácia deu um valor alto, mas isto não é suficiente para considerarmos que temos um bom modelo. Veja que a sensibilidade está muito baixa e que o ideal é que tenhamos valores altos de sensibilidade e especificidade.

Observar que a prevalência é muito baixa o que está afetando os resultados do modelo.

y_chapeu_knn1 <- predict(t_knn1, conj_teste, type = "class")


confusionMatrix(y_chapeu_knn1, conj_teste$inadimplente, positive="Sim") 
#> Confusion Matrix and Statistics
#> 
#>           Reference
#> Prediction Nao Sim
#>        Nao 964  28
#>        Sim   3   6
#>                                           
#>                Accuracy : 0.969           
#>                  95% CI : (0.9563, 0.9789)
#>     No Information Rate : 0.966           
#>     P-Value [Acc > NIR] : 0.3395          
#>                                           
#>                   Kappa : 0.2687          
#>                                           
#>  Mcnemar's Test P-Value : 1.629e-05       
#>                                           
#>             Sensitivity : 0.176471        
#>             Specificity : 0.996898        
#>          Pos Pred Value : 0.666667        
#>          Neg Pred Value : 0.971774        
#>              Prevalence : 0.033966        
#>          Detection Rate : 0.005994        
#>    Detection Prevalence : 0.008991        
#>       Balanced Accuracy : 0.586684        
#>                                           
#>        'Positive' Class : Sim             
#> 

Curva ROC

Para a curva ROC é necessário que obtenhamos as probabilidades e não das classes, vejam nos comandos abaixo como se obtem as probabilidades.

# 
p_chapeu_knn1 <- predict(t_knn1, conj_teste, type = "prob")
head(p_chapeu_knn1)
#>            Nao        Sim
#> [1,] 0.9894737 0.01052632
#> [2,] 1.0000000 0.00000000
#> [3,] 1.0000000 0.00000000
#> [4,] 1.0000000 0.00000000
#> [5,] 0.9894737 0.01052632
#> [6,] 1.0000000 0.00000000
# Aqui gera o curva e salvo numa variável
roc_knn1 <- roc(conj_teste$inadimplente ~ p_chapeu_knn1[,2], plot = FALSE, print.auc=FALSE)

# Visualização com ggroc
ggroc(roc_knn1) +
  ggplot2::labs(title = "ROC - KNN", x = "1 - Especificidade", y = "Sensibilidade") +
  ggplot2::theme_minimal()

Area embaixo da curva ROC

# Area abaixo da Curva (AUC)
as.numeric(roc_knn1$auc)
#> [1] 0.9191405

Variando K

Anteriormente usamos k=95 . Este parametro, em geral, deve ser ajustado para melhoramos os modelo KNN. Para isto vamos usar a função train da biblioteca caret

Observe que a otimização de k, no exemplo abaixo, é feita através de acurácia. O k também pode ser otimizado usando o valor do AUC (área embaixo da curva ROC).

set.seed(2025)

# Usando validação cruzada para obter o valor de k através da função train da biblioteca caret e o controle do treino e fazendo um gride de valores para k.
ctrl <- trainControl(method = "repeatedcv", 
                     number = 5,
                     repeats = 5)
t_knn2 <- train(inadimplente ~ balanco + receita + estudante,
                method = "knn", 
                trControl= ctrl,
                tuneGrid = data.frame(k = seq(5,150, by=5)),
                metric = "Accuracy",
                data = conj_treino)
## Resultados do treino
t_knn2
#> k-Nearest Neighbors 
#> 
#> 8999 samples
#>    3 predictor
#>    2 classes: 'Nao', 'Sim' 
#> 
#> No pre-processing
#> Resampling: Cross-Validated (5 fold, repeated 5 times) 
#> Summary of sample sizes: 7199, 7199, 7199, 7199, 7200, 7199, ... 
#> Resampling results across tuning parameters:
#> 
#>   k    Accuracy   Kappa     
#>     5  0.9704413  0.42266071
#>    10  0.9716414  0.41185688
#>    15  0.9723526  0.41049306
#>    20  0.9730415  0.41527673
#>    25  0.9729749  0.39961432
#>    30  0.9729305  0.39434063
#>    35  0.9725527  0.37664935
#>    40  0.9724415  0.36463583
#>    45  0.9723749  0.35587285
#>    50  0.9722193  0.34677085
#>    55  0.9722415  0.34231099
#>    60  0.9719303  0.32682892
#>    65  0.9717303  0.31692192
#>    70  0.9715748  0.30754847
#>    75  0.9709079  0.27742350
#>    80  0.9706412  0.25982263
#>    85  0.9703523  0.24349134
#>    90  0.9702412  0.23166492
#>    95  0.9697301  0.20218244
#>   100  0.9690410  0.16434426
#>   105  0.9686633  0.13872824
#>   110  0.9683299  0.11503994
#>   115  0.9678409  0.07816959
#>   120  0.9676187  0.05780951
#>   125  0.9673964  0.04083628
#>   130  0.9672408  0.02760184
#>   135  0.9671297  0.02148245
#>   140  0.9670853  0.01773521
#>   145  0.9669741  0.01144419
#>   150  0.9669741  0.01144419
#> 
#> Accuracy was used to select the optimal model using the largest value.
#> The final value used for the model was k = 20.
plot(t_knn2)

## Previsões com o resultaddos do treino
prev_knn2 <- predict(t_knn2, conj_teste)
confusionMatrix(prev_knn2, conj_teste$inadimplente,  positive="Sim")
#> Confusion Matrix and Statistics
#> 
#>           Reference
#> Prediction Nao Sim
#>        Nao 961  23
#>        Sim   6  11
#>                                           
#>                Accuracy : 0.971           
#>                  95% CI : (0.9587, 0.9805)
#>     No Information Rate : 0.966           
#>     P-Value [Acc > NIR] : 0.219168        
#>                                           
#>                   Kappa : 0.4182          
#>                                           
#>  Mcnemar's Test P-Value : 0.002967        
#>                                           
#>             Sensitivity : 0.32353         
#>             Specificity : 0.99380         
#>          Pos Pred Value : 0.64706         
#>          Neg Pred Value : 0.97663         
#>              Prevalence : 0.03397         
#>          Detection Rate : 0.01099         
#>    Detection Prevalence : 0.01698         
#>       Balanced Accuracy : 0.65866         
#>                                           
#>        'Positive' Class : Sim             
#> 

Curva ROC dos 2 melhores modelos

prev_knn1 <- predict(t_knn1, conj_teste, type = "prob")
prev_knn2 <- predict(t_knn2, conj_teste, type = "prob")
roc_knn1 <- roc(conj_teste$inadimplente ~ prev_knn1[,2], plot = FALSE, print.auc=FALSE)
roc_knn2 <- roc(conj_teste$inadimplente ~ prev_knn2[,2], plot = FALSE, print.auc=FALSE)

# Visualização com ggroc
ggroc(list(knn1= roc_knn1, knn2= roc_knn2)) +
  ggplot2::labs(title = "ROC - KNN", x = "1 - Especificidade", y = "Sensibilidade") +
  ggplot2::theme_minimal()

## Area embaixo das curvas
as.numeric(roc_knn1$auc)
#> [1] 0.9191405
as.numeric(roc_knn2$auc)
#> [1] 0.9041304

Observe que os resultados de área abaixo da ROC não são suficientes para a escolha do k, pois precisamos estar atentos ao valores de sensibilidade e especificidade! A depender da importância de cada um destes valores para o problema em questão, podemos escolher um k que nos dê um bom equilíbrio entre estes dois valores.

Reprodutibilidade

#> R version 4.5.1 (2025-06-13 ucrt)
#> Platform: x86_64-w64-mingw32/x64
#> Running under: Windows 11 x64 (build 26200)
#> 
#> Matrix products: default
#>   LAPACK version 3.12.1
#> 
#> locale:
#> [1] LC_COLLATE=Portuguese_Brazil.utf8  LC_CTYPE=Portuguese_Brazil.utf8   
#> [3] LC_MONETARY=Portuguese_Brazil.utf8 LC_NUMERIC=C                      
#> [5] LC_TIME=Portuguese_Brazil.utf8    
#> 
#> time zone: America/Sao_Paulo
#> tzcode source: internal
#> 
#> attached base packages:
#> [1] stats     graphics  grDevices utils     datasets  methods   base     
#> 
#> other attached packages:
#>  [1] ISLR_1.4        pROC_1.19.0.1   patchwork_1.3.1 caret_7.0-1    
#>  [5] lattice_0.22-7  lubridate_1.9.4 forcats_1.0.0   stringr_1.5.1  
#>  [9] dplyr_1.1.4     purrr_1.1.0     readr_2.1.5     tidyr_1.3.1    
#> [13] tibble_3.3.0    ggplot2_3.5.2   tidyverse_2.0.0
#> 
#> loaded via a namespace (and not attached):
#>  [1] gtable_0.3.6         xfun_0.52            htmlwidgets_1.6.4   
#>  [4] recipes_1.3.1        tzdb_0.5.0           vctrs_0.6.5         
#>  [7] tools_4.5.1          generics_0.1.4       stats4_4.5.1        
#> [10] parallel_4.5.1       proxy_0.4-27         ModelMetrics_1.2.2.2
#> [13] pkgconfig_2.0.3      Matrix_1.7-3         data.table_1.17.8   
#> [16] RColorBrewer_1.1-3   lifecycle_1.0.4      compiler_4.5.1      
#> [19] farver_2.1.2         codetools_0.2-20     htmltools_0.5.8.1   
#> [22] class_7.3-23         yaml_2.3.10          prodlim_2025.04.28  
#> [25] pillar_1.11.0        MASS_7.3-65          gower_1.0.2         
#> [28] iterators_1.0.14     rpart_4.1.24         foreach_1.5.2       
#> [31] nlme_3.1-168         parallelly_1.45.1    lava_1.8.1          
#> [34] tidyselect_1.2.1     digest_0.6.37        stringi_1.8.7       
#> [37] future_1.67.0        reshape2_1.4.4       listenv_0.9.1       
#> [40] labeling_0.4.3       splines_4.5.1        fastmap_1.2.0       
#> [43] grid_4.5.1           cli_3.6.5            magrittr_2.0.3      
#> [46] dichromat_2.0-0.1    survival_3.8-3       e1071_1.7-16        
#> [49] future.apply_1.20.0  withr_3.0.2          scales_1.4.0        
#> [52] timechange_0.3.0     rmarkdown_2.29       globals_0.18.0      
#> [55] nnet_7.3-20          timeDate_4041.110    hms_1.1.3           
#> [58] evaluate_1.0.4       knitr_1.50           hardhat_1.4.1       
#> [61] rlang_1.1.6          Rcpp_1.1.0           glue_1.8.0          
#> [64] ipred_0.9-15         rstudioapi_0.17.1    jsonlite_2.0.0      
#> [67] R6_2.6.1             plyr_1.8.9