Skip to content

Commit ed86005

Browse files
committed
Use PyCryptodome instead of PyCrypto
1 parent b3af499 commit ed86005

File tree

4 files changed

+32
-54
lines changed

4 files changed

+32
-54
lines changed

requirements.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
PyCrypto >= 2.1
1+
PyCryptodome

src/potr/compatcrypto/pycrypto.py

+22-44
Original file line numberDiff line numberDiff line change
@@ -15,17 +15,15 @@
1515
# You should have received a copy of the GNU Lesser General Public License
1616
# along with this library. If not, see <http://www.gnu.org/licenses/>.
1717

18-
try:
19-
import Crypto
20-
except ImportError:
21-
import crypto as Crypto
22-
23-
from Crypto import Cipher
24-
from Crypto.Hash import SHA256 as _SHA256
25-
from Crypto.Hash import SHA as _SHA1
26-
from Crypto.Hash import HMAC as _HMAC
27-
from Crypto.PublicKey import DSA
28-
import Crypto.Random.random
18+
from Cryptodome import Cipher
19+
from Cryptodome.Hash import HMAC as _HMAC
20+
from Cryptodome.Hash import SHA256 as _SHA256
21+
from Cryptodome.Hash import SHA as _SHA1
22+
from Cryptodome.PublicKey import DSA
23+
from Cryptodome.Random import random
24+
from Cryptodome.Signature import DSS
25+
from Cryptodome.Util import Counter
26+
2927
from numbers import Number
3028

3129
from potr.compatcrypto import common
@@ -45,36 +43,14 @@ def SHA256HMAC(key, data):
4543

4644
def AESCTR(key, counter=0):
4745
if isinstance(counter, Number):
48-
counter = Counter(counter)
49-
if not isinstance(counter, Counter):
46+
counter = Counter.new(nbits=64, prefix=long_to_bytes(counter, 8), initial_value=0)
47+
# in pycrypto Counter used to be an object,
48+
# in pycryptodome it's now only a dict.
49+
# This tries to validate its "type" so we don't feed anything as a counter
50+
if set(counter) != set(Counter.new(64)):
5051
raise TypeError
5152
return Cipher.AES.new(key, Cipher.AES.MODE_CTR, counter=counter)
5253

53-
class Counter(object):
54-
def __init__(self, prefix):
55-
self.prefix = prefix
56-
self.val = 0
57-
58-
def inc(self):
59-
self.prefix += 1
60-
self.val = 0
61-
62-
def __setattr__(self, attr, val):
63-
if attr == 'prefix':
64-
self.val = 0
65-
super(Counter, self).__setattr__(attr, val)
66-
67-
def __repr__(self):
68-
return '<Counter(p={p!r},v={v!r})>'.format(p=self.prefix, v=self.val)
69-
70-
def byteprefix(self):
71-
return long_to_bytes(self.prefix, 8)
72-
73-
def __call__(self):
74-
bytesuffix = long_to_bytes(self.val, 8)
75-
self.val += 1
76-
return self.byteprefix() + bytesuffix
77-
7854
@common.registerkeytype
7955
class DSAKey(common.PK):
8056
keyType = 0x0000
@@ -107,12 +83,14 @@ def fingerprint(self):
10783
def sign(self, data):
10884
# 2 <= K <= q
10985
K = randrange(2, self.priv.q)
110-
r, s = self.priv.sign(data, K)
86+
M = bytes_to_long(data)
87+
r, s = self.priv._sign(M, K)
11188
return long_to_bytes(r, 20) + long_to_bytes(s, 20)
11289

11390
def verify(self, data, sig):
11491
r, s = bytes_to_long(sig[:20]), bytes_to_long(sig[20:])
115-
return self.pub.verify(data, (r, s))
92+
M = bytes_to_long(data)
93+
return self.pub._verify(M, (r, s))
11694

11795
def __hash__(self):
11896
return bytes_to_long(self.fingerprint())
@@ -128,8 +106,8 @@ def __ne__(self, other):
128106
@classmethod
129107
def generate(cls):
130108
privkey = DSA.generate(1024)
131-
return cls((privkey.key.y, privkey.key.g, privkey.key.p, privkey.key.q,
132-
privkey.key.x), private=True)
109+
return cls((privkey.y, privkey.g, privkey.p, privkey.q,
110+
privkey.x), private=True)
133111

