Arvores de Classificação - Única e GBM

Author

Ricardo Accioly

Published

August 20, 2024

Bibliotecas

Dados

Vamos começar a aplicar a metodologia de árvores usando árvores de classificação para analisar os dados existentes em Carseats. Este conjunto de dados (simulado) é sobre venda de assentos de criança para carros. Ele tem 400 observações das seguintes variáveis (11), cujos nomes serão convertidos para o português:

Sales: vendas em unidades (em mil) em cada local

CompPrice: preço cobrado pelo competidor em cada local

Income: nível de renda da comunidade local (em mil US$)

Advertising: orçamento local de propaganda (em mil US$)

Population: população na região (em mil)

Price: preço cobrado pela empresa em cada local

ShelveLoc: um fator com níveis Ruim, Bom e Medio indicando a qualidade da localização das prateleiras para os assentos em cada lugar

Age: idade media da população local

Education: nível de educação em cada local

Urban: um fator Sim e Não indicando se a loja esta em uma área urbana ou rural

US: um fator indicando se a loja é nos EUA ou não

Neste dados, Sales é a variável resposta, só que ela é uma variável contínua, por este motivo vamos usá-la para criar uma variável binária. Vamos usar a função ifelse() para criar a variável binária, que chamaremos de alta, ela assume os valores Sim se Sales for maior que 8 e assume o valor Não caso contrário:

data(Carseats)
summary(Carseats)
     Sales          CompPrice       Income        Advertising    
 Min.   : 0.000   Min.   : 77   Min.   : 21.00   Min.   : 0.000  
 1st Qu.: 5.390   1st Qu.:115   1st Qu.: 42.75   1st Qu.: 0.000  
 Median : 7.490   Median :125   Median : 69.00   Median : 5.000  
 Mean   : 7.496   Mean   :125   Mean   : 68.66   Mean   : 6.635  
 3rd Qu.: 9.320   3rd Qu.:135   3rd Qu.: 91.00   3rd Qu.:12.000  
 Max.   :16.270   Max.   :175   Max.   :120.00   Max.   :29.000  
   Population        Price        ShelveLoc        Age          Education   
 Min.   : 10.0   Min.   : 24.0   Bad   : 96   Min.   :25.00   Min.   :10.0  
 1st Qu.:139.0   1st Qu.:100.0   Good  : 85   1st Qu.:39.75   1st Qu.:12.0  
 Median :272.0   Median :117.0   Medium:219   Median :54.50   Median :14.0  
 Mean   :264.8   Mean   :115.8                Mean   :53.32   Mean   :13.9  
 3rd Qu.:398.5   3rd Qu.:131.0                3rd Qu.:66.00   3rd Qu.:16.0  
 Max.   :509.0   Max.   :191.0                Max.   :80.00   Max.   :18.0  
 Urban       US     
 No :118   No :142  
 Yes:282   Yes:258  
                    
                    
                    
                    
str(Carseats)
'data.frame':   400 obs. of  11 variables:
 $ Sales      : num  9.5 11.22 10.06 7.4 4.15 ...
 $ CompPrice  : num  138 111 113 117 141 124 115 136 132 132 ...
 $ Income     : num  73 48 35 100 64 113 105 81 110 113 ...
 $ Advertising: num  11 16 10 4 3 13 0 15 0 0 ...
 $ Population : num  276 260 269 466 340 501 45 425 108 131 ...
 $ Price      : num  120 83 80 97 128 72 108 120 124 124 ...
 $ ShelveLoc  : Factor w/ 3 levels "Bad","Good","Medium": 1 2 3 3 1 1 3 2 3 3 ...
 $ Age        : num  42 65 59 55 38 78 71 67 76 76 ...
 $ Education  : num  17 10 12 14 13 16 15 10 10 17 ...
 $ Urban      : Factor w/ 2 levels "No","Yes": 2 2 2 2 2 1 2 2 1 1 ...
 $ US         : Factor w/ 2 levels "No","Yes": 2 2 2 2 1 2 1 2 1 2 ...

Manipulando os dados

