Surrogate model containing a single mlr3::LearnerRegr.

Parameters

assert_insample_perf

logical(1)
Should the insample performance of the mlr3::LearnerRegr be asserted after updating the surrogate? If the assertion fails (i.e., the insample performance based on the perf_measure does not meet the perf_threshold), an error is thrown. Default is FALSE.

perf_measure

mlr3::MeasureRegr
Performance measure that should be used to assert the insample performance of the mlr3::LearnerRegr. Only relevant if assert_insample_perf = TRUE. Default is mlr3::mlr_measures_regr.rsq.

perf_threshold

numeric(1)
Threshold the insample performance of the mlr3::LearnerRegr should be asserted against. Only relevant if assert_insample_perf = TRUE. Default is 0.

catch_errors

logical(1)
Should errors during updating the surrogate be caught and propagated to the loop_function which can then handle the failed acquisition function optimization (as a result of the failed surrogate) appropriately by, e.g., proposing a randomly sampled point for evaluation? Default is TRUE.

impute_method

character(1)
Method to impute missing values in the case of updating on an asynchronous bbotk::ArchiveAsync with pending evaluations. Can be "mean" to use mean imputation or "random" to sample values uniformly at random between the empirical minimum and maximum. Default is "random".
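
The parameters above can be changed after construction. The following is a minimal sketch, assuming the surrogate exposes them through a `param_set` field (not listed in this section) and that the installed paradox version provides `set_values()`. The featureless learner and the threshold of 0.5 are arbitrary choices for illustration.

```r
library(mlr3)
library(mlr3mbo)

# A surrogate without an archive is enough to inspect and set parameters.
surrogate = srlrn(lrn("regr.featureless"))

# Assumption: parameters are reachable via `param_set` (a paradox::ParamSet).
surrogate$param_set$set_values(
  assert_insample_perf = TRUE,       # check insample performance after each update
  perf_measure = msr("regr.rsq"),    # measure used for the assertion
  perf_threshold = 0.5,              # illustrative threshold, not a recommendation
  catch_errors = TRUE,               # let the loop function handle update failures
  impute_method = "mean"             # only relevant for asynchronous archives
)

# Read the values back to confirm what is currently set.
surrogate$param_set$values
```

With `assert_insample_perf = TRUE`, a later update is expected to throw an error when the chosen measure does not meet the threshold, as described for the parameter above.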

Super class

mlr3mbo::Surrogate -> SurrogateLearner

Active bindings

print_id

(character)
Id used when printing.

n_learner

(integer(1))
Returns the number of surrogate models.

assert_insample_perf

(numeric())
Asserts whether the current insample performance meets the performance threshold.

packages

(character())
Set of required packages. A warning is signaled if at least one of the packages is not installed. The packages are loaded (not attached) later on demand via requireNamespace().

feature_types

(character())
Stores the feature types the surrogate can handle, e.g. "logical", "numeric", or "factor". A complete list of candidate feature types, grouped by task type, is stored in mlr_reflections$task_feature_types.

properties

(character())
Stores a set of properties/capabilities the surrogate has. A complete list of candidate properties, grouped by task type, is stored in mlr_reflections$learner_properties.

predict_type

(character(1))
Retrieves the currently active predict type, e.g. "response".

Methods

Inherited methods


Method new()

Creates a new instance of this R6 class.

Usage

SurrogateLearner$new(learner, archive = NULL, cols_x = NULL, col_y = NULL)

Arguments

learner

(mlr3::LearnerRegr).

archive

(bbotk::Archive | NULL)
bbotk::Archive of the bbotk::OptimInstance.

cols_x

(character() | NULL)
Column ids of variables that should be used as features. By default, automatically inferred based on the archive.

col_y

(character(1) | NULL)
Column id of variable that should be used as a target. By default, automatically inferred based on the archive.
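
The constructor can also be called with explicit feature and target columns instead of relying on inference from the archive. The sketch below assumes a one-dimensional toy objective with domain column `x` and codomain column `y`. The objective and the featureless learner are illustrative choices, not part of the class itself.

```r
library(bbotk)
library(paradox)
library(mlr3)
library(mlr3mbo)

# Toy objective: minimize y = x^2 on [-10, 10].
objective = ObjectiveRFun$new(
  fun = function(xs) list(y = xs$x ^ 2),
  domain = ps(x = p_dbl(lower = -10, upper = 10)),
  codomain = ps(y = p_dbl(tags = "minimize")))

instance = OptimInstanceBatchSingleCrit$new(
  objective = objective,
  terminator = trm("evals", n_evals = 5))

# Name the feature and target columns explicitly rather than inferring them.
surrogate = SurrogateLearner$new(
  learner = lrn("regr.featureless"),
  archive = instance$archive,
  cols_x = "x",
  col_y = "y")

# A SurrogateLearner wraps exactly one regression learner.
surrogate$n_learner
```

Passing the columns explicitly is mainly useful when the archive contains additional columns that should not enter the surrogate as features.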


Method predict()

Predict mean response and standard error.

Usage

SurrogateLearner$predict(xdt)

Arguments

xdt

(data.table::data.table())
New data. One row per observation.

Returns

data.table::data.table() with the columns mean and se.
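
As a usage sketch for `predict()`, the code below fits the surrogate on a small random design and then predicts at three new points. It assumes that mlr3learners, DiceKriging and rgenoud are installed (needed by `default_gp()`). The toy objective and the query points are illustrative.

```r
library(bbotk)
library(paradox)
library(mlr3learners)
library(mlr3mbo)
library(data.table)

set.seed(1)

# Toy objective: minimize y = x^2 on [-10, 10].
objective = ObjectiveRFun$new(
  fun = function(xs) list(y = xs$x ^ 2),
  domain = ps(x = p_dbl(lower = -10, upper = 10)),
  codomain = ps(y = p_dbl(tags = "minimize")))

instance = OptimInstanceBatchSingleCrit$new(
  objective = objective,
  terminator = trm("evals", n_evals = 5))

# Evaluate an initial random design so the archive contains training data.
instance$eval_batch(generate_design_random(instance$search_space, n = 4)$data)

# Fit the Gaussian process surrogate on the archive.
surrogate = srlrn(default_gp(), archive = instance$archive)
surrogate$update()

# New data: one row per point, columns matching the feature columns.
xdt_new = data.table(x = c(-5, 0, 5))
pred = surrogate$predict(xdt_new)
pred
```

The result has one row per row of `xdt_new`, with the predicted mean in `mean` and the standard error in `se`.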


Method clone()

The objects of this class are cloneable with this method.

Usage

SurrogateLearner$clone(deep = FALSE)

Arguments

deep

Whether to make a deep clone.

Examples

if (requireNamespace("mlr3learners") &&
    requireNamespace("DiceKriging") &&
    requireNamespace("rgenoud")) {
  library(bbotk)
  library(paradox)
  library(mlr3learners)
  library(mlr3mbo)  # provides default_gp() and srlrn()

  fun = function(xs) {
    list(y = xs$x ^ 2)
  }
  domain = ps(x = p_dbl(lower = -10, upper = 10))
  codomain = ps(y = p_dbl(tags = "minimize"))
  objective = ObjectiveRFun$new(fun = fun, domain = domain, codomain = codomain)

  instance = OptimInstanceBatchSingleCrit$new(
    objective = objective,
    terminator = trm("evals", n_evals = 5))

  xdt = generate_design_random(instance$search_space, n = 4)$data

  instance$eval_batch(xdt)

  learner = default_gp()

  surrogate = srlrn(learner, archive = instance$archive)

  surrogate$update()

  surrogate$learner$model
}
#> 
#> Call:
#> DiceKriging::km(design = data, response = task$truth(), covtype = "matern5_2", 
#>     nugget = 2.83305750865222e-07, optim.method = "gen", control = pv$control)
#> 
#> Trend  coeff.:
#>                Estimate
#>  (Intercept)     5.8590
#> 
#> Covar. type  : matern5_2 
#> Covar. coeff.:
#>                Estimate
#>     theta(x)     1.1710
#> 
#> Variance estimate: 21.95183
#> 
#> Nugget effect : 2.833058e-07
#>