Skip to content

Commit 0bf72b7

Browse files
committed
[GR-13482] There are missing specializations for PByteArray in B2aBase64Node.
PullRequest: graalpython/375
2 parents ed7c93f + 494b45e commit 0bf72b7

File tree

3 files changed

+202
-13
lines changed

3 files changed

+202
-13
lines changed
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
# Copyright (c) 2019, Oracle and/or its affiliates.
2+
# Copyright (C) 1996-2017 Python Software Foundation
3+
#
4+
# Licensed under the PYTHON SOFTWARE FOUNDATION LICENSE VERSION 2
5+
6+
import unittest
7+
import binascii
8+
import array
9+
import sys
10+
11+
class MyInt():
12+
def __init__(self, value):
13+
self.value = value
14+
15+
def __int__(self):
16+
return self.value
17+
18+
class BinASCIITest(unittest.TestCase):
19+
20+
type2test = bytes
21+
22+
def test_b2a_base64_newline(self):
23+
b = self.type2test(b'hello')
24+
self.assertEqual(binascii.b2a_base64(b),
25+
b'aGVsbG8=\n')
26+
if (sys.version_info.major >= 3 and sys.version_info.minor >= 6):
27+
self.assertEqual(binascii.b2a_base64(b, newline=True),
28+
b'aGVsbG8=\n')
29+
self.assertEqual(binascii.b2a_base64(b, newline=False),
30+
b'aGVsbG8=')
31+
32+
def test_b2a_base64_int_newline(self):
33+
b = self.type2test(b'hello')
34+
if (sys.version_info.major >= 3 and sys.version_info.minor >= 6):
35+
self.assertEqual(binascii.b2a_base64(b, newline=125),
36+
b'aGVsbG8=\n')
37+
self.assertEqual(binascii.b2a_base64(b, newline=-10),
38+
b'aGVsbG8=\n')
39+
self.assertEqual(binascii.b2a_base64(b, newline=0),
40+
b'aGVsbG8=')
41+
42+
def test_b2a_base64_object_newline(self):
43+
b = self.type2test(b'hello')
44+
if (sys.version_info.major >= 3 and sys.version_info.minor >= 6):
45+
self.assertEqual(binascii.b2a_base64(b, newline=MyInt(125)),
46+
b'aGVsbG8=\n')
47+
self.assertEqual(binascii.b2a_base64(b, newline=MyInt(-10)),
48+
b'aGVsbG8=\n')
49+
self.assertEqual(binascii.b2a_base64(b, newline=MyInt(0)),
50+
b'aGVsbG8=')
51+
52+
def test_b2a_base64_wrong_newline(self):
53+
b = self.type2test(b'hello')
54+
if (sys.version_info.major >= 3 and sys.version_info.minor >= 6):
55+
self.assertRaises(TypeError, binascii.b2a_base64, b, newline='ahoj')
56+
57+
def test_b2a_base64_return_type(self):
58+
b = self.type2test(b'hello')
59+
self.assertEqual(type(binascii.b2a_base64(b)), bytes)
60+
if (sys.version_info.major >= 3 and sys.version_info.minor >= 6):
61+
self.assertEqual(type(binascii.b2a_base64(b, newline=False)), bytes)
62+
63+
class ArrayBinASCIITest(BinASCIITest):
64+
def type2test(self, s):
65+
return array.array('b', list(s))
66+
67+
68+
class BytearrayBinASCIITest(BinASCIITest):
69+
type2test = bytearray
70+
71+
72+
class MemoryviewBinASCIITest(BinASCIITest):
73+
type2test = memoryview
74+
75+
class IndependetTest(unittest.TestCase):
76+
77+
def test_b2a_base64_wrong_first_arg(self):
78+
if (sys.version_info.major >= 3 and sys.version_info.minor >= 6):
79+
self.assertRaises(TypeError, binascii.b2a_base64, 'Ahoj', newline=True)
80+
self.assertRaises(TypeError, binascii.b2a_base64, 10, newline=True)

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/modules/BinasciiModuleBuiltins.java

