Skip to content

Commit 41693da

Browse files
authored
Merge pull request #300 from stevenpawley/master
Addition of liquidSVM engine to svm_rbf
2 parents 4c31cf6 + a1a5078 commit 41693da

File tree

3 files changed

+353
-0
lines changed

3 files changed

+353
-0
lines changed

R/svm_rbf.R

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,15 @@
3333
#' following _engines_:
3434
#' \itemize{
3535
#' \item \pkg{R}: `"kernlab"` (the default)
36+
#' \item \pkg{R}: `"liquidSVM"`
3637
#' }
3738
#'
39+
#' Note that models created using the `liquidSVM` engine cannot be saved like
40+
#' conventional R objects. The `fit` slot of the `model_fit` object has to be
41+
#' saved separately using the `liquidSVM::write.liquidSVM()` function. Likewise
42+
#' to restore a model, the `fit` slot has to be replaced with the model that is
43+
#' read using the `liquidSVM::read.liquidSVM()` function.
44+
#'
3845
#' @includeRmd man/rmd/svm-rbf.Rmd details
3946
#'
4047
#' @importFrom purrr map_lgl
@@ -158,6 +165,21 @@ translate.svm_rbf <- function(x, engine = x$engine, ...) {
158165
}
159166

160167
}
168+
169+
if (x$engine == "liquidSVM") {
170+
# convert parameter arguments
171+
if (any(arg_names == "sigma")) {
172+
arg_vals$gammas <- rlang::quo(1 / !!sqrt(arg_vals$sigma))
173+
arg_vals$sigma <- NULL
174+
}
175+
176+
if (any(arg_names == "C")) {
177+
arg_vals$lambdas <- arg_vals$C
178+
arg_vals$C <- NULL
179+
}
180+
181+
}
182+
161183
x$method$fit$args <- arg_vals
162184

163185
# worried about people using this to modify the specification

R/svm_rbf_data.R

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,3 +140,139 @@ set_pred(
140140
)
141141
)
142142

143+
# ------------------------------------------------------------------------------
144+
145+
set_model_engine("svm_rbf", "classification", "liquidSVM")
146+
set_model_engine("svm_rbf", "regression", "liquidSVM")
147+
set_dependency("svm_rbf", "liquidSVM", "liquidSVM")
148+
149+
set_model_arg(
150+
model = "svm_rbf",
151+
eng = "liquidSVM",
152+
parsnip = "cost",
153+
original = "lambdas",
154+
func = list(pkg = "dials", fun = "cost"),
155+
has_submodel = FALSE
156+
)
157+
set_model_arg(
158+
model = "svm_rbf",
159+
eng = "liquidSVM",
160+
parsnip = "rbf_sigma",
161+
original = "gammas",
162+
func = list(pkg = "dials", fun = "rbf_sigma"),
163+
has_submodel = FALSE
164+
)
165+
set_fit(
166+
model = "svm_rbf",
167+
eng = "liquidSVM",
168+
mode = "regression",
169+
value = list(
170+
interface = "matrix",
171+
protect = c("x", "y"),
172+
func = c(pkg = "liquidSVM", fun = "svm"),
173+
defaults = list(
174+
folds = 1,
175+
threads = 0
176+
)
177+
)
178+
)
179+
set_fit(
180+
model = "svm_rbf",
181+
eng = "liquidSVM",
182+
mode = "classification",
183+
value = list(
184+
interface = "matrix",
185+
protect = c("x", "y"),
186+
func = c(pkg = "liquidSVM", fun = "svm"),
187+
defaults = list(
188+
folds = 1,
189+
threads = 0
190+
)
191+
)
192+
)
193+
set_pred(
194+
model = "svm_rbf",
195+
eng = "liquidSVM",
196+
mode = "regression",
197+
type = "numeric",
198+
value = list(
199+
pre = NULL,
200+
post = NULL,
201+
func = c(fun = "predict"),
202+
args =
203+
list(
204+
object = quote(object$fit),
205+
newdata = quote(new_data)
206+
)
207+
)
208+
)
209+
set_pred(
210+
model = "svm_rbf",
211+
eng = "liquidSVM",
212+
mode = "regression",
213+
type = "raw",
214+
value = list(
215+
pre = NULL,
216+
post = NULL,
217+
func = c(fun = "predict"),
218+
args = list(
219+
object = quote(object$fit),
220+
newdata = quote(new_data))
221+
)
222+
)
223+
set_pred(
224+
model = "svm_rbf",
225+
eng = "liquidSVM",
226+
mode = "classification",
227+
type = "class",
228+
value = list(
229+
pre = NULL,
230+
post = NULL,
231+
func = c(fun = "predict"),
232+
args =
233+
list(
234+
object = quote(object$fit),
235+
newdata = quote(new_data)
236+
)
237+
)
238+
)
239+
set_pred(
240+
model = "svm_rbf",
241+
eng = "liquidSVM",
242+
mode = "classification",
243+
type = "prob",
244+
value = list(
245+
pre = function(x, object) {
246+
if (object$fit$predict.prob == FALSE)
247+
stop("`svm` model does not appear to use class probabilities. Was ",
248+
"the model fit with `predict.prob = TRUE`?", call. = FALSE)
249+
x
250+
},
251+
post = function(result, object) {
252+
res <- tibble::as_tibble(result)
253+
names(res) <- object$lvl
254+
res
255+
},
256+
func = c(fun = "predict"),
257+
args =
258+
list(
259+
object = quote(object$fit),
260+
newdata = quote(new_data),
261+
predict.prob = TRUE
262+
)
263+
)
264+
)
265+
set_pred(
266+
model = "svm_rbf",
267+
eng = "liquidSVM",
268+
mode = "classification",
269+
type = "raw",
270+
value = list(
271+
pre = NULL,
272+
post = NULL,
273+
func = c(fun = "predict"),
274+
args = list(
275+
object = quote(object$fit),
276+
newdata = quote(new_data))
277+
)
278+
)

