Skip to content

Add type hints for Repository.branches, Repository.references and Repository.remotes #1384

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jul 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 93 additions & 7 deletions pygit2/_pygit2.pyi
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Iterator, Literal, Optional, overload
from typing import Iterator, Literal, Optional, overload, Type, TypedDict
from io import IOBase
from . import Index
from .enums import (
Expand All @@ -19,6 +19,10 @@ from .enums import (
ResetMode,
SortMode,
)
from collections.abc import Generator

from .repository import BaseRepository
from .remotes import Remote

GIT_OBJ_BLOB = Literal[3]
GIT_OBJ_COMMIT = Literal[1]
Expand Down Expand Up @@ -73,15 +77,15 @@ class Reference:
def delete(self) -> None: ...
def log(self) -> Iterator[RefLogEntry]: ...
@overload
def peel(self, type: 'Literal[GIT_OBJ_COMMIT]') -> 'Commit': ...
def peel(self, type: 'Literal[GIT_OBJ_COMMIT] | Type[Commit]') -> 'Commit': ...
@overload
def peel(self, type: 'Literal[GIT_OBJ_TREE]') -> 'Tree': ...
def peel(self, type: 'Literal[GIT_OBJ_TREE] | Type[Tree]') -> 'Tree': ...
@overload
def peel(self, type: 'Literal[GIT_OBJ_TAG]') -> 'Tag': ...
def peel(self, type: 'Literal[GIT_OBJ_TAG] | Type[Tag]') -> 'Tag': ...
@overload
def peel(self, type: 'Literal[GIT_OBJ_BLOB]') -> 'Blob': ...
def peel(self, type: 'Literal[GIT_OBJ_BLOB] | Type[Blob]') -> 'Blob': ...
@overload
def peel(self, type: 'None') -> 'Commit|Tree|Blob': ...
def peel(self, type: 'None' = None) -> 'Commit|Tree|Tag|Blob': ...
def rename(self, new_name: str) -> None: ...
def resolve(self) -> Reference: ...
def set_target(self, target: _OidArg, message: str = ...) -> None: ...
Expand Down Expand Up @@ -122,7 +126,7 @@ class Branch(Reference):
def delete(self) -> None: ...
def is_checked_out(self) -> bool: ...
def is_head(self) -> bool: ...
def rename(self, name: str, force: bool = False) -> None: ...
def rename(self, name: str, force: bool = False) -> 'Branch': ... # type: ignore[override]

class Commit(Object):
author: Signature
Expand Down Expand Up @@ -329,6 +333,80 @@ class RefdbBackend:
class RefdbFsBackend(RefdbBackend):
def __init__(self, *args, **kwargs) -> None: ...

class References:
def __init__(self, repository: BaseRepository) -> None: ...
def __getitem__(self, name: str) -> Reference: ...
def get(self, key: str) -> Reference: ...
def __iter__(self) -> Iterator[str]: ...
def iterator(
self, references_return_type: ReferenceFilter = ...
) -> Iterator[Reference]: ...
def create(self, name: str, target: _OidArg, force: bool = False) -> Reference: ...
def delete(self, name: str) -> None: ...
def __contains__(self, name: str) -> bool: ...
@property
def objects(self) -> list[Reference]: ...
def compress(self) -> None: ...

_Proxy = None | Literal[True] | str

class _StrArray:
# incomplete
count: int

class ProxyOpts:
# incomplete
type: object
url: str

class PushOptions:
version: int
pb_parallelism: int
callbacks: object # TODO
proxy_opts: ProxyOpts
follow_redirects: object # TODO
custom_headers: _StrArray
remote_push_options: _StrArray

class _LsRemotesDict(TypedDict):
local: bool
loid: Oid | None
name: str | None
symref_target: str | None
oid: Oid

class RemoteCollection:
def __init__(self, repo: BaseRepository) -> None: ...
def __len__(self) -> int: ...
def __iter__(self): ...
def __getitem__(self, name: str | int) -> Remote: ...
def names(self) -> Generator[str, None, None]: ...
def create(self, name: str, url: str, fetch: str | None = None) -> Remote: ...
def create_anonymous(self, url: str) -> Remote: ...
def rename(self, name: str, new_name: str) -> list[str]: ...
def delete(self, name: str) -> None: ...
def set_url(self, name: str, url: str) -> None: ...
def set_push_url(self, name: str, url: str) -> None: ...
def add_fetch(self, name: str, refspec: str) -> None: ...
def add_push(self, name: str, refspec: str) -> None: ...

class Branches:
local: 'Branches'
remote: 'Branches'
def __init__(
self,
repository: BaseRepository,
flag: BranchType = ...,
commit: Commit | _OidArg | None = None,
) -> None: ...
def __getitem__(self, name: str) -> Branch: ...
def get(self, key: str) -> Branch: ...
def __iter__(self) -> Iterator[str]: ...
def create(self, name: str, commit: Commit, force: bool = False) -> Branch: ...
def delete(self, name: str) -> None: ...
def with_commit(self, commit: Commit | _OidArg | None) -> 'Branches': ...
def __contains__(self, name: _OidArg) -> bool: ...

class Repository:
_pointer: bytes
default_signature: Signature
Expand All @@ -342,10 +420,14 @@ class Repository:
path: str
refdb: Refdb
workdir: str
references: References
remotes: RemoteCollection
branches: Branches
def __init__(self, *args, **kwargs) -> None: ...
def TreeBuilder(self, src: Tree | _OidArg = ...) -> TreeBuilder: ...
def _disown(self, *args, **kwargs) -> None: ...
def _from_c(self, *args, **kwargs) -> None: ...
def __getitem__(self, key: str | bytes | Oid | Reference) -> Commit: ...
def add_worktree(self, name: str, path: str, ref: Reference = ...) -> Worktree: ...
def applies(
self,
Expand Down Expand Up @@ -394,6 +476,9 @@ class Repository:
ref: str = 'refs/notes/commits',
force: bool = False,
) -> Oid: ...
def create_reference(
self, name: str, target: _OidArg, force: bool = False
) -> Reference: ...
def create_reference_direct(
self, name: str, target: _OidArg, force: bool, message: Optional[str] = None
) -> Reference: ...
Expand Down Expand Up @@ -443,6 +528,7 @@ class Repository:
def revparse(self, revspec: str) -> RevSpec: ...
def revparse_ext(self, revision: str) -> tuple[Object, Reference]: ...
def revparse_single(self, revision: str) -> Object: ...
def set_ident(self, name: str, email: str) -> None: ...
def set_odb(self, odb: Odb) -> None: ...
def set_refdb(self, refdb: Refdb) -> None: ...
def status(
Expand Down
51 changes: 32 additions & 19 deletions pygit2/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,24 +65,29 @@
# Standard Library
from contextlib import contextmanager
from functools import wraps
from typing import Optional, Union
from typing import Optional, Union, TYPE_CHECKING, Callable, Generator

# pygit2
from ._pygit2 import Oid, DiffFile
from .enums import CheckoutNotify, CheckoutStrategy, CredentialType, StashApplyProgress
from .errors import check_error, Passthrough
from .ffi import ffi, C
from .utils import maybe_string, to_bytes, ptr_to_bytes, StrArray
from .credentials import Username, UserPass, Keypair

_Credentials = Username | UserPass | Keypair

if TYPE_CHECKING:
from .remotes import TransferProgress
from ._pygit2 import ProxyOpts, PushOptions
#
# The payload is the way to pass information from the pygit2 API, through
# libgit2, to the Python callbacks. And back.
#


class Payload:
def __init__(self, **kw: object):
def __init__(self, **kw: object) -> None:
for key, value in kw.items():
setattr(self, key, value)
self._stored_exception = None
Expand Down Expand Up @@ -113,12 +118,18 @@ class RemoteCallbacks(Payload):
RemoteCallbacks(certificate=certificate).
"""

def __init__(self, credentials=None, certificate_check=None):
push_options: 'PushOptions'

def __init__(
self,
credentials: _Credentials | None = None,
certificate_check: Callable[[None, bool, bytes], bool] | None = None,
) -> None:
super().__init__()
if credentials is not None:
self.credentials = credentials
self.credentials = credentials # type: ignore[method-assign, assignment]
if certificate_check is not None:
self.certificate_check = certificate_check
self.certificate_check = certificate_check # type: ignore[method-assign, assignment]

def sideband_progress(self, string: str) -> None:
"""
Expand All @@ -136,7 +147,7 @@ def credentials(
url: str,
username_from_url: Union[str, None],
allowed_types: CredentialType,
):
) -> _Credentials:
"""
Credentials callback. If the remote server requires authentication,
this function will be called and its return value used for
Expand All @@ -159,7 +170,7 @@ def credentials(
"""
raise Passthrough

def certificate_check(self, certificate: None, valid: bool, host: str) -> bool:
def certificate_check(self, certificate: None, valid: bool, host: bytes) -> bool:
"""
Certificate callback. Override with your own function to determine
whether to accept the server's certificate.
Expand All @@ -181,7 +192,7 @@ def certificate_check(self, certificate: None, valid: bool, host: str) -> bool:

raise Passthrough

def transfer_progress(self, stats):
def transfer_progress(self, stats: 'TransferProgress') -> None:
"""
During the download of new data, this will be regularly called with
the indexer's progress.
Expand All @@ -196,7 +207,7 @@ def transfer_progress(self, stats):

def push_transfer_progress(
self, objects_pushed: int, total_objects: int, bytes_pushed: int
):
) -> None:
"""
During the upload portion of a push, this will be regularly called
with progress information.
Expand All @@ -207,7 +218,7 @@ def push_transfer_progress(
Override with your own function to report push transfer progress.
"""

