diff --git a/examples/HyperNet/main.jl b/examples/HyperNet/main.jl index c0adb995d..9afad5e38 100644 --- a/examples/HyperNet/main.jl +++ b/examples/HyperNet/main.jl @@ -8,11 +8,20 @@ using Lux, ADTypes, ComponentArrays, LuxCUDA, MLDatasets, MLUtils, OneHotArrays, CUDA.allowscalar(false) # ## Loading Datasets -function load_dataset(::Type{dset}, n_train::Int, n_eval::Int, batchsize::Int) where {dset} - imgs, labels = dset(:train)[1:n_train] +function load_dataset(::Type{dset}, n_train::Union{Nothing, Int}, + n_eval::Union{Nothing, Int}, batchsize::Int) where {dset} + if n_train === nothing + imgs, labels = dset(:train) + else + imgs, labels = dset(:train)[1:n_train] + end x_train, y_train = reshape(imgs, 28, 28, 1, n_train), onehotbatch(labels, 0:9) - imgs, labels = dset(:test)[1:n_eval] + if n_eval === nothing + imgs, labels = dset(:test) + else + imgs, labels = dset(:test)[1:n_eval] + end x_test, y_test = reshape(imgs, 28, 28, 1, n_eval), onehotbatch(labels, 0:9) return ( @@ -21,7 +30,9 @@ function load_dataset(::Type{dset}, n_train::Int, n_eval::Int, batchsize::Int) w ) end -function load_datasets(n_train=1024, n_eval=32, batchsize=256) +function load_datasets(batchsize=256) + n_train = parse(Bool, get(ENV, "CI", "false")) ? 1024 : nothing + n_eval = parse(Bool, get(ENV, "CI", "false")) ? 32 : nothing return load_dataset.((MNIST, FashionMNIST), n_train, n_eval, batchsize) end diff --git a/examples/NeuralODE/main.jl b/examples/NeuralODE/main.jl index 8d198b6e3..62663d8bd 100644 --- a/examples/NeuralODE/main.jl +++ b/examples/NeuralODE/main.jl @@ -18,10 +18,15 @@ CUDA.allowscalar(false) # ## Loading MNIST function loadmnist(batchsize, train_split) ## Load MNIST: Only 1500 for demonstration purposes - N = 1500 + N = parse(Bool, get(ENV, "CI", "false")) ? 1500 : nothing dataset = MNIST(; split=:train) - imgs = dataset.features[:, :, 1:N] - labels_raw = dataset.targets[1:N] + if N !== nothing + imgs = dataset.features[:, :, 1:N] + labels_raw = dataset.targets[1:N] + else + imgs = dataset.features + labels_raw = dataset.targets + end ## Process images into (H,W,C,BS) batches x_data = Float32.(reshape(imgs, size(imgs, 1), size(imgs, 2), 1, size(imgs, 3))) diff --git a/examples/SimpleChains/main.jl b/examples/SimpleChains/main.jl index 672680059..1ff12bc23 100644 --- a/examples/SimpleChains/main.jl +++ b/examples/SimpleChains/main.jl @@ -14,10 +14,15 @@ using SimpleChains: SimpleChains # ## Loading MNIST function loadmnist(batchsize, train_split) ## Load MNIST - N = 2000 + N = parse(Bool, get(ENV, "CI", "false")) ? 1500 : nothing dataset = MNIST(; split=:train) - imgs = dataset.features[:, :, 1:N] - labels_raw = dataset.targets[1:N] + if N !== nothing + imgs = dataset.features[:, :, 1:N] + labels_raw = dataset.targets[1:N] + else + imgs = dataset.features + labels_raw = dataset.targets + end ## Process images into (H, W, C, BS) batches x_data = Float32.(reshape(imgs, size(imgs, 1), size(imgs, 2), 1, size(imgs, 3)))