tests/testthat/test_svm_liquidsvm.R

Lines changed: 195 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,195 @@
1+
library(testthat)
2+
library(parsnip)
3+
library(rlang)
4+
library(tibble)
5+
6+
# ------------------------------------------------------------------------------
7+
8+
test_that('primary arguments', {
9+
basic <- svm_rbf(mode = "classification")
10+
basic_liquidSVM <- translate(basic %>% set_engine("liquidSVM"))
11+
12+
expect_equal(
13+
object = basic_liquidSVM$method$fit$args,
14+
expected = list(
15+
x = expr(missing_arg()),
16+
y = expr(missing_arg()),
17+
folds = 1,
18+
threads = 0
19+
)
20+
)
21+
22+
rbf_sigma <-
23+
svm_rbf(mode = "classification", rbf_sigma = .2) %>%
24+
set_engine("liquidSVM")
25+
rbf_sigma_liquidSVM <- translate(rbf_sigma)
26+
27+
expect_equal(
28+
object = rbf_sigma_liquidSVM$method$fit$args,
29+
expected = list(
30+
x = expr(missing_arg()),
31+
y = expr(missing_arg()),
32+
gammas = quo(.2),
33+
folds = 1,
34+
threads = 0
35+
)
36+
)
37+
38+
})
39+
40+
test_that('engine arguments', {
41+
42+
liquidSVM_scale <-
43+
svm_rbf() %>%
44+
set_mode("classification") %>%
45+
set_engine("liquidSVM", scale = FALSE, predict.prob = TRUE, threads = 2, gpus = 1)
46+
47+
expect_equal(
48+
object = translate(liquidSVM_scale, "liquidSVM")$method$fit$args,
49+
expected = list(
50+
x = expr(missing_arg()),
51+
y = expr(missing_arg()),
52+
scale = new_quosure(FALSE, env = empty_env()),
53+
predict.prob = new_quosure(TRUE, env = empty_env()),
54+
threads = new_quosure(2, env = empty_env()),
55+
gpus = new_quosure(1, env = empty_env()),
56+
folds = 1
57+
)
58+
)
59+
60+
})
61+
62+
63+
test_that('updating', {
64+
65+
expr1 <- svm_rbf() %>% set_engine("liquidSVM", scale = TRUE)
66+
expr1_exp <- svm_rbf(rbf_sigma = .1) %>% set_engine("liquidSVM", scale = TRUE)
67+
68+
expr3 <- svm_rbf(rbf_sigma = .2) %>% set_engine("liquidSVM")
69+
expr3_exp <- svm_rbf(rbf_sigma = .3) %>% set_engine("liquidSVM")
70+
71+
expect_equal(update(expr1, rbf_sigma = .1), expr1_exp)
72+
expect_equal(update(expr3, rbf_sigma = .3, fresh = TRUE), expr3_exp)
73+
})
74+
75+
test_that('bad input', {
76+
expect_error(svm_rbf(mode = "reallyunknown"))
77+
expect_error(translate(svm_rbf() %>% set_engine( NULL)))
78+
})
79+
80+
# ------------------------------------------------------------------------------
81+
# define model specification for classification and regression
82+
83+
reg_mod <-
84+
svm_rbf(rbf_sigma = .1, cost = 0.25) %>%
85+
set_engine("liquidSVM", random_seed = 1234, folds = 1) %>%
86+
set_mode("regression")
87+
88+
cls_mod <-
89+
svm_rbf(rbf_sigma = .1, cost = 0.125) %>%
90+
set_engine("liquidSVM", random_seed = 1234, folds = 1) %>%
91+
set_mode("classification")
92+
93+
ctrl <- fit_control(verbosity = 0, catch = FALSE)
94+
95+
# ------------------------------------------------------------------------------
96+
97+
test_that('svm rbf regression', {
98+
99+
skip_if_not_installed("liquidSVM")
100+
101+
expect_error(
102+
fit_xy(
103+
reg_mod,
104+
control = ctrl,
105+
x = iris[, 2:4],
106+
y = iris$Sepal.Length
107+
),
108+
regexp = NA
109+
)
110+
111+
expect_error(
112+
fit(
113+
reg_mod,
114+
Sepal.Length ~ .,
115+
data = iris[, -5],
116+
control = ctrl
117+
),
118+
regexp = NA
119+
)
120+
121+
})
122+
123+
124+
test_that('svm rbf regression prediction', {
125+
126+
skip_if_not_installed("liquidSVM")
127+
128+
reg_form <-
129+
fit(
130+
object = reg_mod,
131+
formula = Sepal.Length ~ .,
132+
data = iris[, -5],
133+
control = ctrl
134+
)
135+
136+
reg_xy_form <-
137+
fit_xy(
138+
object = reg_mod,
139+
x = iris[, 2:4],
140+
y = iris$Sepal.Length,
141+
control = ctrl
142+
)
143+
expect_equal(reg_form$spec, reg_xy_form$spec)
144+
145+
liquidSVM_form <-
146+
liquidSVM::svm(
147+
x = Sepal.Length ~ .,
148+
y = iris[, -5],
149+
gammas = .1,
150+
lambdas = 0.25,
151+
folds = 1,
152+
random_seed = 1234
153+
)
154+
155+
liquidSVM_xy_form <-
156+
liquidSVM::svm(
157+
x = iris[, 2:4],
158+
y = iris$Sepal.Length,
159+
gammas = .1,
160+
lambdas = 0.25,
161+
folds = 1,
162+
random_seed = 1234
163+
)
164+
165+
# check coeffs for liquidSVM formula and liquidSVM xy fit interfaces
166+
expect_equal(liquidSVM::getSolution(liquidSVM_form)[c("coeff", "sv")],
167+
liquidSVM::getSolution(liquidSVM_xy_form)[c("coeff", "sv")])
168+
169+
# check predictions for liquidSVM formula and liquidSVM xy interfaces
170+
liquidSVM_form_preds <- predict(liquidSVM_form, iris[1:3, 2:4])
171+
liquidSVM_form_xy_preds <- predict(liquidSVM_xy_form, iris[1:3, 2:4])
172+
expect_equal(liquidSVM_form_preds, liquidSVM_form_xy_preds)
173+
174+
# check predictions for parsnip formula and liquidSVM formula interfaces
175+
liquidSVM_pred <-
176+
structure(
177+
list(.pred = liquidSVM_form_preds),
178+
row.names = c(NA, -3L), class = c("tbl_df", "tbl", "data.frame"))
179+
180+
parsnip_pred <- predict(reg_form, iris[1:3, 2:4])
181+
expect_equal(as.data.frame(liquidSVM_pred), as.data.frame(parsnip_pred))
182+
183+
# check that coeffs are equal for formula methods called via parsnip and liquidSVM
184+
expect_equal(liquidSVM::getSolution(reg_form$fit)[c("coeff", "sv")],
185+
liquidSVM::getSolution(liquidSVM_form)[c("coeff", "sv")])
186+
187+
# check coeffs are equivalent for parsnip fit_xy and parsnip formula methods
188+
expect_equal(liquidSVM::getSolution(reg_form$fit)[c("coeff", "sv")],
189+
liquidSVM::getSolution(reg_xy_form$fit)[c("coeff", "sv")])
190+
191+
# check predictions are equal for parsnip xy and liquidSVM xy methods
192+
parsnip_xy_pred <- predict(reg_xy_form, iris[1:3, -c(1, 5)])
193+
expect_equal(as.data.frame(liquidSVM_pred), as.data.frame(parsnip_xy_pred))
194+
})
195+

0 commit comments

Comments
 (0)