Lines changed: 121 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2018, Oracle and/or its affiliates. All rights reserved.
2+
* Copyright (c) 2018, 2019, Oracle and/or its affiliates. All rights reserved.
33
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
44
*
55
* The Universal Permissive License (UPL), Version 1.0
@@ -51,24 +51,34 @@
5151
import com.oracle.graal.python.builtins.PythonBuiltinClassType;
5252
import com.oracle.graal.python.builtins.PythonBuiltins;
5353
import com.oracle.graal.python.builtins.objects.PNone;
54+
import com.oracle.graal.python.builtins.objects.array.PArray;
5455
import com.oracle.graal.python.builtins.objects.bytes.BytesNodes;
5556
import com.oracle.graal.python.builtins.objects.bytes.PBytes;
5657
import com.oracle.graal.python.builtins.objects.bytes.PIBytesLike;
5758
import com.oracle.graal.python.builtins.objects.common.SequenceStorageNodes;
59+
import com.oracle.graal.python.builtins.objects.ints.PInt;
60+
import com.oracle.graal.python.builtins.objects.memoryview.PMemoryView;
5861
import com.oracle.graal.python.builtins.objects.module.PythonModule;
5962
import com.oracle.graal.python.builtins.objects.type.PythonClass;
6063
import com.oracle.graal.python.nodes.attributes.ReadAttributeFromObjectNode;
64+
import com.oracle.graal.python.nodes.call.special.LookupAndCallUnaryNode;
6165
import com.oracle.graal.python.nodes.function.PythonBuiltinBaseNode;
6266
import com.oracle.graal.python.nodes.function.builtins.PythonBinaryBuiltinNode;
6367
import com.oracle.graal.python.nodes.function.builtins.PythonUnaryBuiltinNode;
68+
import com.oracle.graal.python.nodes.truffle.PythonArithmeticTypes;
69+
import com.oracle.graal.python.nodes.util.CastToIntegerFromIntNode;
6470
import com.oracle.graal.python.runtime.PythonCore;
6571
import com.oracle.graal.python.runtime.exception.PException;
72+
import static com.oracle.graal.python.runtime.exception.PythonErrorType.SystemError;
6673
import com.oracle.truffle.api.CompilerDirectives;
6774
import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary;
6875
import com.oracle.truffle.api.dsl.Cached;
76+
import com.oracle.truffle.api.dsl.Fallback;
6977
import com.oracle.truffle.api.dsl.GenerateNodeFactory;
7078
import com.oracle.truffle.api.dsl.NodeFactory;
7179
import com.oracle.truffle.api.dsl.Specialization;
80+
import com.oracle.truffle.api.dsl.TypeSystemReference;
81+
import com.oracle.truffle.api.profiles.ConditionProfile;
7282
import com.sun.org.apache.xerces.internal.impl.dv.util.Base64;
7383

7484
@CoreFunctions(defineModule = "binascii")
@@ -163,26 +173,124 @@ private ReadAttributeFromObjectNode getAttrNode() {
163173
}
164174

