conformalize matrix interface

library(misc)

Example: Conformal Prediction with Out-of-Sample Coverage

In this example, we demonstrate how to use the conformalize function to perform conformal prediction and calculate the out-of-sample coverage rate.

Simulated Data

We will generate a simple dataset for demonstration purposes.

# Reproducible toy regression problem: two uniform predictors and a response
# that is linear in both, plus Gaussian noise with standard deviation 0.5.
set.seed(123)
n <- 200
x <- matrix(runif(2 * n), nrow = n)
# The linear signal is supplied as the mean of the normal draws, which is the
# same as adding zero-mean noise to it.
y <- rnorm(n, mean = 3 * x[, 1] + 2 * x[, 2], sd = 0.5)
data <- data.frame(
  x1 = x[, 1],
  x2 = x[, 2],
  y = y
)

Fit Conformal Model

We will use a random forest (ranger::ranger) as the fit_func and a thin wrapper around its predict method as the predict_func.

library(stats)

# Define fit and predict functions
# Fit a ranger random forest of `y` on the columns of `x`.
#
# x   : predictors, e.g. a numeric matrix with one row per observation.
# y   : response vector.
# ... : forwarded unchanged to ranger::ranger().
#
# Returns the fitted ranger object.
fit_func <- function(x, y, ...) {
  # Response and predictors go into one data frame. Unnamed matrix columns are
  # named X1, X2, ... here, and `predict` must later see the same names.
  df <- data.frame(y = y, x)
  print(head(df))
  ranger::ranger(y ~ ., data = df, ...)
}

# Predict with a model produced by fit_func.
#
# obj  : fitted model object (its `predict` method must accept `data =`).
# newx : new predictors, columns in the same order as the training `x`.
#
# Returns the `predictions` component of the predict() result.
predict_func <- function(obj, newx) {
  # Rename the columns to X1, X2, ... so they match the names data.frame()
  # gave the unnamed training columns in fit_func. seq_len() is used instead
  # of 1:ncol() so the sequence never counts backwards.
  colnames(newx) <- paste0("X", seq_len(ncol(newx)))
  predict(object = obj, data = newx)$predictions # only accepts a named newx
}

# Wrap the random-forest fit/predict pair in a conformal model.
# All arguments are passed by name, so their order is irrelevant.
conformal_model <- misc::conformalize(
  fit_func = fit_func,
  predict_func = predict_func,
  x = x,
  y = y,
  seed = 123,
  split_ratio = 0.8
)
##           y        X1        X2
## 1 0.7480449 0.4447680 0.3233450
## 2 3.2344670 0.8746823 0.3087868
## 3 2.8086211 0.5726334 0.6743764
## 4 2.7359976 0.9373141 0.1054177
## 5 1.5328762 0.2656867 0.4831677
## 6 4.2696182 0.8578277 0.7267025

Generate Predictions and Prediction Intervals

We will use the predict.conformalize method to generate predictions and calculate prediction intervals.

# Fifty fresh predictor draws, with the column names the fitted model
# expects (X1, X2)
new_data <- data.frame(
  X1 = runif(n = 50),
  X2 = runif(n = 50)
)

