Skip to content
This repository was archived by the owner on Dec 6, 2023. It is now read-only.

Commit 752d908

Browse files
committed
Add dataset loaders for news20, usps and mnist.
1 parent 1124ffa commit 752d908

File tree

4 files changed

+89
-0
lines changed

4 files changed

+89
-0
lines changed

Makefile

+29
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
2+
# Directory into which the download-* targets move the datasets.
# Simply-expanded (:=) so the value is computed once, at parse time.
DATADIR := $(HOME)/lightning_data

# datadir names a task rather than a file to build; declaring it phony keeps
# it working even if a file or directory called "datadir" exists here.
.PHONY: datadir

# Create the dataset directory. -p makes this a no-op when it already exists.
datadir:
	mkdir -p "$(DATADIR)"
6+
7+
# download-news20 is a command name, not a file produced by the build.
.PHONY: download-news20

# Fetch the two news20 archives (news20.scale and news20.t.scale),
# decompress them in the current directory and move the results into
# $(DATADIR). Each recipe line runs in its own shell and make stops at the
# first command that fails.
download-news20: datadir
	./download.sh http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/news20.scale.bz2
	bunzip2 news20.scale.bz2
	mv news20.scale "$(DATADIR)"
	./download.sh http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/news20.t.scale.bz2
	bunzip2 news20.t.scale.bz2
	mv news20.t.scale "$(DATADIR)"
14+
15+
# download-usps is a command name, not a file produced by the build.
.PHONY: download-usps

# Fetch the two usps archives (usps and usps.t), decompress them in the
# current directory and move the results into $(DATADIR). Each recipe line
# runs in its own shell and make stops at the first command that fails.
download-usps: datadir
	./download.sh http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/usps.bz2
	bunzip2 usps.bz2
	mv usps "$(DATADIR)"
	./download.sh http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/usps.t.bz2
	bunzip2 usps.t.bz2
	mv usps.t "$(DATADIR)"
22+
23+
# download-mnist is a command name, not a file produced by the build.
.PHONY: download-mnist

# Fetch the two mnist archives (mnist.scale and mnist.scale.t), decompress
# them in the current directory and move the results into $(DATADIR). Each
# recipe line runs in its own shell and make stops at the first command that
# fails.
download-mnist: datadir
	./download.sh http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/mnist.scale.bz2
	bunzip2 mnist.scale.bz2
	mv mnist.scale "$(DATADIR)"
	./download.sh http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/mnist.scale.t.bz2
	bunzip2 mnist.scale.t.bz2
	mv mnist.scale.t "$(DATADIR)"

download.sh

+23
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
#!/bin/bash

# Download a single URL into the current directory, using wget when it is
# installed and falling back to curl otherwise.
#
# Usage: download.sh url
#
# Exit status:
#   1  no URL was given, or neither wget nor curl could be found
#   otherwise the exit status of the download tool that was run

if [ -z "$1" ]; then
    echo "usage: $(basename "$0") url"
    exit 1
fi

# `command -v` prints the path and succeeds only when the tool exists, so the
# lookup and the availability test are a single step.
if wget_path=$(command -v wget); then
    "$wget_path" "$1"
    exit $?
fi

if curl_path=$(command -v curl); then
    # -O  save under the remote file name
    # -f  return a non-zero status on HTTP errors instead of saving the
    #     server's error page as if it were the dataset
    # -L  follow redirects
    "$curl_path" -f -L -O "$1"
    exit $?
fi

# Neither tool is available: fail loudly rather than exiting 0, so callers
# (e.g. make recipes) stop here instead of failing later on a missing file.
echo "$(basename "$0"): neither wget nor curl was found in PATH" >&2
exit 1
23+

lightning/datasets/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .loaders import load_news20, load_usps, load_mnist

lightning/datasets/loaders.py

+36
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
import os
2+
3+
try:
4+
from svmlight_loader import load_svmlight_files
5+
except ImportError:
6+
from sklearn.datasets import load_svmlight_files
7+
8+
from sklearn.datasets.base import get_data_home as _get_data_home
9+
10+
def get_data_home():
    """Return the directory in which the lightning datasets are looked up.

    The path is derived from the data home reported by scikit-learn's
    ``get_data_home`` by replacing every occurrence of ``"scikit_learn"``
    in it with ``"lightning"``.
    """
    sklearn_home = _get_data_home()
    lightning_home = sklearn_home.replace("scikit_learn", "lightning")
    return lightning_home
12+
13+
def _load(train_file, test_file, name):
14+
if not os.path.exists(train_file) or not os.path.exists(test_file):
15+
raise IOError("Dataset missing! " +
16+
"Run 'make download-%s' at the project root." % name)
17+
18+
return load_svmlight_files((train_file, test_file))
19+
20+
def load_news20():
    """Load the news20 dataset from the lightning data directory.

    Reads ``news20.scale`` (train) and ``news20.t.scale`` (test) through
    ``_load``, which raises ``IOError`` when either file is missing.
    """
    root = get_data_home()
    train_path, test_path = [os.path.join(root, fname)
                             for fname in ("news20.scale", "news20.t.scale")]
    return _load(train_path, test_path, "news20")
25+
26+
def load_usps():
    """Load the usps dataset from the lightning data directory.

    Reads ``usps`` (train) and ``usps.t`` (test) through ``_load``, which
    raises ``IOError`` when either file is missing.
    """
    root = get_data_home()
    train_path, test_path = [os.path.join(root, fname)
                             for fname in ("usps", "usps.t")]
    return _load(train_path, test_path, "usps")
31+
32+
def load_mnist():
    """Load the mnist dataset from the lightning data directory.

    Reads ``mnist.scale`` (train) and ``mnist.scale.t`` (test) through
    ``_load``, which raises ``IOError`` when either file is missing.
    """
    root = get_data_home()
    train_path, test_path = [os.path.join(root, fname)
                             for fname in ("mnist.scale", "mnist.scale.t")]
    return _load(train_path, test_path, "mnist")

0 commit comments

Comments
 (0)