Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 46 additions & 12 deletions src/josepy/jwa.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes # type: ignore
from cryptography.hazmat.primitives import hmac # type: ignore
from cryptography.hazmat.primitives.asymmetric import padding # type: ignore
from cryptography.hazmat.primitives.asymmetric import padding, ec # type: ignore

from josepy import errors, interfaces, jwk

Expand Down Expand Up @@ -75,7 +75,6 @@ def __repr__(self):


class _JWAHS(JWASignature):

kty = jwk.JWKOct

def __init__(self, name, hash_):
Expand All @@ -100,7 +99,6 @@ def verify(self, key, msg, sig):


class _JWARSA(object):

kty = jwk.JWKRSA
padding = NotImplemented
hash = NotImplemented
Expand Down Expand Up @@ -163,15 +161,51 @@ def __init__(self, name, hash_):
self.hash = hash_()


class _JWAES(JWASignature): # pylint: disable=abstract-class-not-used
class _JWAEC(JWASignature):
kty = jwk.JWKEC

# TODO: implement ES signatures
def __init__(self, name, hash_):
super(_JWAEC, self).__init__(name)
self.hash = hash_()

def sign(self, key, msg): # pragma: no cover
raise NotImplementedError()
def sign(self, key, msg):
"""Sign the ``msg`` using ``key``."""
# If cryptography library supports new style api (v1.4 and later)
new_api = hasattr(key, 'sign')
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not just inline this?

try:
if new_api:
return key.sign(msg, ec.ECDSA(self.hash))
signer = key.signer(ec.ECDSA(self.hash))
except AttributeError as error:
logger.debug(error, exc_info=True)
raise errors.Error('Public key cannot be used for signing')
except ValueError as error: # digest too large
logger.debug(error, exc_info=True)
raise errors.Error(str(error))
signer.update(msg)
try:
return signer.finalize()
except ValueError as error:
logger.debug(error, exc_info=True)
raise errors.Error(str(error))

def verify(self, key, msg, sig): # pragma: no cover
raise NotImplementedError()
def verify(self, key, msg, sig):
"""Verify the ``msg` and ``sig`` using ``key``."""
# If cryptography library supports new style api (v1.4 and later)
new_api = hasattr(key, 'verify')
if not new_api:
verifier = key.verifier(sig, ec.ECDSA(self.hash))
verifier.update(msg)
try:
if new_api:
key.verify(sig, msg, ec.ECDSA(self.hash))
else:
verifier.verify()
except cryptography.exceptions.InvalidSignature as error:
logger.debug(error, exc_info=True)
return False
else:
return True


#: HMAC using SHA-256
Expand All @@ -196,8 +230,8 @@ def verify(self, key, msg, sig): # pragma: no cover
PS512 = JWASignature.register(_JWAPS('PS512', hashes.SHA512))

#: ECDSA using P-256 and SHA-256
ES256 = JWASignature.register(_JWAES('ES256'))
ES256 = JWASignature.register(_JWAEC('ES256', hashes.SHA256))
#: ECDSA using P-384 and SHA-384
ES384 = JWASignature.register(_JWAES('ES384'))
ES384 = JWASignature.register(_JWAEC('ES384', hashes.SHA384))
#: ECDSA using P-521 and SHA-512
ES512 = JWASignature.register(_JWAES('ES512'))
ES512 = JWASignature.register(_JWAEC('ES512', hashes.SHA512))
79 changes: 79 additions & 0 deletions src/josepy/jwa_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
RSA256_KEY = test_util.load_rsa_private_key('rsa256_key.pem')
RSA512_KEY = test_util.load_rsa_private_key('rsa512_key.pem')
RSA1024_KEY = test_util.load_rsa_private_key('rsa1024_key.pem')
EC_P256_KEY = test_util.load_ec_private_key('ec_p256_key.pem')
EC_P384_KEY = test_util.load_ec_private_key('ec_p384_key.pem')
EC_P521_KEY = test_util.load_ec_private_key('ec_p521_key.pem')


class JWASignatureTest(unittest.TestCase):
Expand Down Expand Up @@ -133,5 +136,81 @@ def test_verify_old_api(self):
verifier.verify.called]))


class JWAECTest(unittest.TestCase):

def test_sign_no_private_part(self):
from josepy.jwa import ES256
self.assertRaises(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this can be on one line.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed by a5fb321

errors.Error, ES256.sign, EC_P256_KEY.public_key(), b'foo')

