In this example, we demonstrate how to use the
conformalize function to perform conformal prediction and
calculate the out-of-sample coverage rate.
We will generate a simple dataset for demonstration purposes.
We will use a random forest (ranger::ranger) as the
fit_func and a wrapper around its predict
method as the predict_func.
library(stats)
# Define fit and predict functions
# Fit a ranger model of `y` on the columns of `x`.
#
# Args:
#   x:   predictors; combined with `y` through data.frame()
#   y:   response values
#   ...: forwarded unchanged to ranger::ranger()
# Returns: the object produced by ranger::ranger()
fit_func <- function(x, y, ...) {
  # The column names created here are the ones `predict` looks for later.
  # The local name `df` is kept on purpose: it is echoed in the model's
  # printed call.
  df <- data.frame(y = y, x)

  # Show the first rows of the modelling frame
  print(utils::head(df))

  ranger::ranger(y ~ ., data = df, ...)
}
# Predict with a model fitted by fit_func on new predictor values.
#
# Args:
#   obj:  fitted model object (as returned by fit_func)
#   newx: matrix or data frame of predictors, columns in training order
# Returns: the `predictions` element of the predict() result
predict_func <- function(obj, newx) {
  # Mandatory: the names must match those of `df` in fit_func (X1, X2, ...).
  # seq_len() is used instead of 1:ncol(newx), which would produce c(1, 0)
  # for a zero-column input.
  colnames(newx) <- paste0("X", seq_len(ncol(newx)))

  # predict() only accepts a named `newx`
  predict(object = obj, data = newx)$predictions
}
# Apply conformalize: fit the model through fit_func and keep what is needed
# to build prediction intervals later.
conformal_model <- misc::conformalize(
  x = x,
  y = y,
  fit_func = fit_func,
  predict_func = predict_func,
  split_ratio = 0.8, # presumably the share of rows used for fitting -- confirm in misc docs
  seed = 123         # fixed seed so the split is reproducible
)
## y X1 X2
## 1 0.7480449 0.4447680 0.3233450
## 2 3.2344670 0.8746823 0.3087868
## 3 2.8086211 0.5726334 0.6743764
## 4 2.7359976 0.9373141 0.1054177
## 5 1.5328762 0.2656867 0.4831677
## 6 4.2696182 0.8578277 0.7267025
We will use the predict.conformalize method to generate
predictions and calculate prediction intervals.
# New data for prediction: 50 rows, two uniform(0, 1) predictors named like
# the training columns
new_data <- data.frame(
  X1 = runif(50),
  X2 = runif(50)
)

