Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
101 commits
Select commit Hold shift + click to select a range
3fb75c2
...
sebffischer Oct 18, 2024
5f0d45a
...
sebffischer Jan 10, 2025
e877385
...
sebffischer Jan 10, 2025
796c840
...
sebffischer Jan 10, 2025
5b1e25d
...
sebffischer Mar 31, 2025
36732e9
Merge branch 'main' into paper2
sebffischer Apr 3, 2025
67b00b9
hack
sebffischer Apr 3, 2025
a6103d1
...
sebffischer Apr 3, 2025
753b796
...
sebffischer Apr 4, 2025
139b353
...
sebffischer Apr 4, 2025
3aecccc
...
sebffischer Apr 4, 2025
966f66e
...
sebffischer Apr 4, 2025
785c311
...
sebffischer Apr 4, 2025
a40f8d4
...
sebffischer Apr 4, 2025
6c9a428
...
sebffischer Apr 4, 2025
169105a
only keep train predictions if they are needed
sebffischer Apr 7, 2025
20aad01
...
sebffischer Apr 8, 2025
1c2327f
...
sebffischer Apr 8, 2025
b807406
...
sebffischer Apr 8, 2025
e694709
...
sebffischer Apr 8, 2025
bce4248
...
sebffischer Apr 8, 2025
827c3b7
...
sebffischer Apr 8, 2025
169d7b5
...
sebffischer Apr 8, 2025
be1749d
...
sebffischer Apr 8, 2025
05c9b8a
...
sebffischer Apr 8, 2025
e1167e0
...
sebffischer Apr 11, 2025
9317b60
...
sebffischer Apr 14, 2025
f7e1215
...
sebffischer Apr 14, 2025
fc64308
...
sebffischer Apr 14, 2025
4cc1753
work on paper
sebffischer Jul 16, 2025
58ce4ab
include Rprofile
sebffischer Jul 16, 2025
7ed471e
add results
sebffischer Jul 18, 2025
1c9568e
Merge branch 'main' into paper2
sebffischer Jul 18, 2025
bebda3d
...
sebffischer Jul 18, 2025
e99bf06
fix conflict
sebffischer Jul 18, 2025
b9f9eb1
benchmark
sebffischer Jul 19, 2025
8206c71
...
sebffischer Jul 19, 2025
f4ac761
...
sebffischer Jul 19, 2025
794d185
...
sebffischer Jul 19, 2025
5ea6945
...
sebffischer Jul 20, 2025
f31890e
...
sebffischer Jul 20, 2025
2b7e4e9
last line
sebffischer Jul 20, 2025
577a04b
...
sebffischer Jul 20, 2025
8d109d9
...
sebffischer Jul 20, 2025
3008c3a
...
sebffischer Jul 21, 2025
b59d09d
...
sebffischer Jul 21, 2025
4b39665
update result
sebffischer Jul 23, 2025
a19be3e
update fiels
sebffischer Jul 24, 2025
9f9983a
...
sebffischer Jul 25, 2025
dd08427
...
sebffischer Jul 25, 2025
ed5cdea
...
sebffischer Jul 25, 2025
781c20f
...
sebffischer Jul 25, 2025
3e846df
...
sebffischer Jul 29, 2025
393e7c0
...
sebffischer Aug 1, 2025
085d0fb
...
sebffischer Aug 1, 2025
9f2c015
...
sebffischer Aug 1, 2025
313cf40
...
sebffischer Aug 1, 2025
e53c0f6
...
sebffischer Aug 1, 2025
3bc79e2
...
sebffischer Aug 1, 2025
fadef47
...
sebffischer Aug 2, 2025
f64978c
...
sebffischer Aug 4, 2025
bccba11
...
sebffischer Aug 4, 2025
331b39e
...
sebffischer Aug 4, 2025
ad5d8cf
...
sebffischer Aug 7, 2025
879637b
...
sebffischer Aug 7, 2025
08acd23
...
sebffischer Aug 7, 2025
f77b258
...
sebffischer Aug 8, 2025
38bf3fb
...
sebffischer Aug 9, 2025
979c9b5
...
sebffischer Aug 9, 2025
242c91d
...
sebffischer Aug 10, 2025
a83106c
...
sebffischer Aug 10, 2025
4b84811
...
sebffischer Aug 10, 2025
a681e84
...
sebffischer Aug 10, 2025
b15f194
...
sebffischer Aug 10, 2025
0481d25
benchmark
sebffischer Aug 14, 2025
723dd7c
...
sebffischer Aug 14, 2025
f9adc0a
overwrite experiment with cpu only on gpu server
sebffischer Aug 15, 2025
4511b74
hope I fixed it for good
sebffischer Aug 15, 2025
7de5b04
run new rocker on gpu cluter
sebffischer Aug 15, 2025
fba71fa
refactor
sebffischer Aug 15, 2025
ba32ffc
...
sebffischer Aug 15, 2025
a196f5e
...
sebffischer Aug 15, 2025
b24778a
...
sebffischer Aug 15, 2025
fd37790
...g
sebffischer Aug 16, 2025
efc3cdf
...
sebffischer Aug 16, 2025
1147997
...
sebffischer Aug 16, 2025
8a160b5
use higher reps, higehr latent
sebffischer Aug 16, 2025
bb5d3f1
...
sebffischer Aug 16, 2025
83d487d
gpu results
sebffischer Aug 19, 2025
b6a97bf
cpu results
sebffischer Aug 19, 2025
dcf74bb
...
sebffischer Aug 20, 2025
93eae67
cpu version cheap
sebffischer Aug 20, 2025
ff174c0
...
sebffischer Aug 20, 2025
f0218d3
...
sebffischer Aug 20, 2025
7a8bd08
cheap version is reproducible
sebffischer Aug 20, 2025
1028e73
...
sebffischer Aug 25, 2025
6b59388
readme
sebffischer Aug 26, 2025
5ce4023
...
sebffischer Aug 26, 2025
99da0b7
...
sebffischer Aug 27, 2025
458cb44
update extracted paper code
sebffischer Aug 27, 2025
579acb5
...
sebffischer Sep 5, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
The table of contents is too big for display.
Diff view
Diff view
  •  
  •  
  •  