def test_es256_sign_and_verify(self):
from josepy.jwa import ES256
message = b'foo'
signature = ES256.sign(EC_P256_KEY, message)
self.assertTrue(ES256.verify(EC_P256_KEY.public_key(), message, signature))

def test_es384_sign_and_verify(self):
from josepy.jwa import ES384
message = b'foo'
signature = ES384.sign(EC_P384_KEY, message)
self.assertTrue(ES384.verify(EC_P384_KEY.public_key(), message, signature))

def test_es512_sign_and_verify(self):
from josepy.jwa import ES512
message = b'foo'
signature = ES512.sign(EC_P521_KEY, message)
self.assertTrue(ES512.verify(EC_P521_KEY.public_key(), message, signature))

def test_verify_with_wrong_jwa(self):
from josepy.jwa import ES256, ES384
message = b'foo'
signature = ES256.sign(EC_P256_KEY, message)
self.assertFalse(ES384.verify(EC_P384_KEY.public_key(), message, signature))

def test_verify_with_different_key(self):
from josepy.jwa import ES256
from cryptography.hazmat.primitives.asymmetric import ec
from cryptography.hazmat.backends import default_backend

message = b'foo'
signature = ES256.sign(EC_P256_KEY, message)
different_key = ec.generate_private_key(ec.SECP256R1, default_backend())
self.assertFalse(ES256.verify(different_key.public_key(), message, signature))

def test_sign_new_api(self):
from josepy.jwa import ES256
key = mock.MagicMock()
ES256.sign(key, "message")
self.assertTrue(key.sign.called)

def test_sign_old_api(self):
from josepy.jwa import ES256
key = mock.MagicMock(spec=[u'signer'])
signer = mock.MagicMock()
key.signer.return_value = signer
ES256.sign(key, "message")
self.assertTrue(all([
key.signer.called,
signer.update.called,
signer.finalize.called]))

def test_verify_new_api(self):
from josepy.jwa import ES256
key = mock.MagicMock()
ES256.verify(key, "message", "signature")
self.assertTrue(key.verify.called)

def test_verify_old_api(self):
from josepy.jwa import ES256
key = mock.MagicMock(spec=[u'verifier'])
verifier = mock.MagicMock()
key.verifier.return_value = verifier
ES256.verify(key, "message", "signature")
self.assertTrue(all([
key.verifier.called,
verifier.update.called,
verifier.verify.called]))


if __name__ == '__main__':
unittest.main() # pragma: no cover
154 changes: 129 additions & 25 deletions src/josepy/jwk.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,30 +120,6 @@ def load(cls, data, password=None, backend=None):
raise errors.Error('Unsupported algorithm: {0}'.format(key.__class__))


@JWK.register
class JWKES(JWK): # pragma: no cover
# pylint: disable=abstract-class-not-used
"""ES JWK.

.. warning:: This is not yet implemented!

"""
typ = 'ES'
cryptography_key_types = (
ec.EllipticCurvePublicKey, ec.EllipticCurvePrivateKey)
required = ('crv', JWK.type_field_name, 'x', 'y')

def fields_to_partial_json(self):
raise NotImplementedError()

@classmethod
def fields_from_json(cls, jobj):
raise NotImplementedError()

def public_key(self):
raise NotImplementedError()


@JWK.register
class JWKOct(JWK):
"""Symmetric JWK."""
Expand Down Expand Up @@ -194,6 +170,7 @@ def _encode_param(cls, data):
:rtype: unicode

"""

def _leading_zeros(arg):
if len(arg) % 2:
return '0' + arg
Expand Down Expand Up @@ -248,7 +225,7 @@ def fields_from_json(cls, jobj):

key = rsa.RSAPrivateNumbers(
p, q, d, dp, dq, qi, public_numbers).private_key(
default_backend())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the reason for this change?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My IDE must have auto-formatted this. I will revert this change.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed by 2c3cabe

default_backend())

return cls(key=key)

Expand All @@ -275,3 +252,130 @@ def fields_to_partial_json(self):
}
return dict((key, self._encode_param(value))
for key, value in six.iteritems(params))


@JWK.register
class JWKEC(JWK):
"""EC JWK.

:ivar key: :class:`~cryptography.hazmat.primitives.asymmetric.ec.EllipticCurvePrivateKey`
or :class:`~cryptography.hazmat.primitives.asymmetric.ec.EllipticCurvePublicKey` wrapped
in :class:`~josepy.util.ComparableRSAKey`