cad_crianca <- Carseats %>% rename(vendas = Sales, 
                                   preco_comp = CompPrice,
                                   renda = Income,
                                   propaganda = Advertising,
                                   populacao = Population,
                                   preco = Price,
                                   local_prat = ShelveLoc,
                                   idade = Age,
                                   educacao = Education,
                                   urbano = Urban,
                                   eua = US)

cad_crianca <- cad_crianca %>% mutate(alta = ifelse(vendas > 8, "Sim",
                                                   "Não")) %>%
                              mutate(alta = factor(alta))

cad_crianca<- cad_crianca %>% mutate(local_prat =  case_when(
                                      local_prat == "Bad"  ~ "Ruim",
                                      local_prat == "Good" ~ "Bom",
                                      local_prat == "Medium" ~ "Medio"))%>%                               mutate(local_prat = factor(local_prat))

cad_crianca<- cad_crianca %>% mutate(urbano =  case_when(
                                      urbano == "Yes"  ~ "Sim",
                                      urbano == "No" ~ "Não")) %>%                                       mutate(urbano = factor(urbano))

cad_crianca<- cad_crianca %>% mutate(eua =  case_when(
                                      eua == "Yes"  ~ "Sim",
                                      eua == "No" ~ "Não")) %>%                                          mutate(eua = factor(eua))

cad_crianca<- cad_crianca %>% select(-vendas)

str(cad_crianca)
'data.frame':   400 obs. of  11 variables:
 $ preco_comp: num  138 111 113 117 141 124 115 136 132 132 ...
 $ renda     : num  73 48 35 100 64 113 105 81 110 113 ...
 $ propaganda: num  11 16 10 4 3 13 0 15 0 0 ...
 $ populacao : num  276 260 269 466 340 501 45 425 108 131 ...
 $ preco     : num  120 83 80 97 128 72 108 120 124 124 ...
 $ local_prat: Factor w/ 3 levels "Bom","Medio",..: 3 1 2 2 3 3 2 1 2 2 ...
 $ idade     : num  42 65 59 55 38 78 71 67 76 76 ...
 $ educacao  : num  17 10 12 14 13 16 15 10 10 17 ...
 $ urbano    : Factor w/ 2 levels "Não","Sim": 2 2 2 2 2 1 2 2 1 1 ...
 $ eua       : Factor w/ 2 levels "Não","Sim": 2 2 2 2 1 2 1 2 1 2 ...
 $ alta      : Factor w/ 2 levels "Não","Sim": 2 2 2 1 1 2 1 2 1 1 ...
summary(cad_crianca)
   preco_comp      renda          propaganda       populacao    
 Min.   : 77   Min.   : 21.00   Min.   : 0.000   Min.   : 10.0  
 1st Qu.:115   1st Qu.: 42.75   1st Qu.: 0.000   1st Qu.:139.0  
 Median :125   Median : 69.00   Median : 5.000   Median :272.0  
 Mean   :125   Mean   : 68.66   Mean   : 6.635   Mean   :264.8  
 3rd Qu.:135   3rd Qu.: 91.00   3rd Qu.:12.000   3rd Qu.:398.5  
 Max.   :175   Max.   :120.00   Max.   :29.000   Max.   :509.0  
     preco       local_prat      idade          educacao    urbano     eua     
 Min.   : 24.0   Bom  : 85   Min.   :25.00   Min.   :10.0   Não:118   Não:142  
 1st Qu.:100.0   Medio:219   1st Qu.:39.75   1st Qu.:12.0   Sim:282   Sim:258  
 Median :117.0   Ruim : 96   Median :54.50   Median :14.0                      
 Mean   :115.8               Mean   :53.32   Mean   :13.9                      
 3rd Qu.:131.0               3rd Qu.:66.00   3rd Qu.:16.0                      
 Max.   :191.0               Max.   :80.00   Max.   :18.0                      
  alta    
 Não:236  
 Sim:164  
          
          
          
          

Treino e Teste

library(caret)
set.seed(21)
y <- cad_crianca$alta
indice_teste <- createDataPartition(y, times = 1, p = 0.2, list = FALSE)

conj_treino <- cad_crianca %>% slice(-indice_teste)
conj_teste <- cad_crianca %>% slice(indice_teste)

