Arvores de Classificação - GBM

Author

Ricardo Accioly

Published

May 30, 2025

Bibliotecas

Dados

Vamos começar a aplicar a metodologia de árvores usando árvores de classificação para analisar os dados existentes em Carseats. Este conjunto de dados (simulado) é sobre venda de assentos de criança para carros. Ele tem 400 observações das seguintes variáveis (11), cujos nomes serão convertidos para o português:

Sales: vendas em unidades (em mil) em cada local

CompPrice: preço cobrado pelo competidor em cada local

Income: nível de renda da comunidade local (em mil US$)

Advertising: orçamento local de propaganda (em mil US$)

Population: população na região (em mil)

Price: preço cobrado pela empresa em cada local

ShelveLoc: um fator com níveis Ruim, Bom e Medio indicando a qualidade da localização das prateleiras para os assentos em cada lugar

Age: idade media da população local

Education: nível de educação em cada local

Urban: um fator Sim e Não indicando se a loja esta em uma área urbana ou rural

US: um fator indicando se a loja é nos EUA ou não

Neste dados, Sales é a variável resposta, só que ela é uma variável contínua, por este motivo vamos usá-la para criar uma variável binária. Vamos usar a função ifelse() para criar a variável binária, que chamaremos de alta, ela assume os valores Sim se Sales for maior que 8 e assume o valor Não caso contrário:

data(Carseats)
summary(Carseats)
     Sales          CompPrice       Income        Advertising    
 Min.   : 0.000   Min.   : 77   Min.   : 21.00   Min.   : 0.000  
 1st Qu.: 5.390   1st Qu.:115   1st Qu.: 42.75   1st Qu.: 0.000  
 Median : 7.490   Median :125   Median : 69.00   Median : 5.000  
 Mean   : 7.496   Mean   :125   Mean   : 68.66   Mean   : 6.635  
 3rd Qu.: 9.320   3rd Qu.:135   3rd Qu.: 91.00   3rd Qu.:12.000  
 Max.   :16.270   Max.   :175   Max.   :120.00   Max.   :29.000  
   Population        Price        ShelveLoc        Age          Education   
 Min.   : 10.0   Min.   : 24.0   Bad   : 96   Min.   :25.00   Min.   :10.0  
 1st Qu.:139.0   1st Qu.:100.0   Good  : 85   1st Qu.:39.75   1st Qu.:12.0  
 Median :272.0   Median :117.0   Medium:219   Median :54.50   Median :14.0  
 Mean   :264.8   Mean   :115.8                Mean   :53.32   Mean   :13.9  
 3rd Qu.:398.5   3rd Qu.:131.0                3rd Qu.:66.00   3rd Qu.:16.0  
 Max.   :509.0   Max.   :191.0                Max.   :80.00   Max.   :18.0  
 Urban       US     
 No :118   No :142  
 Yes:282   Yes:258  
                    
                    
                    
                    
str(Carseats)
'data.frame':   400 obs. of  11 variables:
 $ Sales      : num  9.5 11.22 10.06 7.4 4.15 ...
 $ CompPrice  : num  138 111 113 117 141 124 115 136 132 132 ...
 $ Income     : num  73 48 35 100 64 113 105 81 110 113 ...
 $ Advertising: num  11 16 10 4 3 13 0 15 0 0 ...
 $ Population : num  276 260 269 466 340 501 45 425 108 131 ...
 $ Price      : num  120 83 80 97 128 72 108 120 124 124 ...
 $ ShelveLoc  : Factor w/ 3 levels "Bad","Good","Medium": 1 2 3 3 1 1 3 2 3 3 ...
 $ Age        : num  42 65 59 55 38 78 71 67 76 76 ...
 $ Education  : num  17 10 12 14 13 16 15 10 10 17 ...
 $ Urban      : Factor w/ 2 levels "No","Yes": 2 2 2 2 2 1 2 2 1 1 ...
 $ US         : Factor w/ 2 levels "No","Yes": 2 2 2 2 1 2 1 2 1 2 ...

Manipulando os dados

