Skip to content

Commit 5766f1d

Browse files
committed
[GH-15] Add test and fix patterns with bytes.
PullRequest: graalpython/168
2 parents 47d1816 + 271408d commit 5766f1d

File tree

5 files changed

+93
-69
lines changed

5 files changed

+93
-69
lines changed

graalpython/com.oracle.graal.python.test/src/tests/test_re.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,19 @@ def test_special_re_compile():
6565
_dquote_re = re.compile(r'"(?:[^"\\]|\\.)*"')
6666

6767

68+
def test_json_bytes_re_compile():
69+
import json
70+
assert isinstance(json.encoder.HAS_UTF8.pattern, bytes)
71+
assert json.encoder.HAS_UTF8.search(b"\x80") is not None
72+
assert json.encoder.HAS_UTF8.search(b"space") is None
73+
try:
74+
json.encoder.HAS_UTF8.search("\x80")
75+
except TypeError as e:
76+
pass
77+
else:
78+
assert False, "searching a bytes-pattern in a str did not raise"
79+
80+
6881
class S(str):
6982
def __getitem__(self, index):
7083
return S(super().__getitem__(index))
@@ -297,4 +310,3 @@ def test_escaping(self):
297310
self.assertTrue(match)
298311
assert "frac" in match.groupdict()
299312
assert match.groupdict()["frac"] == "1"
300-

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/modules/SREModuleBuiltins.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,12 @@ Object call(TruffleObject callable, String arg1, int arg2,
292292
return doIt(callable, arg1, arg2, runtimeError, typeError, invokeNode);
293293
}
294294

295+
@SuppressWarnings("unused")
296+
@Fallback
297+
Object call(Object callable, Object arg1, Object arg2) {
298+
throw raise(RuntimeError);
299+
}
300+
295301
protected static Node createExecute() {
296302
return Message.EXECUTE.createNode();
297303
}

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/cext/PythonObjectNativeWrapperMR.java

Lines changed: 30 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -88,9 +88,11 @@
8888
import com.oracle.graal.python.nodes.attributes.WriteAttributeToObjectNode;
8989
import com.oracle.graal.python.nodes.call.special.LookupAndCallBinaryNode;
9090
import com.oracle.graal.python.nodes.call.special.LookupAndCallUnaryNode;
91+
import com.oracle.graal.python.nodes.classes.IsSubtypeNode;
9192
import com.oracle.graal.python.nodes.object.GetClassNode;
9293
import com.oracle.graal.python.nodes.truffle.PythonArithmeticTypes;
9394
import com.oracle.graal.python.runtime.PythonContext;
95+
import com.oracle.graal.python.runtime.PythonCore;
9496
import com.oracle.graal.python.runtime.exception.PException;
9597
import com.oracle.graal.python.runtime.interop.PythonMessageResolution;
9698
import com.oracle.graal.python.runtime.sequence.PSequence;
@@ -114,6 +116,7 @@
114116
import com.oracle.truffle.api.interop.UnsupportedTypeException;
115117
import com.oracle.truffle.api.nodes.Node;
116118
import com.oracle.truffle.api.nodes.UnexpectedResultException;
119+
import com.oracle.truffle.api.profiles.BranchProfile;
117120
import com.oracle.truffle.api.profiles.ConditionProfile;
118121
import com.oracle.truffle.api.profiles.ValueProfile;
119122

@@ -279,14 +282,33 @@ Object doTpAsNumber(PythonClass object, @SuppressWarnings("unused") String key)
279282
}
280283

