Skip to content

Commit 5248a5a

Browse files
committed
[GR-13351] Missing specialization [String,String,Long,PNone] in StartsWithNode.
PullRequest: graalpython/366
2 parents 0bf72b7 + d361df7 commit 5248a5a

File tree

4 files changed

+235
-21
lines changed

4 files changed

+235
-21
lines changed

graalpython/com.oracle.graal.python.test/src/tests/test_string.py

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright (c) 2018, Oracle and/or its affiliates.
1+
# Copyright (c) 2018, 2019, Oracle and/or its affiliates.
22
# Copyright (C) 1996-2017 Python Software Foundation
33
#
44
# Licensed under the PYTHON SOFTWARE FOUNDATION LICENSE VERSION 2
@@ -881,6 +881,56 @@ def test_count(self):
881881
self.assertEqual(rem, 0, '%s != 0 for %s' % (rem, i))
882882
self.assertEqual(r1, r2, '%s != %s for %s' % (r1, r2, i))
883883

884+
def test_startswith(self):
885+
self.checkequal(True, 'hello', 'startswith', 'he')
886+
self.checkequal(True, 'hello', 'startswith', 'hello')
887+
self.checkequal(False, 'hello', 'startswith', 'hello world')
888+
self.checkequal(True, 'hello', 'startswith', '')
889+
self.checkequal(False, 'hello', 'startswith', 'ello')
890+
self.checkequal(True, 'hello', 'startswith', 'ello', 1)
891+
self.checkequal(True, 'hello', 'startswith', 'o', 4)
892+
self.checkequal(False, 'hello', 'startswith', 'o', 5)
893+
self.checkequal(True, 'hello', 'startswith', '', 5)
894+
self.checkequal(False, 'hello', 'startswith', 'lo', 6)
895+
self.checkequal(True, 'helloworld', 'startswith', 'lowo', 3)
896+
self.checkequal(True, 'helloworld', 'startswith', 'lowo', 3, 7)
897+
self.checkequal(False, 'helloworld', 'startswith', 'lowo', 3, 6)
898+
self.checkequal(True, '', 'startswith', '', 0, 1)
899+
self.checkequal(True, '', 'startswith', '', 0, 0)
900+
if (sys.version_info.major >= 3 and sys.version_info.minor >= 6):
901+
self.checkequal(False, '', 'startswith', '', 1, 0)
902+
903+
# test negative indices
904+
self.checkequal(True, 'hello', 'startswith', 'he', 0, -1)
905+
self.checkequal(True, 'hello', 'startswith', 'he', -53, -1)
906+
self.checkequal(False, 'hello', 'startswith', 'hello', 0, -1)
907+
self.checkequal(False, 'hello', 'startswith', 'hello world', -1, -10)
908+
self.checkequal(False, 'hello', 'startswith', 'ello', -5)
909+
self.checkequal(True, 'hello', 'startswith', 'ello', -4)
910+
self.checkequal(False, 'hello', 'startswith', 'o', -2)
911+
self.checkequal(True, 'hello', 'startswith', 'o', -1)
912+
self.checkequal(True, 'hello', 'startswith', '', -3, -3)
913+
self.checkequal(False, 'hello', 'startswith', 'lo', -9)
914+
915+
self.checkraises(TypeError, 'hello', 'startswith')
916+
#self.checkraises(TypeError, 'hello', 'startswith', 42)
917+
918+
# test tuple arguments
919+
self.checkequal(True, 'hello', 'startswith', ('he', 'ha'))
920+
self.checkequal(False, 'hello', 'startswith', ('lo', 'llo'))
921+
self.checkequal(True, 'hello', 'startswith', ('hellox', 'hello'))
922+
self.checkequal(False, 'hello', 'startswith', ())
923+
self.checkequal(True, 'helloworld', 'startswith', ('hellowo',
924+
'rld', 'lowo'), 3)
925+
self.checkequal(False, 'helloworld', 'startswith', ('hellowo', 'ello',
926+
'rld'), 3)
927+
self.checkequal(True, 'hello', 'startswith', ('lo', 'he'), 0, -1)
928+
self.checkequal(False, 'hello', 'startswith', ('he', 'hel'), 0, 1)
929+
self.checkequal(True, 'hello', 'startswith', ('he', 'hel'), 0, 2)
930+
931+
self.checkraises(TypeError, 'hello', 'startswith', (42,))
932+
self.checkequal(True, 'hello', 'startswith', ('he', 42))
933+
self.checkraises(TypeError, 'hello', 'startswith', ('ne', 42,))
884934

