I am using the xgboostExplainer
package and I make my plot using the showWaterfall
function. I would like to customize the output of the showWaterfall
plot a little further. I first run the following:
library(xgboost)
library(xgboostExplainer)
# Load the mushroom example data that ships with xgboost.
data(agaricus.train, package = "xgboost")
data(agaricus.test, package = "xgboost")
train <- agaricus.train
test <- agaricus.test

# Wrap the feature matrices and their labels as DMatrix objects.
xgb.train.data <- xgb.DMatrix(train$data, label = train$label)
xgb.test.data <- xgb.DMatrix(test$data, label = test$label)

# Fit a binary classifier with 10 boosting rounds.
# The argument is spelled `params` in full; the shortened `param` only worked
# through R's partial argument matching.
param <- list(objective = "binary:logistic")
xgb.model <- xgboost(params = param, data = xgb.train.data, nrounds = 10)

# Build the explainer from the training data, then break every test
# prediction down into per-feature contributions.
explained <- buildExplainer(
  xgb.model,
  xgb.train.data,
  type = "binary",
  base_score = 0.5
)
pred.breakdown <- explainPredictions(xgb.model, explained, xgb.test.data)

# Stock waterfall plot for the second test observation.
showWaterfall(
  xgb.model,
  explained,
  xgb.test.data,
  test$data,
  2,
  type = "binary"
)
Next I take the source code of the showWaterfall
function (from here) and make some modifications to the ggplot
part of the code.
#' Customised waterfall plot of one prediction's feature contributions
#'
#' Breaks the prediction for row `idx` down with `explainPredictions()`,
#' prints the actual label (when the DMatrix has one), the prediction, the
#' summed weight and the breakdown, and returns a `waterfalls::waterfall()`
#' plot of the contributions.
#'
#' @param xgb.model Fitted model, passed through to `explainPredictions()`.
#' @param explainer Explainer object, passed through to `explainPredictions()`.
#' @param DMatrix DMatrix holding the observations; row `idx` is sliced out.
#' @param data.matrix Matrix whose row `idx` supplies the feature values shown
#'   in the bar labels. Assumes its columns are in the same order as the
#'   feature columns of the breakdown -- TODO confirm against caller.
#' @param idx Index of the observation to explain.
#' @param type `"regression"` plots the raw weights; any other value is
#'   treated as binary and the y axis is labelled with probabilities.
#' @param threshold Contributions with an absolute value below this are summed
#'   into a single `"other"` bar.
#' @param limits y-axis limits (on the log-odds scale) for the binary plot.
#' @param sign_colours Optional length-2 vector `c(negative, positive)` of
#'   fill colours. `NULL` (the default) keeps the waterfalls package's own
#'   colour-by-sign palette.
#' @param rect_text_size Relative size of the numbers drawn inside the bars,
#'   forwarded to `waterfalls::waterfall()`. The default `1` leaves the text
#'   size unchanged.
#' @return The ggplot object produced by `waterfalls::waterfall()` plus the
#'   scales, labels and theme added here.
showWaterfall2 <- function(xgb.model, explainer, DMatrix, data.matrix, idx,
                           type = "binary", threshold = 0.0001,
                           limits = c(NA, NA),
                           sign_colours = NULL, rect_text_size = 1) {
  # Slice once and reuse the single-row DMatrix for breakdown and label.
  row_dmatrix <- slice(DMatrix, as.integer(idx))
  breakdown <- explainPredictions(xgb.model, explainer, row_dmatrix)

  weight <- rowSums(breakdown)
  if (type == "regression") {
    pred <- weight
  } else {
    pred <- 1 / (1 + exp(-weight))
  }

  contributions <- as.matrix(breakdown)[1, ]
  feature_values <- data.matrix[idx, ]

  # Largest absolute contribution first; feature values follow the same order.
  by_magnitude <- order(abs(contributions), decreasing = TRUE)
  contributions <- contributions[by_magnitude]
  feature_values <- feature_values[by_magnitude]

  # Pull the intercept out so it can be placed at the front of the plot.
  is_intercept <- names(contributions) == "intercept"
  intercept <- contributions[is_intercept]
  feature_values <- feature_values[!is_intercept]
  contributions <- contributions[!is_intercept]

  # Fold negligible contributions into a single "other" entry.
  negligible <- which(abs(contributions) < threshold)
  other_impact <- 0
  if (length(negligible) > 0) {
    other_impact <- sum(contributions[negligible])
    names(other_impact) <- "other"
    contributions <- contributions[-negligible]
    feature_values <- feature_values[-negligible]
  }

  # The "other" bar is only drawn when it has a non-zero total.
  has_other <- abs(other_impact) > 0
  contributions <- c(intercept, contributions, if (has_other) other_impact)
  feature_values <- c("", feature_values, if (has_other) "")
  labels <- paste0(names(contributions), " = ", feature_values)
  labels[1] <- "intercept"
  if (has_other) {
    labels[length(labels)] <- "other"
  }

  if (!is.null(getinfo(DMatrix, "label"))) {
    cat("\nActual: ", getinfo(row_dmatrix, "label"))
  }
  cat("\nPrediction: ", pred)
  cat("\nWeight: ", weight)
  cat("\nBreakdown")
  cat("\n")
  print(contributions)

  # waterfall() ignores fill_colours while fill_by_sign is TRUE, so custom
  # sign colours are expanded to one colour per bar and fill_by_sign is
  # switched off.
  if (is.null(sign_colours)) {
    fill_by_sign <- TRUE
    fill_colours <- NULL
  } else {
    stopifnot(length(sign_colours) == 2)
    fill_by_sign <- FALSE
    fill_colours <- unname(
      ifelse(contributions >= 0, sign_colours[2], sign_colours[1])
    )
  }

  base_plot <- waterfalls::waterfall(
    values = contributions,
    rect_text_labels = round(contributions, 2),
    rect_text_size = rect_text_size,
    labels = labels,
    total_rect_text = round(weight, 2),
    calc_total = TRUE,
    total_axis_text = "Prediction",
    fill_colours = fill_colours,
    fill_by_sign = fill_by_sign
  )

  # ggplot2 calls are namespace-qualified so the function does not depend on
  # ggplot2 being attached by the caller.
  if (type == "regression") {
    base_plot +
      ggplot2::theme(
        axis.text.x = ggplot2::element_text(angle = 45, hjust = 1)
      )
  } else {
    # Bars live on the log-odds scale; breaks are placed at the log-odds of
    # 2%, 4%, ..., 98% and labelled with the corresponding probabilities.
    inverse_logit_labels <- function(x) {
      1 / (1 + exp(-x))
    }
    logit <- function(x) {
      log(x / (1 - x))
    }
    ybreaks <- logit(seq(2, 98, 2) / 100)

    base_plot +
      ggplot2::scale_y_continuous(
        labels = inverse_logit_labels,
        breaks = ybreaks,
        limits = limits
      ) +
      ggplot2::scale_color_brewer(palette = "Set1") +
      ggplot2::labs(
        title = "MyModelTitle",
        x = "MyVariables",
        y = "ModelProbabilities"
      ) +
      ggplot2::theme(
        axis.text.x = ggplot2::element_text(angle = 45, hjust = 1),
        axis.line.y = ggplot2::element_blank(),
        axis.ticks.y = ggplot2::element_blank(),
        strip.background = ggplot2::element_rect(fill = "darkred"),
        panel.background = ggplot2::element_blank(),
        panel.grid.major = ggplot2::element_blank(),
        panel.grid.minor = ggplot2::element_blank()
      )
  }
}
I next run my newly modified function:
# Draw the customised waterfall for the second test observation.
showWaterfall2(
  xgb.model = xgb.model,
  explainer = explained,
  DMatrix = xgb.test.data,
  data.matrix = test$data,
  idx = 2,
  type = "binary"
)
I would like to make two minor adjustments to the function and ggplot
code. The only part of the function I modified was the following (which corresponds to the ggplot
part of the code):
# Excerpt: the plotting expression from the binary branch of showWaterfall2.
# It relies on breakdown_summary, labels, weight, inverse_logit_labels,
# ybreaks and limits, all of which are defined inside that function.
waterfalls::waterfall(values = breakdown_summary,
rect_text_labels = round(breakdown_summary, 2),
labels = labels,
total_rect_text = round(weight, 2),
calc_total = TRUE,
total_axis_text = "Prediction",
# The two arguments below are disabled; enabling them colours the bars
# manually instead of by sign.
# NOTE(review): while they stay commented out, the comma after "Prediction"
# leaves an empty trailing argument in this call.
# NOTE(review): waterfalls::waterfall appears to offer rect_text_size for the
# numbers inside the bars -- confirm against the package documentation.
#fill_colours = c("blue", "red"),
#fill_by_sign = FALSE
) +
# Breaks sit at the log-odds in ybreaks; inverse_logit_labels converts them
# back to probabilities for the axis labels.
scale_y_continuous(labels = inverse_logit_labels,
breaks = ybreaks, limits = limits) +
# NOTE(review): this sets the colour scale, not the fill scale; presumably it
# does not change the bar fills -- confirm.
scale_color_brewer(palette = "Set1") +
#scale_fill_manual(values = c('darkblue', 'darkred')) +
#scale_color_manual(values = c('darkblue', 'darkred')) +
labs(title = "MyModelTitle",
x = "MyVariables",
y = "ModelProbabilities") +
#coord_flip() +
#theme(axis.text.x = element_text(angle = 45, hjust = 1)) +
# Tilt the x-axis labels and strip the y-axis line/ticks, panel background
# and grid lines.
theme(
axis.text.x = element_text(angle = 45, hjust = 1),
#aspect.ratio = 1,
axis.line.y = element_blank(),
axis.ticks.y = element_blank(),
strip.background = element_rect(fill = 'darkred'),
panel.background = element_blank(),
panel.grid.major = element_blank(),
panel.grid.minor = element_blank()
)
If I uncomment #fill_colours = c("blue", "red")
and #fill_by_sign = FALSE
I can manually color the bars myself. However I like the waterfall
package method of fill_by_sign = TRUE
(negative signs get a different color to positive signs). However nowhere in the documentation of the waterfalls
package (here) does it say how to change the base colors.
How can I change the ggplot
or waterfalls
base colors?
How can I also make the text (the numbers inside the bars) larger? Adding text = element_text(size = 20)
to the theme()
section doesn't seem to work for me.