Skip to content

Commit e044edb

Browse files
committed
Merge branch 'joshs-working-branch' into develop
2 parents 6f3f8d9 + a9629e4 commit e044edb

27 files changed

+1968
-1557
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,165 @@
1+
import os
2+
from functools import partial
3+
4+
os.environ['JAX_PLATFORMS'] = 'cpu'
5+
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '1.0'
6+
os.environ["XLA_FLAGS"] = f"--xla_force_host_platform_device_count={8}"
7+
from jax._src.partition_spec import PartitionSpec
8+
from jax.experimental.shard_map import shard_map
9+
10+
from dsa2000_common.common.jax_utils import create_mesh
11+
from dsa2000_common.common.jvp_linear_op import JVPLinearOp
12+
from dsa2000_common.common.logging import dsa_logger
13+
from dsa2000_common.common.mixed_precision_utils import mp_policy
14+
import itertools
15+
import time
16+
from typing import Dict, Any
17+
18+
import jax
19+
import numpy as np
20+
21+
from dsa2000_cal.ops.residuals import compute_residual_BTC
22+
23+
24+
def prepare_data(D: int, Ts, Tm, Cs, Cm) -> Dict[str, Any]:
25+
num_antennas = 2048
26+
baseline_pairs = np.asarray(list(itertools.combinations(range(num_antennas), 2)),
27+
dtype=np.int32)
28+
antenna1 = baseline_pairs[:, 0]
29+
antenna2 = baseline_pairs[:, 1] # [B]
30+
31+
sort_idxs = np.lexsort((antenna1, antenna2))
32+
antenna1 = antenna1[sort_idxs]
33+
antenna2 = antenna2[sort_idxs]
34+
35+
B = antenna1.shape[0]
36+
vis_model = np.zeros((D, B, Tm, Cm, 2, 2), dtype=mp_policy.vis_dtype)
37+
vis_data = np.zeros((B, Ts, Cs, 2, 2), dtype=mp_policy.vis_dtype)
38+
gains = np.zeros((D, num_antennas, Tm, Cm, 2, 2), dtype=mp_policy.gain_dtype)
39+
return dict(
40+
vis_model=vis_model,
41+
vis_data=vis_data,
42+
gains=gains,
43+
antenna1=antenna1,
44+
antenna2=antenna2
45+
)
46+
47+
48+
def entry_point(data):
49+
vis_model = data['vis_model']
50+
vis_data = data['vis_data']
51+
gains = data['gains']
52+
antenna1 = data['antenna1']
53+
antenna2 = data['antenna2']
54+
55+
def fn(gains):
56+
res = compute_residual_BTC(vis_model=vis_model, vis_data=vis_data, gains=gains,
57+
antenna1=antenna1, antenna2=antenna2)
58+
return res
59+
60+
J_bare = JVPLinearOp(fn, linearize=False)
61+
J = J_bare(gains)
62+
R = fn(gains)
63+
g = J.matvec(R, adjoint=True)
64+
return J.matvec(J.matvec(g), adjoint=True)
65+
66+
67+
def build_sharded_entry_point(devices):
68+
mesh = create_mesh((len(devices),), ('B',), devices)
69+
P = PartitionSpec
70+
in_specs = dict(
71+
vis_model=P(None, 'B'),
72+
vis_data=P('B'),
73+
gains=P(),
74+
antenna1=P('B'),
75+
antenna2=P('B')
76+
)
77+
out_specs = P('B')
78+
79+
@partial(shard_map, mesh=mesh, in_specs=(in_specs,), out_specs=out_specs)
80+
def entry_point_sharded(local_data):
81+
return entry_point(local_data) # [_B, Tm, Cm, 2, 2]
82+
83+
return entry_point_sharded, mesh
84+
85+
86+
def main():
87+
cpus = jax.devices("cpu")
88+
# gpus = jax.devices("cuda")
89+
cpu = cpus[0]
90+
# gpu = gpus[0]
91+
92+
entry_point_jit = jax.jit(entry_point)
93+
sharded_entry_point, mesh = build_sharded_entry_point(cpus)
94+
sharded_entry_point_jit = jax.jit(sharded_entry_point)
95+
# Run benchmarking over number of calibration directions
96+
time_array = []
97+
d_array = []
98+
for D in range(1, 9):
99+
data = prepare_data(D, Ts=1, Tm=1, Cs=1, Cm=1)
100+
101+
with jax.default_device(cpu):
102+
data = jax.device_put(data)
103+
entry_point_jit_compiled = entry_point_jit.lower(data).compile()
104+
t0 = time.time()
105+
for _ in range(3):
106+
jax.block_until_ready(entry_point_jit_compiled(data))
107+
t1 = time.time()
108+
dt = (t1 - t0) / 3
109+
dsa_logger.info(f"BTC: Residual: CPU D={D}: {dt}")
110+
time_array.append(dt)
111+
d_array.append(D)
112+
113+
sharded_entry_point_jit_compiled = sharded_entry_point_jit.lower(data).compile()
114+
t0 = time.time()
115+
for _ in range(3):
116+
jax.block_until_ready(sharded_entry_point_jit_compiled(data))
117+
t1 = time.time()
118+
dt = (t1 - t0) / 3
119+
dsa_logger.info(f"BTC: Residual (sharded): CPU D={D}: {dt}")
120+
121+
# data = prepare_data(D, Ts=4, Tm=1, Cs=4, Cm=1)
122+
# with jax.default_device(cpu):
123+
# data = jax.device_put(data)
124+
# entry_point_jit_compiled = entry_point_jit.lower(data).compile()
125+
# t0 = time.time()
126+
# jax.block_until_ready(entry_point_jit_compiled(data))
127+
# t1 = time.time()
128+
# dsa_logger.info(f"BTC: Subtract (per-GPU): CPU D={D}: {t1 - t0}")
129+
#
130+
# sharded_entry_point_jit_compiled = sharded_entry_point_jit.lower(data).compile()
131+
# t0 = time.time()
132+
# for _ in range(1):
133+
# jax.block_until_ready(sharded_entry_point_jit_compiled(data))
134+
# t1 = time.time()
135+
# dt = (t1 - t0) / 1
136+
# dsa_logger.info(f"BTC: Subtract (per-GPU sharded): CPU D={D}: {dt}")
137+
#
138+
# data = prepare_data(D, Ts=4, Tm=1, Cs=40, Cm=1)
139+
# with jax.default_device(cpu):
140+
# data = jax.device_put(data)
141+
# entry_point_jit_compiled = entry_point_jit.lower(data).compile()
142+
# t0 = time.time()
143+
# jax.block_until_ready(entry_point_jit_compiled(data))
144+
# t1 = time.time()
145+
# dsa_logger.info(f"BTC: Subtract (all-GPU): CPU D={D}: {t1 - t0}")
146+
#
147+
# sharded_entry_point_jit_compiled = sharded_entry_point_jit.lower(data).compile()
148+
# t0 = time.time()
149+
# for _ in range(1):
150+
# jax.block_until_ready(sharded_entry_point_jit_compiled(data))
151+
# t1 = time.time()
152+
# dt = (t1 - t0) / 1
153+
# dsa_logger.info(f"BTC: Subtract (all-GPU sharded): CPU D={D}: {dt}")
154+
155+
# Fit line to data using scipy
156+
time_array = np.array(time_array)
157+
d_array = np.array(d_array)
158+
from scipy.optimize import curve_fit
159+
160+
popt, pcov = curve_fit(lambda x, a, b: a * x ** b, d_array, time_array)
161+
dsa_logger.info(f"BTC: Fit: {popt}")
162+
163+
164+
if __name__ == '__main__':
165+
main()
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
import os
2+
from functools import partial
3+
4+
os.environ['JAX_PLATFORMS'] = 'cpu'
5+
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '1.0'
6+
os.environ["XLA_FLAGS"] = f"--xla_force_host_platform_device_count={8}"
7+
8+
from jax._src.partition_spec import PartitionSpec
9+
from jax.experimental.shard_map import shard_map
10+
11+
from dsa2000_common.common.jax_utils import create_mesh
12+
from dsa2000_common.common.jvp_linear_op import JVPLinearOp
13+
from dsa2000_common.common.logging import dsa_logger
14+
from dsa2000_common.common.mixed_precision_utils import mp_policy
15+
import itertools
16+
import time
17+
from typing import Dict, Any
18+
19+
import jax
20+
import numpy as np
21+
22+
from dsa2000_cal.ops.residuals import compute_residual_TBC
23+
24+
25+
def prepare_data(D: int, Ts, Tm, Cs, Cm) -> Dict[str, Any]:
26+
num_antennas = 2048
27+
baseline_pairs = np.asarray(list(itertools.combinations(range(num_antennas), 2)),
28+
dtype=np.int32)
29+
antenna1 = baseline_pairs[:, 0]
30+
antenna2 = baseline_pairs[:, 1] # [B]
31+
32+
sort_idxs = np.lexsort((antenna1, antenna2))
33+
antenna1 = antenna1[sort_idxs]
34+
antenna2 = antenna2[sort_idxs]
35+
36+
B = antenna1.shape[0]
37+
vis_model = np.zeros((D, Tm, B, Cm, 2, 2), dtype=mp_policy.vis_dtype)
38+
vis_data = np.zeros((Ts, B, Cs, 2, 2), dtype=mp_policy.vis_dtype)
39+
gains = np.zeros((D, Tm, num_antennas, Cm, 2, 2), dtype=mp_policy.gain_dtype)
40+
return dict(
41+
vis_model=vis_model,
42+
vis_data=vis_data,
43+
gains=gains,
44+
antenna1=antenna1,
45+
antenna2=antenna2
46+
)
47+
48+
49+
def entry_point(data):
50+
vis_model = data['vis_model']
51+
vis_data = data['vis_data']
52+
gains = data['gains']
53+
antenna1 = data['antenna1']
54+
antenna2 = data['antenna2']
55+
56+
def fn(gains):
57+
res = compute_residual_TBC(vis_model=vis_model, vis_data=vis_data, gains=gains,
58+
antenna1=antenna1, antenna2=antenna2)
59+
return res
60+
61+
J_bare = JVPLinearOp(fn, linearize=False)
62+
J = J_bare(gains)
63+
R = fn(gains)
64+
g = J.matvec(R, adjoint=True)
65+
return J.matvec(J.matvec(g), adjoint=True)
66+
67+
68+
def build_sharded_entry_point(devices):
69+
mesh = create_mesh((len(devices),), ('B',), devices)
70+
P = PartitionSpec
71+
in_specs = dict(
72+
vis_model=P(None, None, 'B'),
73+
vis_data=P(None, 'B'),
74+
gains=P(),
75+
antenna1=P('B'),
76+
antenna2=P('B')
77+
)
78+
out_specs = P('B')
79+
80+
@partial(shard_map, mesh=mesh, in_specs=(in_specs,), out_specs=out_specs)
81+
def entry_point_sharded(local_data):
82+
return entry_point(local_data) # [ Tm, _B, Cm, 2, 2]
83+
84+
return entry_point_sharded, mesh
85+
86+
87+
def main():
88+
cpus = jax.devices("cpu")
89+
# gpus = jax.devices("cuda")
90+
cpu = cpus[0]
91+
# gpu = gpus[0]
92+
93+
entry_point_jit = jax.jit(entry_point)
94+
sharded_entry_point, mesh = build_sharded_entry_point(cpus)
95+
sharded_entry_point_jit = jax.jit(sharded_entry_point)
96+
# Run benchmarking over number of calibration directions
97+
time_array = []
98+
shard_time_array = []
99+
d_array = []
100+
for D in range(1, 9):
101+
data = prepare_data(D, Ts=1, Tm=1, Cs=1, Cm=1)
102+
with jax.default_device(cpu):
103+
data = jax.device_put(data)
104+
entry_point_jit_compiled = entry_point_jit.lower(data).compile()
105+
t0 = time.time()
106+
for _ in range(3):
107+
jax.block_until_ready(entry_point_jit_compiled(data))
108+
t1 = time.time()
109+
dt = (t1 - t0) / 3
110+
dsa_logger.info(f"TBC: Residual: CPU D={D}: {dt}")
111+
time_array.append(dt)
112+
d_array.append(D)
113+
114+
sharded_entry_point_jit_compiled = sharded_entry_point_jit.lower(data).compile()
115+
t0 = time.time()
116+
for _ in range(3):
117+
jax.block_until_ready(sharded_entry_point_jit_compiled(data))
118+
t1 = time.time()
119+
dt = (t1 - t0) / 3
120+
dsa_logger.info(f"TBC: Residual (sharded): CPU D={D}: {dt}")
121+
shard_time_array.append(dt)
122+
#
123+
# data = prepare_data(D, Ts=4, Tm=1, Cs=4, Cm=1)
124+
# with jax.default_device(cpu):
125+
# data = jax.device_put(data)
126+
# entry_point_jit_compiled = entry_point_jit.lower(data).compile()
127+
# t0 = time.time()
128+
# jax.block_until_ready(entry_point_jit_compiled(data))
129+
# t1 = time.time()
130+
# dsa_logger.info(f"TBC: Subtract (per-GPU): CPU D={D}: {t1 - t0}")
131+
#
132+
# sharded_entry_point_jit_compiled = sharded_entry_point_jit.lower(data).compile()
133+
# t0 = time.time()
134+
# for _ in range(1):
135+
# jax.block_until_ready(sharded_entry_point_jit_compiled(data))
136+
# t1 = time.time()
137+
# dt = (t1 - t0) / 1
138+
# dsa_logger.info(f"TBC: Subtract (per-GPU sharded): CPU D={D}: {dt}")
139+
#
140+
# data = prepare_data(D, Ts=4, Tm=1, Cs=40, Cm=1)
141+
# with jax.default_device(cpu):
142+
# data = jax.device_put(data)
143+
# entry_point_jit_compiled = entry_point_jit.lower(data).compile()
144+
# t0 = time.time()
145+
# jax.block_until_ready(entry_point_jit_compiled(data))
146+
# t1 = time.time()
147+
# dsa_logger.info(f"TBC: Subtract (all-GPU): CPU D={D}: {t1 - t0}")
148+
#
149+
# sharded_entry_point_jit_compiled = sharded_entry_point_jit.lower(data).compile()
150+
# t0 = time.time()
151+
# for _ in range(1):
152+
# jax.block_until_ready(sharded_entry_point_jit_compiled(data))
153+
# t1 = time.time()
154+
# dt = (t1 - t0) / 1
155+
# dsa_logger.info(f"TBC: Subtract (all-GPU sharded): CPU D={D}: {dt}")
156+
157+
# Fit line to data using scipy
158+
time_array = np.array(time_array)
159+
d_array = np.array(d_array)
160+
from scipy.optimize import curve_fit
161+
162+
popt, pcov = curve_fit(lambda x, a, b: a * x ** b, d_array, time_array)
163+
dsa_logger.info(f"TBC: Fit: {popt}")
164+
165+
shard_time_array = np.array(shard_time_array)
166+
167+
popt, pcov = curve_fit(lambda x, a, b: a * x ** b, d_array, shard_time_array)
168+
dsa_logger.info(f"TBC: Fit (sharded): {popt}")
169+
170+
171+
if __name__ == '__main__':
172+
main()

0 commit comments

Comments
 (0)