|
66 | 66 | "outputs": [],
|
67 | 67 | "source": [
|
68 | 68 | "def zero_pad(arr, device=device):\n",
|
| 69 | + " '''\n", |
| 70 | + " Pad arr with zeros to double the size. First dim is assumed to be batch dim which\n", |
| 71 | + " won't be changed\n", |
| 72 | + " '''\n", |
69 | 73 | " out_arr = torch.zeros(arr.shape[0], arr.shape[1] * 2, arr.shape[2] * 2, device=device, dtype=arr.dtype)\n",
|
70 | 74 | " \n",
|
71 | 75 | " as1 = (arr.shape[1] + 1) // 2\n",
|
|
74 | 78 | " return out_arr\n",
|
75 | 79 | "\n",
|
76 | 80 | "def zero_unpad(arr, original_shape):\n",
|
77 |
| - " \n", |
| 81 | + " '''\n", |
| 82 | + " Strip off padding of arr with zeros to halve the size. First dim is assumed to be batch dim which\n", |
| 83 | + " won't be changed\n", |
| 84 | + " '''\n", |
78 | 85 | " as1 = (original_shape[1] + 1) // 2\n",
|
79 | 86 | " as2 = (original_shape[2] + 1) // 2\n",
|
80 | 87 | " return arr[:, as1:as1 + original_shape[1], as2:as2 + original_shape[2]]\n"
|
|
86 | 93 | "metadata": {},
|
87 | 94 | "outputs": [],
|
88 | 95 | "source": [
|
89 |
| - "def scalable_angular_spectrum(psi, z, lbd, L, device=device):\n", |
| 96 | + "def scalable_angular_spectrum(psi, z, lbd, L, device=device, skip_final_phase=True):\n", |
| 97 | + " '''\n", |
| 98 | + " Returns the complex electrical field psi propagated with the Scalable Angular Spectrum Method.\n", |
| 99 | + " \n", |
| 100 | + " Parameters:\n", |
| 101 | + " psi (torch.tensor): the quadratically shaped input field, with leading batch dimension\n", |
| 102 | + " z (number): propagation distance\n", |
| 103 | + " lbd (number): wavelength\n", |
| 104 | + " L (number): physical sidelength of the input field\n", |
| 105 | + " skip_final_phase=True: Skip final multiplication of phase factor. For M>2 undersampled,\n", |
| 106 | + " \n", |
| 107 | + " Returns:\n", |
| 108 | + " psi_final (torch.tensor): Propagated field\n", |
| 109 | + " Q (number): Output field size, corresponds to magnificiation * L\n", |
| 110 | + " \n", |
| 111 | + " '''\n", |
90 | 112 | " N = psi.shape[-1]\n",
|
91 | 113 | " z_limit = (- 4 * L * np.sqrt(8*L**2 / N**2 + lbd**2) * np.sqrt(L**2 * 1 / (8 * L**2 + N**2 * lbd**2))\\\n",
|
92 | 114 | " / (lbd * (-1+2 * np.sqrt(2) * np.sqrt(L**2 * 1 / (8 * L**2 + N**2 * lbd**2)))))\n",
|
|
137 | 159 | " dq = lbd * z / L_new\n",
|
138 | 160 | " Q = dq * N * pad_factor\n",
|
139 | 161 | " \n",
|
| 162 | + " q_y = torch.fft.ifftshift(torch.tensor(np.linspace(-Q/2, Q/2, N_new, endpoint=False),\n", |
| 163 | + " device=device).reshape(1, 1, N_new), dim=(-1))\n", |
| 164 | + " q_x = q_y.reshape(1, N_new, 1)\n", |
140 | 165 | " \n",
|
141 |
| - " # skip final H_2 phase\n", |
142 | 166 | " H_1 = torch.exp(1j * k / (2 * z) * (x**2 + y**2))\n",
|
143 |
| - " psi_p_final = torch.fft.fftshift(torch.fft.fft2(H_1 * psi_precomp), dim=(-1,-2))\n", |
144 | 167 | "\n",
|
| 168 | + " if skip_final_phase:\n", |
| 169 | + " psi_p_final = torch.fft.fftshift(torch.fft.fft2(H_1 * psi_precomp), dim=(-1,-2))\n", |
| 170 | + " else:\n", |
| 171 | + " H_2 = np.exp(1j * k * z) * torch.exp(1j * k / (2 * z) * (q_x**2 + q_y**2))\n", |
| 172 | + " psi_p_final = torch.fft.fftshift(H_2 * torch.fft.fft2(H_1 * psi_precomp), dim=(-1,-2))\n", |
| 173 | + " \n", |
| 174 | + " \n", |
145 | 175 | " psi_final = zero_unpad(psi_p_final, psi.shape)\n",
|
146 | 176 | " \n",
|
147 |
| - " return psi_final, Q \n" |
| 177 | + " return psi_final, Q / 2\n" |
148 | 178 | ]
|
149 | 179 | },
|
150 | 180 | {
|
|
178 | 208 | {
|
179 | 209 | "data": {
|
180 | 210 | "text/plain": [
|
181 |
| - "<matplotlib.image.AxesImage at 0x7f13aae403d0>" |
| 211 | + "<matplotlib.image.AxesImage at 0x7fbe06e3f3a0>" |
182 | 212 | ]
|
183 | 213 | },
|
184 | 214 | "execution_count": 7,
|
|
203 | 233 | {
|
204 | 234 | "cell_type": "code",
|
205 | 235 | "execution_count": 8,
|
206 |
| - "metadata": { |
207 |
| - "scrolled": false |
208 |
| - }, |
| 236 | + "metadata": {}, |
209 | 237 | "outputs": [
|
210 | 238 | {
|
211 | 239 | "data": {
|
212 | 240 | "text/plain": [
|
213 |
| - "<matplotlib.image.AxesImage at 0x7f13c43bdbe0>" |
| 241 | + "<matplotlib.image.AxesImage at 0x7fbe1c2431c0>" |
214 | 242 | ]
|
215 | 243 | },
|
216 | 244 | "execution_count": 8,
|
|
241 | 269 | "name": "stdout",
|
242 | 270 | "output_type": "stream",
|
243 | 271 | "text": [
|
244 |
| - "5.52 ms ± 148 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" |
| 272 | + "5.6 ms ± 292 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" |
245 | 273 | ]
|
246 | 274 | }
|
247 | 275 | ],
|
|
0 commit comments