# Point predictions and 95% split-conformal intervals for the new data
predictions <- predict(
  conformal_model,
  newdata = new_data,
  predict_func = predict_func,
  method = "split",
  level = 0.95
)
## 
## [1] "object's value:"
## $fit
## Ranger result
## 
## Call:
##  ranger::ranger(y ~ ., data = df, ...) 
## 
## Type:                             Regression 
## Number of trees:                  500 
## Sample size:                      160 
## Number of independent variables:  2 
## Mtry:                             1 
## Target node size:                 5 
## Variable importance mode:         none 
## Splitrule:                        variance 
## OOB prediction error (MSE):       0.2878791 
## R squared (OOB):                  0.768701 
## 
## $residuals
##  [1] -0.3517141289 -0.6278733391 -0.8919451321 -0.3368367515 -0.5860523628
##  [6] -0.3086364683 -0.4745539452 -0.1903110973 -0.1398032514  0.1325369233
## [11] -0.1279973669 -0.1934639470  1.1106438197  0.9771857987  0.1000399597
## [16]  0.8476079199  0.0763709313  1.2510146081 -0.6432988213 -0.5824886831
## [21]  0.1166032311 -0.7442515277 -1.1530115058 -0.1149572542  0.0004330684
## [26]  1.0725410572  0.8389778811 -0.5485431787 -1.5616738419  0.4032726270
## [31] -1.0039911190 -0.6838454247 -0.4422398209  0.6323147429 -0.8067167962
## [36]  1.2611515451 -0.3922378097  0.1163975193 -0.2456218514 -0.7815854456
## 
## $sd_residuals
## [1] 0.6915685
## 
## $scaled_residuals
## [1] -0.1806242
## 
## attr(,"class")
## [1] "conformalize"
# Preview the first rows: point prediction (fit) and interval bounds (lwr, upr)
head(predictions)
##            fit        lwr      upr
## [1,] 0.6598268 -0.5916947 1.911348
## [2,] 0.9518570 -0.2996645 2.203378
## [3,] 2.1487265  0.8972051 3.400248
## [4,] 0.7785231 -0.4729984 2.030045
## [5,] 2.8954852  1.6439638 4.147007
## [6,] 0.8294482 -0.4220733 2.080970
# Residuals stored on the fitted conformal object
residuals(conformal_model)
##  [1] -0.3517141289 -0.6278733391 -0.8919451321 -0.3368367515 -0.5860523628
##  [6] -0.3086364683 -0.4745539452 -0.1903110973 -0.1398032514  0.1325369233
## [11] -0.1279973669 -0.1934639470  1.1106438197  0.9771857987  0.1000399597
## [16]  0.8476079199  0.0763709313  1.2510146081 -0.6432988213 -0.5824886831
## [21]  0.1166032311 -0.7442515277 -1.1530115058 -0.1149572542  0.0004330684
## [26]  1.0725410572  0.8389778811 -0.5485431787 -1.5616738419  0.4032726270
## [31] -1.0039911190 -0.6838454247 -0.4422398209  0.6323147429 -0.8067167962
## [36]  1.2611515451 -0.3922378097  0.1163975193 -0.2456218514 -0.7815854456

Calculate Out-of-Sample Coverage Rate

The coverage rate is the proportion of true values that fall within the prediction intervals.

# Simulate true values for the new data.
# new_data's columns are named X1 and X2 (upper case). `$x1` / `$x2` would
# return NULL, leaving true_y empty and the coverage below NaN.
true_y <- 3 * new_data$X1 + 2 * new_data$X2 + rnorm(nrow(new_data), sd = 0.5)

# Check if true values fall within the prediction intervals
coverage <- mean(true_y >= predictions[, "lwr"] & true_y <= predictions[, "upr"])

cat("Out-of-sample coverage rate:", coverage)
## Out-of-sample coverage rate: NaN

Results

  • The prediction intervals are calculated using the split conformal method.
  • The out-of-sample coverage rate is displayed, which should be close to the specified confidence level (e.g., 0.95).

Example: Conformal Prediction with the MASS::Boston Dataset

In this example, we use the MASS::Boston dataset to demonstrate conformal prediction.

Load the Data

We will use the MASS package to access the Boston dataset.

library(MASS)

# Bring the Boston data frame into the workspace
data("Boston", package = "MASS")

# Show the first six rows
head(Boston, n = 6L)
##      crim zn indus chas   nox    rm  age    dis rad tax ptratio  black lstat
## 1 0.00632 18  2.31    0 0.538 6.575 65.2 4.0900   1 296    15.3 396.90  4.98
## 2 0.02731  0  7.07    0 0.469 6.421 78.9 4.9671   2 242    17.8 396.90  9.14
## 3 0.02729  0  7.07    0 0.469 7.185 61.1 4.9671   2 242    17.8 392.83  4.03
## 4 0.03237  0  2.18    0 0.458 6.998 45.8 6.0622   3 222    18.7 394.63  2.94
## 5 0.06905  0  2.18    0 0.458 7.147 54.2 6.0622   3 222    18.7 396.90  5.33
## 6 0.02985  0  2.18    0 0.458 6.430 58.7 6.0622   3 222    18.7 394.12  5.21
##   medv
## 1 24.0
## 2 21.6
## 3 34.7
## 4 33.4
## 5 36.2
## 6 28.7

