Skip to content

Commit ddab546

Browse files
chadrikwoile
authored andcommitted
fix: Improve type annotations
1 parent 74153b7 commit ddab546

File tree

3 files changed

+29
-15
lines changed

3 files changed

+29
-15
lines changed

commitizen/bump.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,19 @@
44
import re
55
from collections import OrderedDict
66
from string import Template
7+
from typing import cast
78

89
from commitizen.defaults import MAJOR, MINOR, PATCH, bump_message, encoding
910
from commitizen.exceptions import CurrentVersionNotFoundError
1011
from commitizen.git import GitCommit, smart_open
11-
from commitizen.version_schemes import DEFAULT_SCHEME, Version, VersionScheme
12+
from commitizen.version_schemes import DEFAULT_SCHEME, Increment, Version, VersionScheme
1213

1314
VERSION_TYPES = [None, PATCH, MINOR, MAJOR]
1415

1516

1617
def find_increment(
1718
commits: list[GitCommit], regex: str, increments_map: dict | OrderedDict
18-
) -> str | None:
19+
) -> Increment | None:
1920
if isinstance(increments_map, dict):
2021
increments_map = OrderedDict(increments_map)
2122

@@ -42,7 +43,7 @@ def find_increment(
4243
if increment == MAJOR:
4344
break
4445

45-
return increment
46+
return cast(Increment, increment)
4647

4748

4849
def update_version_in_files(

commitizen/commands/bump.py

+8-6
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,9 @@
2626
from commitizen.providers import get_provider
2727
from commitizen.version_schemes import (
2828
get_version_scheme,
29+
Increment,
2930
InvalidVersion,
31+
Prerelease,
3032
)
3133

3234
logger = getLogger("commitizen")
@@ -112,7 +114,7 @@ def is_initial_tag(self, current_tag_version: str, is_yes: bool = False) -> bool
112114
is_initial = questionary.confirm("Is this the first tag created?").ask()
113115
return is_initial
114116

115-
def find_increment(self, commits: list[git.GitCommit]) -> str | None:
117+
def find_increment(self, commits: list[git.GitCommit]) -> Increment | None:
116118
# Update the bump map to ensure major version doesn't increment.
117119
is_major_version_zero: bool = self.bump_settings["major_version_zero"]
118120
# self.cz.bump_map = defaults.bump_map_major_version_zero
@@ -132,7 +134,7 @@ def find_increment(self, commits: list[git.GitCommit]) -> str | None:
132134
)
133135
return increment
134136

135-
def __call__(self): # noqa: C901
137+
def __call__(self) -> None: # noqa: C901
136138
"""Steps executed to bump."""
137139
provider = get_provider(self.config)
138140

@@ -149,11 +151,11 @@ def __call__(self): # noqa: C901
149151

150152
dry_run: bool = self.arguments["dry_run"]
151153
is_yes: bool = self.arguments["yes"]
152-
increment: str | None = self.arguments["increment"]
153-
prerelease: str | None = self.arguments["prerelease"]
154+
increment: Increment | None = self.arguments["increment"]
155+
prerelease: Prerelease | None = self.arguments["prerelease"]
154156
devrelease: int | None = self.arguments["devrelease"]
155157
is_files_only: bool | None = self.arguments["files_only"]
156-
is_local_version: bool | None = self.arguments["local_version"]
158+
is_local_version: bool = self.arguments["local_version"]
157159
manual_version = self.arguments["manual_version"]
158160
build_metadata = self.arguments["build_metadata"]
159161

@@ -404,7 +406,7 @@ def __call__(self): # noqa: C901
404406
else:
405407
out.success("Done!")
406408

407-
def _get_commit_args(self):
409+
def _get_commit_args(self) -> str:
408410
commit_args = ["-a"]
409411
if self.no_verify:
410412
commit_args.append("--no-verify")

commitizen/version_schemes.py

+17-6
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,16 @@
44
import sys
55
import warnings
66
from itertools import zip_longest
7-
from typing import TYPE_CHECKING, Any, ClassVar, Protocol, Type, cast, runtime_checkable
7+
from typing import (
8+
TYPE_CHECKING,
9+
Any,
10+
ClassVar,
11+
Literal,
12+
Protocol,
13+
Type,
14+
cast,
15+
runtime_checkable,
16+
)
817

918
import importlib_metadata as metadata
1019
from packaging.version import InvalidVersion # noqa: F401: Rexpose the common exception
@@ -28,6 +37,8 @@
2837
from typing import Self
2938

3039

40+
Increment: TypeAlias = Literal["MAJOR", "MINOR", "PATCH"]
41+
Prerelease: TypeAlias = Literal["alpha", "beta", "rc"]
3142
DEFAULT_VERSION_PARSER = r"v?(?P<version>([0-9]+)\.([0-9]+)\.([0-9]+)(?:-([0-9A-Za-z-]+(?:\.[0-9A-Za-z-]+)*))?(?:\+[0-9A-Za-z-]+)?(\w+)?)"
3243

3344

@@ -113,8 +124,8 @@ def __ne__(self, other: object) -> bool:
113124

114125
def bump(
115126
self,
116-
increment: str,
117-
prerelease: str | None = None,
127+
increment: Increment | None,
128+
prerelease: Prerelease | None = None,
118129
prerelease_offset: int = 0,
119130
devrelease: int | None = None,
120131
is_local_version: bool = False,
@@ -203,7 +214,7 @@ def generate_build_metadata(self, build_metadata: str | None) -> str:
203214

204215
return f"+{build_metadata}"
205216

206-
def increment_base(self, increment: str | None = None) -> str:
217+
def increment_base(self, increment: Increment | None = None) -> str:
207218
prev_release = list(self.release)
208219
increments = [MAJOR, MINOR, PATCH]
209220
base = dict(zip_longest(increments, prev_release, fillvalue=0))
@@ -222,8 +233,8 @@ def increment_base(self, increment: str | None = None) -> str:
222233

223234
def bump(
224235
self,
225-
increment: str,
226-
prerelease: str | None = None,
236+
increment: Increment | None,
237+
prerelease: Prerelease | None = None,
227238
prerelease_offset: int = 0,
228239
devrelease: int | None = None,
229240
is_local_version: bool = False,

0 commit comments

Comments
 (0)