165175
@Builtin(name = "b2a_base64", fixedNumOfPositionalArgs = 1, keywordArguments = {"newline"})
176+
@TypeSystemReference(PythonArithmeticTypes.class)
166177
@GenerateNodeFactory
167178
static abstract class B2aBase64Node extends PythonBinaryBuiltinNode {
168-
@Specialization(guards = "isNoValue(newline)")
169-
@TruffleBoundary
170-
String b2a(PBytes data, @SuppressWarnings("unused") PNone newline,
171-
@Cached("create()") SequenceStorageNodes.ToByteArrayNode toByteArray) {
172-
return b2a(data, true, toByteArray);
179+
180+
@Child private SequenceStorageNodes.ToByteArrayNode toByteArray;
181+
@Child private CastToIntegerFromIntNode castToIntNode;
182+
@Child private B2aBase64Node recursiveNode;
183+
184+
private SequenceStorageNodes.ToByteArrayNode getToByteArrayNode() {
185+
if (toByteArray == null) {
186+
CompilerDirectives.transferToInterpreterAndInvalidate();
187+
toByteArray = insert(SequenceStorageNodes.ToByteArrayNode.create());
188+
}
189+
return toByteArray;
190+
}
191+
192+
private CastToIntegerFromIntNode getCastToIntNode() {
193+
if (castToIntNode == null) {
194+
CompilerDirectives.transferToInterpreterAndInvalidate();
195+
castToIntNode = insert(CastToIntegerFromIntNode.create(val -> {
196+
throw raise(PythonBuiltinClassType.TypeError, "an integer is required (got type %p)", val);
197+
}));
198+
}
199+
return castToIntNode;
200+
}
201+
202+
private B2aBase64Node getRecursiveNode() {
203+
if (recursiveNode == null) {
204+
CompilerDirectives.transferToInterpreterAndInvalidate();
205+
recursiveNode = insert(BinasciiModuleBuiltinsFactory.B2aBase64NodeFactory.create());
206+
}
207+
return recursiveNode;
173208
}
174209

175-
@Specialization
176210
@TruffleBoundary
177-
String b2a(PBytes data, boolean newline,
178-
@Cached("create()") SequenceStorageNodes.ToByteArrayNode toByteArray) {
179-
String encode = Base64.encode(toByteArray.execute(data.getSequenceStorage()));
211+
private PBytes b2a(byte[] data, boolean newline) {
212+
String encode = Base64.encode(data);
180213
if (newline) {
181-
return encode + "\n";
182-
} else {
183-
return encode;
214+
return factory().createBytes((encode + "\n").getBytes());
184215
}
216+
return factory().createBytes((encode).getBytes());
217+
}
218+
219+
@Specialization(guards = "isNoValue(newline)")
220+
PBytes b2aBytesLike(PIBytesLike data, @SuppressWarnings("unused") PNone newline) {
221+
return b2aBytesLike(data, 1);
222+
}
223+
224+
@Specialization
225+
PBytes b2aBytesLike(PIBytesLike data, long newline) {
226+
return b2a(getToByteArrayNode().execute(data.getSequenceStorage()), newline != 0);
227+
}
228+
229+
@Specialization
230+
PBytes b2aBytesLike(PIBytesLike data, PInt newline) {
231+
return b2a(getToByteArrayNode().execute(data.getSequenceStorage()), !newline.isZero());
232+
}
233+
234+
@Specialization
235+
PBytes b2aBytesLike(PIBytesLike data, Object newline) {
236+
return (PBytes) getRecursiveNode().execute(data, getCastToIntNode().execute(newline));
185237
}
238+
239+
@Specialization(guards = "isNoValue(newline)")
240+
PBytes b2aArray(PArray data, @SuppressWarnings("unused") PNone newline) {
241+
return b2aArray(data, 1);
242+
}
243+
244+
@Specialization
245+
PBytes b2aArray(PArray data, long newline) {
246+
return b2a(getToByteArrayNode().execute(data.getSequenceStorage()), newline != 0);
247+
}
248+
249+
@Specialization
250+
PBytes b2aArray(PArray data, PInt newline) {
251+
return b2a(getToByteArrayNode().execute(data.getSequenceStorage()), !newline.isZero());
252+
}
253+
254+
@Specialization
255+
PBytes b2aArray(PArray data, Object newline) {
256+
return (PBytes) getRecursiveNode().execute(data, getCastToIntNode().execute(newline));
257+
}
258+
259+
@Specialization(guards = "isNoValue(newline)")
260+
PBytes b2aMmeory(PMemoryView data, @SuppressWarnings("unused") PNone newline,
261+
@Cached("create(TOBYTES)") LookupAndCallUnaryNode toBytesNode,
262+
@Cached("createBinaryProfile()") ConditionProfile isBytesProfile) {
263+
return b2aMemory(data, 1, toBytesNode, isBytesProfile);
264+
}
265+
266+
@Specialization
267+
PBytes b2aMemory(PMemoryView data, long newline,
268+
@Cached("create(TOBYTES)") LookupAndCallUnaryNode toBytesNode,
269+
@Cached("createBinaryProfile()") ConditionProfile isBytesProfile) {
270+
Object bytesObj = toBytesNode.executeObject(data);
271+
if (isBytesProfile.profile(bytesObj instanceof PBytes)) {
272+
return b2aBytesLike((PBytes) bytesObj, newline);
273+
}
274+
throw raise(SystemError, "could not get bytes of memoryview");
275+
}
276+
277+
@Specialization
278+
PBytes b2aMmeory(PMemoryView data, PInt newline,
279+
@Cached("create(TOBYTES)") LookupAndCallUnaryNode toBytesNode,
280+
@Cached("createBinaryProfile()") ConditionProfile isBytesProfile) {
281+
return b2aMemory(data, newline.isZero() ? 0 : 1, toBytesNode, isBytesProfile);
282+
}
283+
284+
@Specialization
285+
PBytes b2aMmeory(PMemoryView data, Object newline) {
286+
return (PBytes) getRecursiveNode().execute(data, getCastToIntNode().execute(newline));
287+
}
288+
289+
@Fallback
290+
PBytes b2sGeneral(Object data, @SuppressWarnings("unused") Object newline) {
291+
throw raise(PythonBuiltinClassType.TypeError, "a bytes-like object is required, not '%p'", data);
292+
}
293+
186294
}
187295

188296
@Builtin(name = "b2a_hex", fixedNumOfPositionalArgs = 1)

mx.graalpython/copyrights/overrides

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,7 @@ graalpython/com.oracle.graal.python.test/src/tests/seq_tests.py,python.copyright
195195
graalpython/com.oracle.graal.python.test/src/tests/slice-test.py,zippy.copyright
196196
graalpython/com.oracle.graal.python.test/src/tests/test_basicobject_pressure.py,zippy.copyright
197197
graalpython/com.oracle.graal.python.test/src/tests/test_binary_arithmetic.py,zippy.copyright
198+
graalpython/com.oracle.graal.python.test/src/tests/test_binascii.py,python.copyright
198199
graalpython/com.oracle.graal.python.test/src/tests/test_bisect-right.py,zippy.copyright
199200
graalpython/com.oracle.graal.python.test/src/tests/test_builtin-list-intrinsification.py,zippy.copyright
200201
graalpython/com.oracle.graal.python.test/src/tests/test_call-bimorphic.py,zippy.copyright

0 commit comments

Comments
 (0)