Skip to content

Commit

Permalink
Merge pull request #142 from bdpedigo/return-types
Browse files Browse the repository at this point in the history
Add return type hints to `chunkedgraph`
  • Loading branch information
fcollman authored Jan 17, 2024
2 parents caae959 + a2c42c9 commit 0e90f94
Showing 1 changed file with 45 additions and 30 deletions.
75 changes: 45 additions & 30 deletions caveclient/chunkedgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
import datetime
import json
import logging
from typing import Iterable
import sys
from typing import Iterable, Union
from urllib.parse import urlencode

import networkx as nx
Expand All @@ -18,6 +19,10 @@
default_global_server_address,
)

# get the version of python at runtime, decide how to specify tuple types

if sys.version_info < (3, 8):
from typing import Tuple as tuple

SERVER_KEY = "cg_server_address"

Expand Down Expand Up @@ -202,7 +207,7 @@ def _process_timestamp(self, timestamp):
else:
return timestamp

def get_roots(self, supervoxel_ids, timestamp=None, stop_layer=None):
def get_roots(self, supervoxel_ids, timestamp=None, stop_layer=None) -> np.ndarray:
"""Get the root ID for a list of supervoxels.
Parameters
Expand Down Expand Up @@ -231,7 +236,7 @@ def get_roots(self, supervoxel_ids, timestamp=None, stop_layer=None):
handle_response(response, as_json=False)
return np.frombuffer(response.content, dtype=np.uint64)

def get_root_id(self, supervoxel_id, timestamp=None, level2=False):
def get_root_id(self, supervoxel_id, timestamp=None, level2=False) -> np.int64:
"""Get the root ID for a specified supervoxel.
Parameters
Expand All @@ -258,7 +263,7 @@ def get_root_id(self, supervoxel_id, timestamp=None, level2=False):
response = self.session.get(url, params=query_d)
return np.int64(handle_response(response, as_json=True)["root_id"])

def get_merge_log(self, root_id):
def get_merge_log(self, root_id) -> list:
"""Get the merge log (splits and merges) for an object.
Parameters
Expand All @@ -278,7 +283,7 @@ def get_merge_log(self, root_id):
response = self.session.get(url)
return handle_response(response)

def get_change_log(self, root_id, filtered=True):
def get_change_log(self, root_id, filtered=True) -> dict:
"""Get the change log (splits and merges) for an object.
Parameters
Expand Down Expand Up @@ -323,7 +328,7 @@ def get_user_operations(
timestamp_start: datetime.datetime,
include_undo: bool = True,
timestamp_end: datetime.datetime = None,
):
) -> pd.DataFrame:
"""
Get operation details for a user ID. Currently, this is only available to
admins.
Expand Down Expand Up @@ -379,7 +384,7 @@ def get_user_operations(
)
return df

def get_tabular_change_log(self, root_ids, filtered=True):
def get_tabular_change_log(self, root_ids, filtered=True) -> dict:
"""Get a detailed changelog for neurons.
Parameters
Expand Down Expand Up @@ -435,7 +440,7 @@ def get_tabular_change_log(self, root_ids, filtered=True):

return changelog_dict

def get_leaves(self, root_id, bounds=None, stop_layer: int = None):
def get_leaves(self, root_id, bounds=None, stop_layer: int = None) -> np.ndarray:
"""Get all supervoxels for a root ID.
Parameters
Expand Down Expand Up @@ -465,7 +470,7 @@ def get_leaves(self, root_id, bounds=None, stop_layer: int = None):
response = self.session.get(url, params=query_d)
return np.int64(handle_response(response)["leaf_ids"])

def do_merge(self, supervoxels, coords, resolution=(4, 4, 40)):
def do_merge(self, supervoxels, coords, resolution=(4, 4, 40)) -> None:
"""Perform a merge on the chunked graph.
Parameters
Expand Down Expand Up @@ -495,7 +500,7 @@ def do_merge(self, supervoxels, coords, resolution=(4, 4, 40)):
)
handle_response(response)

def undo_operation(self, operation_id):
def undo_operation(self, operation_id) -> dict:
"""Undo an operation.
Parameters
Expand Down Expand Up @@ -529,7 +534,7 @@ def execute_split(
root_id,
source_supervoxels=None,
sink_supervoxels=None,
):
) -> tuple[int, list]:
"""Execute a multicut split based on points or supervoxels.
Parameters
Expand Down Expand Up @@ -580,7 +585,7 @@ def preview_split(
source_supervoxels=None,
sink_supervoxels=None,
return_additional_ccs=False,
):
) -> tuple[list, list, bool, list]:
"""Get supervoxel connected components from a preview multicut split.
Parameters
Expand Down Expand Up @@ -643,7 +648,7 @@ def preview_split(
else:
return source_cc, sink_cc, success

def get_children(self, node_id):
def get_children(self, node_id) -> np.ndarray:
"""Get the children of a node in the chunked graph hierarchy.
Parameters
Expand All @@ -663,7 +668,7 @@ def get_children(self, node_id):
response = self.session.get(url)
return np.array(handle_response(response)["children_ids"], dtype=np.int64)

def get_contact_sites(self, root_id, bounds, calc_partners=False):
def get_contact_sites(self, root_id, bounds, calc_partners=False) -> dict:
"""Get contacts for a root ID.
Parameters
Expand Down Expand Up @@ -692,7 +697,9 @@ def get_contact_sites(self, root_id, bounds, calc_partners=False):
contact_d = handle_response(response)
return {int(k): v for k, v in contact_d.items()}

def find_path(self, root_id, src_pt, dst_pt, precision_mode=False):
def find_path(
self, root_id, src_pt, dst_pt, precision_mode=False
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
"""
Find a path between two locations on a root ID using the level 2 chunked
graph.
Expand Down Expand Up @@ -740,7 +747,9 @@ def find_path(self, root_id, src_pt, dst_pt, precision_mode=False):