281284
@Specialization(guards = "eq(TP_AS_BUFFER, key)")
282-
Object doTpAsBuffer(PythonClass object, @SuppressWarnings("unused") String key) {
283-
if (object == getCore().lookupType(PythonBuiltinClassType.PBytes) ||
284-
object == getCore().lookupType(PythonBuiltinClassType.PByteArray) ||
285-
object == getCore().lookupType(PythonBuiltinClassType.PMemoryView) ||
286-
object == getCore().lookupType(PythonBuiltinClassType.PBuffer)) {
287-
return new PyBufferProcsWrapper(object);
288-
}
289-
285+
Object doTpAsBuffer(PythonClass object, @SuppressWarnings("unused") String key,
286+
@Cached("create()") IsSubtypeNode isSubtype,
287+
@Cached("create()") BranchProfile notBytes,
288+
@Cached("create()") BranchProfile notBytearray,
289+
@Cached("create()") BranchProfile notMemoryview,
290+
@Cached("create()") BranchProfile notBuffer) {
291+
PythonCore core = getCore();
292+
PythonBuiltinClass pBytes = core.lookupType(PythonBuiltinClassType.PBytes);
293+
if (isSubtype.execute(object, pBytes)) {
294+
return new PyBufferProcsWrapper(pBytes);
295+
}
296+
notBytes.enter();
297+
PythonBuiltinClass pBytearray = core.lookupType(PythonBuiltinClassType.PByteArray);
298+
if (isSubtype.execute(object, pBytearray)) {
299+
return new PyBufferProcsWrapper(pBytearray);
300+
}
301+
notBytearray.enter();
302+
PythonBuiltinClass pMemoryview = core.lookupType(PythonBuiltinClassType.PMemoryView);
303+
if (isSubtype.execute(object, pMemoryview)) {
304+
return new PyBufferProcsWrapper(pMemoryview);
305+
}
306+
notMemoryview.enter();
307+
PythonBuiltinClass pBuffer = core.lookupType(PythonBuiltinClassType.PBuffer);
308+
if (isSubtype.execute(object, pBuffer)) {
309+
return new PyBufferProcsWrapper(pBuffer);
310+
}
311+
notBuffer.enter();
290312
// NULL pointer
291313
return getToSulongNode().execute(PNone.NO_VALUE);
292314
}

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/complex/PComplex.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727

2828
import com.oracle.graal.python.builtins.objects.object.PythonBuiltinObject;
2929
import com.oracle.graal.python.builtins.objects.type.PythonClass;
30-
import com.oracle.truffle.api.CompilerAsserts;
30+
import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary;
3131

3232
public final class PComplex extends PythonBuiltinObject {
3333
/* Prime multiplier used in string and various other hashes in CPython. */
@@ -89,8 +89,8 @@ public double getImag() {
8989
}
9090

9191
@Override
92+
@TruffleBoundary
9293
public String toString() {
93-
CompilerAsserts.neverPartOfCompilation();
9494
if (Double.compare(real, 0.0) == 0) {
9595
return toString(imag) + "j";
9696
} else {

graalpython/lib-graalpython/_sre.py

Lines changed: 42 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -161,34 +161,22 @@ def __tregex_compile(self, pattern):
161161
def __compile_cpython_sre(self):
162162
if not self.__compiled_sre_pattern:
163163
import _cpython_sre
164-
self.__compiled_sre_pattern = _cpython_sre.compile(self._emit(self.pattern), self.flags, self.code, self.num_groups, self.groupindex, self.indexgroup)
164+
self.__compiled_sre_pattern = _cpython_sre.compile(self.pattern, self.flags, self.code, self.num_groups, self.groupindex, self.indexgroup)
165165
return self.__compiled_sre_pattern
166166

167167

168-
def _decode_string(self, string, flags=0):
168+
def _decode_pattern(self, string, flags=0):
169169
if isinstance(string, str):
170+
# TODO: fix this in the regex engine
171+
pattern = string.replace(r'\"', '"').replace(r"\'", "'")
172+
173+
# TODO: that's not nearly complete but should be sufficient for now
174+
from sre_compile import SRE_FLAG_VERBOSE
175+
if flags & SRE_FLAG_VERBOSE:
176+
pattern = tregex_preprocess_for_verbose(pattern)
177+
return tregex_preprocess_default(pattern)
178+
else:
170179
return string
171-
elif isinstance(string, bytes):
172-
return string.decode()
173-
elif isinstance(string, bytearray):
174-
return string.decode()
175-
elif isinstance(string, memoryview):
176-
# return bytes(string).decode()
177-
raise TypeError("'memoryview' is currently unsupported as search pattern")
178-
raise TypeError("invalid search pattern {!r}".format(string))
179-
180-
181-
def _decode_pattern(self, string, flags=0):
182-
pattern = self._decode_string(string, flags)
183-
184-
# TODO: fix this in the regex engine
185-
pattern = pattern.replace(r'\"', '"').replace(r"\'", "'")
186-
187-
# TODO: that's not nearly complete but should be sufficient for now
188-
from sre_compile import SRE_FLAG_VERBOSE
189-
if flags & SRE_FLAG_VERBOSE:
190-
pattern = tregex_preprocess_for_verbose(pattern)
191-
return tregex_preprocess_default(pattern)
192180

193181

194182
def __repr__(self):
@@ -210,7 +198,6 @@ def __repr__(self):
210198

211199
def _search(self, pattern, string, pos, endpos):
212200
pattern = self.__tregex_compile(pattern)
213-
string = self._decode_string(string)
214201
if endpos == -1 or endpos >= len(string):
215202
result = tregex_call_safe(pattern.exec, string, pos)
216203
else:
@@ -227,29 +214,33 @@ def search(self, string, pos=0, endpos=None):
227214
return self.__compile_cpython_sre().search(string, pos, default(endpos, maxsize()))
228215

229216
def match(self, string, pos=0, endpos=None):
230-
try:
231-
if not self.pattern.startswith("^"):
232-
return self._search("^" + self.pattern, string, pos, default(endpos, -1))
233-
else:
234-
return self._search(self.pattern, string, pos, default(endpos, -1))
235-
except RuntimeError:
236-
return self.__compile_cpython_sre().match(string, pos, default(endpos, maxsize()))
217+
pattern = self.pattern
218+
if isinstance(pattern, str):
219+
try:
220+
if not pattern.startswith("^"):
221+
return self._search("^" + pattern, string, pos, default(endpos, -1))
222+
else:
223+
return self._search(pattern, string, pos, default(endpos, -1))
224+
except RuntimeError:
225+
pass
226+
return self.__compile_cpython_sre().match(string, pos, default(endpos, maxsize()))
237227

238228
def fullmatch(self, string, pos=0, endpos=None):
239-
try:
240-
pattern = self.pattern
241-
if not pattern.startswith("^"):
242-
pattern = "^" + pattern
243-
if not pattern.endswith("$"):
244-
pattern = pattern + "$"
245-
return self._search(pattern, string, pos, default(endpos, -1))
246-
except RuntimeError:
247-
return self.__compile_cpython_sre().fullmatch(string, pos, default(endpos, maxsize()))
229+
pattern = self.pattern
230+
if isinstance(pattern, str):
231+
try:
232+
if not pattern.startswith("^"):
233+
pattern = "^" + pattern
234+
if not pattern.endswith("$"):
235+
pattern = pattern + "$"
236+
return self._search(pattern, string, pos, default(endpos, -1))
237+
except RuntimeError:
238+
pass
239+
return self.__compile_cpython_sre().fullmatch(string, pos, default(endpos, maxsize()))
248240

249241
def findall(self, string, pos=0, endpos=-1):
250242
try:
251243
pattern = self.__tregex_compile(self.pattern)
252-
string = self._decode_string(string)
253244
if endpos > len(string):
254245
endpos = len(string)
255246
elif endpos < 0:
@@ -281,9 +272,9 @@ def group(match_result, group_nr, string):
281272
return string[group_start:group_end]
282273

283274
n = len(repl)
284-
result = self._emit("")
275+
result = ""
285276
start = 0
286-
backslash = self._emit('\\')
277+
backslash = '\\'
287278
pos = repl.find(backslash, start)
288279
while pos != -1 and start < n:
289280
if pos+1 < n:
@@ -292,15 +283,15 @@ def group(match_result, group_nr, string):
292283
group_str = group(match_result, group_nr, string)
293284
if group_str is None:
294285
raise ValueError("invalid group reference %s at position %s" % (group_nr, pos))
295-
result += repl[start:pos] + self._emit(group_str)
286+
result += repl[start:pos] + group_str
296287
start = pos + 2
297288
elif repl[pos + 1] == 'g':
298289
group_ref, group_ref_end, digits_only = self.__extract_groupname(repl, pos + 2)
299290
if group_ref:
300291
group_str = group(match_result, int(group_ref) if digits_only else pattern.groups[group_ref], string)
301292
if group_str is None:
302293
raise ValueError("invalid group reference %s at position %s" % (group_ref, pos))
303-
result += repl[start:pos] + self._emit(group_str)
294+
result += repl[start:pos] + group_str
304295
start = group_ref_end + 1
305296
elif repl[pos + 1] == backslash:
306297
result += repl[start:pos] + backslash
@@ -331,40 +322,33 @@ def sub(self, repl, string, count=0):
331322
n = 0
332323
try:
333324
pattern = self.__tregex_compile(self.pattern)
334-
decoded_string = self._decode_string(string)
335325
result = []
336326
pos = 0
337327
is_string_rep = isinstance(repl, str) or isinstance(repl, bytes) or isinstance(repl, bytearray)
338328
if is_string_rep:
339329
repl = _process_escape_sequences(repl)
340330
progress = True
341-
while (count == 0 or n < count) and pos <= len(decoded_string) and progress:
342-
match_result = tregex_call_safe(pattern.exec, decoded_string, pos)
331+
while (count == 0 or n < count) and pos <= len(string) and progress:
332+
match_result = tregex_call_safe(pattern.exec, string, pos)
343333
if not match_result.isMatch:
344334
break
345335
n += 1
346336
start = match_result.start[0]
347337
end = match_result.end[0]
348-
result.append(self._emit(decoded_string[pos:start]))
338+
result.append(string[pos:start])
349339
if is_string_rep:
350-
result.append(self.__replace_groups(repl, decoded_string, match_result, pattern))
340+
result.append(self.__replace_groups(repl, string, match_result, pattern))
351341
else:
352342
_srematch = SRE_Match(self, pos, -1, match_result)
353343
_repl = repl(_srematch)
354344
result.append(_repl)
355345
pos = end
356346
progress = (start != end)
357-
result.append(self._emit(decoded_string[pos:]))
358-
return self._emit("").join(result)
347+
result.append(string[pos:])
348+
return "".join(result)
359349
except BaseException:
360350
return self.__compile_cpython_sre().sub(repl, string, count)
361351

362-
def _emit(self, str_like_obj):
363-
assert isinstance(str_like_obj, str) or isinstance(str_like_obj, bytes)
364-
if self.__was_bytes != isinstance(str_like_obj, bytes):
365-
return str_like_obj.encode()
366-
return str_like_obj
367-
368352

369353
compile = SRE_Pattern
370354

0 commit comments

Comments
 (0)