Skip to content

Add return type hints to chunkedgraph #142

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jan 17, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 39 additions & 30 deletions caveclient/chunkedgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
default_global_server_address,
)


SERVER_KEY = "cg_server_address"

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -202,7 +201,7 @@ def _process_timestamp(self, timestamp):
else:
return timestamp

def get_roots(self, supervoxel_ids, timestamp=None, stop_layer=None):
def get_roots(self, supervoxel_ids, timestamp=None, stop_layer=None) -> np.ndarray:
"""Get the root ID for a list of supervoxels.

Parameters
Expand Down Expand Up @@ -231,7 +230,7 @@ def get_roots(self, supervoxel_ids, timestamp=None, stop_layer=None):
handle_response(response, as_json=False)
return np.frombuffer(response.content, dtype=np.uint64)

def get_root_id(self, supervoxel_id, timestamp=None, level2=False):
def get_root_id(self, supervoxel_id, timestamp=None, level2=False) -> np.int64:
"""Get the root ID for a specified supervoxel.

Parameters
Expand All @@ -258,7 +257,7 @@ def get_root_id(self, supervoxel_id, timestamp=None, level2=False):
response = self.session.get(url, params=query_d)
return np.int64(handle_response(response, as_json=True)["root_id"])

def get_merge_log(self, root_id):
def get_merge_log(self, root_id) -> list:
"""Get the merge log (splits and merges) for an object.

Parameters
Expand All @@ -278,7 +277,7 @@ def get_merge_log(self, root_id):
response = self.session.get(url)
return handle_response(response)

def get_change_log(self, root_id, filtered=True):
def get_change_log(self, root_id, filtered=True) -> dict:
"""Get the change log (splits and merges) for an object.

Parameters
Expand Down Expand Up @@ -323,7 +322,7 @@ def get_user_operations(
timestamp_start: datetime.datetime,
include_undo: bool = True,
timestamp_end: datetime.datetime = None,
):
) -> pd.DataFrame:
"""
Get operation details for a user ID. Currently, this is only available to
admins.
Expand Down Expand Up @@ -379,7 +378,7 @@ def get_user_operations(
)
return df

def get_tabular_change_log(self, root_ids, filtered=True):
def get_tabular_change_log(self, root_ids, filtered=True) -> dict:
"""Get a detailed changelog for neurons.

Parameters
Expand Down Expand Up @@ -435,7 +434,7 @@ def get_tabular_change_log(self, root_ids, filtered=True):

return changelog_dict

def get_leaves(self, root_id, bounds=None, stop_layer: int = None):
def get_leaves(self, root_id, bounds=None, stop_layer: int = None) -> np.ndarray:
"""Get all supervoxels for a root ID.

Parameters
Expand Down Expand Up @@ -465,7 +464,7 @@ def get_leaves(self, root_id, bounds=None, stop_layer: int = None):
response = self.session.get(url, params=query_d)
return np.int64(handle_response(response)["leaf_ids"])

def do_merge(self, supervoxels, coords, resolution=(4, 4, 40)):
def do_merge(self, supervoxels, coords, resolution=(4, 4, 40)) -> None:
"""Perform a merge on the chunked graph.

Parameters
Expand Down Expand Up @@ -495,7 +494,7 @@ def do_merge(self, supervoxels, coords, resolution=(4, 4, 40)):
)
handle_response(response)

def undo_operation(self, operation_id):
def undo_operation(self, operation_id) -> dict:
"""Undo an operation.

Parameters
Expand Down Expand Up @@ -529,7 +528,7 @@ def execute_split(
root_id,
source_supervoxels=None,
sink_supervoxels=None,
):
) -> tuple[int, list]:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this seems to be improper syntax in our testing environment

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@fcollman what python versions are supposed to be supported? I can't actually find that information anywhere

I think this must be a python version thing; this works for me on 3.11. i have a vague memory of Tuple changing to tuple or something like that

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Recommend merging the above first and rerunning tests for this one on all versions, but I think this should be fixed now

"""Execute a multicut split based on points or supervoxels.

Parameters
Expand Down Expand Up @@ -580,7 +579,7 @@ def preview_split(
source_supervoxels=None,
sink_supervoxels=None,
return_additional_ccs=False,
):
) -> tuple[list, list, bool, list]:
"""Get supervoxel connected components from a preview multicut split.

Parameters
Expand Down Expand Up @@ -643,7 +642,7 @@ def preview_split(
else:
return source_cc, sink_cc, success

def get_children(self, node_id):
def get_children(self, node_id) -> np.ndarray:
"""Get the children of a node in the chunked graph hierarchy.

Parameters
Expand All @@ -663,7 +662,7 @@ def get_children(self, node_id):
response = self.session.get(url)
return np.array(handle_response(response)["children_ids"], dtype=np.int64)

def get_contact_sites(self, root_id, bounds, calc_partners=False):
def get_contact_sites(self, root_id, bounds, calc_partners=False) -> dict:
"""Get contacts for a root ID.

Parameters
Expand Down Expand Up @@ -692,7 +691,9 @@ def get_contact_sites(self, root_id, bounds, calc_partners=False):
contact_d = handle_response(response)
return {int(k): v for k, v in contact_d.items()}

def find_path(self, root_id, src_pt, dst_pt, precision_mode=False):
def find_path(
self, root_id, src_pt, dst_pt, precision_mode=False
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
"""
Find a path between two locations on a root ID using the level 2 chunked
graph.
Expand Down Expand Up @@ -740,7 +741,9 @@ def find_path(self, root_id, src_pt, dst_pt, precision_mode=False):