cad_crianca <- Carseats %>% rename(vendas = Sales, 
                                   preco_comp = CompPrice,
                                   renda = Income,
                                   propaganda = Advertising,
                                   populacao = Population,
                                   preco = Price,
                                   local_prat = ShelveLoc,
                                   idade = Age,
                                   educacao = Education,
                                   urbano = Urban,
                                   eua = US)

cad_crianca <- cad_crianca %>%
  mutate(vendaAlta = ifelse(vendas > 8, "Alta", "Baixa")) %>%
  mutate(vendaAlta = as.factor(vendaAlta)) %>%
  select(-vendas)  # Remover Sales original

# Verificar distribuição
table(cad_crianca$vendaAlta)

 Alta Baixa 
  164   236 
str(cad_crianca)
'data.frame':   400 obs. of  11 variables:
 $ preco_comp: num  138 111 113 117 141 124 115 136 132 132 ...
 $ renda     : num  73 48 35 100 64 113 105 81 110 113 ...
 $ propaganda: num  11 16 10 4 3 13 0 15 0 0 ...
 $ populacao : num  276 260 269 466 340 501 45 425 108 131 ...
 $ preco     : num  120 83 80 97 128 72 108 120 124 124 ...
 $ local_prat: Factor w/ 3 levels "Bad","Good","Medium": 1 2 3 3 1 1 3 2 3 3 ...
 $ idade     : num  42 65 59 55 38 78 71 67 76 76 ...
 $ educacao  : num  17 10 12 14 13 16 15 10 10 17 ...
 $ urbano    : Factor w/ 2 levels "No","Yes": 2 2 2 2 2 1 2 2 1 1 ...
 $ eua       : Factor w/ 2 levels "No","Yes": 2 2 2 2 1 2 1 2 1 2 ...
 $ vendaAlta : Factor w/ 2 levels "Alta","Baixa": 1 1 1 2 2 1 2 1 2 2 ...
summary(cad_crianca)
   preco_comp      renda          propaganda       populacao    
 Min.   : 77   Min.   : 21.00   Min.   : 0.000   Min.   : 10.0  
 1st Qu.:115   1st Qu.: 42.75   1st Qu.: 0.000   1st Qu.:139.0  
 Median :125   Median : 69.00   Median : 5.000   Median :272.0  
 Mean   :125   Mean   : 68.66   Mean   : 6.635   Mean   :264.8  
 3rd Qu.:135   3rd Qu.: 91.00   3rd Qu.:12.000   3rd Qu.:398.5  
 Max.   :175   Max.   :120.00   Max.   :29.000   Max.   :509.0  
     preco        local_prat      idade          educacao    urbano   
 Min.   : 24.0   Bad   : 96   Min.   :25.00   Min.   :10.0   No :118  
 1st Qu.:100.0   Good  : 85   1st Qu.:39.75   1st Qu.:12.0   Yes:282  
 Median :117.0   Medium:219   Median :54.50   Median :14.0            
 Mean   :115.8                Mean   :53.32   Mean   :13.9            
 3rd Qu.:131.0                3rd Qu.:66.00   3rd Qu.:16.0            
 Max.   :191.0                Max.   :80.00   Max.   :18.0            
  eua      vendaAlta  
 No :142   Alta :164  
 Yes:258   Baixa:236  
                      
                      
                      
                      

Treino e Teste

set.seed(21)
y <- cad_crianca$vendaAlta
indice_teste <- createDataPartition(y, times = 1, p = 0.2, list = FALSE)

conj_treino <- cad_crianca[-indice_teste,]
conj_teste <- cad_crianca[indice_teste,]

str(conj_treino)
'data.frame':   319 obs. of  11 variables:
 $ preco_comp: num  111 117 141 124 115 132 132 121 117 122 ...
 $ renda     : num  48 100 64 113 105 110 113 78 94 35 ...
 $ propaganda: num  16 4 3 13 0 0 0 9 4 2 ...
 $ populacao : num  260 466 340 501 45 108 131 150 503 393 ...
 $ preco     : num  83 97 128 72 108 124 124 100 94 136 ...
 $ local_prat: Factor w/ 3 levels "Bad","Good","Medium": 2 3 1 1 3 3 3 1 2 3 ...
 $ idade     : num  65 55 38 78 71 76 76 26 50 62 ...
 $ educacao  : num  10 14 13 16 15 10 17 10 13 18 ...
 $ urbano    : Factor w/ 2 levels "No","Yes": 2 2 2 1 2 1 1 1 2 2 ...
 $ eua       : Factor w/ 2 levels "No","Yes": 2 2 1 2 1 1 2 2 2 1 ...
 $ vendaAlta : Factor w/ 2 levels "Alta","Baixa": 1 2 2 1 2 2 2 1 1 2 ...
