Skip to content

Commit 2f99eca

Browse files
author
Gregory Yauney
committed
abstract out common error functions
1 parent 36a6956 commit 2f99eca

File tree

5 files changed

+51
-53
lines changed

5 files changed

+51
-53
lines changed

benchmarks/blackscholes/error.py

+6-12
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,20 @@
1+
import sys
12
import numpy as np
2-
from scipy.special import erf
3+
sys.path.append('../..')
4+
from error_utils import *
35

46
names_and_norms = [('l0', 0), ('l1', 1), ('l2', 2)]
57
variances = [1, 100, 1000, 10000, 100000]
68
names_and_args = [('%s_%d' % (name, variance), norm, variance)
79
for name, norm in names_and_norms for variance in variances]
810
error_names = [name for name, _, _ in names_and_args]
911

10-
def error_function(perforated, variance):
11-
return erf(abs(perforated)/variance)
12-
13-
def vector_from_file(fn):
14-
with open(fn) as f:
15-
f.readline() # skip the count
16-
return np.asarray([float(line) for line in f if line.strip()])
17-
1812
def error(standard_fn, perforated_fn):
19-
standard = vector_from_file(standard_fn)
20-
perforated = vector_from_file(perforated_fn)
13+
standard = get_vector(standard_fn)
14+
perforated = get_vector(perforated_fn)
2115

2216
results = {name:
23-
error_function(np.linalg.norm(standard - perforated, ord=norm), variance)
17+
norm_and_error_function(standard, perforated, norm, variance)
2418
for name, norm, variance in names_and_args}
2519

2620
# any nan incurs the max error

benchmarks/img-blur/error.py

+6-12
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,21 @@
1-
import PIL
2-
import PIL.Image
3-
import numpy as np
4-
from scipy.special import erf
1+
import sys
2+
sys.path.append('../..')
3+
from error_utils import *
54

65
names_and_norms = [('l0', 0), ('l1', 1), ('l2', 2)]
76
variances = [1000, 10000, 100000]
87
names_and_args = [('%s_%d' % (name, variance), norm, variance)
98
for name, norm in names_and_norms for variance in variances]
109
error_names = [name for name, _, _ in names_and_args]
1110

12-
def error_function(perforated, variance):
13-
return erf(abs(perforated)/variance)
14-
1511
def error(standard_fn, perforated_fn):
16-
standard_img = PIL.Image.open(standard_fn)
17-
perforated_img = PIL.Image.open(perforated_fn)
1812

1913
try:
20-
standard = np.asarray(standard_img.getdata())
21-
perforated = np.asarray(perforated_img.getdata())
14+
standard = get_image(standard_fn)
15+
perforated = get_image(perforated_fn)
2216
except ValueError:
2317
return {name: 1.0 for name in error_names}
2418

2519
return {name:
26-
error_function(np.linalg.norm(standard - perforated, ord=norm), variance)
20+
norm_and_error_function(standard, perforated, norm, variance)
2721
for name, norm, variance in names_and_args}

benchmarks/sobel/error.py

+6-12
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,21 @@
1-
import PIL
2-
import PIL.Image
3-
import numpy as np
4-
from scipy.special import erf
1+
import sys
2+
sys.path.append('../..')
3+
from error_utils import *
54

65
names_and_norms = [('l0', 0), ('l1', 1), ('l2', 2)]
76
variances = [1000, 10000, 100000]
87
names_and_args = [('%s_%d' % (name, variance), norm, variance)
98
for name, norm in names_and_norms for variance in variances]
109
error_names = [name for name, _, _ in names_and_args]
1110

12-
def error_function(perforated, variance):
13-
return erf(abs(perforated)/variance)
14-
1511
def error(standard_fn, perforated_fn):
16-
standard_img = PIL.Image.open(standard_fn)
17-
perforated_img = PIL.Image.open(perforated_fn)
1812

1913
try:
20-
standard = np.asarray(standard_img.getdata())
21-
perforated = np.asarray(perforated_img.getdata())
14+
standard = get_image(standard_fn)
15+
perforated = get_image(perforated_fn)
2216
except ValueError:
2317
return {name: 1.0 for name in error_names}
2418

2519
return {name:
26-
error_function(np.linalg.norm(standard - perforated, ord=norm), variance)
20+
norm_and_error_function(standard, perforated, norm, variance)
2721
for name, norm, variance in names_and_args}

error_utils.py

+28
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
import sys
2+
from scipy.special import erf
3+
import numpy as np
4+
import PIL
5+
import PIL.Image
6+
7+
def error_function(standard, perforated, variance):
8+
return erf(abs(standard - perforated)/variance)
9+
10+
def norm_and_error_function(standard, perforated, norm, variance):
11+
return error_function(0, np.linalg.norm(standard - perforated, ord=norm), variance)
12+
13+
def string_to_matrix(s):
14+
# rows separated by '\n'
15+
# columns separated by ' '
16+
return np.asarray([[float(e) for e in l.split()] for l in s.strip().split('\n')])
17+
18+
def get_contents(fn):
19+
with open(fn, 'r') as f:
20+
return f.read()
21+
22+
def get_vector(fn):
23+
with open(fn) as f:
24+
f.readline() # skip the count
25+
return np.asarray([float(line) for line in f if line.strip()])
26+
27+
def get_image(fn):
28+
return np.asarray(PIL.Image.open(fn).getdata())

tests/matrix_multiply/error.py

+5-17
Original file line numberDiff line numberDiff line change
@@ -1,37 +1,25 @@
11
import sys
2-
import numpy as np
3-
from scipy.special import erf
2+
sys.path.append('../..')
3+
from error_utils import *
44

55
names_and_norms = [('l2', 2), ('froebenius', 'fro')]
66
variances = [1, 10, 100]
77
names_and_args = [('%s_%d' % (name, variance), norm, variance)
88
for name, norm in names_and_norms for variance in variances]
99
error_names = [name for name, _, _ in names_and_args]
1010

11-
def string_to_matrix(s):
12-
return np.asarray([[float(e) for e in l.split()] for l in s.strip().split('\n')])
13-
14-
def error_function(perforated, variance):
15-
return erf(abs(perforated)/variance)
16-
1711
def error(standard_fn, perforated_fn):
18-
standard = get_contents(standard_fn)
19-
perforated = get_contents(perforated_fn)
20-
standard = string_to_matrix(standard)
21-
perforated = string_to_matrix(perforated)
12+
standard = string_to_matrix(get_contents(standard_fn))
13+
perforated = string_to_matrix(get_contents(perforated_fn))
2214

2315
# max error if sizes differ
2416
if standard.shape != perforated.shape:
2517
return {name: 1.0 for name in error_names}
2618

2719
return {name:
28-
error_function(np.linalg.norm(standard - perforated, ord=norm), variance)
20+
norm_and_error_function(standard, perforated, norm, variance)
2921
for name, norm, variance in names_and_args}
3022

31-
def get_contents(fn):
32-
with open(fn, 'r') as f:
33-
return f.read()
34-
3523
def main():
3624
standard_fn = sys.argv[1] if len(sys.argv) > 2 else 'standard.txt'
3725
perforated_fn = sys.argv[2] if len(sys.argv) > 2 else 'perforated.txt'

0 commit comments

Comments
 (0)