str(conj_treino)
'data.frame':   319 obs. of  11 variables:
 $ preco_comp: num  138 111 113 141 124 136 132 121 117 122 ...
 $ renda     : num  73 48 35 64 113 81 110 78 94 35 ...
 $ propaganda: num  11 16 10 3 13 15 0 9 4 2 ...
 $ populacao : num  276 260 269 340 501 425 108 150 503 393 ...
 $ preco     : num  120 83 80 128 72 120 124 100 94 136 ...
 $ local_prat: Factor w/ 3 levels "Bom","Medio",..: 3 1 2 3 3 1 2 3 1 2 ...
 $ idade     : num  42 65 59 38 78 67 76 26 50 62 ...
 $ educacao  : num  17 10 12 13 16 10 10 10 13 18 ...
 $ urbano    : Factor w/ 2 levels "Não","Sim": 2 2 2 2 1 2 1 1 2 2 ...
 $ eua       : Factor w/ 2 levels "Não","Sim": 2 2 2 1 2 2 1 2 2 1 ...
 $ alta      : Factor w/ 2 levels "Não","Sim": 2 2 2 1 2 2 1 2 2 1 ...
prop.table(table(conj_treino$alta))

      Não       Sim 
0.5893417 0.4106583 
str(conj_teste)
'data.frame':   81 obs. of  11 variables:
 $ preco_comp: num  117 115 132 115 147 145 114 121 123 103 ...
 $ renda     : num  100 105 113 28 74 119 38 41 42 93 ...
 $ propaganda: num  4 0 0 11 13 16 13 5 11 15 ...
 $ populacao : num  466 45 131 29 251 294 317 412 16 188 ...
 $ preco     : num  97 108 124 86 131 113 128 110 134 103 ...
 $ local_prat: Factor w/ 3 levels "Bom","Medio",..: 2 2 2 1 1 3 1 2 2 3 ...
 $ idade     : num  55 71 76 53 52 42 50 54 59 74 ...
 $ educacao  : num  14 15 17 18 10 12 16 10 13 16 ...
 $ urbano    : Factor w/ 2 levels "Não","Sim": 2 2 1 2 2 2 2 2 2 2 ...
 $ eua       : Factor w/ 2 levels "Não","Sim": 2 1 2 2 2 2 2 2 2 2 ...
 $ alta      : Factor w/ 2 levels "Não","Sim": 1 1 1 2 2 2 2 1 1 1 ...
prop.table(table(conj_teste$alta))

      Não       Sim 
0.5925926 0.4074074 

Arvore de Classificação

Na biblioteca rpart as arvores de classificação são obtidas usando o método class. Existem alguns controles que podem ser feitos nos parametros da arvore.

Neste exemplo só definimos o menor conjunto de dados numa partição (minsplit) e o parametro de complexidade cp. Posteriormente vamos ampliar este controle. Um valor de cp muito pequeno ocasiona overfitting e um valor muito grande resulta numa arvore muito pequena (underfitting). Nos dois casos se diminui o desempenho do modelo.

##Usando rpart para desenvolver a arvore  
library(rpart)
arvcl <- rpart(alta ~ ., 
                data=conj_treino,
                method="class", #para arvore de classificação
                control=rpart.control(minsplit=30,cp=0.02))
plot(arvcl)
text(arvcl,pretty=0)

Regras

# Regras de Decisão
arvcl
n= 319 

node), split, n, loss, yval, (yprob)
      * denotes terminal node

 1) root 319 131 Não (0.58934169 0.41065831)  
   2) local_prat=Medio,Ruim 249  77 Não (0.69076305 0.30923695)  
     4) preco>=92 213  51 Não (0.76056338 0.23943662)  
       8) idade>=49.5 123  15 Não (0.87804878 0.12195122) *
       9) idade< 49.5 90  36 Não (0.60000000 0.40000000)  
        18) preco>=124.5 44   6 Não (0.86363636 0.13636364) *
        19) preco< 124.5 46  16 Sim (0.34782609 0.65217391) *
     5) preco< 92 36  10 Sim (0.27777778 0.72222222)  
      10) local_prat=Ruim 13   5 Não (0.61538462 0.38461538) *
      11) local_prat=Medio 23   2 Sim (0.08695652 0.91304348) *
   3) local_prat=Bom 70  16 Sim (0.22857143 0.77142857)  
     6) preco>=142.5 10   3 Não (0.70000000 0.30000000) *
     7) preco< 142.5 60   9 Sim (0.15000000 0.85000000) *

