Skip to content

Commit 072837d

Browse files
committed
Added FNN, liquidSVM and neuralnet engines
1 parent 4c31cf6 commit 072837d

File tree

8 files changed

+904
-0
lines changed

8 files changed

+904
-0
lines changed

R/mlp_data.R

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -311,3 +311,125 @@ set_pred(
311311
)
312312
)
313313
)
314+
315+
# ------------------------------------------------------------------------------
316+
317+
set_model_engine("mlp", "classification", "neuralnet")
318+
set_model_engine("mlp", "regression", "neuralnet")
319+
set_dependency("mlp", "neuralnet", "neuralnet")
320+
321+
set_model_arg(
322+
model = "mlp",
323+
eng = "neuralnet",
324+
parsnip = "hidden_units",
325+
original = "hidden",
326+
func = list(pkg = "dials", fun = "hidden_units"),
327+
has_submodel = FALSE
328+
)
329+
set_fit(
330+
model = "mlp",
331+
eng = "neuralnet",
332+
mode = "classification",
333+
value = list(
334+
interface = "formula",
335+
protect = c("formula", "data"),
336+
func = c(pkg = "neuralnet", fun = "neuralnet"),
337+
defaults = list(
338+
rep = 1,
339+
linear.output = FALSE
340+
)
341+
)
342+
)
343+
set_fit(
344+
model = "mlp",
345+
eng = "neuralnet",
346+
mode = "regression",
347+
value = list(
348+
interface = "formula",
349+
protect = c("formula", "data"),
350+
func = c(pkg = "neuralnet", fun = "neuralnet"),
351+
defaults = list(
352+
rep = 1,
353+
linear.output = TRUE
354+
)
355+
)
356+
)
357+
set_pred(
358+
model = "mlp",
359+
eng = "neuralnet",
360+
mode = "classification",
361+
type = "class",
362+
value = list(
363+
pre = NULL,
364+
post = function(x, object) object$lvl[apply(x, 1, which.max)],
365+
func = c(pkg = "stats", fun = "predict"),
366+
args = list(
367+
object = quote(object$fit),
368+
newdata = quote(new_data)
369+
)
370+
)
371+
)
372+
set_pred(
373+
model = "mlp",
374+
eng = "neuralnet",
375+
mode = "classification",
376+
type = "prob",
377+
value = list(
378+
pre = NULL,
379+
post = function(x, object) {
380+
colnames(x) <- object$lvl
381+
tibble::as_tibble(x)
382+
},
383+
func = c(pkg = "stats", fun = "predict"),
384+
args = list(
385+
object = quote(object$fit),
386+
newdata = quote(new_data),
387+
prob = TRUE
388+
)
389+
)
390+
)
391+
set_pred(
392+
model = "mlp",
393+
eng = "neuralnet",
394+
mode = "classification",
395+
type = "raw",
396+
value = list(
397+
pre = NULL,
398+
post = NULL,
399+
func = c(pkg = "stats", fun = "predict"),
400+
args = list(
401+
object = quote(object$fit),
402+
newdata = quote(new_data)
403+
)
404+
)
405+
)
406+
set_pred(
407+
model = "mlp",
408+
eng = "neuralnet",
409+
mode = "regression",
410+
type = "numeric",
411+
value = list(
412+
pre = NULL,
413+
post = function(x, object) as.numeric(x),
414+
func = c(pkg = "stats", fun = "predict"),
415+
args = list(
416+
object = quote(object$fit),
417+
newdata = quote(new_data)
418+
)
419+
)
420+
)
421+
set_pred(
422+
model = "mlp",
423+
eng = "neuralnet",
424+
mode = "regression",
425+
type = "raw",
426+
value = list(
427+
pre = NULL,
428+
post = NULL,
429+
func = c(pkg = "stats", fun = "predict"),
430+
args = list(
431+
object = quote(object$fit),
432+
newdata = quote(new_data)
433+
)
434+
)
435+
)

R/nearest_neighbor.R

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,3 +211,86 @@ knn_by_k <- function(k, object, new_data, type, ...) {
211211
dplyr::mutate(neighbors = k, .row = dplyr::row_number()) %>%
212212
dplyr::select(.row, neighbors, dplyr::starts_with(".pred"))
213213
}
214+
215+
# ------------------------------------------------------------------------------
216+
217+
#' Nearest neighbors using FNN
218+
#'
219+
#' `fnn_train` is a wrapper for `FNN` fast nearest neighbor models
220+
#'
221+
#' @param x a data frame or matrix of predictors.
222+
#' @param y a vector (factor or numeric) or matrix (numeric) of outcome data.
223+
#' @param k a vector (integer) of the number of neighbours to consider.
224+
#' @param algorithm character, one of c("kd_tree", "cover_tree", "brute"),
225+
#' default = "kd_tree"
226+
#' @param ... additional arguments to pass to FNN, currently unused.
227+
#'
228+
#' @return list containing the FNN call
229+
#' @export
230+
fnn_train <- function(x, y = NULL, k = 1, algorithm = "kd_tree", ...) {
231+
232+
# regression
233+
if (is.numeric(y)) {
234+
fun <- "knn.reg"
235+
main_args <- list(
236+
train = rlang::enquo(x),
237+
y = rlang::enquo(y),
238+
k = k,
239+
algorithm = algorithm)
240+
call <- parsnip:::make_call(fun = fun, ns = "FNN", main_args)
241+
rlang::eval_tidy(call, env = rlang::current_env())
242+
243+
# for classification return unevaluated call because FNN:knn
244+
# trains and predicts in same call
245+
} else {
246+
fun <- "knn"
247+
main_args <- list(
248+
train = rlang::enquo(x),
249+
cl = rlang::enquo(y),
250+
k = k,
251+
algorithm = algorithm)
252+
call <- parsnip:::make_call(fun = fun, ns = "FNN", main_args)
253+
list(call = call)
254+
}
255+
}
256+
257+
258+
#' Nearest neighbors prediction using FNN
259+
#'
260+
#' `fnn_pred` is a wrapper for `FNN` fast nearest neighbor models
261+
#'
262+
#' @param object parsnip model spec.
263+
#' @param newdata data.frame or matrix of training data.
264+
#' @param prob logical return predicted probability of the winning class,
265+
#' default = FALSE.
266+
#' @param ... additional arguments to pass to FNN, currently unused.
267+
#'
268+
#' @return data.frame containing the predicted results.
269+
#' @export
270+
fnn_pred <- function(object, newdata, prob = FALSE, ...) {
271+
272+
# modify the call for prediction
273+
object$call$test <- newdata
274+
275+
# regression result
276+
if ("y" %in% names(object$call)) {
277+
res <- rlang::eval_tidy(object$call)
278+
res <- res$pred
279+
280+
# classification result
281+
} else {
282+
object$call$prob <- prob
283+
lvl <- levels(rlang::eval_tidy(object$call$cl))
284+
res <- rlang::eval_tidy(object$call)
285+
286+
# probability for winning class
287+
if (prob == FALSE) {
288+
attributes(res) <- NULL
289+
res <- factor(lvl[res], levels = lvl)
290+
} else {
291+
res <- attr(res, "prob")
292+
}
293+
}
294+
295+
res
296+
}