prop.table(table(conj_treino$alta))
numeric(0)
str(conj_teste)
'data.frame':   81 obs. of  11 variables:
 $ preco_comp: num  138 113 136 125 139 98 115 131 157 118 ...
 $ renda     : num  73 35 81 90 32 118 54 84 53 71 ...
 $ propaganda: num  11 10 15 2 0 0 0 11 0 4 ...
 $ populacao : num  276 269 425 367 176 19 406 29 403 148 ...
 $ preco     : num  120 80 120 131 82 107 128 96 124 114 ...
 $ local_prat: Factor w/ 3 levels "Bad","Good","Medium": 1 3 2 3 2 3 3 3 1 3 ...
 $ idade     : num  42 59 67 35 54 64 42 44 58 80 ...
 $ educacao  : num  17 12 10 18 11 17 17 17 16 13 ...
 $ urbano    : Factor w/ 2 levels "No","Yes": 2 2 2 2 1 2 2 1 2 2 ...
 $ eua       : Factor w/ 2 levels "No","Yes": 2 2 2 2 1 1 2 2 1 1 ...
 $ vendaAlta : Factor w/ 2 levels "Alta","Baixa": 1 1 1 2 1 2 2 1 2 2 ...
prop.table(table(conj_teste$alta))
numeric(0)

GBM

Ajustando um modelo GBM

set.seed(21)

# Definir grade de parâmetros
tune_grid <- expand.grid(
  n.trees = c(100, 300, 500),
  interaction.depth = c(1, 3, 5),
  shrinkage = c(0.01, 0.1),
  n.minobsinnode = c(5, 10)
)

# Treinamento com validação cruzada
ctrl <- trainControl(method = "cv", number = 5)

gbm_caret <- train(
  vendaAlta ~ .,
  data = conj_treino,
  method = "gbm",
  distribution = "bernoulli",
  trControl = ctrl,
  tuneGrid = tune_grid,
  verbose = FALSE
)

Melhor modelo

# Melhor modelo
gbm_caret$bestTune
   n.trees interaction.depth shrinkage n.minobsinnode
23     300                 1       0.1             10

Verificando os Resultados

# Previsões
pred_class <- predict(gbm_caret, newdata = conj_teste)

# Matriz de confusão
confusionMatrix(pred_class, conj_teste$vendaAlta)
Confusion Matrix and Statistics

          Reference
Prediction Alta Baixa
     Alta    25     6
     Baixa    8    42
                                         
               Accuracy : 0.8272         
                 95% CI : (0.727, 0.9022)
    No Information Rate : 0.5926         
    P-Value [Acc > NIR] : 5.302e-06      
                                         
                  Kappa : 0.6386         
                                         
 Mcnemar's Test P-Value : 0.7893         
                                         
            Sensitivity : 0.7576         
            Specificity : 0.8750         
         Pos Pred Value : 0.8065         
         Neg Pred Value : 0.8400         
             Prevalence : 0.4074         
         Detection Rate : 0.3086         
   Detection Prevalence : 0.3827         
      Balanced Accuracy : 0.8163         
                                         
       'Positive' Class : Alta           
                                         

Curva ROC

# Probabilidades previstas
pred_prob <- predict(gbm_caret, newdata = conj_teste, type = "prob")

# Curva ROC
roc_obj <- roc(response = conj_teste$vendaAlta, predictor = pred_prob$Alta)
plot(roc_obj, col = "blue", lwd = 2, main = "Curva ROC - GBM")

auc(roc_obj)
Area under the curve: 0.9394

Importância das variáveis

# Obter importância
importancia <- varImp(gbm_caret)

# Visualização gráfica
plot(importancia, top = 10, main = "Importância das Variáveis - GBM")