Desenhando a Árvore de uma forma mais clara

library(rattle)
library(rpart.plot)
library(RColorBrewer)
fancyRpartPlot(arvcl, caption = NULL)

Previsões

# Fazendo Previsões
y_chapeu <- predict(arvcl, newdata = conj_teste, type="class")

confusionMatrix(y_chapeu, conj_teste$alta, positive="Sim") 
Confusion Matrix and Statistics

          Reference
Prediction Não Sim
       Não  37  13
       Sim  11  20
                                          
               Accuracy : 0.7037          
                 95% CI : (0.5919, 0.8001)
    No Information Rate : 0.5926          
    P-Value [Acc > NIR] : 0.02573         
                                          
                  Kappa : 0.3805          
                                          
 Mcnemar's Test P-Value : 0.83826         
                                          
            Sensitivity : 0.6061          
            Specificity : 0.7708          
         Pos Pred Value : 0.6452          
         Neg Pred Value : 0.7400          
             Prevalence : 0.4074          
         Detection Rate : 0.2469          
   Detection Prevalence : 0.3827          
      Balanced Accuracy : 0.6884          
                                          
       'Positive' Class : Sim             
                                          

Arvore de Classificação no caret

##Usando rpart para desenvolver a arvore  
library(rpart)
set.seed(21)
## Otimizamos o valor de cp usando um 10-fold cv
# O parametro tuneLength diz para o algoritmo escolher diferentes valores para cp
# O parametro tuneGrid permite decidir que valores cp deve assumir enquanto que o
# tuneLength somente limita o número default de parametros que se usa.
tgrid <- expand.grid(cp = seq(0.01,0.10,0.001))
ctrl <- trainControl(method = "cv", classProbs=TRUE)
arvclass <- train(alta ~ . , data = conj_treino, method = "rpart",
                 trControl = ctrl,
                 tuneGrid = tgrid
                 )
# Mostra a acurácia vs cp (parametro de complexidade)
plot(arvclass)

## Indica o melhor valor de cp
arvclass$bestTune
     cp
7 0.016

Uma forma melhor de ver a Árvore

## melhorando apresentação da árvore
library(rattle)
library(rpart.plot)
library(RColorBrewer)
fancyRpartPlot(arvclass$finalModel, caption = NULL)

Previsões

# Fazendo Previsões
y_chapeu <- arvclass %>% predict(conj_teste) %>% 
                   factor(levels = levels(conj_teste$alta))
head(y_chapeu)
[1] Não Não Não Sim Sim Não
Levels: Não Sim
confusionMatrix(y_chapeu, conj_teste$alta, positive="Sim") 
Confusion Matrix and Statistics

          Reference
Prediction Não Sim
       Não  40  13
       Sim   8  20
                                          
               Accuracy : 0.7407          
                 95% CI : (0.6314, 0.8318)
    No Information Rate : 0.5926          
    P-Value [Acc > NIR] : 0.003896        
                                          
                  Kappa : 0.45            
                                          
 Mcnemar's Test P-Value : 0.382733        
                                          
            Sensitivity : 0.6061          
            Specificity : 0.8333          
         Pos Pred Value : 0.7143          
         Neg Pred Value : 0.7547          
             Prevalence : 0.4074          
         Detection Rate : 0.2469          
   Detection Prevalence : 0.3457          
      Balanced Accuracy : 0.7197          
                                          
       'Positive' Class : Sim             
                                          

Verificando a consistencia dos resultados

set.seed(121)
y <- cad_crianca$alta
indice_teste <- createDataPartition(y, times = 1, p = 0.2, list = FALSE)

conj_treino <- cad_crianca %>% slice(-indice_teste)
conj_teste <- cad_crianca %>% slice(indice_teste)

