Skip to content

Commit b5cf7d7

Browse files
committed
first pass at spark execution support
1 parent 0d1a503 commit b5cf7d7

File tree

4 files changed

+98
-76
lines changed

4 files changed

+98
-76
lines changed

R/fit.R

Lines changed: 49 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ fit_formula <- function(object, formula, data, engine = engine, control, ...) {
151151
env = current_env()
152152
)
153153
} else {
154-
if(object$method$interface %in% c("data.frame", "matrix")) {
154+
if(object$method$interface %in% c("data.frame", "matrix", "spark")) {
155155
res <- formula_to_xy(object = object, formula = formula, data = data, control)
156156
} else {
157157
stop("I don't know about the ",
@@ -162,31 +162,54 @@ fit_formula <- function(object, formula, data, engine = engine, control, ...) {
162162
res
163163
}
164164

165+
166+
xy_to_xy <- function(object, x, y, control, ...) {
167+
fit_expr <- object$method$fit_call
168+
fit_expr[["x"]] <- quote(x)
169+
fit_expr[["y"]] <- quote(y)
170+
eval_mod(
171+
fit_expr,
172+
capture = control$verbosity == 0,
173+
catch = control$catch,
174+
env = current_env()
175+
)
176+
}
177+
xy_to_matrix <- function(object, x, y, control, ...) {
178+
if (object$method$interface == "matrix" && !is.matrix(x))
179+
x <- as.matrix(x)
180+
xy_to_xy(object, x, y, control, ...)
181+
}
182+
xy_to_df <- function(object, x, y, control, ...) {
183+
if (object$method$interface == "data.frame" && !is.data.frame(x))
184+
x <- as.data.frame(x)
185+
xy_to_xy(object, x, y, control, ...)
186+
}
187+
xy_to_spark <- function(object, x, y, control, ...) {
188+
sdf <- sparklyr::sdf_bind_cols(x, y)
189+
fit_expr <- object$method$fit_call
190+
fit_expr[["x"]] <- quote(sdf)
191+
fit_expr[["features_col"]] <- quote(colnames(x))
192+
fit_expr[["label_col"]] <- quote(colnames(y))
193+
eval_mod(
194+
fit_expr,
195+
capture = control$verbosity == 0,
196+
catch = control$catch,
197+
env = current_env()
198+
)
199+
}
200+
201+
165202
fit_xy <- function(object, x, y, control, ...) {
166203
opts <- quos(...)
167204

168-
# Look up the model's interface (e.g. formula, recipes, etc)
169-
# and delegate to the connector functions (`xy_to_formula` etc)
170-
if(object$method$interface == "formula") {
171-
res <- xy_to_formula(object = object, x = x, y = y, control)
172-
} else {
173-
if(object$method$interface %in% c("data.frame", "matrix")) {
174-
fit_expr <- object$method$fit_call
175-
fit_expr[["x"]] <- quote(x)
176-
fit_expr[["y"]] <- quote(y)
177-
res <-
178-
eval_mod(
179-
fit_expr,
180-
capture = control$verbosity == 0,
181-
catch = control$catch,
182-
env = current_env()
183-
)
184-
} else {
185-
stop("I don't know about the ",
186-
object$method$interface, " interface.",
187-
call. = FALSE)
188-
}
189-
}
205+
res <- switch(
206+
object$method$interface,
207+
formula = xy_to_formula(object = object, x = x, y = y, control, ...),
208+
matrix = xy_to_matrix(object = object, x = x, y = y, control, ...),
209+
data.frame = xy_to_df(object = object, x = x, y = y, control, ...),
210+
spark = xy_to_spark(object = object, x = x, y = y, control, ...),
211+
stop("Unknown interface")
212+
)
190213
res
191214
}
192215

@@ -394,11 +417,11 @@ has_both_or_none <- function(a, b)
394417
check_interface <- function(formula, recipe, x, y, data, cl) {
395418
inher(formula, "formula", cl)
396419
inher(recipe, "recipe", cl)
397-
inher(x, c("data.frame", "matrix"), cl)
420+
inher(x, c("data.frame", "matrix", "tbl_spark"), cl)
398421
# `y` can be a vector (which is not a class), or a factor (which is not a vector)
399422
if(!is.null(y) && !is.vector(y))
400-
inher(y, c("data.frame", "matrix", "factor"), cl)
401-
inher(data, c("data.frame", "matrix"), cl)
423+
inher(y, c("data.frame", "matrix", "factor", "tbl_spark"), cl)
424+
inher(data, c("data.frame", "matrix", "tbl_spark"), cl)
402425

403426
x_interface <- !is.null(x) & !is.null(y)
404427
rec_interface <- !is.null(recipe) & !is.null(data)

R/logistic_reg_constr.R

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,14 @@ logistic_reg_engines <- data.frame(
1414
glm = TRUE,
1515
glmnet = TRUE,
1616
spark = TRUE,
17-
stan = TRUE,
17+
stan = TRUE,
1818
row.names = c("classification")
1919
)
2020

2121
###################################################################
2222

2323
#' @importFrom stats binomial
24-
logistic_reg_glm_constr <-
24+
logistic_reg_glm_constr <-
2525
function(
2626
formula = missing_arg(),
2727
family = binomial,
@@ -42,10 +42,10 @@ logistic_reg_glm_constr <-
4242
) {
4343
libs <- "stats"
4444
interface <- "formula"
45-
protect = c("formula", "data", "weights")
45+
protect = c("formula", "data", "weights")
4646
has_dots <- TRUE
4747
fit_name <- "glm"
48-
fit_args <-
48+
fit_args <-
4949
enexprs(
5050
formula = formula,
5151
family = family,
@@ -64,7 +64,7 @@ logistic_reg_glm_constr <-
6464
y = y,
6565
contrasts = contrasts
6666
)
67-
res <-
67+
res <-
6868
list(
6969
library = libs,
7070
interface = interface,
@@ -78,7 +78,7 @@ logistic_reg_glm_constr <-
7878
res
7979
}
8080

81-
logistic_reg_glmnet_constr <-
81+
logistic_reg_glmnet_constr <-
8282
function(
8383
x = as.matrix(x),
8484
y = missing_arg(),
@@ -105,11 +105,11 @@ logistic_reg_glmnet_constr <-
105105
type.multinomial = c("ungrouped", "grouped")
106106
) {
107107
libs <- "glmnet"
108-
interface <- "data.frame"
109-
protect = c("x", "y", "weights", "family")
108+
interface <- "matrix"
109+
protect = c("x", "y", "weights", "family")
110110
has_dots <- FALSE
111111
fit_name <- "glmnet"
112-
fit_args <-
112+
fit_args <-
113113
enexprs(
114114
x = x,
115115
y = y,
@@ -135,7 +135,7 @@ logistic_reg_glmnet_constr <-
135135
standardize.response = standardize.response,
136136
type.multinomial = type.multinomial
137137
)
138-
res <-
138+
res <-
139139
list(
140140
library = libs,
141141
interface = interface,
@@ -149,7 +149,7 @@ logistic_reg_glmnet_constr <-
149149
res
150150
}
151151

152-
logistic_reg_stan_constr <-
152+
logistic_reg_stan_constr <-
153153
function(
154154
formula = missing_arg(),
155155
family = binomial(),
@@ -173,10 +173,10 @@ logistic_reg_stan_constr <-
173173
) {
174174
libs <- "rstanarm"
175175
interface <- "formula"
176-
protect = c("formula", "data", "weights")
176+
protect = c("formula", "data", "weights")
177177
has_dots <- TRUE
178178
fit_name <- "stan_glm"
179-
fit_args <-
179+
fit_args <-
180180
enexprs(
181181
formula = formula,
182182
family = family,
@@ -198,7 +198,7 @@ logistic_reg_stan_constr <-
198198
QR = QR,
199199
sparse = sparse
200200
)
201-
res <-
201+
res <-
202202
list(
203203
library = libs,
204204
interface = interface,
@@ -213,7 +213,7 @@ logistic_reg_stan_constr <-
213213
}
214214

215215

216-
logistic_reg_spark_constr <-
216+
logistic_reg_spark_constr <-
217217
function(
218218
x = missing_arg(),
219219
formula = NULL,
@@ -239,11 +239,11 @@ logistic_reg_spark_constr <-
239239
uid = random_string("logistic_regression_")
240240
) {
241241
libs <- "sparklyr"
242-
interface <- "formula"
243-
protect = c("features_col", "label_col", "x", "weight_col")
242+
interface <- "spark"
243+
protect = c("features_col", "label_col", "x", "weight_col")
244244
has_dots <- TRUE
245245
fit_name <- "ml_logistic_regression"
246-
fit_args <-
246+
fit_args <-
247247
enexprs(
248248
x = x,
249249
formula = formula,
@@ -268,7 +268,7 @@ logistic_reg_spark_constr <-
268268
raw_prediction_col = raw_prediction_col,
269269
uid = uid
270270
)
271-
res <-
271+
res <-
272272
list(
273273
library = libs,
274274
interface = interface,

R/rand_forest_constr.R

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
rand_forest_arg_key <- data.frame(
33
randomForest = c("mtry", "ntree", "nodesize"),
44
ranger = c("mtry", "num.trees", "min.node.size"),
5-
spark =
5+
spark =
66
c("feature_subset_strategy", "num_trees", "min_instances_per_node"),
77
stringsAsFactors = FALSE,
88
row.names = c("mtry", "trees", "min_n")
@@ -19,7 +19,7 @@ rand_forest_engines <- data.frame(
1919

2020
###################################################################
2121

22-
rand_forest_ranger_constr <-
22+
rand_forest_ranger_constr <-
2323
function(
2424
formula = missing_arg(),
2525
data = missing_arg(),
@@ -52,10 +52,10 @@ rand_forest_ranger_constr <-
5252
) {
5353
libs <- "ranger"
5454
interface <- "formula"
55-
protect = c("formula", "data", "case.weights")
55+
protect = c("formula", "data", "case.weights")
5656
has_dots <- FALSE
5757
fit_name <- "ranger"
58-
fit_args <-
58+
fit_args <-
5959
enexprs(
6060
formula = formula,
6161
data = data,
@@ -86,7 +86,7 @@ rand_forest_ranger_constr <-
8686
status.variable.name = status.variable.name,
8787
classification = classification
8888
)
89-
res <-
89+
res <-
9090
list(
9191
library = libs,
9292
interface = interface,
@@ -100,13 +100,13 @@ rand_forest_ranger_constr <-
100100
res
101101
}
102102

103-
rand_forest_randomForest_constr <-
103+
rand_forest_randomForest_constr <-
104104
function(
105-
x = missing_arg(),
106-
y = missing_arg(),
107-
xtest = NULL,
108-
ytest = NULL,
109-
ntree = 500,
105+
x = missing_arg(),
106+
y = missing_arg(),
107+
xtest = NULL,
108+
ytest = NULL,
109+
ntree = 500,
110110
mtry = if (!is.null(y) && !is.factor(y))
111111
max(floor(ncol(x) / 3), 1)
112112
else
@@ -134,7 +134,7 @@ rand_forest_randomForest_constr <-
134134
keep.forest = !is.null(y) && is.null(xtest),
135135
corr.bias = FALSE,
136136
keep.inbag = FALSE
137-
)
137+
)
138138
{
139139
libs <- "randomForest"
140140
interface <- "data.frame"
@@ -166,8 +166,8 @@ rand_forest_randomForest_constr <-
166166
keep.forest = keep.forest,
167167
corr.bias = corr.bias,
168168
keep.inbag = keep.inbag
169-
)
170-
res <-
169+
)
170+
res <-
171171
list(
172172
library = libs,
173173
interface = interface,
@@ -178,11 +178,11 @@ rand_forest_randomForest_constr <-
178178
fit_call = NULL
179179
)
180180
class(res) <- c("rand_forest_constr")
181-
res
181+
res
182182
}
183183

184184

185-
rand_forest_spark_constr <-
185+
rand_forest_spark_constr <-
186186
function(
187187
x = missing_arg(),
188188
formula = NULL,
@@ -211,7 +211,7 @@ rand_forest_spark_constr <-
211211
)
212212
{
213213
libs <- "sparklyr"
214-
interface <- "tbl_spark"
214+
interface <- "spark"
215215
protect = c("x", "features_col", "label_col", "type")
216216
has_dots <- TRUE
217217
fit_name <- "ml_random_forest"
@@ -241,8 +241,8 @@ rand_forest_spark_constr <-
241241
uid = uid,
242242
response = response,
243243
features = features
244-
)
245-
res <-
244+
)
245+
res <-
246246
list(
247247
library = libs,
248248
interface = interface,
@@ -253,7 +253,7 @@ rand_forest_spark_constr <-
253253
fit_call = NULL
254254
)
255255
class(res) <- c("rand_forest_constr")
256-
res
256+
res
257257
}
258258

259259
###################################################################

tests/testthat/test_logistic_reg.R

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -315,17 +315,16 @@ test_that('glmnet execution', {
315315
# regexp = NA
316316
# )
317317

318-
# fails during R CMD check but works outside of that
319-
# expect_error(
320-
# fit(
321-
# lc_basic,
322-
# engine = "glmnet",
323-
# control = ctrl,
324-
# x = lending_club[, num_pred],
325-
# y = lending_club$Class
326-
# ),
327-
# regexp = NA
328-
# )
318+
expect_error(
319+
fit(
320+
lc_basic,
321+
engine = "glmnet",
322+
control = ctrl,
323+
x = lending_club[, num_pred],
324+
y = lending_club$Class
325+
),
326+
regexp = NA
327+
)
329328

330329
# fails because `glmnet` requires a matrix
331330
# expect_error(

0 commit comments

Comments
 (0)