Skip to content

Commit 1358a67

Browse files
author
Jake Moss
committed
Update tests, make "division not support" exception lowest priority
1 parent a2573c1 commit 1358a67

File tree

3 files changed

+73
-60
lines changed

3 files changed

+73
-60
lines changed

src/flint/flint_base/flint_base.pyx

+12-12
Original file line numberDiff line numberDiff line change
@@ -581,94 +581,94 @@ cdef class flint_mpoly(flint_elem):
581581

582582
def __divmod__(self, other):
583583
if typecheck(other, type(self)):
584-
self._division_check(other)
585584
self.context().compatible_context_check(other.context())
585+
self._division_check(other)
586586
return self._divmod_mpoly_(other)
587587

588588
other = self.context().any_as_scalar(other)
589589
if other is NotImplemented:
590590
return NotImplemented
591591

592-
self._division_check(other)
593592
other = self.context().scalar_as_mpoly(other)
593+
self._division_check(other)
594594
return self._divmod_mpoly_(other)
595595

596596
def __rdivmod__(self, other):
597597
other = self.context().any_as_scalar(other)
598598
if other is NotImplemented:
599599
return NotImplemented
600600

601-
self._division_check(self)
602601
other = self.context().scalar_as_mpoly(other)
602+
other._division_check(self)
603603
return other._divmod_mpoly_(self)
604604

605605
def __truediv__(self, other):
606606
if typecheck(other, type(self)):
607-
self._division_check(other)
608607
self.context().compatible_context_check(other.context())
608+
self._division_check(other)
609609
return self._truediv_mpoly_(other)
610610

611611
other = self.context().any_as_scalar(other)
612612
if other is NotImplemented:
613613
return NotImplemented
614614

615-
self._division_check(other)
616615
other = self.context().scalar_as_mpoly(other)
616+
self._division_check(other)
617617
return self._truediv_mpoly_(other)
618618

619619
def __rtruediv__(self, other):
620620
other = self.context().any_as_scalar(other)
621621
if other is NotImplemented:
622622
return NotImplemented
623623

624-
self._division_check(self)
625624
other = self.context().scalar_as_mpoly(other)
625+
other._division_check(self)
626626
return other._truediv_mpoly_(self)
627627

628628
def __floordiv__(self, other):
629629
if typecheck(other, type(self)):
630-
self._division_check(other)
631630
self.context().compatible_context_check(other.context())
631+
self._division_check(other)
632632
return self._floordiv_mpoly_(other)
633633

634634
other = self.context().any_as_scalar(other)
635635
if other is NotImplemented:
636636
return NotImplemented
637637

638-
self._division_check(other)
639638
other = self.context().scalar_as_mpoly(other)
639+
self._division_check(other)
640640
return self._floordiv_mpoly_(other)
641641

642642
def __rfloordiv__(self, other):
643643
other = self.context().any_as_scalar(other)
644644
if other is NotImplemented:
645645
return NotImplemented
646646

647-
self._division_check(self)
648647
other = self.context().scalar_as_mpoly(other)
648+
other._division_check(self)
649649
return other._floordiv_mpoly_(self)
650650

651651
def __mod__(self, other):
652652
if typecheck(other, type(self)):
653-
self._division_check(other)
654653
self.context().compatible_context_check(other.context())
654+
self._division_check(other)
655655
return self._mod_mpoly_(other)
656656

657657
other = self.context().any_as_scalar(other)
658658
if other is NotImplemented:
659659
return NotImplemented
660660

661-
self._division_check(other)
662661
other = self.context().scalar_as_mpoly(other)
662+
self._division_check(other)
663663
return self._mod_mpoly_(other)
664664

665665
def __rmod__(self, other):
666666
other = self.context().any_as_scalar(other)
667667
if other is NotImplemented:
668668
return NotImplemented
669669

670-
self._division_check(self)
671670
other = self.context().scalar_as_mpoly(other)
671+
other._division_check(self)
672672
return other._mod_mpoly_(self)
673673

674674
def __contains__(self, x):

src/flint/test/test_all.py