str(conj_treino)
'data.frame':   319 obs. of  11 variables:
 $ preco_comp: num  138 113 141 124 115 136 132 121 117 115 ...
 $ renda     : num  73 35 64 113 105 81 113 78 94 28 ...
 $ propaganda: num  11 10 3 13 0 15 0 9 4 11 ...
 $ populacao : num  276 269 340 501 45 425 131 150 503 29 ...
 $ preco     : num  120 80 128 72 108 120 124 100 94 86 ...
 $ local_prat: Factor w/ 3 levels "Bom","Medio",..: 3 2 3 3 2 1 2 3 1 1 ...
 $ idade     : num  42 59 38 78 71 67 76 26 50 53 ...
 $ educacao  : num  17 12 13 16 15 10 17 10 13 18 ...
 $ urbano    : Factor w/ 2 levels "Não","Sim": 2 2 2 1 2 2 1 1 2 2 ...
 $ eua       : Factor w/ 2 levels "Não","Sim": 2 2 1 2 1 2 2 2 2 2 ...
 $ alta      : Factor w/ 2 levels "Não","Sim": 2 2 1 2 1 2 1 2 2 2 ...
prop.table(table(conj_treino$alta))

      Não       Sim 
0.5893417 0.4106583 
str(conj_teste)
'data.frame':   81 obs. of  11 variables:
 $ preco_comp: num  111 117 132 122 125 139 103 125 122 121 ...
 $ renda     : num  48 100 110 35 90 32 74 94 76 90 ...
 $ propaganda: num  16 4 0 2 2 0 0 0 0 0 ...
 $ populacao : num  260 466 108 393 367 176 359 447 270 150 ...
 $ preco     : num  83 97 124 136 131 82 97 89 100 108 ...
 $ local_prat: Factor w/ 3 levels "Bom","Medio",..: 1 2 2 2 2 1 3 1 1 3 ...
 $ idade     : num  65 55 76 62 35 54 55 30 60 75 ...
 $ educacao  : num  10 14 10 18 18 11 11 12 18 16 ...
 $ urbano    : Factor w/ 2 levels "Não","Sim": 2 2 1 2 2 1 2 2 1 2 ...
 $ eua       : Factor w/ 2 levels "Não","Sim": 2 2 1 1 2 1 2 1 1 1 ...
 $ alta      : Factor w/ 2 levels "Não","Sim": 2 1 1 1 1 2 1 2 2 1 ...
prop.table(table(conj_teste$alta))

      Não       Sim 
0.5925926 0.4074074 

Obtendo a arvore

##Usando rpart para desenvolver a arvore  
library(rpart)
arvcl <- rpart(alta ~ ., 
                data=conj_treino,
                method="class", #para arvore de classificação
                control=rpart.control(minsplit=30,cp=0.02))

Desenhando a Árvore de uma forma mais clara

library(rattle)
library(rpart.plot)
library(RColorBrewer)
fancyRpartPlot(arvcl, caption = NULL)

GBM

Criando um grid para avaliar os parametros

hiper_grid <- expand.grid(
  shrinkage = c(.001, .01, .1),
  interaction.depth = c(1, 3, 5),
  n.minobsinnode = c(5, 10, 15),
  bag.fraction = c(.65, 1),
  optimal_trees = 0, # um lugar para guardar resultados
  min_erro = 0   # um lugar para guardar resultados
  )

# número total de combinações
nrow(hiper_grid)
[1] 54

Avaliando o grid de parametros

library(gbm)
conj_treino$alta <- as.numeric(conj_treino$alta)
conj_treino <- transform(conj_treino, alta=alta - 1)
conj_treino$eua <- as.numeric(conj_treino$eua)
conj_treino <- transform(conj_treino, eua=eua - 1)
conj_treino$local_prat <- as.numeric(conj_treino$local_prat)
conj_treino <- transform(conj_treino, local_prat=local_prat - 1)


#Busca no grid
for(i in 1:nrow(hiper_grid)) {

  #
  set.seed(21)

  # treina o modelo
  gbm.tune <- gbm(
    formula = alta ~ .,
    distribution = "bernoulli",
    data = conj_treino,
    n.trees = 5000,
    interaction.depth = hiper_grid$interaction.depth[i],
    shrinkage = hiper_grid$shrinkage[i],
    n.minobsinnode = hiper_grid$n.minobsinnode[i],
    bag.fraction = hiper_grid$bag.fraction[i],
    train.fraction = .75,
    n.cores = NULL,
    verbose = FALSE
  )

  # adiciona os erros de treino e arvores ao grid
  hiper_grid$optimal_trees[i] <- which.min(gbm.tune$valid.error)
  hiper_grid$min_erro[i] <- min(gbm.tune$valid.error)
}