Split the Data

We will split the data into training and test sets to ensure they are disjoint.

# Reproducible 80/20 split of the Boston rows into disjoint sets
set.seed(123)
n <- nrow(MASS::Boston)
train_indices <- sample.int(n, size = floor(0.8 * n))
train_data <- MASS::Boston[train_indices, ]
# Every row index that was not drawn for training, in ascending order
test_data <- MASS::Boston[setdiff(seq_len(n), train_indices), ]

Fit Conformal Model

# Predict with a fitted model on predictors that already carry column names.
#
# obj  : fitted model object (its `predict` method must accept `data =`).
# newx : new predictors; column names are passed through untouched.
#
# Returns the `predictions` component of the predict() result.
predict_func <- function(obj, newx) {
  model_output <- predict(object = obj, data = newx) # only accepts a named newx
  model_output$predictions
}


# Fit the conformal model on the training rows: every column except the
# response `medv` is a predictor.
conformal_model_boston <- misc::conformalize(
  fit_func = fit_func,
  predict_func = predict_func,
  x = as.matrix(train_data[, setdiff(names(train_data), "medv")]),
  y = train_data[["medv"]],
  seed = 123
)
##        y    crim   zn indus chas   nox    rm  age    dis rad tax ptratio  black
## 11  15.0 0.22489 12.5  7.87    0 0.524 6.377 94.3 6.3467   5 311    15.2 392.52
## 153 15.3 1.12658  0.0 19.58    1 0.871 5.012 88.0 1.6102   5 403    14.7 343.28
## 10  18.9 0.17004 12.5  7.87    0 0.524 6.004 85.9 6.5921   5 311    15.2 386.71
## 397 12.5 5.87205  0.0 18.10    0 0.693 6.405 96.0 1.6768  24 666    20.2 396.90
## 362 19.9 3.83684  0.0 18.10    0 0.770 6.251 91.1 2.2955  24 666    20.2 350.65
## 35  13.5 1.61282  0.0  8.14    0 0.538 6.096 96.9 3.7598   4 307    21.0 248.31
##     lstat
## 11  20.45
## 153 12.12
## 10  17.10
## 397 19.37
## 362 14.19
## 35  20.34

Generate Predictions and Prediction Intervals

We will use the predict.conformalize method to generate predictions and calculate prediction intervals for the test set.

