Skip to content

Commit 544f489

Browse files
committed
jwcrypto: type check generation of RSA keys
1 parent 2967d75 commit 544f489

File tree

1 file changed

+25
-7
lines changed

1 file changed

+25
-7
lines changed

stubs/jwcrypto/jwcrypto/jwk.pyi

+25-7
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from collections.abc import Callable, Sequence
22
from enum import Enum
33
from typing import Any, Literal, NamedTuple, TypeVar, overload
4-
from typing_extensions import Self, deprecated
4+
from typing_extensions import Self, TypeAlias, deprecated
55

66
from cryptography.hazmat.primitives import hashes
77
from cryptography.hazmat.primitives.asymmetric import ec, rsa
@@ -46,7 +46,8 @@ class _X448_CURVE(NamedTuple):
4646
pubkey: UnimplementedOKPCurveKey
4747
privkey: UnimplementedOKPCurveKey
4848

49-
JWKTypesRegistry: dict[str, str]
49+
_JWKKeyTypeSupported: TypeAlias = Literal["oct", "RSA", "EC", "OKP"]
50+
JWKTypesRegistry: dict[_JWKKeyTypeSupported, str]
5051

5152
class ParmType(Enum):
5253
name = "A string with a name" # pyright: ignore[reportAssignmentType]
@@ -63,8 +64,12 @@ class JWKParameter(NamedTuple):
6364
JWKValuesRegistry: dict[str, dict[str, JWKParameter]]
6465
JWKParamsRegistry: dict[str, JWKParameter]
6566
JWKEllipticCurveRegistry: dict[str, str]
66-
JWKUseRegistry: dict[str, str]
67-
JWKOperationsRegistry: dict[str, str]
67+
_JWKUseSupported: TypeAlias = Literal["sig", "enc"]
68+
JWKUseRegistry: dict[_JWKUseSupported, str]
69+
_JWKOperationSupported: TypeAlias = Literal[
70+
"sign", "verify", "encrypt", "decrypt", "wrapKey", "unwrapKey", "deriveKey", "deriveBits"
71+
]
72+
JWKOperationsRegistry: dict[_JWKOperationSupported, str]
6873
JWKpycaCurveMap: dict[str, str]
6974
IANANamedInformationHashAlgorithmRegistry: dict[
7075
str,
@@ -79,7 +84,6 @@ IANANamedInformationHashAlgorithmRegistry: dict[
7984
| hashes.BLAKE2b
8085
| None,
8186
]
82-
JWKKeyTypeSupported = Literal["oct", "RSA", "EC", "OKP"]
8387

8488
class InvalidJWKType(JWException):
8589
value: str | None
@@ -103,8 +107,22 @@ class JWK(dict[str, Any]):
103107
# function. The possible arguments depend on the value of `kty`.
104108
# TODO: Add overloads for the individual `kty` values.
105109
@classmethod
106-
def generate(cls, kty: JWKKeyTypeSupported, **kwargs) -> Self: ...
107-
def generate_key(self, kty: JWKKeyTypeSupported, **kwargs) -> None: ...
110+
@overload
111+
def generate(
112+
cls,
113+
*,
114+
kty: Literal["RSA"],
115+
public_exponent: int | None = None,
116+
size: int | None = None,
117+
kid: str | None = None,
118+
alg: str | None = None,
119+
use: _JWKUseSupported | None = None,
120+
key_ops: list[_JWKOperationSupported] | None = None,
121+
) -> Self: ...
122+
@classmethod
123+
@overload
124+
def generate(cls, *, kty: _JWKKeyTypeSupported, **kwargs) -> Self: ...
125+
def generate_key(self, *, kty: _JWKKeyTypeSupported, **kwargs) -> None: ...
108126
def import_key(self, **kwargs) -> None: ...
109127
@classmethod
110128
def from_json(cls, key) -> Self: ...

0 commit comments

Comments
 (0)