return centroids, l2_path, failed_l2_ids

def get_subgraph(self, root_id, bounds):
def get_subgraph(
self, root_id, bounds
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
"""Get subgraph of root id within a bounding box.
Parameters
Expand Down Expand Up @@ -771,7 +780,7 @@ def get_subgraph(self, root_id, bounds):
rd = handle_response(response)
return np.int64(rd["nodes"]), np.double(rd["affinities"]), np.int32(rd["areas"])

def level2_chunk_graph(self, root_id):
def level2_chunk_graph(self, root_id) -> list:
"""
Get graph of level 2 chunks, the smallest agglomeration level above supervoxels.
Expand All @@ -793,7 +802,7 @@ def level2_chunk_graph(self, root_id):
r = handle_response(self.session.get(url))
return r["edge_graph"]

def remesh_level2_chunks(self, chunk_ids):
def remesh_level2_chunks(self, chunk_ids) -> None:
"""Submit specific level 2 chunks to be remeshed in case of a problem.
Parameters
Expand All @@ -808,7 +817,7 @@ def remesh_level2_chunks(self, chunk_ids):
r = self.session.post(url, json=data)
r.raise_for_status()

def get_operation_details(self, operation_ids: Iterable[int]):
def get_operation_details(self, operation_ids: Iterable[int]) -> dict:
"""Get the details of a list of operations.
Parameters
Expand All @@ -824,8 +833,8 @@ def get_operation_details(self, operation_ids: Iterable[int]):
dictionaries contain the following keys:
"added_edges"/"removed_edges": list of list of int
List of edges added (if a merge) or removed (if a split) by this
operation. Each edge is a list of two supervoxel IDs (source and
List of edges added (if a merge) or removed (if a split) by this
operation. Each edge is a list of two supervoxel IDs (source and
target).
"roots": list of int
List of root IDs that were created by this operation.
Expand Down Expand Up @@ -864,7 +873,7 @@ def get_lineage_graph(
as_nx_graph=False,
exclude_links_to_future=False,
exclude_links_to_past=False,
):
) -> Union[dict, nx.DiGraph]:
"""
Returns the lineage graph for a root ID, optionally cut off in the past or
the future.
Expand Down Expand Up @@ -961,7 +970,9 @@ def get_lineage_graph(
else:
return r

def get_latest_roots(self, root_id, timestamp=None, timestamp_future=None):
def get_latest_roots(
self, root_id, timestamp=None, timestamp_future=None
) -> np.ndarray:
"""
Returns root IDs that are related to the given `root_id` at a given
timestamp. Can be used to find the "latest" root IDs associated with an object.
Expand Down Expand Up @@ -1023,7 +1034,7 @@ def get_latest_roots(self, root_id, timestamp=None, timestamp_future=None):
in_degrees = np.array(list(in_degree_dict.values()))
return nodes[in_degrees == 0]

def get_original_roots(self, root_id, timestamp_past=None):
def get_original_roots(self, root_id, timestamp_past=None) -> np.ndarray:
"""Returns root IDs that are the latest successors of a given root ID.
Parameters
Expand Down Expand Up @@ -1054,7 +1065,7 @@ def get_original_roots(self, root_id, timestamp_past=None):
in_degrees = np.array(list(in_degree_dict.values()))
return nodes[in_degrees == 0]

def is_latest_roots(self, root_ids, timestamp=None):
def is_latest_roots(self, root_ids, timestamp=None) -> np.ndarray:
"""Check whether these root IDs are still a root at this timestamp.
Parameters
Expand Down Expand Up @@ -1177,7 +1188,9 @@ def suggest_latest_roots(
else:
return curr_ids[order]

def is_valid_nodes(self, node_ids, start_timestamp=None, end_timestamp=None):
def is_valid_nodes(
self, node_ids, start_timestamp=None, end_timestamp=None
) -> np.ndarray:
"""Check whether nodes are valid for given timestamp range.
Valid is defined as existing in the chunked graph. This makes no statement
Expand Down Expand Up @@ -1236,7 +1249,7 @@ def is_valid_nodes(self, node_ids, start_timestamp=None, end_timestamp=None):

return np.isin(node_ids, valid_ids)

def get_root_timestamps(self, root_ids):
def get_root_timestamps(self, root_ids) -> np.ndarray:
"""Retrieves timestamps when roots where created.
Parameters
Expand All @@ -1263,7 +1276,9 @@ def get_root_timestamps(self, root_ids):
[datetime.datetime.fromtimestamp(ts, pytz.UTC) for ts in r["timestamp"]]
)

def get_past_ids(self, root_ids, timestamp_past=None, timestamp_future=None):
def get_past_ids(
self, root_ids, timestamp_past=None, timestamp_future=None
) -> dict:
"""
For a set of root IDs, get the list of IDs at a past or future time point
that could contain parts of the same object.
Expand Down Expand Up @@ -1320,7 +1335,7 @@ def get_delta_roots(
timestamp_future: datetime.datetime = datetime.datetime.now(
datetime.timezone.utc
),
):
) -> tuple[np.ndarray, np.ndarray]:
"""
Get the list of roots that have changed between `timetamp_past` and
`timestamp_future`.
Expand Down Expand Up @@ -1350,7 +1365,7 @@ def get_delta_roots(
r = handle_response(self.session.get(url, params=params))
return np.array(r["old_roots"]), np.array(r["new_roots"])

def get_oldest_timestamp(self):
def get_oldest_timestamp(self) -> datetime.datetime:
"""Get the oldest timestamp in the database.
Returns
Expand Down

0 comments on commit 0e90f94

Please sign in to comment.