def update_tips(self, refname, old, new):
def update_tips(self, refname: str, old: Oid, new: Oid) -> None:
"""
Update tips callback. Override with your own function to report
reference updates.
Expand All @@ -224,7 +235,7 @@ def update_tips(self, refname, old, new):
The reference's new value.
"""

def push_update_reference(self, refname, message):
def push_update_reference(self, refname: str, message: str) -> None:
"""
Push update reference callback. Override with your own function to
report the remote's acceptance or rejection of reference updates.
Expand All @@ -244,7 +255,7 @@ class CheckoutCallbacks(Payload):
in your class, which you can then pass to checkout operations.
"""

def __init__(self):
def __init__(self) -> None:
super().__init__()

def checkout_notify_flags(self) -> CheckoutNotify:
Expand Down Expand Up @@ -275,7 +286,7 @@ def checkout_notify(
baseline: Optional[DiffFile],
target: Optional[DiffFile],
workdir: Optional[DiffFile],
):
) -> None:
"""
Checkout will invoke an optional notification callback for
certain cases - you pick which ones via `checkout_notify_flags`.
Expand All @@ -290,7 +301,9 @@ def checkout_notify(
"""
pass

def checkout_progress(self, path: str, completed_steps: int, total_steps: int):
def checkout_progress(
self, path: str, completed_steps: int, total_steps: int
) -> None:
"""
Optional callback to notify the consumer of checkout progress.
"""
Expand All @@ -304,7 +317,7 @@ class StashApplyCallbacks(CheckoutCallbacks):
in your class, which you can then pass to stash apply or pop operations.
"""

def stash_apply_progress(self, progress: StashApplyProgress):
def stash_apply_progress(self, progress: StashApplyProgress) -> None:
"""
Stash application progress notification function.