2 changes: 1 addition & 1 deletion .github/workflows/r-cmd-check.yml
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ jobs:
if: runner.os == 'Windows'
id: get_package_version_windows
run: |
$version = Rscript -e 'cat(as.character(packageVersion("torchvision")))'
$version = Rscript -e 'cat(as.character(packageVersion("torchvision")))'
echo "TORCHVISION_PACKAGE_VERSION=$version" >> $env:GITHUB_ENV

- name: Get torch cache path (Linux/macOS)
Expand Down
10 changes: 8 additions & 2 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,18 @@ mlr3torch*.tgz
*~
docs
inst/doc
*.html
**/.DS_Store
/doc/
/Meta/
CRAN-SUBMISSION
paper/data
.idea/
.vsc/
paper/data
paper/data/
paper/benchmark/registry
.vscode/
paper/benchmark/registry-linux-cpu/
paper/benchmark/registry-macos/
paper/benchmark/registry-linux-gpu/
paper/benchmark/registry-linux-gpu-optimizer/
paper/benchmark/registry-linux-gpu-old/
18 changes: 12 additions & 6 deletions R/learner_torch_methods.R
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ learner_torch_train = function(self, private, super, task, param_vals) {
stopf("Training Dataloader of Learner '%s' has length 0", self$id)
}

