Fast optimization of classification thresholds


Binary classification problems (target/non-target) are often modeled as a pair (f, \theta) where f : \mathbb{R}^D \to [0, 1] is our model, which maps input vectors to scores, and \theta \in [0, 1] is our threshold, such that we predict x to be of target class iff f(x) \geq \theta . Otherwise, we predict it to be of non-target class.

[Figure: a dataset's scores and a threshold, showing the regions where the model predicts the "target" and "non-target" classes.]

The threshold \theta is usually set to 0.5 , but this need not be the best choice. If f can be assumed to be well-calibrated1 (which usually isn't the case) and the costs of a false positive and a false negative are known constants (call them C_{FP} and C_{FN}), then a theoretical cost-minimizing threshold can be computed with the formula \frac{C_{FP}}{C_{FP} + C_{FN}} . That means 0.5 is the optimal threshold under these conditions only when C_{FP} = C_{FN} . In imbalanced classification that equality rarely holds, and even in balanced settings it often doesn't.
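
As a quick worked example with made-up costs: if a false positive costs C_{FP} = 1 and a false negative costs C_{FN} = 4 , the cost-minimizing threshold would be \frac{1}{1 + 4} = 0.2 , i.e. we should flag items as target much more readily than the default 0.5 would.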

That formula won't save us every time, though; we need something else. Sometimes our model isn't well-calibrated, and even after applying calibration methods we may not trust that the result is properly calibrated. In other cases the costs of false positives and false negatives may be unknown or may not be constant. What's the cost of saying a dog is a cat if you're photo-categorizing software? Or of wrongly tagging a sports news article as politics? Those costs aren't always easy to quantify.

So suppose we wanted to use some other strategy to select the classification threshold. We could instead try to find the threshold \theta that maximizes the F_1 score on our validation set, or some other evaluation metric. How would we go about doing that?

The rest of this blogpost deals with finding thresholds that maximize (or minimize) classification metrics in \mathcal{O}(N \log N) time, where N is the size of the validation set used for optimizing the threshold.

The algorithm itself is very simple, but analyzing it carefully can help paint a clearer picture of how changing the threshold affects our metrics, and of what's going on when we choose to use thresholded metrics for our classification models.

The general idea behind optimizing these thresholds quickly shows up in a bunch of applications in Data Science. Something very similar to what I'm describing is used by scikit-learn2 to find splits when building CART trees (the similarity to this problem is pretty obvious if you remember how CART works), and also to build the precision-recall curve and related metrics like average precision.

Algorithm for general threshold optimization

Visualization

Let’s start by visualizing examples of the functions we’re trying to optimize.

Remember we’re trying to find the best threshold so that the resulting target/no-target classification using that threshold maximizes some evaluation metric like accuracy, F1, etc.3

Here the blue line is the function itself. The x axis represents all possible thresholds, and the y axis has the values of the evaluation metric being considered.

We also added green circles and red crosses. They represent our model's output scores for items in our dataset. In these examples the dataset (one would usually use the validation set) has 5 elements of target class and 5 of non-target class.4 These circles and crosses are not really part of the function itself, but visualizing the dataset together with the function helps (me) understand what's going on.

Summarizing: the blue line is the evaluation function as we change the threshold. Circles and crosses mark scores for elements in our dataset of target and non-target class respectively.

Note how the value of our metric only changes when we move the threshold "over" some score in the dataset, which changes the TP/FN or FP/TN counts (more on that later). This will be true of any metric that can be defined as a function of TP, FP, TN and FN.5

[Figure: example functions from thresholds to accuracy and from thresholds to precision.]

Algorithm description

Looking at the shapes of these functions can be pretty discouraging. They don't belong to any of the well-known function families that are easy to optimize: they're clearly not linear or affine, they're not convex, in fact not even continuous, and everywhere in the domain the derivative is either undefined or 0, so all gradient-based methods are out.

But they have a very lovely property: we only need to evaluate a finite number of points to know every value the function takes. The function is piecewise constant and can only change at the observed scores, so evaluating it at each score in the dataset (plus one threshold above the largest score) covers its entire range. This can be seen in the example functions, and isn't hard to prove more rigorously.

That property of only having to evaluate a finite number of thresholds takes us all the way from not even knowing how to find the best threshold to having an obvious quadratic time algorithm: just try every possible value and pick the best one! Shown here as python-ish pseudo-code (remember vectorization will come later).

import numpy as np


def find_threshold_quad(y_proba, y_val, eval_func):
    """Find a threshold that maximizes eval_func."""
    evals = np.empty_like(y_proba)
    for i, threshold in enumerate(y_proba):
        # Pseudo-code: assume confusion_matrix returns the
        # four counts in this order; it runs in linear time
        # w.r.t. dataset size.
        tp, fp, tn, fn = confusion_matrix(
            y_val, y_proba >= threshold
        )
        # assume eval_func runs in constant time!
        evals[i] = eval_func(tp, fp, tn, fn)
    best_threshold_idx = np.argmax(evals)
    return y_proba[best_threshold_idx]

Something to note here: we're saying that eval_func should run in constant time, but scikit-learn style thresholded evaluation functions (like sklearn.metrics.precision_score) usually take linear time. How can that be?

That’s because scikit-learn style functions count the TPs, FPs, etc. as part of the evaluation function itself. For the purpose of this blogpost we are using evaluation functions that take the TP, FP, TN and FN as pre-computed values. We will assume the eval_funcs are constant time functions for the rest of this blogpost. Accuracy, precision, recall, F1, g-mean, and most other thresholded metrics have obvious constant time implementations as functions of TP, FP, TN and FN. For example, accuracy is just \frac{TP + TN}{TP + TN + FP + FN} , so it always involves 3 sums and a division.

We can improve on the above algorithm's time complexity, though. The trick is the following: if our list of scores is sorted, there is no need to fully recompute the TPs, FNs, etc. for every single threshold; they only change incrementally as we move from the smallest threshold to the largest (or from largest to smallest).

Assume all scores are different from each other to simplify the argument and the pseudo-code. We’ll deal with the problem of repeated scores later when we vectorize our code.

For the smallest threshold, i.e. when we are using the smallest predicted score as the threshold, everything is classified as target class, as every score (including itself) is greater than or equal to the smallest one. That means TN and FN are 0, while TP and FP are the counts of elements of ground-truth target class and non-target class respectively.

[Figure: the minimum threshold, where every element is classified as target.]

As we move the threshold to the right, for each score we pass we only need to update those numbers (TP, FP, etc.) by incrementing or decrementing them based on the class of the item we just passed. Each time we "pass over" a score, if it belonged to the target class, we lose a TP and gain an FN; if it belonged to the non-target class, we lose an FP and gain a TN.

[Figure: the second threshold, after passing over the smallest score.]

This realization that we don't need to recompute TP , etc. from scratch every time lets us get rid of our linear-time confusion_matrix function and replace it with constant-time updates of our tp, fp, tn and fn variables.

We don’t quite manage to drop time complexity to linear time over the entire optimization because we need to sort the scores to do this. After we sort them, the actual optimization itself is just linear time, though.

def find_threshold_loglinear(y_proba, y_val, eval_func):
    """Find a threshold that maximizes eval_func."""
    sort_idx = np.argsort(y_proba)
    y_proba, y_val = y_proba[sort_idx], y_val[sort_idx]
    evals = np.empty_like(y_proba)
    # At the smallest threshold everything is predicted as
    # target: TP and FP start at the positive and negative
    # counts, TN and FN start at 0.
    tp = y_val.sum()
    fp = (y_val == 0).sum()
    tn, fn = 0, 0
    for i, threshold in enumerate(y_proba):
        # assume eval_func runs in constant time.
        evals[i] = eval_func(tp, fp, tn, fn)
        # Note we update the counts _after_ evaluating:
        # passing over a target score turns a TP into an FN,
        # passing over a non-target score turns an FP into a TN.
        if y_val[i] == 1:
            tp -= 1
            fn += 1
        else:
            fp -= 1
            tn += 1
    best_threshold_idx = np.argmax(evals)
    return y_proba[best_threshold_idx]
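
As a minimal usage sketch (the arrays below are made-up validation labels and scores, and f1_from_counts is the helper sketched earlier):

y_val = np.array([0, 0, 1, 0, 1, 1])
y_proba = np.array([0.1, 0.4, 0.35, 0.8, 0.65, 0.7])
best = find_threshold_loglinear(y_proba, y_val, f1_from_counts)

The returned best is always one of the observed scores.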

That’s pretty much all there is to it. In the following section we’ll present a vectorized numpy implementation of this same algorithm, which also covers some edge-cases we ignored here.

In particular, think about what happens when the best threshold is one that classifies everything as non-target. Is that case handled by the pseudo-code above? Also think about what happens when a score is repeated in the dataset. Are we handling those cases correctly? How could that be fixed?

Vectorized numpy implementation

def find_threshold(y_val, y_proba, eval_func_vec):
    """Vectorized threshold optimization.
    
    This uses `np.cumsum` to vectorize the conditional
    for-loops we used in the naïve version.
    """
    # We start by sorting the probas in descending
    # order. (That's what the `-` is for)
    sort_idx = np.argsort(-y_proba)
    y_proba, y_val = y_proba[sort_idx], y_val[sort_idx]
    # We need to add an np.inf as a possible proba to
    # handle the case where the best threshold is one
    # where all predictions are of _non-target_.
    y_proba_completed = np.concatenate(
        [[np.inf], y_proba]
    )
    unique_idx = np.concatenate(
        [
            np.where(np.diff(y_proba_completed))[0], 
            # Due to the way np.diff works, we need
            # to add the final element by hand.
            [y_val.size]
        ]
    )
    # The 0s prepended below make the cumsums at index 0
    # (the np.inf threshold) equal 0: nothing is predicted
    # as target at that threshold.
    y_val_completed = np.concatenate([[0], y_val])
    y_val_completed_reverse = np.concatenate(
        [[0], y_val == 0]
    )
    tps = np.cumsum(y_val_completed)
    fns = np.sum(y_val_completed) - tps
    fps = np.cumsum(y_val_completed_reverse)
    tns = np.sum(y_val_completed_reverse) - fps
    # This eval_func_vec should compute the metrics
    # for several TP, FP, TN, FN configurations at once.
    values = eval_func_vec(
        tps[unique_idx], fps[unique_idx],
        tns[unique_idx], fns[unique_idx],
    )
    # nanargmax for dealing with undefined metrics,
    # like precision above highest score.
    best_unique_idx = np.nanargmax(values)
    best_idx = unique_idx[best_unique_idx]
    return y_proba_completed[best_idx]

The last question to answer is: how do you implement an eval_func_vec? Thanks to numpy's operator overloading, it's usually very simple. A couple of examples follow.

def accuracy_vec(tps, fps, tns, fns):
    return (tps + tns) / (tps + tns + fps + fns)

def geometric_mean_vec(tps, fps, tns, fns):
    # Sensitivity is recall of the target class,
    # specificity is recall of the non-target class.
    sensitivities = tps / (tps + fns)
    specificities = tns / (tns + fps)
    return np.sqrt(sensitivities * specificities)
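
And a minimal usage sketch, with made-up validation labels and scores just to show the calling convention (note the argument order is y_val first here):

y_val = np.array([0, 1, 0, 1, 1, 0, 1])
y_proba = np.array([0.2, 0.9, 0.45, 0.6, 0.55, 0.3, 0.8])
best_threshold = find_threshold(y_val, y_proba, accuracy_vec)

If the best option is to classify everything as non-target, the returned threshold is np.inf, coming from the value we prepended to y_proba_completed.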

The vectorized find_threshold above could be used for finding splits in algorithms for building decision trees, or, with small modifications, it could be used for creating precision-recall curves and other metric-to-metric curves fast.
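
As a rough sketch of that kind of modification (this precision_recall_points helper is made up for illustration; it ignores repeated scores and doesn't follow scikit-learn's exact endpoint conventions), the same descending sort plus cumsum gives all precision/recall pairs at once:

def precision_recall_points(y_val, y_proba):
    # Sort by score, descending, as in find_threshold.
    sort_idx = np.argsort(-y_proba)
    y_val = y_val[sort_idx]
    # TP and FP counts when thresholding at each score.
    tps = np.cumsum(y_val)
    fps = np.cumsum(y_val == 0)
    precisions = tps / (tps + fps)
    # Assumes at least one positive example in y_val.
    recalls = tps / y_val.sum()
    return precisions, recalls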

Similar cumsum-type optimizations can be applied to time series data when counting items before some point in time, computing running averages, and so on. The idea behind this algorithm (and its vectorization) is useful in a variety of situations in Data Science, and is worth keeping in mind.
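
For instance, a running mean of a (made-up) series can be computed in one pass with the same trick:

values = np.array([3.0, 1.0, 4.0, 1.0, 5.0])
# Mean of the first k values, for every k, via one cumsum.
running_mean = np.cumsum(values) / np.arange(1, values.size + 1)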

Footnotes

  1. Well-calibrated intuitively means that if we take the set A := \{x: f(x) \approx p\} , then the ratio of items of actual target class in A should be close to p . In other words, when a model is well-calibrated, f(x) can be interpreted as the probability that x is of class target.
  2. Here’s a link to an example place where sklearn uses the idea for trees and here’s how they use it for computing precision-recall curves and metrics that depends on counting TPs and FPs at all thresholds. Both those pieces of code have a lot of surrounding type-conversion, edge-case-matching, optional-parameter-considering code, so looking at the implementation in this blogpost first may help separate the important parts of that from the noise.
  3. When I say “the best threshold” I actually mean “one of the thresholds that are the best”. In other words, a threshold such that no other threshold is strictly better.
  4. In case the example is misleading, I want to clarify that the datasets needn’t be balanced. It happens to be balanced in these examples, but I could have chosen an imbalanced dataset as well.
  5. Alternatively, in terms of TP, FP, positive count, and negative count. Or other things from which we can compute the confusion matrix.
