Skip to content

Commit f52d168

Browse files
authored
Merge pull request #23 from davidrsch/tests-improvement
Improving tests
2 parents 0944b0f + b058d3e commit f52d168

File tree

3 files changed

+370
-0
lines changed

3 files changed

+370
-0
lines changed
File renamed without changes.
Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
# Mock get_keras_object to isolate the logic of collect_compile_args
2+
mock_get_keras_object <- function(name, type, ...) {
3+
# Return a simple string representation for testing purposes
4+
paste0("mocked_", type, "_", name)
5+
}
6+
7+
# Mock optimizer to avoid keras dependency
8+
mock_optimizer_adam <- function(...) {
9+
"mocked_optimizer_adam"
10+
}
11+
12+
test_that("collect_compile_args handles single-output cases correctly", {
13+
# Mock the keras3::optimizer_adam function
14+
testthat::with_mocked_bindings(
15+
.env = as.environment("package:kerasnip"),
16+
get_keras_object = mock_get_keras_object,
17+
{
18+
# Case 1: Single output, non-character loss and metrics
19+
dummy_loss_obj <- structure(list(), class = "dummy_loss")
20+
dummy_metric_obj <- structure(list(), class = "dummy_metric")
21+
22+
args <- collect_compile_args(
23+
all_args = list(
24+
compile_loss = dummy_loss_obj,
25+
compile_metrics = list(dummy_metric_obj)
26+
),
27+
learn_rate = 0.01,
28+
default_loss = "mse",
29+
default_metrics = "mae"
30+
)
31+
expect_equal(args$loss, dummy_loss_obj)
32+
expect_equal(args$metrics, list(dummy_metric_obj))
33+
}
34+
)
35+
})
36+
37+
test_that("collect_compile_args handles multi-output cases correctly", {
38+
testthat::with_mocked_bindings(
39+
.env = as.environment("package:kerasnip"),
40+
get_keras_object = mock_get_keras_object,
41+
{
42+
# Case 2: Multi-output, single string for loss and metrics
43+
args <- collect_compile_args(
44+
all_args = list(
45+
compile_loss = "categorical_crossentropy",
46+
compile_metrics = "accuracy"
47+
),
48+
learn_rate = 0.01,
49+
default_loss = list(out1 = "mse", out2 = "mae"),
50+
default_metrics = list(out1 = "mse", out2 = "mae")
51+
)
52+
expect_equal(args$loss, "mocked_loss_categorical_crossentropy")
53+
expect_equal(args$metrics, "mocked_metric_accuracy")
54+
55+
# Case 3: Multi-output, named list with mixed types
56+
dummy_loss_obj_2 <- structure(list(), class = "dummy_loss_2")
57+
args_mixed <- collect_compile_args(
58+
all_args = list(
59+
compile_loss = list(out1 = "mae", out2 = dummy_loss_obj_2)
60+
),
61+
learn_rate = 0.01,
62+
default_loss = list(out1 = "mse", out2 = "mae"),
63+
default_metrics = list(out1 = "mse", out2 = "mae")
64+
)
65+
expect_equal(args_mixed$loss$out1, "mocked_loss_mae")
66+
expect_equal(args_mixed$loss$out2, dummy_loss_obj_2)
67+
}
68+
)
69+
})
70+
71+
test_that("collect_compile_args handles named list of metrics (multi-output) correctly", {
72+
testthat::with_mocked_bindings(
73+
.env = as.environment("package:kerasnip"),
74+
get_keras_object = mock_get_keras_object,
75+
{
76+
# Test case: Named list of metrics with mixed types (character and object)
77+
dummy_metric_obj_3 <- structure(list(), class = "dummy_metric_3")
78+
args_mixed_metrics <- collect_compile_args(
79+
all_args = list(
80+
compile_metrics = list(out1 = "accuracy", out2 = dummy_metric_obj_3)
81+
),
82+
learn_rate = 0.01,
83+
default_loss = list(out1 = "mse", out2 = "mae"),
84+
default_metrics = list(out1 = "mse", out2 = "mae") # Important: default_metrics must be a named list for this path
85+
)
86+
expect_equal(args_mixed_metrics$metrics$out1, "mocked_metric_accuracy")
87+
expect_equal(args_mixed_metrics$metrics$out2, dummy_metric_obj_3)
88+
89+
# Test case: Named list of metrics with all characters
90+
args_all_char_metrics <- collect_compile_args(
91+
all_args = list(
92+
compile_metrics = list(out1 = "accuracy", out2 = "mse")
93+
),
94+
learn_rate = 0.01,
95+
default_loss = list(out1 = "mse", out2 = "mae"),
96+
default_metrics = list(out1 = "mse", out2 = "mae")
97+
)
98+
expect_equal(args_all_char_metrics$metrics$out1, "mocked_metric_accuracy")
99+
expect_equal(args_all_char_metrics$metrics$out2, "mocked_metric_mse")
100+
101+
# Test case: Named list of metrics with all objects
102+
dummy_metric_obj_4 <- structure(list(), class = "dummy_metric_4")
103+
dummy_metric_obj_5 <- structure(list(), class = "dummy_metric_5")
104+
args_all_obj_metrics <- collect_compile_args(
105+
all_args = list(
106+
compile_metrics = list(
107+
out1 = dummy_metric_obj_4,
108+
out2 = dummy_metric_obj_5
109+
)
110+
),
111+
learn_rate = 0.01,
112+
default_loss = list(out1 = "mse", out2 = "mae"),
113+
default_metrics = list(out1 = "mse", out2 = "mae")
114+
)
115+
expect_equal(args_all_obj_metrics$metrics$out1, dummy_metric_obj_4)
116+
expect_equal(args_all_obj_metrics$metrics$out2, dummy_metric_obj_5)
117+
}
118+
)
119+
})
120+
121+
test_that("collect_compile_args throws errors for invalid multi-output args", {
122+
# Case 4: Multi-output, invalid loss argument
123+
expect_error(
124+
collect_compile_args(
125+
all_args = list(compile_loss = list("a", "b")), # Unnamed list
126+
learn_rate = 0.01,
127+
default_loss = list(out1 = "mse", out2 = "mae"),
128+
default_metrics = list(out1 = "mse", out2 = "mae")
129+
),
130+
"For multiple outputs, 'compile_loss' must be a single string or a named list of losses."
131+
)
132+
133+
# Case 5: Multi-output, invalid metrics argument
134+
expect_error(
135+
collect_compile_args(
136+
all_args = list(compile_metrics = list("a", "b")), # Unnamed list
137+
learn_rate = 0.01,
138+
default_loss = list(out1 = "mse", out2 = "mae"),
139+
default_metrics = list(out1 = "mse", out2 = "mae")
140+
),
141+
"For multiple outputs, 'compile_metrics' must be a single string or a named list of metrics."
142+
)
143+
})
Lines changed: 227 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,227 @@
1+
skip_if_no_keras()
2+
3+
# Mock object for post-processing functions
4+
mock_object_single_output <- list(
5+
fit = list(
6+
lvl = c("setosa", "versicolor", "virginica") # For classification levels
7+
)
8+
)
9+
class(mock_object_single_output) <- "model_fit"
10+
11+
mock_object_multi_output <- list(
12+
fit = list(
13+
lvl = list(
14+
output1 = c("classA", "classB"),
15+
output2 = c("typeX", "typeY", "typeZ")
16+
)
17+
)
18+
)
19+
class(mock_object_multi_output) <- "model_fit"
20+
21+
# --- Tests for keras_postprocess_numeric ---
22+
23+
test_that("keras_postprocess_numeric handles single output (matrix) correctly", {
24+
results <- matrix(c(0.1, 0.2, 0.3), ncol = 1)
25+
processed <- keras_postprocess_numeric(results, mock_object_single_output)
26+
expect_s3_class(processed, "tbl_df")
27+
expect_equal(names(processed), ".pred")
28+
expect_equal(processed$.pred, c(0.1, 0.2, 0.3))
29+
})
30+
31+
test_that("keras_postprocess_numeric handles single output (named list with one element) correctly", {
32+
results <- list(output1 = matrix(c(0.1, 0.2, 0.3), ncol = 1))
33+
names(results) <- "output1"
34+
processed <- keras_postprocess_numeric(results, mock_object_multi_output)
35+
expect_s3_class(processed, "tbl_df")
36+
expect_equal(names(processed), ".pred")
37+
expect_equal(processed$.pred, matrix(c(0.1, 0.2, 0.3), ncol = 1)) # Changed expected
38+
})
39+
40+
41+
test_that("keras_postprocess_numeric handles multi-output (named list) correctly", {
42+
results <- list(
43+
output1 = matrix(c(0.1, 0.2), ncol = 1),
44+
output2 = matrix(c(0.4, 0.5), ncol = 1)
45+
)
46+
names(results) <- c("output1", "output2")
47+
processed <- keras_postprocess_numeric(results, mock_object_multi_output)
48+
expect_s3_class(processed, "tbl_df")
49+
expect_equal(names(processed), c(".pred_output1", ".pred_output2"))
50+
# Change expected values to 1-column matrices
51+
expect_equal(processed$.pred_output1, matrix(c(0.1, 0.2), ncol = 1))
52+
expect_equal(processed$.pred_output2, matrix(c(0.4, 0.5), ncol = 1))
53+
})
54+
55+
# --- Tests for keras_postprocess_probs ---
56+
57+
test_that("keras_postprocess_probs handles single output (matrix) correctly", {
58+
results <- matrix(
59+
c(
60+
0.1,
61+
0.9,
62+
0.0, # Example probabilities for 3 classes
63+
0.2,
64+
0.1,
65+
0.7,
66+
0.3,
67+
0.3,
68+
0.4
69+
),
70+
ncol = 3,
71+
byrow = TRUE
72+
)
73+
processed <- keras_postprocess_probs(results, mock_object_single_output)
74+
expect_s3_class(processed, "tbl_df")
75+
expect_equal(names(processed), c("setosa", "versicolor", "virginica")) # Updated expected names
76+
expect_equal(processed$setosa, c(0.1, 0.2, 0.3)) # Access by correct column name
77+
expect_equal(processed$versicolor, c(0.9, 0.1, 0.3)) # Access by correct column name
78+
expect_equal(processed$virginica, c(0.0, 0.7, 0.4)) # Access by correct column name
79+
})
80+
81+
test_that("keras_postprocess_probs handles multi-output (named list) correctly", {
82+
results <- list(
83+
output1 = matrix(c(0.1, 0.9, 0.2, 0.8), ncol = 2, byrow = TRUE),
84+
output2 = matrix(c(0.3, 0.4, 0.3, 0.5, 0.2, 0.3), ncol = 3, byrow = TRUE)
85+
)
86+
names(results) <- c("output1", "output2")
87+
processed <- keras_postprocess_probs(results, mock_object_multi_output)
88+
expect_s3_class(processed, "tbl_df")
89+
expect_equal(
90+
names(processed),
91+
c(
92+
".pred_output1_classA",
93+
".pred_output1_classB",
94+
".pred_output2_typeX",
95+
".pred_output2_typeY",
96+
".pred_output2_typeZ"
97+
)
98+
)
99+
expect_equal(processed$.pred_output1_classA, c(0.1, 0.2))
100+
expect_equal(processed$.pred_output2_typeX, c(0.3, 0.5))
101+
})
102+
103+
test_that("keras_postprocess_probs handles multi-output with NULL levels fallback", {
104+
results <- list(
105+
output1 = matrix(c(0.1, 0.9, 0.2, 0.8), ncol = 2, byrow = TRUE)
106+
)
107+
names(results) <- "output1"
108+
mock_object_null_lvl <- list(
109+
fit = list(
110+
lvl = list(output1 = NULL) # Simulate NULL levels for this output
111+
)
112+
)
113+
class(mock_object_null_lvl) <- "model_fit"
114+
processed <- keras_postprocess_probs(results, mock_object_null_lvl)
115+
expect_s3_class(processed, "tbl_df")
116+
expect_equal(
117+
names(processed),
118+
c(".pred_output1_class1", ".pred_output1_class2")
119+
)
120+
})
121+
122+
# --- Tests for keras_postprocess_classes ---
123+
124+
test_that("keras_postprocess_classes handles single output (multiclass) correctly", {
125+
results <- matrix(c(0.1, 0.8, 0.1, 0.2, 0.1, 0.7), ncol = 3, byrow = TRUE)
126+
processed <- keras_postprocess_classes(results, mock_object_single_output)
127+
expect_s3_class(processed, "tbl_df")
128+
expect_equal(names(processed), ".pred_class")
129+
expect_equal(
130+
as.character(processed$.pred_class),
131+
c("versicolor", "virginica")
132+
)
133+
expect_true(is.factor(processed$.pred_class))
134+
expect_equal(
135+
levels(processed$.pred_class),
136+
c("setosa", "versicolor", "virginica")
137+
)
138+
})
139+
140+
test_that("keras_postprocess_classes handles single output (binary) correctly", {
141+
results <- matrix(c(0.6, 0.4), ncol = 1) # Changed to single column
142+
mock_object_binary_lvl <- list(
143+
fit = list(
144+
lvl = c("negative", "positive")
145+
)
146+
)
147+
class(mock_object_binary_lvl) <- "model_fit"
148+
processed <- keras_postprocess_classes(results, mock_object_binary_lvl)
149+
expect_s3_class(processed, "tbl_df")
150+
expect_equal(names(processed), ".pred_class")
151+
expect_equal(as.character(processed$.pred_class), c("positive", "negative")) # Changed expected
152+
expect_true(is.factor(processed$.pred_class))
153+
expect_equal(levels(processed$.pred_class), c("negative", "positive"))
154+
})
155+
156+
test_that("keras_postprocess_classes handles multi-output (named list) correctly", {
157+
results <- list(
158+
output1 = matrix(c(0.1, 0.9, 0.2, 0.8), ncol = 2, byrow = TRUE), # Binary
159+
output2 = matrix(c(0.3, 0.4, 0.3, 0.5, 0.2, 0.3), ncol = 3, byrow = TRUE) # Multiclass
160+
)
161+
names(results) <- c("output1", "output2")
162+
processed <- keras_postprocess_classes(results, mock_object_multi_output)
163+
expect_s3_class(processed, "tbl_df")
164+
expect_equal(
165+
names(processed),
166+
c(".pred_class_output1", ".pred_class_output2")
167+
)
168+
expect_equal(
169+
as.character(processed$.pred_class_output1),
170+
c("classB", "classB")
171+
)
172+
expect_equal(as.character(processed$.pred_class_output2), c("typeY", "typeX"))
173+
expect_true(is.factor(processed$.pred_class_output1))
174+
expect_true(is.factor(processed$.pred_class_output2))
175+
expect_equal(levels(processed$.pred_class_output1), c("classA", "classB"))
176+
expect_equal(
177+
levels(processed$.pred_class_output2),
178+
c("typeX", "typeY", "typeZ")
179+
)
180+
})
181+
182+
test_that("keras_postprocess_classes handles multi-output with NULL levels fallback", {
183+
results <- list(
184+
output1 = matrix(c(0.6, 0.4, 0.2, 0.8), ncol = 2, byrow = TRUE) # Binary
185+
)
186+
names(results) <- "output1"
187+
mock_object_null_lvl <- list(
188+
fit = list(
189+
lvl = list(output1 = NULL) # Simulate NULL levels for this output
190+
)
191+
)
192+
class(mock_object_null_lvl) <- "model_fit"
193+
processed <- keras_postprocess_classes(results, mock_object_null_lvl)
194+
expect_s3_class(processed, "tbl_df")
195+
expect_equal(names(processed), c(".pred_class_output1"))
196+
expect_equal(
197+
as.character(processed$.pred_class_output1),
198+
c("class1", "class2")
199+
) # Changed expected
200+
expect_true(is.factor(processed$.pred_class_output1))
201+
expect_equal(levels(processed$.pred_class_output1), c("class1", "class2"))
202+
})
203+
204+
test_that("keras_postprocess_classes handles multi-output (binary, single column) correctly", {
205+
results <- list(
206+
output1 = matrix(c(0.6, 0.4, 0.2, 0.8), ncol = 1, byrow = TRUE) # Single column binary output
207+
)
208+
names(results) <- "output1"
209+
mock_object_multi_output_binary <- list(
210+
fit = list(
211+
lvl = list(output1 = c("negative", "positive")) # Levels for binary output
212+
)
213+
)
214+
class(mock_object_multi_output_binary) <- "model_fit"
215+
processed <- keras_postprocess_classes(
216+
results,
217+
mock_object_multi_output_binary
218+
)
219+
expect_s3_class(processed, "tbl_df")
220+
expect_equal(names(processed), c(".pred_class_output1"))
221+
expect_equal(
222+
as.character(processed$.pred_class_output1),
223+
c("positive", "negative", "negative", "positive")
224+
) # Expected based on 0.5 threshold
225+
expect_true(is.factor(processed$.pred_class_output1))
226+
expect_equal(levels(processed$.pred_class_output1), c("negative", "positive"))
227+
})

0 commit comments

Comments
 (0)