network = private$.network(task, param_vals)$to(device = param_vals$device)
network = private$.network(task, param_vals)
network$to(device = param_vals$device)
if (isTRUE(param_vals$jit_trace) && !inherits(network, "script_module")) {
example = get_example_batch(loader_train)$x
example = lapply(example, function(x) x$to(device = param_vals$device))
Expand Down Expand Up @@ -134,6 +135,8 @@ train_loop = function(ctx, cbs) {

ctx$network$train()

forward = get_forward(ctx$network)

# if we increment epoch at the end of the loop it has the wrong value
# during the final two callback stages
ctx$epoch = 0L
Expand All @@ -145,6 +148,7 @@ train_loop = function(ctx, cbs) {
indices = list()
train_iterator = dataloader_make_iter(ctx$loader_train)
ctx$step = 0L
eval_train = eval_train_in_epoch(ctx)
while (ctx$step < length(ctx$loader_train)) {
ctx$step = ctx$step + 1
ctx$batch = dataloader_next(train_iterator)
Expand All @@ -155,9 +159,9 @@ train_loop = function(ctx, cbs) {
call("on_batch_begin")

if (length(ctx$batch$x) == 1L) {
ctx$y_hat = ctx$network(ctx$batch$x[[1L]])
ctx$y_hat = forward(ctx$batch$x[[1L]])
} else {
ctx$y_hat = do.call(ctx$network, ctx$batch$x)
ctx$y_hat = do.call(forward, ctx$batch$x)
}

loss = ctx$loss_fn(ctx$y_hat, ctx$batch$y)
Expand All @@ -167,14 +171,16 @@ train_loop = function(ctx, cbs) {
call("on_after_backward")

ctx$last_loss = loss$item()
predictions[[length(predictions) + 1]] = ctx$y_hat$detach()
indices[[length(indices) + 1]] = as.integer(ctx$batch$.index$to(device = "cpu"))
if (eval_train) {
predictions[[length(predictions) + 1]] = ctx$y_hat$detach()
indices[[length(indices) + 1]] = as.integer(ctx$batch$.index$to(device = "cpu"))
}
ctx$optimizer$step()

call("on_batch_end")
}

ctx$last_scores_train = if (eval_train_in_epoch(ctx)) {
ctx$last_scores_train = if (eval_train) {
measure_prediction(
pred_tensor = torch_cat(predictions, dim = 1L),
measures = ctx$measures_train,
Expand Down
6 changes: 1 addition & 5 deletions R/nn.R
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,5 @@
#' # is the same as:
#' po2 = nn("linear")
nn = function(.key, ...) {
args = list(...)
if (is.null(args$id)) {
args$id = .key
}
invoke(po, .obj = paste0("nn_", .key), .args = args)
invoke(po, .obj = paste0("nn_", .key), id = .key, ...)
}
21 changes: 19 additions & 2 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,23 @@ order_named_args = function(f, l) {
l2
}

get_forward = function(net) {
if (inherits(net, "script_module")) {
is_training = net$is_training
trainforward = net$trainforward
evalforward = net$evalforward
function(...) {
if (is_training()) {
trainforward(...)
} else {
evalforward(...)
}
}
} else {
net$forward
}
}


#' @title Network Output Dimension
#' @description
Expand Down Expand Up @@ -314,7 +331,7 @@ all_or_none_ = function(...) {
single_lazy_tensor = function(task) {
identical(task$feature_types[, "type"][[1L]], "lazy_tensor")
}

n_num_features = function(task) {
sum(task$feature_types$type %in% c("numeric", "integer"))
}
Expand All @@ -325,4 +342,4 @@ n_categ_features = function(task) {

n_ltnsr_features = function(task) {
sum(task$feature_types$type == "lazy_tensor")
}
}
4 changes: 2 additions & 2 deletions man-roxygen/paramset_torchlearner.R
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@
#' The batch size (required).
#' * `shuffle` :: `logical(1)`\cr
#' Whether to shuffle the instances in the dataset. This is initialized to `TRUE`,
#' which differs from the default (`FALSE`).
#' which differs from the default of the [`torch::dataloader`] which is `FALSE`.
#' * `sampler` :: [`torch::sampler`]\cr
#' Object that defines how the dataloader draw samples.
#' * `batch_sampler` :: [`torch::sampler`]\cr
Expand All @@ -91,4 +91,4 @@
#' * `worker_packages` :: `character()`\cr
#' Which packages to load on the workers.
#'
#' Also see `torch::dataloder` for more information.
#' Also see [`torch::dataloder`] for more information.
75 changes: 75 additions & 0 deletions mlr3torch-benchmark-5274609.out
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
[INFO] Extracting squashfs filesystem...
Parallel unsquashfs: Using 92 processors
57832 inodes (209999 blocks) to write


created 55975 files
created 6137 directories
created 1735 symlinks
created 0 devices
created 0 fifos
created 0 sockets

==========
== CUDA ==
==========

CUDA Version 12.4.1

Container image Copyright (c) 2016-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.

This container image and its contents are governed by the NVIDIA Deep Learning Container License.
By pulling and using the container, you accept the terms and conditions of this license:
https://developer.nvidia.com/ngc/nvidia-deep-learning-container-license

A copy of this license is made available in this container at /NGC-DL-CONTAINER-LICENSE for your convenience.

WARNING: The NVIDIA Driver was not detected. GPU functionality will not be available.
Use the NVIDIA Container Toolkit to start this container with GPU support; see
https://docs.nvidia.com/datacenter/cloud-native/ .

R version 4.5.0 (2025-04-11)
Platform: x86_64-pc-linux-gnu
Running under: Ubuntu 22.04.4 LTS

Matrix products: default
BLAS: /usr/local/lib/R/lib/libRblas.so
LAPACK: /usr/local/lib/R/lib/libRlapack.so; LAPACK version 3.12.1

locale:
[1] LC_CTYPE=en_US.UTF-8 LC_NUMERIC=C
[3] LC_TIME=en_US.UTF-8 LC_COLLATE=en_US.UTF-8
[5] LC_MONETARY=en_US.UTF-8 LC_MESSAGES=en_US.UTF-8
[7] LC_PAPER=en_US.UTF-8 LC_NAME=C
[9] LC_ADDRESS=C LC_TELEPHONE=C
[11] LC_MEASUREMENT=en_US.UTF-8 LC_IDENTIFICATION=C

time zone: Etc/UTC
tzcode source: system (glibc)

attached base packages:
[1] stats graphics grDevices utils datasets methods base

loaded via a namespace (and not attached):
[1] compiler_4.5.0

Attaching package: ‘mlr3misc’

The following object is masked from ‘package:batchtools’:

chunk

Sourcing configuration file '/mnt/data/mlr3torch/paper/batchtools.conf.R' ...
Loading required package: checkmate
Created registry in '/mnt/data/mlr3torch/paper/benchmark/registry' using cluster functions 'Interactive'
Exporting new objects: 'time_rtorch' ...
Adding problem 'runtime_train'
Adding algorithm 'pytorch'
Adding algorithm 'rtorch'
Adding algorithm 'mlr3torch'
Adding 180 experiments ('runtime_train'[30] x 'rtorch'[2] x repls[3]) ...
Adding 180 experiments ('runtime_train'[30] x 'mlr3torch'[2] x repls[3]) ...
Adding 180 experiments ('runtime_train'[30] x 'pytorch'[2] x repls[3]) ...
Adding 180 experiments ('runtime_train'[30] x 'rtorch'[2] x repls[3]) ...
Adding 180 experiments ('runtime_train'[30] x 'mlr3torch'[2] x repls[3]) ...
Adding 180 experiments ('runtime_train'[30] x 'pytorch'[2] x repls[3]) ...
Loading