Skip to content

Commit de11eed

Browse files
authored
Fix reflected elementwise binary ops (#68)
1 parent ad672e4 commit de11eed

File tree

3 files changed

+44
-2
lines changed

3 files changed

+44
-2
lines changed

examples/shallow_water.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,7 @@ def rhs(u, v, e):
233233
# potential vorticity
234234
dudy[:, 1:-1] = (u[:, 1:] - u[:, :-1]) / dy
235235
dvdx[1:-1, :] = (v[1:, :] - v[:-1, :]) / dx
236-
q[:, :] = (dvdx - dudy + coriolis) / H_at_f
236+
q[:, :] = (coriolis - dudy + dvdx) / H_at_f
237237

238238
# Advection of potential vorticity, Arakawa and Hsu (1990)
239239
# Define alpha, beta, gamma, delta for each cell in T points

src/EWBinOp.cpp

+29-1
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,26 @@ static ::imex::ptensor::EWBinOpId ddpt2mlir(const EWBinOpId bop) {
8484
}
8585
}
8686

87+
bool is_reflected_op(EWBinOpId op) {
88+
switch (op) {
89+
case __RADD__:
90+
case __RAND__:
91+
case __RFLOORDIV__:
92+
case __RLSHIFT__:
93+
case __RMOD__:
94+
case __RMUL__:
95+
case __ROR__:
96+
case __RPOW__:
97+
case __RRSHIFT__:
98+
case __RSUB__:
99+
case __RTRUEDIV__:
100+
case __RXOR__:
101+
return true;
102+
default:
103+
return false;
104+
}
105+
}
106+
87107
struct DeferredEWBinOp : public Deferred {
88108
id_type _a;
89109
id_type _b;
@@ -107,8 +127,16 @@ struct DeferredEWBinOp : public Deferred {
107127
::imex::ptensor::toMLIR(builder, DDPT::jit::getPTDType(_dtype));
108128
auto outTyp = aTyp.cloneWith(shape(), outElemType);
109129

130+
::mlir::Value one, two;
131+
if (is_reflected_op(_op)) {
132+
one = bv;
133+
two = av;
134+
} else {
135+
one = av;
136+
two = bv;
137+
}
110138
auto bop = builder.create<::imex::ptensor::EWBinOp>(
111-
loc, outTyp, builder.getI32IntegerAttr(ddpt2mlir(_op)), av, bv);
139+
loc, outTyp, builder.getI32IntegerAttr(ddpt2mlir(_op)), one, two);
112140

113141
dm.addVal(this->guid(), bop,
114142
[this](uint64_t rank, void *l_allocated, void *l_aligned,

test/test_ewb.py

+14
Original file line numberDiff line numberDiff line change
@@ -182,3 +182,17 @@ def test_add_typecast(self):
182182
assert c.dtype == ctype
183183
c2 = dt.to_numpy(c)
184184
assert numpy.allclose(c2, [1, 2, 3, 4, 5, 6, 7, 8])
185+
186+
def test_reflected_sub(self):
187+
a = dt.full((4,), 4, dtype=dt.float64)
188+
b = 10.0
189+
c = b - a
190+
c2 = dt.to_numpy(c)
191+
assert numpy.allclose(c2, [6, 6, 6, 6])
192+
193+
def test_reflected_div(self):
194+
a = dt.full((4,), 2, dtype=dt.float64)
195+
b = 10.0
196+
c = b / a
197+
c2 = dt.to_numpy(c)
198+
assert numpy.allclose(c2, [5, 5, 5, 5])

0 commit comments

Comments
 (0)