# Predict with split conformal method on the test data.
# Only the predictor columns are passed: `medv` is the response and was not
# part of the matrix the model was trained on.
predictions_boston <- predict(
  conformal_model_boston,
  newdata = as.matrix(test_data[, setdiff(names(test_data), "medv")]),
  level = 0.95,
  method = "split",
  predict_func = predict_func
)
## 
## [1] "object's value:"
## $fit
## Ranger result
## 
## Call:
##  ranger::ranger(y ~ ., data = df, ...) 
## 
## Type:                             Regression 
## Number of trees:                  500 
## Sample size:                      202 
## Number of independent variables:  13 
## Mtry:                             3 
## Target node size:                 5 
## Variable importance mode:         none 
## Splitrule:                        variance 
## OOB prediction error (MSE):       13.88378 
## R squared (OOB):                  0.8497858 
## 
## $residuals
##   [1]  -3.28995224  -0.76321776   2.43128905  -2.66240151  -0.99287952
##   [6]   2.70121471  -2.42204429  -1.55376000  -2.56573810  -3.26182053
##  [11]  -2.14750667  -3.32581873   2.63805000  -3.16368111  -0.18665491
##  [16]   4.75706000   1.91775692   0.17052333  23.22278310  -3.92153000
##  [21]  -2.90377188  -2.01840810  -1.43485155  -5.79080450  -6.06995808
##  [26]   0.29719359   0.28491333   0.27125333  -0.99347000  -1.85960667
##  [31]   2.10630692   1.28491786   1.08958000  -0.03567000  -0.61526291
##  [36]  -4.23196833  -0.18508529   0.65849333  -0.65427429   3.12121000
##  [41]  -1.74895000  -1.48603000   2.20137667   2.93456194  -1.10091333
##  [46]  -7.04727762  -1.25147789  -2.56449333  -2.72583333  -0.31294785
##  [51]   0.79210756   8.54976667  -3.74300005  -2.86267333  -2.72342333
##  [56]   0.24300667   1.35786058  -0.21181667   0.88757333  -0.42359000
##  [61]   4.69163000  -0.96359889   0.66065359  -0.75930000  -1.30661167
##  [66]  -4.03423333   5.21864760  -3.41794288  -1.17415476  -5.75970684
##  [71]   3.99781333   2.82440860   3.23760137  -2.09188333  -0.43265667
##  [76]   0.28025667   1.90285132  -0.66357667   0.37622751  -1.72012333
##  [81]  -3.60811810  -1.11760463   3.47579692  -0.93810272   5.71293333
##  [86]  -1.19071621  -1.40754143  -1.02234815   1.73892137   0.66767333
##  [91]  -1.84796667  -1.89577701  -1.67581062   0.91369779   2.46615571
##  [96]   1.82967667   3.23964571   1.33185565   0.59930333   2.67968333
## [101]  -1.69234339  -2.04407466   1.67464874   1.35952333  -3.74173287
## [106]  -1.14513333   2.09679930   4.16788000  -1.80788667   4.88675667
## [111]   0.61670667  -0.97427667   0.30204090   2.81902667   1.33740701
## [116]   1.35159447   2.71234556  18.28314843  -5.38726210  -4.98795000
## [121]  -1.93378921   1.11416333   0.45888857  -0.64968798  -5.63412045
## [126]  -1.79377333   2.19376667   0.29618905   8.15903905  -0.81654000
## [131]  -0.43899577   0.13727578  -0.73637667  -1.56695667   0.61336333
## [136]  -1.61457714  -1.85142641  -0.04208333   2.18082333  -0.05410923
## [141]  -1.69643645  -1.05970000  -1.41252805   0.24018524  -2.90524333
## [146]   0.87718000  -4.74593554   1.14160137  -0.36370946  -2.36319333
## [151]  -1.52786832  -0.82554667  -0.66698403  -3.51356000  -0.22256070
## [156]  -2.46117333   5.40267000  -1.87395333  -1.54967593  -1.27451143
## [161]  -1.77721333  -1.18105474   3.10490370   0.25962197 -10.30566648
## [166]   3.36484000   0.55351667   0.76607333   2.61153667   1.55455333
## [171]  -1.59643667   1.81080762   4.28905158  -2.09246429   2.04975681
## [176]   7.35960095  -2.43278149  -1.33287801  -2.12767333   2.59084667
## [181]  -3.68412954   2.79636778   0.26392333  -2.15853429  -2.89614000
## [186]   2.09732000  -0.29039667  -0.33629000   0.82993794   3.32955000
## [191]   1.82944150  -6.48816667   1.79487190  -3.49737476  -2.38569872
## [196]   0.62226415   9.56466667   3.69821468   0.15197667   1.87070333
## [201]  -0.41697375   9.91625667
## 
## $sd_residuals
## [1] 3.558924
## 
## $scaled_residuals
## [1] 0.01476135
## 
## attr(,"class")
## [1] "conformalize"
# Preview the first rows: point prediction (fit) and interval bounds (lwr, upr)
head(predictions_boston)
##           fit       lwr      upr
## [1,] 27.06333 21.007326 33.11933
## [2,] 19.08097 13.024973 25.13697
## [3,] 21.23130 15.175301 27.28730
## [4,] 18.69498 12.638978 24.75098
## [5,] 15.74951  9.693507 21.80551
## [6,] 21.32154 15.265536 27.37754
# Residuals of the Boston model (not the simulated-data model fitted earlier)
residuals(conformal_model_boston)
##   [1]  -3.28995224  -0.76321776   2.43128905  -2.66240151  -0.99287952
##   [6]   2.70121471  -2.42204429  -1.55376000  -2.56573810  -3.26182053
##  [11]  -2.14750667  -3.32581873   2.63805000  -3.16368111  -0.18665491
##  [16]   4.75706000   1.91775692   0.17052333  23.22278310  -3.92153000
##  [21]  -2.90377188  -2.01840810  -1.43485155  -5.79080450  -6.06995808
##  [26]   0.29719359   0.28491333   0.27125333  -0.99347000  -1.85960667
##  [31]   2.10630692   1.28491786   1.08958000  -0.03567000  -0.61526291
##  [36]  -4.23196833  -0.18508529   0.65849333  -0.65427429   3.12121000
##  [41]  -1.74895000  -1.48603000   2.20137667   2.93456194  -1.10091333
##  [46]  -7.04727762  -1.25147789  -2.56449333  -2.72583333  -0.31294785
##  [51]   0.79210756   8.54976667  -3.74300005  -2.86267333  -2.72342333
##  [56]   0.24300667   1.35786058  -0.21181667   0.88757333  -0.42359000
##  [61]   4.69163000  -0.96359889   0.66065359  -0.75930000  -1.30661167
##  [66]  -4.03423333   5.21864760  -3.41794288  -1.17415476  -5.75970684
##  [71]   3.99781333   2.82440860   3.23760137  -2.09188333  -0.43265667
##  [76]   0.28025667   1.90285132  -0.66357667   0.37622751  -1.72012333
##  [81]  -3.60811810  -1.11760463   3.47579692  -0.93810272   5.71293333
##  [86]  -1.19071621  -1.40754143  -1.02234815   1.73892137   0.66767333
##  [91]  -1.84796667  -1.89577701  -1.67581062   0.91369779   2.46615571
##  [96]   1.82967667   3.23964571   1.33185565   0.59930333   2.67968333
## [101]  -1.69234339  -2.04407466   1.67464874   1.35952333  -3.74173287
## [106]  -1.14513333   2.09679930   4.16788000  -1.80788667   4.88675667
## [111]   0.61670667  -0.97427667   0.30204090   2.81902667   1.33740701
## [116]   1.35159447   2.71234556  18.28314843  -5.38726210  -4.98795000
## [121]  -1.93378921   1.11416333   0.45888857  -0.64968798  -5.63412045
## [126]  -1.79377333   2.19376667   0.29618905   8.15903905  -0.81654000
## [131]  -0.43899577   0.13727578  -0.73637667  -1.56695667   0.61336333
## [136]  -1.61457714  -1.85142641  -0.04208333   2.18082333  -0.05410923
## [141]  -1.69643645  -1.05970000  -1.41252805   0.24018524  -2.90524333
## [146]   0.87718000  -4.74593554   1.14160137  -0.36370946  -2.36319333
## [151]  -1.52786832  -0.82554667  -0.66698403  -3.51356000  -0.22256070
## [156]  -2.46117333   5.40267000  -1.87395333  -1.54967593  -1.27451143
## [161]  -1.77721333  -1.18105474   3.10490370   0.25962197 -10.30566648
## [166]   3.36484000   0.55351667   0.76607333   2.61153667   1.55455333
## [171]  -1.59643667   1.81080762   4.28905158  -2.09246429   2.04975681
## [176]   7.35960095  -2.43278149  -1.33287801  -2.12767333   2.59084667
## [181]  -3.68412954   2.79636778   0.26392333  -2.15853429  -2.89614000
## [186]   2.09732000  -0.29039667  -0.33629000   0.82993794   3.32955000
## [191]   1.82944150  -6.48816667   1.79487190  -3.49737476  -2.38569872
## [196]   0.62226415   9.56466667   3.69821468   0.15197667   1.87070333
## [201]  -0.41697375   9.91625667

Calculate Out-of-Sample Coverage Rate

The coverage rate is the proportion of true values in the test set that fall within the prediction intervals.

# Observed responses for the held-out rows
true_y_boston <- test_data[["medv"]]

# Share of held-out responses lying inside [lwr, upr]
coverage_boston <- mean(
  predictions_boston[, "lwr"] <= true_y_boston &
    predictions_boston[, "upr"] >= true_y_boston
)

cat("Out-of-sample coverage rate for Boston dataset:", coverage_boston)
## Out-of-sample coverage rate for Boston dataset: 0.9215686