+55-48
Original file line numberDiff line numberDiff line change
@@ -2782,6 +2782,12 @@ def _all_mpolys():
27822782
def test_mpolys():
27832783
for P, get_context, S, is_field in _all_mpolys():
27842784

2785+
# Division under modulo will raise a flint exception if something is not invertible, crashing the program. We
2786+
# can't tell before what is invertible and what is not before hand so we always raise an exception, except for
2787+
# fmpz_mpoly, that returns an bool noting if the division is exact or not.
2788+
division_not_supported = P is not flint.fmpz_mpoly and not is_field
2789+
characteristic_zero = not (P is flint.fmpz_mod_mpoly or P is flint.nmod_mpoly)
2790+
27852791
ctx = get_context(nvars=2)
27862792

27872793
assert raises(lambda: get_context(nvars=2, ordering="bad"), TypeError)
@@ -3045,7 +3051,7 @@ def quick_poly():
30453051
assert raises(lambda: quick_poly().imul(P(ctx=ctx1)), IncompatibleContextError)
30463052
assert raises(lambda: quick_poly().imul(None), NotImplementedError)
30473053

3048-
if (P is flint.fmpz_mod_mpoly or P is flint.nmod_mpoly) and not ctx.is_prime():
3054+
if division_not_supported:
30493055
assert raises(lambda: quick_poly() // mpoly({(1, 1): 1}), DomainError)
30503056
assert raises(lambda: quick_poly() % mpoly({(1, 1): 1}), DomainError)
30513057
assert raises(lambda: divmod(quick_poly(), mpoly({(1, 1): 1})), DomainError)
@@ -3056,9 +3062,6 @@ def quick_poly():
30563062
assert divmod(quick_poly(), mpoly({(1, 1): 1})) \
30573063
== (mpoly({(1, 1): 4}), mpoly({(1, 0): 3, (0, 1): 2, (0, 0): 1}))
30583064

3059-
if (P is flint.fmpz_mod_mpoly or P is flint.nmod_mpoly) and not ctx.is_prime():
3060-
pass
3061-
else:
30623065
assert 1 / P(1, ctx=ctx) == P(1, ctx=ctx)
30633066
assert quick_poly() / 1 == quick_poly()
30643067
assert quick_poly() // 1 == quick_poly()
@@ -3078,9 +3081,7 @@ def quick_poly():
30783081

30793082
f = mpoly({(1, 1): 4, (0, 0): 1})
30803083
g = mpoly({(0, 1): 2, (1, 0): 2})
3081-
if (P is flint.fmpz_mod_mpoly or P is flint.nmod_mpoly) and not ctx.is_prime():
3082-
pass
3083-
else:
3084+
if not division_not_supported:
30843085
assert 1 // quick_poly() == P(ctx=ctx)
30853086
assert 1 % quick_poly() == P(1, ctx=ctx)
30863087
assert divmod(1, quick_poly()) == (P(ctx=ctx), P(1, ctx=ctx))
@@ -3089,43 +3090,44 @@ def quick_poly():
30893090
assert S(1) % quick_poly() == P(1, ctx=ctx)
30903091
assert divmod(S(1), quick_poly()) == (P(ctx=ctx), P(1, ctx=ctx))
30913092

3092-
assert raises(lambda: quick_poly() / None, TypeError)
3093-
assert raises(lambda: quick_poly() // None, TypeError)
3094-
assert raises(lambda: quick_poly() % None, TypeError)
3095-
assert raises(lambda: divmod(quick_poly(), None), TypeError)
3096-
3097-
assert raises(lambda: None / quick_poly(), TypeError)
3098-
assert raises(lambda: None // quick_poly(), TypeError)
3099-
assert raises(lambda: None % quick_poly(), TypeError)
3100-
assert raises(lambda: divmod(None, quick_poly()), TypeError)
3101-
3102-
assert raises(lambda: quick_poly() / 0, ZeroDivisionError)
3103-
assert raises(lambda: quick_poly() // 0, ZeroDivisionError)
3104-
assert raises(lambda: quick_poly() % 0, ZeroDivisionError)
3105-
assert raises(lambda: divmod(quick_poly(), 0), ZeroDivisionError)
3106-
3107-
assert raises(lambda: 1 / P(ctx=ctx), ZeroDivisionError)
3108-
assert raises(lambda: 1 // P(ctx=ctx), ZeroDivisionError)
3109-
assert raises(lambda: 1 % P(ctx=ctx), ZeroDivisionError)
3110-
assert raises(lambda: divmod(1, P(ctx=ctx)), ZeroDivisionError)
3111-
3112-
assert raises(lambda: quick_poly() / P(ctx=ctx), ZeroDivisionError)
3113-
assert raises(lambda: quick_poly() // P(ctx=ctx), ZeroDivisionError)
3114-
assert raises(lambda: quick_poly() % P(ctx=ctx), ZeroDivisionError)
3115-
assert raises(lambda: divmod(quick_poly(), P(ctx=ctx)), ZeroDivisionError)
3116-
3117-
assert raises(lambda: quick_poly() / P(1, ctx=ctx1), IncompatibleContextError)
3118-
assert raises(lambda: quick_poly() // P(1, ctx=ctx1), IncompatibleContextError)
3119-
assert raises(lambda: quick_poly() % P(1, ctx=ctx1), IncompatibleContextError)
3120-
assert raises(lambda: divmod(quick_poly(), P(1, ctx=ctx1)), IncompatibleContextError)
3121-
31223093
assert f * g / mpoly({(0, 1): 1, (1, 0): 1}) \
31233094
== mpoly({(1, 1): 8, (0, 0): 2})
31243095

31253096
if not is_field:
31263097
assert raises(lambda: 1 / quick_poly(), DomainError)
31273098
assert raises(lambda: quick_poly() / P(2, ctx=ctx), DomainError)
31283099

3100+
# We prefer various other errors to the "division not supported" domain error so these are safe.
3101+
assert raises(lambda: quick_poly() / None, TypeError)
3102+
assert raises(lambda: quick_poly() // None, TypeError)
3103+
assert raises(lambda: quick_poly() % None, TypeError)
3104+
assert raises(lambda: divmod(quick_poly(), None), TypeError)
3105+
3106+
assert raises(lambda: None / quick_poly(), TypeError)
3107+
assert raises(lambda: None // quick_poly(), TypeError)
3108+
assert raises(lambda: None % quick_poly(), TypeError)
3109+
assert raises(lambda: divmod(None, quick_poly()), TypeError)
3110+
3111+
assert raises(lambda: quick_poly() / 0, ZeroDivisionError)
3112+
assert raises(lambda: quick_poly() // 0, ZeroDivisionError)
3113+
assert raises(lambda: quick_poly() % 0, ZeroDivisionError)
3114+
assert raises(lambda: divmod(quick_poly(), 0), ZeroDivisionError)
3115+
3116+
assert raises(lambda: 1 / P(ctx=ctx), ZeroDivisionError)
3117+
assert raises(lambda: 1 // P(ctx=ctx), ZeroDivisionError)
3118+
assert raises(lambda: 1 % P(ctx=ctx), ZeroDivisionError)
3119+
assert raises(lambda: divmod(1, P(ctx=ctx)), ZeroDivisionError)
3120+
3121+
assert raises(lambda: quick_poly() / P(ctx=ctx), ZeroDivisionError)
3122+
assert raises(lambda: quick_poly() // P(ctx=ctx), ZeroDivisionError)
3123+
assert raises(lambda: quick_poly() % P(ctx=ctx), ZeroDivisionError)
3124+
assert raises(lambda: divmod(quick_poly(), P(ctx=ctx)), ZeroDivisionError)
3125+
3126+
assert raises(lambda: quick_poly() / P(1, ctx=ctx1), IncompatibleContextError)
3127+
assert raises(lambda: quick_poly() // P(1, ctx=ctx1), IncompatibleContextError)
3128+
assert raises(lambda: quick_poly() % P(1, ctx=ctx1), IncompatibleContextError)
3129+
assert raises(lambda: divmod(quick_poly(), P(1, ctx=ctx1)), IncompatibleContextError)
3130+
31293131
assert quick_poly() ** 0 == P(1, ctx=ctx)
31303132
assert quick_poly() ** 1 == quick_poly()
31313133
assert quick_poly() ** 2 == mpoly({
@@ -3146,29 +3148,34 @@ def quick_poly():
31463148
# # XXX: Not sure what this should do in general:
31473149
assert raises(lambda: pow(P(1, ctx=ctx), 2, 3), NotImplementedError)
31483150

3149-
if (P is not flint.fmpz_mod_mpoly and P is not flint.nmod_mpoly) or f.context().is_prime():
3151+
if division_not_supported:
3152+
assert raises(lambda: (f * g).gcd(f), DomainError)
3153+
else:
31503154
if is_field:
31513155
assert (f * g).gcd(f) == f / 4
31523156
else:
31533157
assert (f * g).gcd(f) == f
31543158
assert raises(lambda: quick_poly().gcd(None), TypeError)
31553159
assert raises(lambda: quick_poly().gcd(P(ctx=ctx1)), IncompatibleContextError)
3156-
else:
3157-
assert raises(lambda: (f * g).gcd(f), DomainError)
31583160

3159-
if P is flint.fmpz_mod_mpoly or P is flint.nmod_mpoly:
3160-
if is_field:
3161-
assert (f * g).factor() == (S(8), [(mpoly({(0, 1): 1, (1, 0): 1}), 1), (f / 4, 1)])
3161+
if division_not_supported:
3162+
# Factorisation not allowed over Z/nZ for n not prime.
3163+
# Flint would abort so we raise an exception instead:
3164+
assert raises(lambda: (f * g).factor(), DomainError)
3165+
elif characteristic_zero:
3166+
# Primitive factors over Z for fmpz_mpoly and fmpq_mpoly
3167+
assert (f * g).factor() == (S(2), [(g / 2, 1), (f, 1)])
3168+
elif is_field:
3169+
# Monic polynomials over Z/pZ for nmod_mpoly and fmpz_mod_mpoly
3170+
assert (f * g).factor() == (S(8), [(g / 2, 1), (f / 4, 1)])
3171+
3172+
if division_not_supported:
3173+
assert raises(lambda: (f * g).sqrt(), DomainError)
31623174
else:
3163-
assert (f * g).factor() == (S(2), [(mpoly({(0, 1): 1, (1, 0): 1}), 1), (f, 1)])
3164-
3165-
if (P is not flint.fmpz_mod_mpoly and P is not flint.nmod_mpoly) or f.context().is_prime():
31663175
assert (f * f).sqrt() == f
31673176
if P is flint.fmpz_mpoly:
31683177
assert (f * f).sqrt(assume_perfect_square=True) == f
31693178
assert raises(lambda: quick_poly().sqrt(), ValueError)
3170-
else:
3171-
assert raises(lambda: (f * g).sqrt(), DomainError)
31723179

31733180
p = quick_poly()
31743181
assert p.derivative(0) == p.derivative("x0") == mpoly({(0, 0): 3, (1, 2): 8})

src/flint/types/fmpz_mod_mpoly.pyx

+6
Original file line numberDiff line numberDiff line change
@@ -831,6 +831,9 @@ cdef class fmpz_mod_mpoly(flint_mpoly):
831831
fmpz c
832832
fmpz_mod_mpoly u
833833

834+
if not self.ctx.is_prime():
835+
raise DomainError("factorisation with non-prime modulus is not supported")
836+
834837
fmpz_mod_mpoly_factor_init(fac, self.ctx.val)
835838
if not fmpz_mod_mpoly_factor(fac, self.val, self.ctx.val):
836839
raise RuntimeError("factorisation failed")
@@ -871,6 +874,9 @@ cdef class fmpz_mod_mpoly(flint_mpoly):
871874
fmpz c
872875
fmpz_mod_mpoly u
873876

877+
if not self.ctx.is_prime():
878+
raise DomainError("factorisation with non-prime modulus is not supported")
879+
874880
fmpz_mod_mpoly_factor_init(fac, self.ctx.val)
875881
if not fmpz_mod_mpoly_factor_squarefree(fac, self.val, self.ctx.val):
876882
raise RuntimeError("factorisation failed")

0 commit comments

Comments
 (0)