Skip to content

Commit

Permalink
Merge pull request #295 from CAVEconnectome/skeleton_dev
Browse files Browse the repository at this point in the history
SkeletonClient.get_skeleton() 'dict' output format now offers numpy a…
  • Loading branch information
kebwi authored Jan 27, 2025
2 parents fb062c8 + 00e7c3c commit dc6f96d
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 14 deletions.
15 changes: 14 additions & 1 deletion caveclient/skeletonservice.py
Original file line number Diff line number Diff line change
Expand Up @@ -585,7 +585,20 @@ def get_skeleton(
return sk_json
if endpoint_format == "flatdict":
assert self._server_version >= Version("0.6.0")
return SkeletonClient.decompressBytesToDict(response.content)
sk_json = SkeletonClient.decompressBytesToDict(response.content)
if "edges" in sk_json.keys():
sk_json["edges"] = np.array(sk_json["edges"])
if "mesh_to_skel_map" in sk_json.keys():
sk_json["mesh_to_skel_map"] = np.array(sk_json["mesh_to_skel_map"])
if "vertices" in sk_json.keys():
sk_json["vertices"] = np.array(sk_json["vertices"])
if "lvl2_ids" in sk_json.keys():
sk_json["lvl2_ids"] = np.array(sk_json["lvl2_ids"])
if "radius" in sk_json.keys():
sk_json["radius"] = np.array(sk_json["radius"])
if "compartment" in sk_json.keys():
sk_json["compartment"] = np.array(sk_json["compartment"])
return sk_json
if endpoint_format == "swccompressed":
file_content = SkeletonClient.decompressBytesToString(response.content)

Expand Down
9 changes: 8 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,14 @@ docs = [
"pymdown-extensions",
]
lint = ["ruff"]
test = ["pytest", "pytest-cov", "pytest-env", "pytest-mock", "responses"]
test = [
"deepdiff>=8.1.1",
"pytest",
"pytest-cov",
"pytest-env",
"pytest-mock",
"responses",
]

[tool.bumpversion]
allow_dirty = false
Expand Down
33 changes: 22 additions & 11 deletions tests/test_skeletons.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
import copy

import deepdiff
import numpy as np
import responses
from packaging.version import Version

Expand Down Expand Up @@ -184,8 +188,8 @@ def test_get_skeleton(self, myclient, mocker):
},
"edges": [
[
0,
1
1,
0
],
],
"mesh_to_skel_map": [
Expand All @@ -195,25 +199,32 @@ def test_get_skeleton(self, myclient, mocker):
"root": 0,
"vertices": [
[
971832,
842176,
906480
1054848., 827272., 601920.
],
[
972568,
842920,
905920
1054856., 827192., 601920.
],
],
"compartment": [
3,
3
],
"radius": [
237.11754897434668,
237.11754897434668
203.6853403, 203.6853403
],
'lvl2_ids': [
173056326983745934, 173126695727923522
]
}

sk_result = copy.deepcopy(sk)
sk_result["edges"] = np.array(sk_result["edges"])
sk_result["mesh_to_skel_map"] = np.array(sk_result["mesh_to_skel_map"])
sk_result["vertices"] = np.array(sk_result["vertices"])
sk_result["lvl2_ids"] = np.array(sk_result["lvl2_ids"])
sk_result["radius"] = np.array(sk_result["radius"])
sk_result["compartment"] = np.array(sk_result["compartment"])

dict_bytes = SkeletonClient.compressDictToBytes(sk)
responses.add(responses.GET, url=metadata_url, body=dict_bytes, status=200)

Expand All @@ -223,4 +234,4 @@ def test_get_skeleton(self, myclient, mocker):
)

result = myclient.skeleton.get_skeleton(0, None, 4, "dict")
assert result == sk
assert not deepdiff.DeepDiff(result, sk_result)
29 changes: 28 additions & 1 deletion uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit dc6f96d

Please sign in to comment.