"""
typ = 'EC'
__slots__ = ('key',)
cryptography_key_types = (
ec.EllipticCurvePublicKey, ec.EllipticCurvePrivateKey)
required = ('crv', JWK.type_field_name, 'x', 'y')

def __init__(self, *args, **kwargs):
if 'key' in kwargs and not isinstance(
kwargs['key'], util.ComparableECKey):
kwargs['key'] = util.ComparableECKey(kwargs['key'])
super(JWKEC, self).__init__(*args, **kwargs)

@classmethod
def _encode_param(cls, data):
"""Encode Base64urlUInt.
:type data: long
:rtype: unicode
"""

def _leading_zeros(arg):
if len(arg) % 2:
return '0' + arg
return arg

return json_util.encode_b64jose(binascii.unhexlify(
_leading_zeros(hex(data)[2:].rstrip('L'))))

@classmethod
def _decode_param(cls, data, name, valid_lengths):
"""Decode Base64urlUInt."""
try:
binary = json_util.decode_b64jose(data)
if len(binary) not in valid_lengths:
raise errors.DeserializationError(
'Expected parameter "{name}" to be {valid_lengths} bytes '
'after base64-decoding; got {length} bytes instead'.format(
name=name, valid_lengths=valid_lengths, length=len(binary))
)
return int(binascii.hexlify(binary), 16)
except ValueError: # invalid literal for long() with base 16
raise errors.DeserializationError()

@classmethod
def _curve_name_to_crv(cls, curve_name):
if curve_name == 'secp256r1':
return 'P-256'
if curve_name == 'secp384r1':
return 'P-384'
if curve_name == 'secp521r1':
return 'P-521'
raise errors.SerializationError()

@classmethod
def _crv_to_curve(cls, crv):
# crv is case-sensitive
if crv == 'P-256':
return ec.SECP256R1()
if crv == 'P-384':
return ec.SECP384R1()
if crv == 'P-521':
return ec.SECP521R1()
raise errors.DeserializationError()

@classmethod
def _expected_length_for_curve(cls, curve):
if isinstance(curve, ec.SECP256R1):
return range(32, 33)
elif isinstance(curve, ec.SECP384R1):
return range(48, 49)
elif isinstance(curve, ec.SECP521R1):
return range(63, 67)

def fields_to_partial_json(self):
params = {}
if isinstance(self.key._wrapped, ec.EllipticCurvePublicKey):
public = self.key.public_numbers()
elif isinstance(self.key._wrapped, ec.EllipticCurvePrivateKey):
private = self.key.private_numbers()
public = self.key.public_key().public_numbers()
params.update({
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not just do params['d'] = private.private_value ? It's ~3 times faster.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed by 6e8d041

'd': private.private_value,
})
else:
raise errors.SerializationError(
'Supplied key is neither of type EllipticCurvePublicKey nor EllipticCurvePrivateKey')
params.update({
'x': public.x,
'y': public.y,
})
params = dict((key, self._encode_param(value))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

dict comprehensions are faster, and probably also more Pythonic. It works in Python 2.

{key: self._encode_param(value) for key, value in six.iteritems(params)}

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! Fixed by 6e8d041

for key, value in six.iteritems(params))
params['crv'] = self._curve_name_to_crv(public.curve.name)
return params

@classmethod
def fields_from_json(cls, jobj):
# pylint: disable=invalid-name
curve = cls._crv_to_curve(jobj['crv'])
expected_length = cls._expected_length_for_curve(curve)
x, y = (cls._decode_param(jobj[n], n, expected_length) for n in ('x', 'y'))
public_numbers = ec.EllipticCurvePublicNumbers(x=x, y=y, curve=curve)
if 'd' not in jobj: # public key
key = public_numbers.public_key(default_backend())
else: # private key
d = cls._decode_param(jobj['d'], 'd', expected_length)
key = ec.EllipticCurvePrivateNumbers(d, public_numbers).private_key(
default_backend())
return cls(key=key)

def public_key(self):
# Unlike RSAPrivateKey, EllipticCurvePrivateKey does not contain public_key()
if hasattr(self.key, 'public_key'):
key = self.key.public_key()
else:
key = self.key.public_numbers().public_key(default_backend())
return type(self)(key=key)
Loading