R/nearest_neighbor_data.R

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,3 +172,116 @@ set_pred(
172172
)
173173
)
174174
)
175+
176+
# ------------------------------------------------------------------------------
177+
178+
set_model_engine("nearest_neighbor", "classification", "FNN")
179+
set_model_engine("nearest_neighbor", "regression", "FNN")
180+
set_dependency("nearest_neighbor", "FNN", "FNN")
181+
182+
set_model_arg(
183+
model = "nearest_neighbor",
184+
eng = "FNN",
185+
parsnip = "neighbors",
186+
original = "k",
187+
func = list(pkg = "dials", fun = "neighbors"),
188+
has_submodel = FALSE
189+
)
190+
set_fit(
191+
model = "nearest_neighbor",
192+
eng = "FNN",
193+
mode = "regression",
194+
value = list(
195+
interface = "matrix",
196+
protect = c("x", "y"),
197+
func = c(fun = "fnn_train"),
198+
defaults = list()
199+
)
200+
)
201+
set_fit(
202+
model = "nearest_neighbor",
203+
eng = "FNN",
204+
mode = "classification",
205+
value = list(
206+
interface = "matrix",
207+
protect = c("x", "y"),
208+
func = c(fun = "fnn_train"),
209+
defaults = list()
210+
)
211+
)
212+
set_pred(
213+
model = "nearest_neighbor",
214+
eng = "FNN",
215+
mode = "regression",
216+
type = "numeric",
217+
value = list(
218+
pre = NULL,
219+
post = NULL,
220+
func = c(fun = "fnn_pred"),
221+
args = list(
222+
object = quote(object$fit),
223+
newdata = quote(new_data)
224+
)
225+
)
226+
)
227+
set_pred(
228+
model = "nearest_neighbor",
229+
eng = "FNN",
230+
mode = "regression",
231+
type = "raw",
232+
value = list(
233+
pre = NULL,
234+
post = NULL,
235+
func = c(fun = "fnn_pred"),
236+
args = list(
237+
object = quote(object$fit),
238+
newdata = quote(new_data)
239+
)
240+
)
241+
)
242+
set_pred(
243+
model = "nearest_neighbor",
244+
eng = "FNN",
245+
mode = "classification",
246+
type = "class",
247+
value = list(
248+
pre = NULL,
249+
post = NULL,
250+
func = c(fun = "fnn_pred"),
251+
args = list(
252+
object = quote(object$fit),
253+
newdata = quote(new_data)
254+
)
255+
)
256+
)
257+
set_pred(
258+
model = "nearest_neighbor",
259+
eng = "FNN",
260+
mode = "classification",
261+
type = "prob",
262+
value = list(
263+
pre = NULL,
264+
post = function(result, object) tibble::as_tibble(result),
265+
func = c(fun = "fnn_pred"),
266+
args =
267+
list(
268+
object = quote(object$fit),
269+
newdata = quote(new_data),
270+
prob = TRUE
271+
)
272+
)
273+
)
274+
set_pred(
275+
model = "nearest_neighbor",
276+
eng = "FNN",
277+
mode = "classification",
278+
type = "raw",
279+
value = list(
280+
pre = NULL,
281+
post = NULL,
282+
func = c(fun = "fnn_pred"),
283+
args = list(
284+
object = quote(object$fit),
285+
newdata = quote(new_data))
286+
)
287+
)

R/svm_rbf.R

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,21 @@ translate.svm_rbf <- function(x, engine = x$engine, ...) {
158158
}
159159

160160
}
161+
162+
if (x$engine == "liquidSVM") {
163+
# convert parameter arguments
164+
if (any(arg_names == "sigma")) {
165+
arg_vals$gammas <- rlang::quo(1 / !!sqrt(arg_vals$sigma))
166+
arg_vals$sigma <- NULL
167+
}
168+
169+
if (any(arg_names == "C")) {
170+
arg_vals$lambdas <- arg_vals$C
171+
arg_vals$C <- NULL
172+
}
173+
174+
}
175+
161176
x$method$fit$args <- arg_vals
162177

163178
# worried about people using this to modify the specification

0 commit comments

Comments
 (0)