Expand Down Expand Up @@ -373,9 +386,9 @@ def git_fetch_options(payload, opts=None):
@contextmanager
def git_proxy_options(
payload: object,
opts: object | None = None,
opts: Optional['ProxyOpts'] = None,
proxy: None | bool | str = None,
):
) -> Generator['ProxyOpts', None, None]:
if opts is None:
opts = ffi.new('git_proxy_options *')
C.git_proxy_options_init(opts, C.GIT_PROXY_OPTIONS_VERSION)
Expand All @@ -386,8 +399,8 @@ def git_proxy_options(
elif type(proxy) is str:
opts.type = C.GIT_PROXY_SPECIFIED
# Keep url in memory, otherwise memory is freed and bad things happen
payload.__proxy_url = ffi.new('char[]', to_bytes(proxy))
opts.url = payload.__proxy_url
payload.__proxy_url = ffi.new('char[]', to_bytes(proxy)) # type: ignore[attr-defined, no-untyped-call]
opts.url = payload.__proxy_url # type: ignore[attr-defined]
else:
raise TypeError('Proxy must be None, True, or a string')
yield opts
Expand Down
14 changes: 10 additions & 4 deletions pygit2/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def credential_type(self) -> CredentialType:
return CredentialType.USERNAME

@property
def credential_tuple(self):
def credential_tuple(self) -> tuple[str]:
return (self._username,)

def __call__(
Expand All @@ -74,7 +74,7 @@ def credential_type(self) -> CredentialType:
return CredentialType.USERPASS_PLAINTEXT

@property
def credential_tuple(self):
def credential_tuple(self) -> tuple[str, str]:
return (self._username, self._password)

def __call__(
Expand Down Expand Up @@ -107,7 +107,11 @@ class Keypair:
"""

def __init__(
self, username: str, pubkey: str | Path, privkey: str | Path, passphrase: str
self,
username: str,
pubkey: str | Path | None,
privkey: str | Path | None,
passphrase: str | None,
):
self._username = username
self._pubkey = pubkey
Expand All @@ -119,7 +123,9 @@ def credential_type(self) -> CredentialType:
return CredentialType.SSH_KEY

@property
def credential_tuple(self):
def credential_tuple(
self,
) -> tuple[str, str | Path | None, str | Path | None, str | None]:
return (self._username, self._pubkey, self._privkey, self._passphrase)

def __call__(
Expand Down
12 changes: 10 additions & 2 deletions pygit2/remotes.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
# Boston, MA 02110-1301, USA.

from __future__ import annotations
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any

# Import from pygit2
from ._pygit2 import Oid
Expand All @@ -49,7 +49,15 @@
class TransferProgress:
"""Progress downloading and indexing data during a fetch."""

def __init__(self, tp):
total_objects: int
indexed_objects: int
received_objects: int
local_objects: int
total_deltas: int
indexed_deltas: int
received_bytes: int

def __init__(self, tp: Any) -> None:
self.total_objects = tp.total_objects
"""Total number of objects to download"""

Expand Down
Loading
Loading