# Predict with split conformal method at the 95% level
predictions <- predict(
  conformal_model,
  newdata = new_data,
  level = 0.95,
  method = "split",
  predict_func = predict_func
)
##
## [1] "object's value:"
## $fit
## Ranger result
##
## Call:
## ranger::ranger(y ~ ., data = df, ...)
##
## Type: Regression
## Number of trees: 500
## Sample size: 160
## Number of independent variables: 2
## Mtry: 1
## Target node size: 5
## Variable importance mode: none
## Splitrule: variance
## OOB prediction error (MSE): 0.2878791
## R squared (OOB): 0.768701
##
## $residuals
## [1] -0.3517141289 -0.6278733391 -0.8919451321 -0.3368367515 -0.5860523628
## [6] -0.3086364683 -0.4745539452 -0.1903110973 -0.1398032514 0.1325369233
## [11] -0.1279973669 -0.1934639470 1.1106438197 0.9771857987 0.1000399597
## [16] 0.8476079199 0.0763709313 1.2510146081 -0.6432988213 -0.5824886831
## [21] 0.1166032311 -0.7442515277 -1.1530115058 -0.1149572542 0.0004330684
## [26] 1.0725410572 0.8389778811 -0.5485431787 -1.5616738419 0.4032726270
## [31] -1.0039911190 -0.6838454247 -0.4422398209 0.6323147429 -0.8067167962
## [36] 1.2611515451 -0.3922378097 0.1163975193 -0.2456218514 -0.7815854456
##
## $sd_residuals
## [1] 0.6915685
##
## $scaled_residuals
## [1] -0.1806242
##
## attr(,"class")
## [1] "conformalize"
## fit lwr upr
## [1,] 0.6598268 -0.5916947 1.911348
## [2,] 0.9518570 -0.2996645 2.203378
## [3,] 2.1487265 0.8972051 3.400248
## [4,] 0.7785231 -0.4729984 2.030045
## [5,] 2.8954852 1.6439638 4.147007
## [6,] 0.8294482 -0.4220733 2.080970
## [1] -0.3517141289 -0.6278733391 -0.8919451321 -0.3368367515 -0.5860523628
## [6] -0.3086364683 -0.4745539452 -0.1903110973 -0.1398032514 0.1325369233
## [11] -0.1279973669 -0.1934639470 1.1106438197 0.9771857987 0.1000399597
## [16] 0.8476079199 0.0763709313 1.2510146081 -0.6432988213 -0.5824886831
## [21] 0.1166032311 -0.7442515277 -1.1530115058 -0.1149572542 0.0004330684
## [26] 1.0725410572 0.8389778811 -0.5485431787 -1.5616738419 0.4032726270
## [31] -1.0039911190 -0.6838454247 -0.4422398209 0.6323147429 -0.8067167962
## [36] 1.2611515451 -0.3922378097 0.1163975193 -0.2456218514 -0.7815854456
The coverage rate is the proportion of true values that fall within the prediction intervals.
# Simulate true values for the new data.
# Column names are case-sensitive: `new_data` holds X1 and X2, whereas
# `new_data$x1` returns NULL, which leaves `true_y` empty and makes the
# coverage NaN.
true_y <- 3 * new_data$X1 + 2 * new_data$X2 + rnorm(50, sd = 0.5)

