Pitfalls of Incorrectly Tuned XGBoost Hyperparameters

How an incorrectly tuned base score undermines numerical stability of model predictions across XGBoost versions

Machine learning practitioners have long heard the adage “garbage in, garbage out.” It underscores the importance of basing any modeling effort on a dataset that’s well understood and appropriate for the problem at hand. A related aspect of modeling that’s just as crucial is understanding the nuances of the modeling algorithm’s hyperparameters.

Broadly speaking, hyperparameters are model parameters whose values cannot be estimated directly from the training data. They control different aspects of the model training process and need to be specified beforehand. For instance, tree-based algorithms typically have a hyperparameter that controls the depth of each tree, and its value must be set before model training starts. In practice, hyperparameters are often tuned based on prior experience, well-established heuristics, or techniques such as grid search.

At Capital One, hyperparameter tuning is an indispensable step in the overall model building process. A well-executed tuning strategy brings benefits such as better predictive performance and faster training time, which in turn helps us extract the most value out of our data. Despite that dedicated effort, however, we sometimes encounter edge cases in the tuning process that end up affecting the modeling outcome in subtle and unexpected ways.

In this post, we will delve into XGBoost, the popular open-source modeling package that implements the gradient boosting machine (GBM) algorithm, and focus on a particular hyperparameter, base_score. Specifically, we will present a scenario in which a model trained with an incorrectly specified base score ends up producing noticeably different predictions across two adjacent versions of the XGBoost package.

By the end of this post, you will learn:

  • The importance of understanding the algorithmic and computational implications when adjusting a model’s hyperparameters.
  • The role of base score in the XGBoost model prediction process.
  • Why it’s a good idea to pick a base score that's commensurate with the scale of the target variable.
  • Why picking a value that's on a vastly different scale might result in subtle and unexpected numerical discrepancies across XGBoost versions.

XGBoost prediction loop and the role of base score

When it comes to model training, XGBoost offers a wide variety of hyperparameters that one can tweak to affect the training outcome. One such hyperparameter is base_score, which allows the user to provide what is effectively a default initial guess to the estimator.

Once the model has been trained and is ready to generate predictions, the algorithm will first sum up the predicted residuals from the individual trees, and then add the base score at the end to arrive at the final predicted value. In pseudo-code, this process can be expressed in a simple for-loop, as illustrated below for a single observation.

```python
predicted_value = 0

for tree in trees:
    # Predicted residual is found by first locating the leaf to which the
    # observation of interest belongs, and then looking up the numerical value
    # for said leaf
    predicted_residual = ...

    predicted_value += predicted_residual

# Add base score at the end to get final prediction
predicted_value += base_score
```

Alternatively, we could set the running sum equal to the base score to start with, in which case we no longer need to add it at the end.

```python
# Alternatively, set running sum equal to base score at the start
predicted_value = base_score

for tree in trees:
    predicted_residual = ...
    predicted_value += predicted_residual
```

Prior to XGBoost version 1.3.0, the running sum in the prediction loop starts from zero. In version 1.3.0, however, the behavior changed so that the running sum is effectively initialized to the base score at the start of the for-loop. Mathematically speaking, these two ways of performing the summation are equivalent. Computationally speaking, however, this isn't always the case. And while this nuance isn't expected to drive meaningful differences in model predictions in the vast majority of use cases, its impact can be magnified and prove problematic depending on the specifics of the modeling setup.
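To make the difference between the two orderings concrete, here is a minimal NumPy sketch. It does not call XGBoost; it assumes each tree's contribution for a single observation is already known, and it uses a deliberately contrived set of residuals (500 trees that each contribute 1e-5) to exaggerate the effect. All arithmetic is done in 32-bit floats, which the replica helper later in the post notes is what XGBoost uses.

```python
import numpy as np

# Contrived example: 500 trees, each contributing a tiny residual for one observation
residuals = np.full(500, 1e-5, dtype=np.float32)
base_score = np.float32(1000.0)


def predict_add_base_at_end(residuals, base_score):
    """Pre-1.3.0 ordering: the running sum starts at zero, base score added last."""
    predicted_value = np.float32(0.0)
    for predicted_residual in residuals:
        predicted_value += predicted_residual  # float32 + float32 stays float32
    return predicted_value + base_score


def predict_start_at_base(residuals, base_score):
    """1.3.0 ordering: the running sum starts at the base score."""
    predicted_value = base_score
    for predicted_residual in residuals:
        predicted_value += predicted_residual
    return predicted_value


at_end = predict_add_base_at_end(residuals, base_score)
at_start = predict_start_at_base(residuals, base_score)
print(at_end, at_start)
```

With the second ordering, each residual is smaller than half the spacing between adjacent float32 values near 1000, so every addition rounds straight back to 1000 and the trees' contributions vanish. With the first ordering, the residuals accumulate to roughly 0.005 before the base score is added. Real models are far less extreme, but the same kind of rounding happens on a smaller scale at every step.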

To highlight this last point, we will now present a scenario in which the value of base score and the order in which it’s added to the running sum in the prediction loop end up having a material impact on the numerical stability of model predictions across two adjacent versions of XGBoost.

Impact of base score on numerical stability of model predictions

On a high level, here’s what we will do:

  • Train a model in XGBoost version 1.2.0 using a specific base score, save the model object, and generate predictions on the training data.
  • Load the same model object in XGBoost version 1.3.0, and generate predictions on the same training data.
  • Examine if the two sets of model predictions are identical.

To set up the experiment, we need to create two virtual environments and install XGBoost 1.2.0 into one and 1.3.0 into the other. One way to achieve this is via the conda environment/package manager; please refer to conda’s official documentation for details.

The training data is completely synthetic. Specifically, the data consists of:

  • 1000 observations.
  • One target variable plus five features, all six of which are randomly generated numbers between zero and one.

```python
import numpy as np

# Set random seed for reproducibility
np.random.seed(2021)

data = np.random.rand(1000, 6)

# Save to disk for later use
np.savetxt('data.csv', data)
```

We will train a model in XGBoost 1.2.0 and save the fitted model along with model predictions on the training data for later use.

```python
# Execute this snippet in xgboost 1.2.0 environment
import numpy as np
import xgboost as xgb

# Load the synthetic data saved in the previous step
data = np.loadtxt('data.csv')
y, x = data[:, 0], data[:, 1:]

# Set hyperparameters for model training
hyperparams = {
    'base_score': 1000,  # deliberately far outside the target's [0, 1) range
    'n_estimators': 500,
    'max_depth': 5,
    'learning_rate': .05
}

# Train model and save
model = xgb.XGBRegressor(**hyperparams)
model.fit(x, y)
model.save_model('model_v120.bin')

# Generate predictions on training data and save
predicted_v120 = model.predict(x)
np.savetxt('predicted_v120.csv', predicted_v120)
```

The choice to use XGBRegressor as the specific model isn’t particularly important here. What is important is that we have set the base score to be much larger than the target variable (recall that the target variable lies between zero and one). As we will discuss in a later section, this affects the numerical stability of model predictions in a subtle and somewhat unexpected way.

Next, we will load the fitted model object into XGBoost 1.3.0 and generate predictions on the training data again.

```python
# Execute this snippet in xgboost 1.3.0 environment
import numpy as np
import xgboost as xgb

data = np.loadtxt('data.csv')
x = data[:, 1:]

# Load previously trained model
model = xgb.XGBRegressor()
model.load_model('model_v120.bin')

predicted_v130 = model.predict(x)
predicted_v120 = np.loadtxt('predicted_v120.csv')
```

We now have two sets of model predictions, both of which are generated by the same underlying model object on the same training data. The only difference is the version of XGBoost that’s used -- 1.2.0 vs. 1.3.0.

Intuitively, we would expect the two sets of predictions to be almost identical. To verify whether that’s indeed the case, we plot a histogram showing the absolute percent difference between the two sets.
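The post does not spell out how the percent difference is computed. The sketch below assumes it is measured relative to the 1.2.0 predictions; applied to the pair of predictions reported for base_score=1000, it reproduces the 0.82% figure. The helper name `abs_percent_diff` is hypothetical.

```python
import numpy as np


def abs_percent_diff(reference, candidate):
    """Element-wise |candidate - reference| as a percentage of |reference|."""
    reference = np.asarray(reference, dtype=np.float64)
    candidate = np.asarray(candidate, dtype=np.float64)
    return np.abs(candidate - reference) / np.abs(reference) * 100.0


# With the arrays from the two environments, usage would look like:
#   diffs = abs_percent_diff(predicted_v120, predicted_v130)
#   counts, bin_edges = np.histogram(diffs, bins=50)
# Here, only the single pair of predictions reported for base_score=1000:
diffs = abs_percent_diff([0.05175781], [0.05133431])
print(f"Max absolute percent difference: {diffs.max():.2f}%")  # prints 0.82%
```

The full `diffs` array is what a histogram of the differences would be drawn from, for example with `np.histogram` or a plotting library.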

[Figure: Absolute percent difference in model predictions between XGBoost versions 1.2.0 and 1.3.0]

```
Max absolute percent difference: 0.82%
1.2.0: 0.05175781
1.3.0: 0.05133431
```

As seen above, the delta between the two sets of predictions can be as high as 0.82%. While we wouldn’t necessarily expect the predictions to be exactly identical (floating-point rounding errors could account for tiny differences), the observed discrepancy is much greater than anticipated; a delta of this magnitude isn’t something we can simply dismiss as negligible without further investigation. In particular, if model usage is tied to a governance procedure, as is often the case in large enterprises, this discrepancy is sure to raise a few eyebrows from the risk management department.

Let’s re-train the model with a different base score and see whether that affects the delta. Setting the base score to 0.5 (the default value in XGBoost) results in model predictions that are almost identical across the two XGBoost versions.

```
Max percent difference: 0.00%
1.2.0: 0.00867423
1.3.0: 0.00867402
```

We then tried a few other base scores and assessed their impact on the maximum difference. As the table below shows, larger base scores go hand in hand with larger maximum differences. Take the extreme value of 25,000, for example: if we train a model with that base score and compare model predictions across the two XGBoost versions, they can differ from one another by as much as 29.41%.

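| Base score | Max absolute percent difference |
|-----------:|--------------------------------:|
| 25,000     | 29.41%                          |
| 5,000      | 8.47%                           |
| 1,000      | 0.82%                           |
| 200        | 0.42%                           |
| 100        | 0.28%                           |
| 20         | 0.04%                           |
| 5          | 0.02%                           |
| 0.5        | 0.00%                           |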

Is the conclusion then that XGBoost gives unreliable model predictions as soon as we upgrade to a newer version of the package? No. As we will shortly see, our particular modeling setup is such that the impact of an abnormally large base score will be magnified.

Indeed, in our original example we deliberately chose the base score to be orders of magnitude larger than the scale of the target variable. In practice, this could be the result of a typo when hyperparameters are passed to the XGBoost API, or of an oversight when the data scientist rescales the target variable but forgets to adjust the base score accordingly.
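Both failure modes can be caught before training with a simple range check. The sketch below is a hypothetical guard, not part of XGBoost: it flags a base score that falls far outside the target's observed range. The `tolerance` of one range-width on either side is an arbitrary illustrative threshold, and the target is a synthetic stand-in rather than the post's data file.

```python
import numpy as np


def base_score_is_commensurate(base_score, y, tolerance=1.0):
    """Hypothetical sanity check, not part of XGBoost.

    Returns True if `base_score` lies within the target's observed range,
    widened by `tolerance` times the width of that range on each side.
    """
    y = np.asarray(y, dtype=np.float64)
    lo, hi = float(y.min()), float(y.max())
    width = max(hi - lo, np.finfo(np.float64).eps)  # guard against a constant target
    return bool(lo - tolerance * width <= base_score <= hi + tolerance * width)


rng = np.random.default_rng(2021)
y = rng.random(1000)  # stand-in target in [0, 1), like the synthetic data

print(base_score_is_commensurate(0.5, y))   # True: the default base score
print(base_score_is_commensurate(1000, y))  # False: the base score from the example
```

Running a check like this right after any rescaling of the target would surface a stale base score before it reaches the model.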

Regardless of how it came to be, the key message to take away is that hyperparameters can affect the modeling outcome in subtle and unexpected ways and that it’s important for us to understand both the algorithmic and -- as we will shortly see -- computational implications when modifying their values.

For the issue at hand, we need to investigate why large values of base score appear to result in model predictions that are numerically unstable across XGBoost versions. Thanks to the guaranteed backward compatibility in XGBoost, it’s not uncommon for data scientists to train a model in one version of XGBoost, save that model, and later reuse the same model in a newer version of XGBoost in order to take advantage of the additional features and performance enhancements introduced in the newer version. It would be unfortunate if those improvements came at the expense of numerical instability and unexplained variances in model predictions.

How an incorrectly specified base score undermines numerical stability

As alluded to earlier, the numerical discrepancy observed in our original example is a direct result of the fact that we have chosen a base score, 1000, that is orders of magnitude larger than the scale of the target variable, which is between zero and one.

To understand how this issue manifests itself, we need to revisit the pseudo-code for the prediction loop introduced earlier in the post. Specifically, we focus on the modified logic as introduced in XGBoost 1.3.0, whereby the running sum is set equal to the base score at the start of the for-loop.

```python
# Start running sum at base score. New behavior introduced in 1.3.0
predicted_value = base_score

for tree in trees:
    predicted_residual = ...
    predicted_value += predicted_residual
```

Let’s assume the predicted residual from the very first tree is -50.0341012. We would expect predicted_value to be equal to 949.9658988 after the first iteration of the loop.

```
1000.0 - 50.0341012 = 949.9658988
```

Under the hood, a computer performs calculations using floating point arithmetic, similar to how scientific notation works. And while it's not as simple as the equation above, intuitively we would expect it to produce the same answer, as illustrated below. (Computers work in base-two as opposed to base-ten, which is what’s shown in the illustration. However, the same principles apply.)

[Figure: Subtraction using scientific notation]
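Written out in base-ten scientific notation, the exact subtraction first rewrites the smaller operand with the larger operand's exponent and then subtracts the significands:

```latex
\begin{aligned}
1.0\times10^{3} - 5.00341012\times10^{1}
  &= 1.0\times10^{3} - 0.0500341012\times10^{3} \\
  &= 0.9499658988\times10^{3} \\
  &= 9.499658988\times10^{2}
\end{aligned}
```

With unlimited digits this is exact. The trouble comes from the finite number of significand digits that remain available once the smaller operand has been shifted to the larger exponent.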

What ends up happening in reality though is a little different.

There are only a finite number of bits a computer can use to store floating point numbers. If we have a number with a large number of trailing digits after the decimal point, there might not be enough room to store all of them. After a certain point, some digits would have to be discarded.

Let's assume we are working with 32-bit floats, which have a precision limit of approximately seven decimal digits. Because we are starting with a base score that's so much larger than the predicted residual from the first tree, the smaller number is effectively "shifted" rightward to line up with the larger one, and its last three digits are discarded in the process.
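NumPy's `float32` type makes this loss visible. The sketch below repeats the subtraction from the example (using the hypothetical first-tree residual of -50.0341012) in 64-bit and 32-bit arithmetic. Near 950, representable float32 values are spaced 2**-14 (about 6.1e-05) apart, so the float32 result must land on that grid and any digits finer than that are lost.

```python
import numpy as np

exact = 1000.0 - 50.0341012                           # float64: ~16 significant digits
single = np.float32(1000.0) - np.float32(50.0341012)  # float32: ~7 significant digits

# Distance from 950 to the next representable float32 value
spacing = np.spacing(np.float32(950.0))

print(f"float64 result: {exact:.7f}")
print(f"float32 result: {float(single):.7f}")
print(f"float32 spacing near 950: {float(spacing):.3g}")
```

The float32 result agrees with the float64 one only to within that spacing. A single discrepancy of this size is harmless, but it is repeated for every tree in the model.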

[Figure: Loss of precision due to floating point arithmetic]

By contrast, if we start the running sum from zero -- the behavior in XGBoost 1.2.0 -- each subsequent addition/subtraction in the for-loop will be performed in a more precise manner, utilizing all the available bits without discarding trailing digits prematurely.

You might be wondering, "What difference does it make if I leave behind a few insignificant digits here and there?" And you'd be right that it usually doesn't matter. However, when the target variable is small, every bit of precision counts. If we perform a large number of operations, each of which sacrifices just a little bit of precision, these small errors compound and can produce an end result that's meaningfully different from the one we would get by starting the running sum from zero. This phenomenon is a direct result of the non-associativity of floating point arithmetic, a property that often catches people by surprise because the computational implementation deviates from the mathematical construct.
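A three-term example in float32 shows the non-associativity directly. Mathematically, (a + b) + c and a + (b + c) are equal, but the grouping decides whether the small term survives rounding. The values are illustrative and are not taken from the model.

```python
import numpy as np

a = np.float32(1000.0)   # a large running sum, e.g. the base score
b = np.float32(-1000.0)  # a large contribution of opposite sign
c = np.float32(1e-5)     # a tiny residual

left = (a + b) + c   # large terms cancel first, so the tiny residual survives
right = a + (b + c)  # tiny residual is absorbed by -1000 before the cancellation

print(left, right)
```

`left` equals float32(1e-5), while `right` is exactly zero. Only the order of operations differs.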

How to remediate numerical instability caused by an incorrectly specified base score

Depending on your model’s specific use case, the kind of numerical discrepancies we have discussed so far may not even matter. However, if numerical stability and precision are important to you, there are a few options for lessening the numerical discrepancies caused by the floating point arithmetic subtleties described earlier.

The most robust remediation is to re-train the model with an appropriately chosen base score -- one that's commensurate with the scale of the target variable.
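For a squared-error regression like the one in this post, one natural way to make the base score commensurate with the target is to use the target mean, because the mean is the constant that minimizes squared error. The sketch below uses a synthetic stand-in target rather than the post's data file. It compares the constant-prediction error of that choice with the default 0.5 and with the mis-specified 1000. Other objectives would call for a different statistic, so treat this as an illustration rather than a rule XGBoost enforces.

```python
import numpy as np

rng = np.random.default_rng(2021)
y = rng.random(1000)  # stand-in target in [0, 1), like the synthetic data


def constant_mse(y, c):
    """Mean squared error of predicting the constant c for every observation."""
    return float(np.mean((y - c) ** 2))


# Under squared-error loss, the best constant predictor is the target mean
base_score = float(np.mean(y))

for candidate in (base_score, 0.5, 1000.0):
    print(f"base_score={candidate:<10.4f} constant-prediction MSE={constant_mse(y, candidate):.4f}")
```

If the target is later rescaled, recomputing the base score from the rescaled target keeps the two in step.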

In cases where re-training the model isn't feasible, we can eliminate the numerical discrepancy by replicating the prediction logic as found in the earlier version of XGBoost.

```python
import numpy as np
import xgboost as xgb


def v120_predict_replica(model, x):
    """Replica of the prediction logic in xgboost version 1.2.0.

    Specifically, base score will be added _at the end_ of the prediction loop,
    rather than at the start. This is achieved via setting `base_margin` to be
    zero, which forces the base score to be ignored.
    https://xgboost.readthedocs.io/en/latest/prediction.html#base-margin
    """
    float32 = np.float32  # xgboost uses 32-bit floats
    # A zero base margin replaces the base score as the loop's starting value
    base_margin = np.zeros(len(x), dtype=float32)
    dmatrix = xgb.DMatrix(x, base_margin=base_margin)
    base_score = model.get_xgb_params()['base_score']
    booster = model.get_booster()
    # Tree outputs summed from zero, base score added last; this only holds for
    # identity-link objectives such as the squared-error regression used here
    predicted_value = booster.predict(dmatrix) + float32(base_score)
    return predicted_value
```

That said, this shouldn't be treated as a permanent fix -- it's simply a stopgap to buy us more time as we make plans for the more robust solution of re-training the model.

Conclusions

In this post, we investigated a scenario in which an XGBoost model trained with an incorrectly specified base score ends up producing numerically unstable predictions across two adjacent XGBoost versions. Even though we focused on just a single hyperparameter, the key message is universal.

Feeding an inappropriate dataset into an algorithm will likely result in an all but useless model. Similarly, attempting to adjust a model’s hyperparameters without fully considering the algorithmic and computational implications might ultimately lead to modeling outcomes that are nuanced, unexpected, and difficult to debug.

If you would like to learn more about the available hyperparameters in XGBoost, the official documentation and this previous Capital One blog post -- “How to Control Your XGBoost Model” -- are good resources to get you started.

Acknowledgements

This post is the culmination of prior work and contributions by my colleagues at Capital One (in alphabetical order):

Zhengnan Zhao, Sr. Manager, Data Science

As a data scientist at Capital One, Zhengnan works at the intersection of data science and software engineering, utilizing tools from the PyData ecosystem to power the firm’s valuations analyses with the goal of optimizing credit policy.