hiper_grid %>% dplyr::arrange(min_erro) %>% head(10)
   shrinkage interaction.depth n.minobsinnode bag.fraction optimal_trees
1       0.10                 3              5         0.65           165
2       0.01                 1             10         0.65          3778
3       0.10                 1             10         0.65           409
4       0.01                 1              5         0.65          3577
5       0.10                 1             15         0.65           295
6       0.01                 1             15         0.65          3555
7       0.10                 1              5         0.65           322
8       0.10                 5              5         0.65           134
9       0.01                 3              5         0.65          1515
10      0.01                 3             10         0.65          1525
    min_erro
1  0.5223560
2  0.5265243
3  0.5273196
4  0.5292021
5  0.5311457
6  0.5364523
7  0.5377417
8  0.5453332
9  0.5509056
10 0.5576744

Modelo final

# 
set.seed(21)

# treina o modelo GBM
gbm.fit.final <- gbm(
  formula = alta ~ .,
  distribution = "bernoulli",
  data = conj_treino,
  n.trees = 165,
  interaction.depth = 3,
  shrinkage = 0.10,
  n.minobsinnode = 5,
  bag.fraction = 0.65, 
  train.fraction = 1,
  n.cores = NULL, 
  verbose = FALSE
  )  

Importância das Variáveis

par(mai = c(1, 2, 1, 2))
summary(
  gbm.fit.final, 
  cBars = 10,
  method = relative.influence, # também pode ser usado permutation.test.gbm
  las = 2
  )

                  var    rel.inf
preco           preco 29.5468476
local_prat local_prat 21.5292204
propaganda propaganda 12.2498142
renda           renda 11.4616394
idade           idade 10.3960594
preco_comp preco_comp 10.1224807
populacao   populacao  2.0513691
eua               eua  1.2347441
educacao     educacao  1.1121800
urbano         urbano  0.2956451

Previsão

conj_teste$alta <- as.numeric(conj_teste$alta)
conj_teste <- transform(conj_teste, alta=alta - 1)
conj_teste$eua <- as.numeric(conj_teste$eua)
conj_teste <- transform(conj_teste, eua=eua - 1)
conj_teste$local_prat <- as.numeric(conj_teste$local_prat)
conj_teste <- transform(conj_teste, local_prat=local_prat - 1)

# Fazendo Previsões
previsao1 <- predict(gbm.fit.final, 
                     newdata = conj_teste,
                     n.trees=gbm.fit.final$n.trees,
                     type = "response")
head(previsao1)
[1] 0.965002359 0.784188568 0.176077409 0.003460984 0.027812237 0.995643378
gbm.ychapeu <- as.factor(ifelse(previsao1 < 0.5,0,1))
 
confusionMatrix(gbm.ychapeu,as.factor(conj_teste$alta), positive="1")
Confusion Matrix and Statistics

          Reference
Prediction  0  1
         0 37  6
         1 11 27
                                          
               Accuracy : 0.7901          
                 95% CI : (0.6854, 0.8727)
    No Information Rate : 0.5926          
    P-Value [Acc > NIR] : 0.0001359       
                                          
                  Kappa : 0.5754          
                                          
 Mcnemar's Test P-Value : 0.3319755       
                                          
            Sensitivity : 0.8182          
            Specificity : 0.7708          
         Pos Pred Value : 0.7105          
         Neg Pred Value : 0.8605          
             Prevalence : 0.4074          
         Detection Rate : 0.3333          
   Detection Prevalence : 0.4691          
      Balanced Accuracy : 0.7945          
                                          
       'Positive' Class : 1               
                                          

Curva ROC

library(pROC)
p_chapeu_gbm <- previsao1
roc_gbm <- roc(conj_teste$alta ~ p_chapeu_gbm, plot = TRUE, print.auc=FALSE, col="black", legacy.axes=TRUE)

as.numeric(roc_gbm$auc)
[1] 0.9034091