Skip to content

Commit

Permalink
Ranger case weights (#2418)
Browse files Browse the repository at this point in the history
* add case weights and repair minprop

* solve case.weights problem
  • Loading branch information
PhilippPro authored and larskotthoff committed Aug 21, 2018
1 parent bc7f986 commit 93d3d3e
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 4 deletions.
7 changes: 5 additions & 2 deletions R/RLearner_classif_ranger.R
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ makeRLearner.classif.ranger = function() {
}

#' @export
trainLearner.classif.ranger = function(.learner, .task, .subset, .weights = NULL, mtry, mtry.perc, min.node.size, ...) {
trainLearner.classif.ranger = function(.learner, .task, .subset, .weights = NULL, mtry, mtry.perc, min.node.size, case.weights, ...) {
tn = getTaskTargetNames(.task)
if (missing(mtry)) {
if (missing(mtry.perc)) {
Expand All @@ -52,8 +52,11 @@ trainLearner.classif.ranger = function(.learner, .task, .subset, .weights = NULL
min.node.size = 1
}
}
if (missing(case.weights)) {
case.weights = .weights
}
ranger::ranger(formula = NULL, dependent.variable = tn, data = getTaskData(.task, .subset),
probability = (.learner$predict.type == "prob"), case.weights = .weights, mtry = mtry, min.node.size = min.node.size, ...)
probability = (.learner$predict.type == "prob"), case.weights = case.weights, mtry = mtry, min.node.size = min.node.size, ...)
}

#' @export
Expand Down
7 changes: 5 additions & 2 deletions R/RLearner_regr_ranger.R
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ makeRLearner.regr.ranger = function() {
}

#' @export
trainLearner.regr.ranger = function(.learner, .task, .subset, .weights = NULL, keep.inbag = NULL, mtry, mtry.perc, ...) {
trainLearner.regr.ranger = function(.learner, .task, .subset, .weights = NULL, keep.inbag = NULL, mtry, mtry.perc, case.weights, ...) {
tn = getTaskTargetNames(.task)
if (missing(mtry)) {
if (missing(mtry.perc)) {
Expand All @@ -47,10 +47,13 @@ trainLearner.regr.ranger = function(.learner, .task, .subset, .weights = NULL, k
mtry = max(1, floor(mtry.perc * getTaskNFeats(.task)))
}
}
if (missing(case.weights)) {
case.weights = .weights
}
keep.inbag = if (is.null(keep.inbag)) FALSE else keep.inbag
keep.inbag = if (.learner$predict.type == "se") TRUE else keep.inbag
ranger::ranger(formula = NULL, dependent.variable = tn, data = getTaskData(.task, .subset),
case.weights = .weights, keep.inbag = keep.inbag, mtry = mtry, ...)
case.weights = case.weights, keep.inbag = keep.inbag, mtry = mtry, ...)
}

#' @export
Expand Down

0 comments on commit 93d3d3e

Please sign in to comment.