Skip to content

Commit 3972635

Browse files
author
Guillaume Fraux
committed
Add a few basic functions for cells
1 parent 8806e76 commit 3972635

File tree

5 files changed

+226
-7
lines changed

5 files changed

+226
-7
lines changed

Cargo.toml

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,12 @@ version = "0.1.0"
44
authors = ["Guillaume Fraux <[email protected]>"]
55
build = "build.rs"
66

7+
[lib]
8+
crate-type = ["cdylib"]
9+
710
[dependencies]
811
lumol-core = {git = "https://github.com/lumol-org/lumol", rev = "788c94e"}
912
cpython = {git = "https://github.com/dgrunwald/rust-cpython/", rev = "4773d2e3"}
1013

11-
[lib]
12-
crate-type = ["cdylib"]
14+
[dev-dependencies]
15+
approx = "*"

src/lib.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,10 @@
44
extern crate cpython;
55
extern crate lumol;
66

7+
#[cfg(test)]
8+
#[macro_use]
9+
extern crate approx;
10+
711
#[macro_use]
812
mod macros;
913
mod error;

src/macros.rs

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ macro_rules! create_instance {
1313

1414
macro_rules! py_run_with {
1515
($py: ident, $($obj: ident),+; $($code: expr),+ $(,)*) => ({
16-
const ASSERT_RAISES_PY: &'static str = "
16+
const ADDITIONAL_DEFINITIONS: &'static str = "
1717
def assert_raises(callable, *args, **kwargs):
1818
throw = True
1919
try:
@@ -22,6 +22,13 @@ def assert_raises(callable, *args, **kwargs):
2222
except LumolError:
2323
pass
2424
assert throw
25+
26+
def assert_approx_eq(a, b):
27+
# the abs function is not available when running from Rust
28+
abs = a - b
29+
if abs < 0:
30+
abs = -abs
31+
assert abs < 1e-12
2532
";
2633
use cpython::PyDict;
2734
let locals = PyDict::new($py);
@@ -33,12 +40,9 @@ def assert_raises(callable, *args, **kwargs):
3340
let error = $py.get_type::<$crate::error::LumolError>();
3441
globals.set_item($py, "LumolError", error).unwrap();
3542

36-
py_run!($py, globals, locals, ASSERT_RAISES_PY);
43+
py_run!($py, globals, locals, ADDITIONAL_DEFINITIONS);
3744
py_run!($py, globals, locals, $($code),+);
3845
});
39-
($py: ident, $obj: ident; $($code: expr),+,) => (
40-
py_run_with!($py, $obj; $($code),+);
41-
)
4246
}
4347

4448
macro_rules! py_run {

src/systems/cell.rs

Lines changed: 206 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,206 @@
1+
use cpython::{PyObject, PyResult, ToPyObject, PythonObject};
2+
use std::cell::RefCell;
3+
use lumol;
4+
5+
use cpython::py_class::CompareOp;
6+
7+
use traits::Callback;
8+
9+
register!(|py, m| {
10+
try!(m.add_class::<CellShape>(py));
11+
try!(m.add_class::<UnitCell>(py));
12+
Ok(())
13+
});
14+
15+
py_class!(class CellShape |py| {
16+
data shape: lumol::sys::CellShape;
17+
18+
@classmethod def triclinic(_cls) -> PyResult<CellShape> {
19+
CellShape::create_instance(py, lumol::sys::CellShape::Triclinic)
20+
}
21+
22+
@classmethod def orthorhombic(_cls) -> PyResult<CellShape> {
23+
CellShape::create_instance(py, lumol::sys::CellShape::Orthorhombic)
24+
}
25+
26+
@classmethod def infinite(_cls) -> PyResult<CellShape> {
27+
CellShape::create_instance(py, lumol::sys::CellShape::Infinite)
28+
}
29+
30+
def __repr__(&self) -> PyResult<String> {
31+
let repr = match *self.shape(py) {
32+
lumol::sys::CellShape::Infinite => "CellShape.Infinite",
33+
lumol::sys::CellShape::Orthorhombic => "CellShape.Orthorhombic",
34+
lumol::sys::CellShape::Triclinic => "CellShape.Triclinic",
35+
};
36+
Ok(repr.into())
37+
}
38+
39+
def __richcmp__(&self, other: &PyObject, op: CompareOp) -> PyResult<PyObject> {
40+
let other = match other.extract::<CellShape>(py) {
41+
Ok(other) => other,
42+
Err(_) => return Ok(py.NotImplemented())
43+
};
44+
match op {
45+
CompareOp::Eq => {
46+
let result = self.shape(py) == other.shape(py);
47+
Ok(result.to_py_object(py).into_object())
48+
}
49+
CompareOp::Ne => {
50+
let result = self.shape(py) != other.shape(py);
51+
Ok(result.to_py_object(py).into_object())
52+
}
53+
_ => Ok(py.NotImplemented())
54+
}
55+
}
56+
});
57+
58+
py_class!(class UnitCell |py| {
59+
data cell: Box<Callback<lumol::sys::UnitCell>>;
60+
def __new__(_cls,
61+
a: f64 = 0.0, b: f64 = a, c: f64 = a,
62+
alpha: f64=90.0, beta: f64=90.0, gamma: f64=90.0) -> PyResult<UnitCell> {
63+
let cell = if a == 0.0 && b == 0.0 && c == 0.0 {
64+
lumol::sys::UnitCell::new()
65+
} else if alpha == 90.0 && alpha == 90.0 && alpha == 90.0 {
66+
lumol::sys::UnitCell::ortho(a, b, c)
67+
} else {
68+
lumol::sys::UnitCell::triclinic(a, b, c, alpha, beta, gamma)
69+
};
70+
UnitCell::create_instance(py, Box::new(RefCell::new(cell)))
71+
}
72+
73+
def a(&self) -> PyResult<f64> {
74+
let mut a = 0.0;
75+
self.cell(py).with_ref(&mut |cell| a = cell.a());
76+
Ok(a)
77+
}
78+
79+
def b(&self) -> PyResult<f64> {
80+
let mut b = 0.0;
81+
self.cell(py).with_ref(&mut |cell| b = cell.b());
82+
Ok(b)
83+
}
84+
85+
def c(&self) -> PyResult<f64> {
86+
let mut c = 0.0;
87+
self.cell(py).with_ref(&mut |cell| c = cell.c());
88+
Ok(c)
89+
}
90+
91+
def alpha(&self) -> PyResult<f64> {
92+
let mut alpha = 0.0;
93+
self.cell(py).with_ref(&mut |cell| alpha = cell.alpha());
94+
Ok(alpha)
95+
}
96+
97+
def beta(&self) -> PyResult<f64> {
98+
let mut beta = 0.0;
99+
self.cell(py).with_ref(&mut |cell| beta = cell.beta());
100+
Ok(beta)
101+
}
102+
103+
def gamma(&self) -> PyResult<f64> {
104+
let mut gamma = 0.0;
105+
self.cell(py).with_ref(&mut |cell| gamma = cell.gamma());
106+
Ok(gamma)
107+
}
108+
109+
def shape(&self) -> PyResult<CellShape> {
110+
let mut shape = lumol::sys::CellShape::Infinite;
111+
self.cell(py).with_ref(&mut |cell| shape = cell.shape());
112+
CellShape::create_instance(py, shape)
113+
}
114+
});
115+
116+
#[cfg(test)]
117+
mod tests {
118+
mod rust {
119+
use cpython::Python;
120+
use super::super::UnitCell;
121+
122+
#[test]
123+
fn constructors() {
124+
let gil = Python::acquire_gil();
125+
let py = gil.python();
126+
127+
let cell = create_instance!(py, UnitCell);
128+
assert_ulps_eq!(cell.a(py).unwrap(), 0.0);
129+
assert_ulps_eq!(cell.b(py).unwrap(), 0.0);
130+
assert_ulps_eq!(cell.c(py).unwrap(), 0.0);
131+
assert_ulps_eq!(cell.alpha(py).unwrap(), 90.0);
132+
assert_ulps_eq!(cell.beta(py).unwrap(), 90.0);
133+
assert_ulps_eq!(cell.gamma(py).unwrap(), 90.0);
134+
135+
let cell = create_instance!(py, UnitCell, (3.0, 4.0, 5.0));
136+
assert_ulps_eq!(cell.a(py).unwrap(), 3.0);
137+
assert_ulps_eq!(cell.b(py).unwrap(), 4.0);
138+
assert_ulps_eq!(cell.c(py).unwrap(), 5.0);
139+
assert_ulps_eq!(cell.alpha(py).unwrap(), 90.0);
140+
assert_ulps_eq!(cell.beta(py).unwrap(), 90.0);
141+
assert_ulps_eq!(cell.gamma(py).unwrap(), 90.0);
142+
143+
let cell = create_instance!(py, UnitCell, (3.0, 4.0, 5.0, 80.0, 90.0, 100.0));
144+
assert_ulps_eq!(cell.a(py).unwrap(), 3.0);
145+
assert_ulps_eq!(cell.b(py).unwrap(), 4.0);
146+
assert_ulps_eq!(cell.c(py).unwrap(), 5.0);
147+
assert_ulps_eq!(cell.alpha(py).unwrap(), 80.0);
148+
assert_ulps_eq!(cell.beta(py).unwrap(), 90.0);
149+
assert_ulps_eq!(cell.gamma(py).unwrap(), 100.0);
150+
}
151+
}
152+
153+
mod python {
154+
use cpython::Python;
155+
use super::super::{UnitCell, CellShape};
156+
157+
#[test]
158+
fn constructors() {
159+
#![allow(non_snake_case)]
160+
161+
let gil = Python::acquire_gil();
162+
let py = gil.python();
163+
let UnitCell = py.get_type::<UnitCell>();
164+
let CellShape = py.get_type::<CellShape>();
165+
166+
py_run_with!(py, UnitCell, CellShape;
167+
"cell = UnitCell()",
168+
"assert cell.a() == 0.0",
169+
"assert cell.b() == 0.0",
170+
"assert cell.c() == 0.0",
171+
"assert cell.alpha() == 90.0",
172+
"assert cell.beta() == 90.0",
173+
"assert cell.gamma() == 90.0",
174+
"assert cell.shape() == CellShape.infinite()",
175+
);
176+
177+
let UnitCell = py.get_type::<UnitCell>();
178+
let CellShape = py.get_type::<CellShape>();
179+
180+
py_run_with!(py, UnitCell, CellShape;
181+
"cell = UnitCell(3, 4, 5)",
182+
"assert cell.a() == 3.0",
183+
"assert cell.b() == 4.0",
184+
"assert cell.c() == 5.0",
185+
"assert cell.alpha() == 90.0",
186+
"assert cell.beta() == 90.0",
187+
"assert cell.gamma() == 90.0",
188+
"assert cell.shape() == CellShape.orthorhombic()",
189+
);
190+
191+
let UnitCell = py.get_type::<UnitCell>();
192+
let CellShape = py.get_type::<CellShape>();
193+
194+
py_run_with!(py, UnitCell, CellShape;
195+
"cell = UnitCell(3, 4, 5, 80, 90, 100)",
196+
"assert_approx_eq(cell.a(), 3.0)",
197+
"assert_approx_eq(cell.b(), 4.0)",
198+
"assert_approx_eq(cell.c(), 5.0)",
199+
"assert_approx_eq(cell.alpha(), 80.0)",
200+
"assert_approx_eq(cell.beta(), 90.0)",
201+
"assert_approx_eq(cell.gamma(), 100.0)",
202+
"assert cell.shape() == CellShape.triclinic()",
203+
);
204+
}
205+
}
206+
}

src/systems/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
mod particle;
2+
mod cell;
23

34
register!(|py, m| {
45
try!(particle::register(py, m));
6+
try!(cell::register(py, m));
57
Ok(())
68
});

0 commit comments

Comments
 (0)