Skip to content

Commit 9d1a997

Browse files
committed
Update python
1 parent 42067c1 commit 9d1a997

File tree

1 file changed

+39
-11
lines changed

1 file changed

+39
-11
lines changed

SAS_pytorch.ipynb

+39-11
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,10 @@
6666
"outputs": [],
6767
"source": [
6868
"def zero_pad(arr, device=device):\n",
69+
" '''\n",
70+
" Pad arr with zeros to double the size. First dim is assumed to be batch dim which\n",
71+
" won't be changed\n",
72+
" '''\n",
6973
" out_arr = torch.zeros(arr.shape[0], arr.shape[1] * 2, arr.shape[2] * 2, device=device, dtype=arr.dtype)\n",
7074
" \n",
7175
" as1 = (arr.shape[1] + 1) // 2\n",
@@ -74,7 +78,10 @@
7478
" return out_arr\n",
7579
"\n",
7680
"def zero_unpad(arr, original_shape):\n",
77-
" \n",
81+
" '''\n",
82+
" Strip off padding of arr with zeros to halve the size. First dim is assumed to be batch dim which\n",
83+
" won't be changed\n",
84+
" '''\n",
7885
" as1 = (original_shape[1] + 1) // 2\n",
7986
" as2 = (original_shape[2] + 1) // 2\n",
8087
" return arr[:, as1:as1 + original_shape[1], as2:as2 + original_shape[2]]\n"
@@ -86,7 +93,22 @@
8693
"metadata": {},
8794
"outputs": [],
8895
"source": [
89-
"def scalable_angular_spectrum(psi, z, lbd, L, device=device):\n",
96+
"def scalable_angular_spectrum(psi, z, lbd, L, device=device, skip_final_phase=True):\n",
97+
" '''\n",
98+
" Returns the complex electrical field psi propagated with the Scalable Angular Spectrum Method.\n",
99+
" \n",
100+
" Parameters:\n",
101+
" psi (torch.tensor): the quadratically shaped input field, with leading batch dimension\n",
102+
" z (number): propagation distance\n",
103+
" lbd (number): wavelength\n",
104+
" L (number): physical sidelength of the input field\n",
105+
" skip_final_phase=True: Skip final multiplication of phase factor. For M>2 undersampled,\n",
106+
" \n",
107+
" Returns:\n",
108+
" psi_final (torch.tensor): Propagated field\n",
109+
" Q (number): Output field size, corresponds to magnificiation * L\n",
110+
" \n",
111+
" '''\n",
90112
" N = psi.shape[-1]\n",
91113
" z_limit = (- 4 * L * np.sqrt(8*L**2 / N**2 + lbd**2) * np.sqrt(L**2 * 1 / (8 * L**2 + N**2 * lbd**2))\\\n",
92114
" / (lbd * (-1+2 * np.sqrt(2) * np.sqrt(L**2 * 1 / (8 * L**2 + N**2 * lbd**2)))))\n",
@@ -137,14 +159,22 @@
137159
" dq = lbd * z / L_new\n",
138160
" Q = dq * N * pad_factor\n",
139161
" \n",
162+
" q_y = torch.fft.ifftshift(torch.tensor(np.linspace(-Q/2, Q/2, N_new, endpoint=False),\n",
163+
" device=device).reshape(1, 1, N_new), dim=(-1))\n",
164+
" q_x = q_y.reshape(1, N_new, 1)\n",
140165
" \n",
141-
" # skip final H_2 phase\n",
142166
" H_1 = torch.exp(1j * k / (2 * z) * (x**2 + y**2))\n",
143-
" psi_p_final = torch.fft.fftshift(torch.fft.fft2(H_1 * psi_precomp), dim=(-1,-2))\n",
144167
"\n",
168+
" if skip_final_phase:\n",
169+
" psi_p_final = torch.fft.fftshift(torch.fft.fft2(H_1 * psi_precomp), dim=(-1,-2))\n",
170+
" else:\n",
171+
" H_2 = np.exp(1j * k * z) * torch.exp(1j * k / (2 * z) * (q_x**2 + q_y**2))\n",
172+
" psi_p_final = torch.fft.fftshift(H_2 * torch.fft.fft2(H_1 * psi_precomp), dim=(-1,-2))\n",
173+
" \n",
174+
" \n",
145175
" psi_final = zero_unpad(psi_p_final, psi.shape)\n",
146176
" \n",
147-
" return psi_final, Q \n"
177+
" return psi_final, Q / 2\n"
148178
]
149179
},
150180
{
@@ -178,7 +208,7 @@
178208
{
179209
"data": {
180210
"text/plain": [
181-
"<matplotlib.image.AxesImage at 0x7f13aae403d0>"
211+
"<matplotlib.image.AxesImage at 0x7fbe06e3f3a0>"
182212
]
183213
},
184214
"execution_count": 7,
@@ -203,14 +233,12 @@
203233
{
204234
"cell_type": "code",
205235
"execution_count": 8,
206-
"metadata": {
207-
"scrolled": false
208-
},
236+
"metadata": {},
209237
"outputs": [
210238
{
211239
"data": {
212240
"text/plain": [
213-
"<matplotlib.image.AxesImage at 0x7f13c43bdbe0>"
241+
"<matplotlib.image.AxesImage at 0x7fbe1c2431c0>"
214242
]
215243
},
216244
"execution_count": 8,
@@ -241,7 +269,7 @@
241269
"name": "stdout",
242270
"output_type": "stream",
243271
"text": [
244-
"5.52 ms ± 148 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
272+
"5.6 ms ± 292 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
245273
]
246274
}
247275
],

0 commit comments

Comments
 (0)