|
33 | 33 | }
|
34 | 34 | ],
|
35 | 35 | "source": [
|
36 |
| - "import ipsuite as ips\n", |
37 | 36 | "from zntrack.utils import cwd_temp_dir\n",
|
38 | 37 | "\n",
|
| 38 | + "import ipsuite as ips\n", |
| 39 | + "\n", |
39 | 40 | "temp_dir = cwd_temp_dir()\n",
|
40 | 41 | "\n",
|
41 |
| - "import ipsuite as ips\n", |
42 | 42 | "\n",
|
43 | 43 | "import os\n",
|
| 44 | + "\n", |
44 | 45 | "from ase import units\n",
|
45 | 46 | "from ase.calculators.emt import EMT\n",
|
46 | 47 | "from ase.io.trajectory import TrajectoryWriter\n",
|
47 | 48 | "from ase.lattice.cubic import FaceCenteredCubic\n",
|
48 |
| - "from ase.md.velocitydistribution import MaxwellBoltzmannDistribution\n", |
49 | 49 | "from ase.md.langevin import Langevin\n",
|
50 |
| - "from ase.visualize import view\n" |
| 50 | + "from ase.md.velocitydistribution import MaxwellBoltzmannDistribution\n", |
| 51 | + "from ase.visualize import view" |
51 | 52 | ]
|
52 | 53 | },
|
53 | 54 | {
|
|
119 | 120 | "# Set up a crystal\n",
|
120 | 121 | "atoms = FaceCenteredCubic(\n",
|
121 | 122 | " directions=[[1, 0, 0], [0, 1, 0], [0, 0, 1]],\n",
|
122 |
| - " symbol='Cu',\n", |
| 123 | + " symbol=\"Cu\",\n", |
123 | 124 | " size=(size, size, size),\n",
|
124 |
| - " pbc=True\n", |
125 |
| - ")\n" |
| 125 | + " pbc=True,\n", |
| 126 | + ")" |
126 | 127 | ]
|
127 | 128 | },
|
128 | 129 | {
|
|
225 | 226 | "metadata": {},
|
226 | 227 | "outputs": [],
|
227 | 228 | "source": [
|
228 |
| - "trajectory.load() # requires the project to have been run" |
| 229 | + "trajectory.load() # requires the project to have been run" |
229 | 230 | ]
|
230 | 231 | },
|
231 | 232 | {
|
|
296 | 297 | ],
|
297 | 298 | "source": [
|
298 | 299 | "with project:\n",
|
299 |
| - " random_test_selection = ips.RandomSelection(data=trajectory, n_configurations=10, name=\"random_test_selection\")\n", |
300 |
| - " random_val_selection = ips.RandomSelection(data=random_test_selection.excluded_atoms, n_configurations=15, name=\"random_val_selection\")\n", |
301 |
| - " random_train_selection = ips.RandomSelection(data=random_val_selection.excluded_atoms, n_configurations=75, name=\"random_train_selection\")\n", |
| 300 | + " random_test_selection = ips.RandomSelection(\n", |
| 301 | + " data=trajectory, n_configurations=10, name=\"random_test_selection\"\n", |
| 302 | + " )\n", |
| 303 | + " random_val_selection = ips.RandomSelection(\n", |
| 304 | + " data=random_test_selection.excluded_atoms,\n", |
| 305 | + " n_configurations=15,\n", |
| 306 | + " name=\"random_val_selection\",\n", |
| 307 | + " )\n", |
| 308 | + " random_train_selection = ips.RandomSelection(\n", |
| 309 | + " data=random_val_selection.excluded_atoms,\n", |
| 310 | + " n_configurations=75,\n", |
| 311 | + " name=\"random_train_selection\",\n", |
| 312 | + " )\n", |
302 | 313 | "project.repro()"
|
303 | 314 | ]
|
304 | 315 | },
|
|
416 | 427 | "with ips.Project() as project:\n",
|
417 | 428 | " trajectory = ips.AddData(file=traj_path, name=\"trajectory\")\n",
|
418 | 429 | " test_split = ips.SplitSelection(data=trajectory, split=0.1, name=\"test_split\")\n",
|
419 |
| - " val_split = ips.SplitSelection(data=test_split.excluded_atoms, split=0.17, name=\"val_split\") # 0.15 / 0.9 * 1.0 \\approx 0.17\n", |
420 |
| - " train_split = val_split.excluded_atoms # 0.8 of the total data\n", |
| 430 | + " val_split = ips.SplitSelection(\n", |
| 431 | + " data=test_split.excluded_atoms, split=0.17, name=\"val_split\"\n", |
| 432 | + " ) # 0.15 / 0.9 * 1.0 \\approx 0.17\n", |
| 433 | + " train_split = val_split.excluded_atoms # 0.8 of the total data\n", |
421 | 434 | "\n",
|
422 |
| - " test_data = ips.UniformTemporalSelection(data=test_split, n_configurations=10, name=\"test_data\")\n", |
423 |
| - " val_data = ips.UniformTemporalSelection(data=val_split, n_configurations=15, name=\"val_data\")\n", |
424 |
| - " train_data = ips.UniformEnergeticSelection(data=train_split, n_configurations=80, name=\"train_data\")\n", |
| 435 | + " test_data = ips.UniformTemporalSelection(\n", |
| 436 | + " data=test_split, n_configurations=10, name=\"test_data\"\n", |
| 437 | + " )\n", |
| 438 | + " val_data = ips.UniformTemporalSelection(\n", |
| 439 | + " data=val_split, n_configurations=15, name=\"val_data\"\n", |
| 440 | + " )\n", |
| 441 | + " train_data = ips.UniformEnergeticSelection(\n", |
| 442 | + " data=train_split, n_configurations=80, name=\"train_data\"\n", |
| 443 | + " )\n", |
425 | 444 | "\n",
|
426 | 445 | "project.repro()"
|
427 | 446 | ]
|
|
0 commit comments