-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathday06.py
More file actions
113 lines (91 loc) · 2.6 KB
/
day06.py
File metadata and controls
113 lines (91 loc) · 2.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
from dataclasses import dataclass
import numpy as np
import nvtx
from numba import cuda
from numba.cuda.cudadrv.devicearray import DeviceNDArray
@dataclass
class Data:
    """Parsed puzzle input, bundled with the device buffers the kernel uses.

    Instances are built by ``parse`` and handed unchanged to ``part1`` and
    ``part2``; every array field lives on the CUDA device.
    """

    # Number of problems in the worksheet (``len(offsets) - 1``).
    n: int
    # 2-D uint8 grid of raw input bytes: every row except the last one.
    numbers: DeviceNDArray
    # 1-D uint8 row of raw bytes from the last input line (newline dropped).
    ops: DeviceNDArray
    # int32 column boundaries: problem ``i`` spans
    # ``offsets[i] .. offsets[i + 1] - 1``.
    offsets: DeviceNDArray
    # int64 scratch buffer of length ``n``; the kernel stores one result per
    # problem here.
    outputs: DeviceNDArray
@nvtx.annotate("Parse Input")
def parse() -> Data:
    """Read ``inputs/day06.in`` and upload its byte grid to the GPU.

    The file is treated as a rectangular grid of ASCII bytes in which every
    line has the same width. All rows but the last become ``numbers``; the
    last row (without its newline) becomes ``ops``. A column is a separator
    when it holds only spaces or newlines in the ``numbers`` rows, so the
    trailing newline column always terminates the final problem.

    Returns:
        A ``Data`` instance whose arrays already live on the device.

    Raises:
        ValueError: If the file has fewer than two lines or its lines do not
            all have the same width.
    """
    with open("inputs/day06.in") as f:
        data = f.read()

    # Some editors strip the final newline; without it the last row would be
    # one byte short and the reshape below would fail with an opaque error.
    if data and not data.endswith("\n"):
        data += "\n"
    if "\n" not in data:
        raise ValueError("inputs/day06.in is empty")

    W = 1 + data.index("\n")  # row width, including the newline byte
    buf = np.frombuffer(data.encode(), dtype=np.uint8)
    if buf.size % W != 0:
        raise ValueError("inputs/day06.in: lines do not all have the same width")
    H = buf.size // W
    if H < 2:
        raise ValueError(
            "inputs/day06.in: expected at least one number row and one operator row"
        )
    buf = buf.reshape((H, W))
    # In a rectangular grid every row ends exactly at its newline byte.
    if not np.all(buf[:, -1] == ord("\n")):
        raise ValueError("inputs/day06.in: lines do not all have the same width")

    numbers = buf[:-1]
    ops = buf[-1][:-1]

    # Separator columns: nothing but spaces/newlines in every number row.
    mask = np.all(np.logical_or(numbers == ord(" "), numbers == ord("\n")), axis=0)
    ind = np.where(mask)[0]

    # prefix[i] is the first column of problem i; prefix[i + 1] is one past
    # its separator column, so each half-open range ends on the separator.
    prefix = np.zeros(len(ind) + 1, dtype=np.int32)
    prefix[1:] = ind + 1
    n = prefix.size - 1

    return Data(
        n,
        cuda.to_device(numbers),
        cuda.to_device(ops),
        cuda.to_device(prefix),
        cuda.device_array(n, dtype=np.int64),
    )
@cuda.jit
def kernel(
    n: int,
    nums: DeviceNDArray,
    ops: DeviceNDArray,
    offsets: DeviceNDArray,
    out: DeviceNDArray,
    by_col: bool,
):
    """Evaluate one problem per thread and store its result in ``out``.

    Thread ``pid`` handles columns ``offsets[pid] .. offsets[pid + 1] - 1`` of
    ``nums``. The operator is the byte at ``ops[offsets[pid]]``: ``+`` sums the
    operands, anything else multiplies them.

    Args:
        n: Number of problems; threads with ``pid >= n`` exit immediately.
        nums: 2-D grid of raw character bytes (rows x columns).
        ops: 1-D row of raw character bytes holding the operators.
        offsets: Column boundaries, length ``n + 1``.
        out: Output buffer, one slot per problem.
        by_col: When false, each grid row inside the column range is read
            left to right as one operand. When true, each column (visited
            right to left) is read top to bottom as one operand.
    """
    pid = cuda.grid(1)
    if pid >= n:
        return

    start = offsets[pid]
    stop = offsets[pid + 1]
    op = ops[start]
    operands = nums.shape[0]
    # Identity element of the chosen operation.
    total = 0 if op == ord("+") else 1

    r1 = range(operands)
    r2 = range(start, stop) if not by_col else range(stop - 1, start - 1, -1)
    outer = r2 if by_col else r1
    inner = r1 if by_col else r2

    for r in outer:
        num = 0
        # Track whether this slice held any digit at all, so that blank
        # slices (e.g. the separator column) are skipped while a genuine
        # operand of 0 is still applied.
        seen_digit = False
        for c in inner:
            v = nums[c, r] if by_col else nums[r, c]
            # Inclusive on both ends: '0' is a digit too.
            if ord("0") <= v and v <= ord("9"):
                num = num * 10 + (v - ord("0"))
                seen_digit = True
        if seen_digit:
            if op == ord("+"):
                total += num
            else:
                total *= num

    out[pid] = total
@nvtx.annotate("Part 1")
def part1(data: Data) -> int:
    """Solve part 1: read each operand row-wise and sum all problem results.

    Launches ``kernel`` with ``by_col=False`` using one thread per problem,
    then copies ``data.outputs`` back to the host and adds the values up.

    Args:
        data: Parsed input as returned by ``parse``.

    Returns:
        The grand total as a built-in ``int``.
    """
    THREADS_PER_BLOCK = 128
    # Ceiling division so that every problem gets a thread.
    blocks = (data.n + THREADS_PER_BLOCK - 1) // THREADS_PER_BLOCK
    kernel[blocks, THREADS_PER_BLOCK](
        data.n, data.numbers, data.ops, data.offsets, data.outputs, False
    )
    # ndarray.sum() yields a NumPy scalar; convert to honour the annotation.
    return int(data.outputs.copy_to_host().sum())
@nvtx.annotate("Part 2")
def part2(data: Data) -> int:
    """Solve part 2: read each operand column-wise and sum all problem results.

    Launches ``kernel`` with ``by_col=True`` using one thread per problem,
    then copies ``data.outputs`` back to the host and adds the values up.

    Args:
        data: Parsed input as returned by ``parse``.

    Returns:
        The grand total as a built-in ``int``.
    """
    THREADS_PER_BLOCK = 128
    # Ceiling division so that every problem gets a thread.
    blocks = (data.n + THREADS_PER_BLOCK - 1) // THREADS_PER_BLOCK
    kernel[blocks, THREADS_PER_BLOCK](
        data.n, data.numbers, data.ops, data.offsets, data.outputs, True
    )
    # ndarray.sum() yields a NumPy scalar; convert to honour the annotation.
    return int(data.outputs.copy_to_host().sum())
@nvtx.annotate("Day 06")
def main():
    """Parse the input once and return the answers as ``(part1, part2)``."""
    data = parse()
    # Tuple elements are evaluated left to right, so part 1 runs first.
    answers = (part1(data), part2(data))
    return answers
if __name__ == "__main__":
    # Print both answers on a single line, separated by a space.
    first, second = main()
    print(first, second)