885935
def test_same_id():
886936
empty_ids = set([id(str()) for i in range(100)])

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/str/StringBuiltins.java

Lines changed: 130 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@
104104
import com.oracle.truffle.api.dsl.TypeSystemReference;
105105
import com.oracle.truffle.api.profiles.BranchProfile;
106106
import com.oracle.truffle.api.profiles.ConditionProfile;
107+
import java.math.BigInteger;
107108

108109
@CoreFunctions(extendClasses = PythonBuiltinClassType.PString)
109110
public final class StringBuiltins extends PythonBuiltins {
@@ -436,55 +437,171 @@ public abstract static class RAddNode extends AddNode {
436437
}
437438

438439
// str.startswith(prefix[, start[, end]])
439-
@Builtin(name = "startswith", minNumOfPositionalArgs = 2, maxNumOfPositionalArgs = 5)
440+
@Builtin(name = "startswith", minNumOfPositionalArgs = 2, maxNumOfPositionalArgs = 4)
440441
@TypeSystemReference(PythonArithmeticTypes.class)
441442
@GenerateNodeFactory
442443
public abstract static class StartsWithNode extends PythonBuiltinNode {
443-
@Specialization
444-
boolean startsWith(String self, String prefix, int start, int end) {
444+
445+
private @Child CastToIndexNode startNode;
446+
private @Child CastToIndexNode endNode;
447+
448+
private CastToIndexNode getStartNode() {
449+
if (startNode == null) {
450+
CompilerDirectives.transferToInterpreterAndInvalidate();
451+
startNode = insert(CastToIndexNode.create(TypeError, val -> {
452+
throw raise(PythonBuiltinClassType.TypeError, "slice indices must be integers or None or have an __index__ method");
453+
}));
454+
}
455+
return startNode;
456+
}
457+
458+
private CastToIndexNode getEndNode() {
459+
if (endNode == null) {
460+
CompilerDirectives.transferToInterpreterAndInvalidate();
461+
endNode = insert(CastToIndexNode.create(TypeError, val -> {
462+
throw raise(PythonBuiltinClassType.TypeError, "slice indices must be integers or None or have an __index__ method");
463+
}));
464+
}
465+
return endNode;
466+
}
467+
468+
@TruffleBoundary
469+
private static int correctIndex(PInt index, String text) {
470+
int textLength = text.length();
471+
BigInteger bIndex = index.getValue();
472+
BigInteger bTextLength = BigInteger.valueOf(textLength);
473+
if (bIndex.compareTo(BigInteger.ZERO) < 0) {
474+
BigInteger result = bIndex.add(bTextLength);
475+
return result.compareTo(BigInteger.valueOf(Integer.MIN_VALUE)) < 0 ? Integer.MIN_VALUE : result.intValue();
476+
}
477+
return bIndex.compareTo(bTextLength) > 0 ? textLength : bIndex.intValue();
478+
}
479+
480+
private static int correctIndex(int index, String text) {
481+
return index < 0 ? index + text.length() : index;
482+
}
483+
484+
private static int correctIndex(long index, String text) {
485+
if (index < 0) {
486+
long result = index + text.length();
487+
return result < Integer.MIN_VALUE ? Integer.MIN_VALUE : (int) result;
488+
}
489+
return index > Integer.MAX_VALUE ? Integer.MAX_VALUE : (int) index;
490+
}
491+
492+
private static boolean doIt(String text, String prefix, int start, int end) {
445493
if (end - start < prefix.length()) {
446494
return false;
447-
} else if (self.startsWith(prefix, start)) {
448-
return true;
449495
}
450-
return false;
496+
return text.startsWith(prefix, start < 0 ? 0 : start);
451497
}
452498

453-
@Specialization
454-
boolean startsWith(String self, PTuple prefix, int start, int end) {
499+
private boolean doIt(String self, PTuple prefix, int start, int end) {
455500
for (Object o : prefix.getArray()) {
456501
if (o instanceof String) {
457-
if (startsWith(self, (String) o, start, end)) {
502+
if (doIt(self, (String) o, start, end)) {
458503
return true;
459504
}
460505
} else if (o instanceof PString) {
461-
if (startsWith(self, ((PString) o).getValue(), start, end)) {
506+
if (doIt(self, ((PString) o).getValue(), start, end)) {
462507
return true;
463508
}
509+
} else {
510+
throw raise(TypeError, "tuple for startswith must only contain str, not %p", o);
464511
}
465512
}
466513
return false;
467514
}
468515

516+
@Specialization
517+
boolean startsWith(String self, String prefix, int start, int end) {
518+
return doIt(self, prefix, correctIndex(start, self), correctIndex(end, self));
519+
}
520+
521+
@Specialization
522+
boolean startsWith(String self, PTuple prefix, int start, int end) {
523+
return doIt(self, prefix, correctIndex(start, self), correctIndex(end, self));
524+
}
525+
469526
@Specialization
470527
boolean startsWith(String self, String prefix, int start, @SuppressWarnings("unused") PNone end) {
471-
return startsWith(self, prefix, start, self.length());
528+
return doIt(self, prefix, correctIndex(start, self), self.length());
529+
}
530+
531+
@Specialization
532+
boolean startsWith(String self, String prefix, long start, @SuppressWarnings("unused") PNone end) {
533+
return doIt(self, prefix, correctIndex(start, self), self.length());
534+
}
535+
536+
@Specialization
537+
boolean startsWith(String self, String prefix, long start, long end) {
538+
return doIt(self, prefix, correctIndex(start, self), correctIndex(end, self));
539+
}
540+
541+
@Specialization(rewriteOn = ArithmeticException.class)
542+
boolean startsWith(String self, String prefix, PInt start, @SuppressWarnings("unused") PNone end) {
543+
return startsWith(self, prefix, start.intValueExact(), self.length());
544+
}
545+
546+
@Specialization
547+
boolean startsWithPIntOvf(String self, String prefix, PInt start, @SuppressWarnings("unused") PNone end) {
548+
return doIt(self, prefix, correctIndex(start, self), self.length());
472549
}
473550

474551
@Specialization
475552
boolean startsWith(String self, String prefix, @SuppressWarnings("unused") PNone start, @SuppressWarnings("unused") PNone end) {
476-
return startsWith(self, prefix, 0, self.length());
553+
return doIt(self, prefix, 0, self.length());
477554
}
478555

479556
@Specialization
480557
boolean startsWith(String self, PTuple prefix, int start, @SuppressWarnings("unused") PNone end) {
481-
return startsWith(self, prefix, start, self.length());
558+
return doIt(self, prefix, correctIndex(start, self), self.length());
559+
}
560+
561+
@Specialization
562+
boolean startsWith(String self, PTuple prefix, long start, @SuppressWarnings("unused") PNone end) {
563+
return doIt(self, prefix, correctIndex(start, self), self.length());
564+
}
565+
566+
@Specialization
567+
boolean startsWith(String self, PTuple prefix, long start, long end) {
568+
return doIt(self, prefix, correctIndex(start, self), correctIndex(end, self));
569+
}
570+
571+
@Specialization(rewriteOn = ArithmeticException.class)
572+
boolean startsWith(String self, PTuple prefix, PInt start, @SuppressWarnings("unused") PNone end) {
573+
return startsWith(self, prefix, start.intValueExact(), end);
574+
}
575+
576+
@Specialization
577+
boolean startsWithPIntOvf(String self, PTuple prefix, PInt start, @SuppressWarnings("unused") PNone end) {
578+
return doIt(self, prefix, correctIndex(start, self), self.length());
482579
}
483580

484581
@Specialization
485582
boolean startsWith(String self, PTuple prefix, @SuppressWarnings("unused") PNone start, @SuppressWarnings("unused") PNone end) {
486583
return startsWith(self, prefix, 0, self.length());
487584
}
585+
586+
@Specialization
587+
boolean startsWith(String self, String prefix, Object start, Object end) {
588+
int sIndex = getStartNode().execute(start);
589+
int eIndex = getEndNode().execute(end);
590+
return doIt(self, prefix, correctIndex(sIndex, self), correctIndex(eIndex, self));
591+
}
592+
593+
@Specialization
594+
boolean startsWith(String self, PTuple prefix, Object start, Object end) {
595+
int sIndex = getStartNode().execute(start);
596+
int eIndex = getEndNode().execute(end);
597+
return doIt(self, prefix, correctIndex(sIndex, self), correctIndex(eIndex, self));
598+
}
599+
600+
@Fallback
601+
@SuppressWarnings("unused")
602+
boolean general(Object self, Object prefix, Object start, Object end) {
603+
throw raise(TypeError, "startswith first arg must be str or a tuple of str, not %p", prefix);
604+
}
488605
}
489606

490607
// str.endswith(suffix[, start[, end]])

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/nodes/util/CastToIndexNode.java

Lines changed: 34 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2018, Oracle and/or its affiliates. All rights reserved.
2+
* Copyright (c) 2018, 2019, Oracle and/or its affiliates. All rights reserved.
33
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
44
*
55
* The Universal Permissive License (UPL), Version 1.0
@@ -46,16 +46,21 @@
4646
import static com.oracle.graal.python.nodes.SpecialMethodNames.__INDEX__;
4747

4848
import com.oracle.graal.python.builtins.PythonBuiltinClassType;
49+
import com.oracle.graal.python.builtins.objects.PNone;
4950
import com.oracle.graal.python.builtins.objects.ints.PInt;
5051
import com.oracle.graal.python.nodes.PNodeWithContext;
5152
import com.oracle.graal.python.nodes.call.special.LookupAndCallUnaryNode;
53+
import com.oracle.graal.python.nodes.truffle.PythonArithmeticTypes;
5254
import com.oracle.truffle.api.CompilerDirectives;
5355
import com.oracle.truffle.api.dsl.Fallback;
5456
import com.oracle.truffle.api.dsl.Specialization;
57+
import com.oracle.truffle.api.dsl.TypeSystemReference;
58+
import java.util.function.Function;
5559

5660
/**
5761
* Converts an arbitrary object to an index-sized integer (which is a Java {@code int}).
5862
*/
63+
@TypeSystemReference(PythonArithmeticTypes.class)
5964
public abstract class CastToIndexNode extends PNodeWithContext {
6065

6166
private static final String ERROR_MESSAGE = "cannot fit 'int' into an index-sized integer";
@@ -65,10 +70,12 @@ public abstract class CastToIndexNode extends PNodeWithContext {
6570

6671
private final PythonBuiltinClassType errorType;
6772
private final boolean recursive;
73+
private final Function<Object, Integer> typeErrorHandler;
6874

69-
protected CastToIndexNode(PythonBuiltinClassType errorType, boolean recursive) {
75+
protected CastToIndexNode(PythonBuiltinClassType errorType, boolean recursive, Function<Object, Integer> typeErrorHandler) {
7076
this.errorType = errorType;
7177
this.recursive = recursive;
78+
this.typeErrorHandler = typeErrorHandler;
7279
}
7380

7481
public abstract int execute(Object x);
@@ -117,27 +124,49 @@ int doPIntOvf(PInt x) {
117124
}
118125
}
119126

127+
@Specialization
128+
public int toInt(double x) {
129+
if (typeErrorHandler != null) {
130+
return typeErrorHandler.apply(x);
131+
}
132+
throw raise(TypeError, "'%p' object cannot be interpreted as an integer", x);
133+
}
134+
120135
@Fallback
121136
int doGeneric(Object x) {
122137
if (recursive) {
123138
if (callIndexNode == null) {
124139
CompilerDirectives.transferToInterpreterAndInvalidate();
125140
callIndexNode = insert(LookupAndCallUnaryNode.create(__INDEX__));
126141
}
142+
Object result = callIndexNode.executeObject(x);
143+
if (result == PNone.NONE) {
144+
if (typeErrorHandler != null) {
145+
return typeErrorHandler.apply(x);
146+
}
147+
throw raise(TypeError, "'%p' object cannot be interpreted as an integer", x);
148+
}
127149
if (recursiveNode == null) {
128150
CompilerDirectives.transferToInterpreterAndInvalidate();
129-
recursiveNode = insert(CastToIndexNodeGen.create(errorType, false));
151+
recursiveNode = insert(CastToIndexNodeGen.create(errorType, false, typeErrorHandler));
130152
}
131153
return recursiveNode.execute(callIndexNode.executeObject(x));
132154
}
155+
if (typeErrorHandler != null) {
156+
return typeErrorHandler.apply(x);
157+
}
133158
throw raise(TypeError, "__index__ returned non-int (type %p)", x);
134159
}
135160

136161
public static CastToIndexNode create() {
137-
return CastToIndexNodeGen.create(IndexError, true);
162+
return CastToIndexNodeGen.create(IndexError, true, null);
138163
}
139164

140165
public static CastToIndexNode createOverflow() {
141-
return CastToIndexNodeGen.create(OverflowError, true);
166+
return CastToIndexNodeGen.create(OverflowError, true, null);
167+
}
168+
169+
public static CastToIndexNode create(PythonBuiltinClassType errorType, Function<Object, Integer> typeErrorHandler) {
170+
return CastToIndexNodeGen.create(errorType, true, typeErrorHandler);
142171
}
143172
}

0 commit comments

Comments
 (0)