# Check if true values fall within the prediction intervals
coverage <- mean(
  true_y >= predictions[, "lwr"] & true_y <= predictions[, "upr"]
)
cat("Out-of-sample coverage rate:", coverage)
## (the previously recorded output was NaN, caused by the column-name typo; re-run to refresh)
MASS::Boston Dataset
In this example, we use the MASS::Boston dataset to
demonstrate conformal prediction.
We will use the MASS package to access the
Boston dataset.
## crim zn indus chas nox rm age dis rad tax ptratio black lstat
## 1 0.00632 18 2.31 0 0.538 6.575 65.2 4.0900 1 296 15.3 396.90 4.98
## 2 0.02731 0 7.07 0 0.469 6.421 78.9 4.9671 2 242 17.8 396.90 9.14
## 3 0.02729 0 7.07 0 0.469 7.185 61.1 4.9671 2 242 17.8 392.83 4.03
## 4 0.03237 0 2.18 0 0.458 6.998 45.8 6.0622 3 222 18.7 394.63 2.94
## 5 0.06905 0 2.18 0 0.458 7.147 54.2 6.0622 3 222 18.7 396.90 5.33
## 6 0.02985 0 2.18 0 0.458 6.430 58.7 6.0622 3 222 18.7 394.12 5.21
## medv
## 1 24.0
## 2 21.6
## 3 34.7
## 4 33.4
## 5 36.2
## 6 28.7
We will split the data into training and test sets to ensure they are disjoint.
# Prediction wrapper for the Boston example.
#
# Args:
#   obj:  fitted model object
#   newx: predictors with column names (predict() only accepts a named newx)
# Returns: the `predictions` element of the predict() result
predict_func <- function(obj, newx) {
  prediction_result <- predict(object = obj, data = newx)
  prediction_result$predictions
}
# Apply conformalize using the training data.
# Predictors are every column except the response `medv`. A logical mask is
# used instead of `-which(...)`: if nothing matched, `-integer(0)` would
# silently drop every column. `drop = FALSE` keeps a data frame even if a
# single predictor remains.
conformal_model_boston <- misc::conformalize(
  x = as.matrix(train_data[, names(train_data) != "medv", drop = FALSE]),
  y = train_data$medv,
  fit_func = fit_func,
  predict_func = predict_func,
  seed = 123 # fixed seed so the internal split is reproducible
)
## y crim zn indus chas nox rm age dis rad tax ptratio black
## 11 15.0 0.22489 12.5 7.87 0 0.524 6.377 94.3 6.3467 5 311 15.2 392.52
## 153 15.3 1.12658 0.0 19.58 1 0.871 5.012 88.0 1.6102 5 403 14.7 343.28
## 10 18.9 0.17004 12.5 7.87 0 0.524 6.004 85.9 6.5921 5 311 15.2 386.71
## 397 12.5 5.87205 0.0 18.10 0 0.693 6.405 96.0 1.6768 24 666 20.2 396.90
## 362 19.9 3.83684 0.0 18.10 0 0.770 6.251 91.1 2.2955 24 666 20.2 350.65
## 35 13.5 1.61282 0.0 8.14 0 0.538 6.096 96.9 3.7598 4 307 21.0 248.31
## lstat
## 11 20.45
## 153 12.12
## 10 17.10
## 397 19.37
## 362 14.19
## 35 20.34
We will use the predict.conformalize method to generate
predictions and calculate prediction intervals for the test set.
# Predict with split conformal method on the test data.
# The response `medv` is removed from `newdata` so the prediction input has
# the same predictor columns as the training call.
# NOTE(review): this assumes predict.conformalize only forwards `newdata` to
# predict_func -- confirm against the misc package.
predictions_boston <- predict(
  conformal_model_boston,
  newdata = as.matrix(test_data[, names(test_data) != "medv", drop = FALSE]),
  level = 0.95,
  method = "split",
  predict_func = predict_func
)
##
## [1] "object's value:"
## $fit
## Ranger result
##
## Call:
## ranger::ranger(y ~ ., data = df, ...)
##
## Type: Regression
## Number of trees: 500
## Sample size: 202
## Number of independent variables: 13
## Mtry: 3
## Target node size: 5
## Variable importance mode: none
## Splitrule: variance
## OOB prediction error (MSE): 13.88378
## R squared (OOB): 0.8497858
##
## $residuals
## [1] -3.28995224 -0.76321776 2.43128905 -2.66240151 -0.99287952
## [6] 2.70121471 -2.42204429 -1.55376000 -2.56573810 -3.26182053
## [11] -2.14750667 -3.32581873 2.63805000 -3.16368111 -0.18665491
## [16] 4.75706000 1.91775692 0.17052333 23.22278310 -3.92153000
## [21] -2.90377188 -2.01840810 -1.43485155 -5.79080450 -6.06995808
## [26] 0.29719359 0.28491333 0.27125333 -0.99347000 -1.85960667
## [31] 2.10630692 1.28491786 1.08958000 -0.03567000 -0.61526291
## [36] -4.23196833 -0.18508529 0.65849333 -0.65427429 3.12121000
## [41] -1.74895000 -1.48603000 2.20137667 2.93456194 -1.10091333
## [46] -7.04727762 -1.25147789 -2.56449333 -2.72583333 -0.31294785
## [51] 0.79210756 8.54976667 -3.74300005 -2.86267333 -2.72342333
## [56] 0.24300667 1.35786058 -0.21181667 0.88757333 -0.42359000
## [61] 4.69163000 -0.96359889 0.66065359 -0.75930000 -1.30661167
## [66] -4.03423333 5.21864760 -3.41794288 -1.17415476 -5.75970684
## [71] 3.99781333 2.82440860 3.23760137 -2.09188333 -0.43265667
## [76] 0.28025667 1.90285132 -0.66357667 0.37622751 -1.72012333
## [81] -3.60811810 -1.11760463 3.47579692 -0.93810272 5.71293333
## [86] -1.19071621 -1.40754143 -1.02234815 1.73892137 0.66767333
## [91] -1.84796667 -1.89577701 -1.67581062 0.91369779 2.46615571
## [96] 1.82967667 3.23964571 1.33185565 0.59930333 2.67968333
## [101] -1.69234339 -2.04407466 1.67464874 1.35952333 -3.74173287
## [106] -1.14513333 2.09679930 4.16788000 -1.80788667 4.88675667
## [111] 0.61670667 -0.97427667 0.30204090 2.81902667 1.33740701
## [116] 1.35159447 2.71234556 18.28314843 -5.38726210 -4.98795000
## [121] -1.93378921 1.11416333 0.45888857 -0.64968798 -5.63412045
## [126] -1.79377333 2.19376667 0.29618905 8.15903905 -0.81654000
## [131] -0.43899577 0.13727578 -0.73637667 -1.56695667 0.61336333
## [136] -1.61457714 -1.85142641 -0.04208333 2.18082333 -0.05410923
## [141] -1.69643645 -1.05970000 -1.41252805 0.24018524 -2.90524333
## [146] 0.87718000 -4.74593554 1.14160137 -0.36370946 -2.36319333
## [151] -1.52786832 -0.82554667 -0.66698403 -3.51356000 -0.22256070
## [156] -2.46117333 5.40267000 -1.87395333 -1.54967593 -1.27451143
## [161] -1.77721333 -1.18105474 3.10490370 0.25962197 -10.30566648
## [166] 3.36484000 0.55351667 0.76607333 2.61153667 1.55455333
## [171] -1.59643667 1.81080762 4.28905158 -2.09246429 2.04975681
## [176] 7.35960095 -2.43278149 -1.33287801 -2.12767333 2.59084667
## [181] -3.68412954 2.79636778 0.26392333 -2.15853429 -2.89614000
## [186] 2.09732000 -0.29039667 -0.33629000 0.82993794 3.32955000
## [191] 1.82944150 -6.48816667 1.79487190 -3.49737476 -2.38569872
## [196] 0.62226415 9.56466667 3.69821468 0.15197667 1.87070333
## [201] -0.41697375 9.91625667
##
## $sd_residuals
## [1] 3.558924
##
## $scaled_residuals
## [1] 0.01476135
##
## attr(,"class")
## [1] "conformalize"
## fit lwr upr
## [1,] 27.06333 21.007326 33.11933
## [2,] 19.08097 13.024973 25.13697
## [3,] 21.23130 15.175301 27.28730
## [4,] 18.69498 12.638978 24.75098
## [5,] 15.74951 9.693507 21.80551
## [6,] 21.32154 15.265536 27.37754
## [1] -0.3517141289 -0.6278733391 -0.8919451321 -0.3368367515 -0.5860523628
## [6] -0.3086364683 -0.4745539452 -0.1903110973 -0.1398032514 0.1325369233
## [11] -0.1279973669 -0.1934639470 1.1106438197 0.9771857987 0.1000399597
## [16] 0.8476079199 0.0763709313 1.2510146081 -0.6432988213 -0.5824886831
## [21] 0.1166032311 -0.7442515277 -1.1530115058 -0.1149572542 0.0004330684
## [26] 1.0725410572 0.8389778811 -0.5485431787 -1.5616738419 0.4032726270
## [31] -1.0039911190 -0.6838454247 -0.4422398209 0.6323147429 -0.8067167962
## [36] 1.2611515451 -0.3922378097 0.1163975193 -0.2456218514 -0.7815854456
The coverage rate is the proportion of true values in the test set that fall within the prediction intervals.
# True values for the test set
true_y_boston <- test_data$medv

# Share of test observations whose true value lies inside [lwr, upr]
coverage_boston <- mean(
  predictions_boston[, "lwr"] <= true_y_boston &
    predictions_boston[, "upr"] >= true_y_boston
)
cat("Out-of-sample coverage rate for Boston dataset:", coverage_boston)
## Out-of-sample coverage rate for Boston dataset: 0.9215686