return centroids, l2_path, failed_l2_ids

def get_subgraph(self, root_id, bounds):
def get_subgraph(
self, root_id, bounds
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
"""Get subgraph of root id within a bounding box.

Parameters
Expand Down Expand Up @@ -771,7 +774,7 @@ def get_subgraph(self, root_id, bounds):
rd = handle_response(response)
return np.int64(rd["nodes"]), np.double(rd["affinities"]), np.int32(rd["areas"])

def level2_chunk_graph(self, root_id):
def level2_chunk_graph(self, root_id) -> list:
"""
Get graph of level 2 chunks, the smallest agglomeration level above supervoxels.

Expand All @@ -793,7 +796,7 @@ def level2_chunk_graph(self, root_id):
r = handle_response(self.session.get(url))
return r["edge_graph"]

def remesh_level2_chunks(self, chunk_ids):
def remesh_level2_chunks(self, chunk_ids) -> None:
"""Submit specific level 2 chunks to be remeshed in case of a problem.

Parameters
Expand All @@ -808,7 +811,7 @@ def remesh_level2_chunks(self, chunk_ids):
r = self.session.post(url, json=data)
r.raise_for_status()

def get_operation_details(self, operation_ids: Iterable[int]):
def get_operation_details(self, operation_ids: Iterable[int]) -> dict:
"""Get the details of a list of operations.

Parameters
Expand All @@ -824,8 +827,8 @@ def get_operation_details(self, operation_ids: Iterable[int]):
dictionaries contain the following keys:

"added_edges"/"removed_edges": list of list of int
List of edges added (if a merge) or removed (if a split) by this
operation. Each edge is a list of two supervoxel IDs (source and
List of edges added (if a merge) or removed (if a split) by this
operation. Each edge is a list of two supervoxel IDs (source and
target).
"roots": list of int
List of root IDs that were created by this operation.
Expand Down Expand Up @@ -864,7 +867,7 @@ def get_lineage_graph(
as_nx_graph=False,
exclude_links_to_future=False,
exclude_links_to_past=False,
):
) -> dict | nx.DiGraph:
"""
Returns the lineage graph for a root ID, optionally cut off in the past or
the future.
Expand Down Expand Up @@ -961,7 +964,9 @@ def get_lineage_graph(
else:
return r

def get_latest_roots(self, root_id, timestamp=None, timestamp_future=None):
def get_latest_roots(
self, root_id, timestamp=None, timestamp_future=None
) -> np.ndarray:
"""
Returns root IDs that are related to the given `root_id` at a given
timestamp. Can be used to find the "latest" root IDs associated with an object.
Expand Down Expand Up @@ -1023,7 +1028,7 @@ def get_latest_roots(self, root_id, timestamp=None, timestamp_future=None):
in_degrees = np.array(list(in_degree_dict.values()))
return nodes[in_degrees == 0]

def get_original_roots(self, root_id, timestamp_past=None):
def get_original_roots(self, root_id, timestamp_past=None) -> np.ndarray:
"""Returns root IDs that are the latest successors of a given root ID.

Parameters
Expand Down Expand Up @@ -1054,7 +1059,7 @@ def get_original_roots(self, root_id, timestamp_past=None):
in_degrees = np.array(list(in_degree_dict.values()))
return nodes[in_degrees == 0]

def is_latest_roots(self, root_ids, timestamp=None):
def is_latest_roots(self, root_ids, timestamp=None) -> np.ndarray:
"""Check whether these root IDs are still a root at this timestamp.

Parameters
Expand Down Expand Up @@ -1177,7 +1182,9 @@ def suggest_latest_roots(
else:
return curr_ids[order]

def is_valid_nodes(self, node_ids, start_timestamp=None, end_timestamp=None):
def is_valid_nodes(
self, node_ids, start_timestamp=None, end_timestamp=None
) -> np.ndarray:
"""Check whether nodes are valid for given timestamp range.

Valid is defined as existing in the chunked graph. This makes no statement
Expand Down Expand Up @@ -1236,7 +1243,7 @@ def is_valid_nodes(self, node_ids, start_timestamp=None, end_timestamp=None):

return np.isin(node_ids, valid_ids)

def get_root_timestamps(self, root_ids):
def get_root_timestamps(self, root_ids) -> np.ndarray:
"""Retrieves timestamps when roots where created.

Parameters
Expand All @@ -1263,7 +1270,9 @@ def get_root_timestamps(self, root_ids):
[datetime.datetime.fromtimestamp(ts, pytz.UTC) for ts in r["timestamp"]]
)

def get_past_ids(self, root_ids, timestamp_past=None, timestamp_future=None):
def get_past_ids(
self, root_ids, timestamp_past=None, timestamp_future=None
) -> dict:
"""
For a set of root IDs, get the list of IDs at a past or future time point
that could contain parts of the same object.
Expand Down Expand Up @@ -1320,7 +1329,7 @@ def get_delta_roots(
timestamp_future: datetime.datetime = datetime.datetime.now(
datetime.timezone.utc
),
):
) -> tuple[np.ndarray, np.ndarray]:
"""
Get the list of roots that have changed between `timetamp_past` and
`timestamp_future`.
Expand Down Expand Up @@ -1350,7 +1359,7 @@ def get_delta_roots(
r = handle_response(self.session.get(url, params=params))
return np.array(r["old_roots"]), np.array(r["new_roots"])

def get_oldest_timestamp(self):
def get_oldest_timestamp(self) -> datetime.datetime:
"""Get the oldest timestamp in the database.

Returns
Expand Down