134112
@classmethod
135113
def parsePayload(cls, data, private=False):
@@ -143,7 +121,7 @@ def parsePayload(cls, data, private=False):
143121
return cls((y, g, p, q), private=False), data
144122

145123
def getrandbits(k):
146-
return Crypto.Random.random.getrandbits(k)
124+
return random.getrandbits(k)
147125

148126
def randrange(start, stop):
149-
return Crypto.Random.random.randrange(start, stop)
127+
return random.randrange(start, stop)

src/potr/crypt.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323

2424

2525
from potr.compatcrypto import SHA256, SHA1, SHA1HMAC, SHA256HMAC, \
26-
Counter, AESCTR, PK, getrandbits, randrange
26+
AESCTR, PK, getrandbits, randrange
2727
from potr.utils import bytes_to_long, long_to_bytes, pack_mpi, read_mpi
2828
from potr import proto
2929

@@ -69,8 +69,8 @@ def __init__(self, sendenc, sendmac, rcvenc, rcvmac):
6969
self.sendmac = sendmac
7070
self.rcvenc = rcvenc
7171
self.rcvmac = rcvmac
72-
self.sendctr = Counter(0)
73-
self.rcvctr = Counter(0)
72+
self.sendctr = 0
73+
self.rcvctr = 0
7474
self.sendmacused = False
7575
self.rcvmacused = False
7676

@@ -177,12 +177,12 @@ def handleDataMessage(self, msg):
177177
sesskey.rcvmacused = True
178178

179179
newCtrPrefix = bytes_to_long(msg.ctr)
180-
if newCtrPrefix <= sesskey.rcvctr.prefix:
180+
if newCtrPrefix <= sesskey.rcvctr:
181181
logger.error('CTR must increase (old %r, new %r)',
182182
sesskey.rcvctr.prefix, newCtrPrefix)
183183
raise InvalidParameterError
184184

185-
sesskey.rcvctr.prefix = newCtrPrefix
185+
sesskey.rcvctr = newCtrPrefix
186186

187187
logger.debug('handle: enc={0!r} mac={1!r} ctr={2!r}' \
188188
.format(sesskey.rcvenc, sesskey.rcvmac, sesskey.rcvctr))
@@ -232,7 +232,7 @@ def createDataMessage(self, message, flags=0, tlvs=None):
232232
tlvs = []
233233

234234
sess = self.sessionkeys[1][0]
235-
sess.sendctr.inc()
235+
sess.sendctr += 1
236236

237237
logger.debug('create: enc={0!r} mac={1!r} ctr={2!r}' \
238238
.format(sess.sendenc, sess.sendmac, sess.sendctr))
@@ -242,7 +242,7 @@ def createDataMessage(self, message, flags=0, tlvs=None):
242242
encmsg = AESCTR(sess.sendenc, sess.sendctr).encrypt(plainBuf)
243243

244244
msg = proto.DataMessage(flags, self.ourKeyid-1, self.theirKeyid,
245-
long_to_bytes(self.ourDHKey.pub), sess.sendctr.byteprefix(),
245+
long_to_bytes(self.ourDHKey.pub), long_to_bytes(sess.sendctr, 8),
246246
encmsg, b'', b''.join(self.savedMacKeys))
247247

248248
self.savedMacKeys = []

tests/test_compatcrypto.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -64,10 +64,10 @@ def test_AESCTR_counter_counter(self):
6464
key = potr.utils.long_to_bytes(
6565
potr.compatcrypto.getrandbits(128), 16)
6666

67-
aes_encrypter = potr.compatcrypto.AESCTR(key, potr.compatcrypto.Counter(2013))
67+
aes_encrypter = potr.compatcrypto.AESCTR(key, 2013)
6868
ciphertext = aes_encrypter.encrypt(b'setec astronomy')
6969

70-
aes_decrypter = potr.compatcrypto.AESCTR(key, potr.compatcrypto.Counter(2013))
70+
aes_decrypter = potr.compatcrypto.AESCTR(key, 2013)
7171
self.assertEqual(aes_decrypter.decrypt(ciphertext), b'setec astronomy')
7272

7373
def test_getrandbits(self):

0 commit comments

Comments
 (0)