Set actionTracingId to tracingId for editableMappingUpdates (#8361)
* Set actionTracingId to tracingId for editableMappingUpdates

* fix sql syntax

* include annotation id

* wip repair update actions script

* wip fetch relevant updates

* iterate on repairing updates

* put updated updates

* skip the reverse, the order in which we deal with the update groups doesn't matter here

* undo application.conf change
fm3 authored Jan 30, 2025
1 parent 014a8e8 commit 2171b8f
Showing 4 changed files with 189 additions and 6 deletions.
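In short: the migration had stamped editable mapping update actions with the editable mapping id as their actionTracingId; after this change it uses the tracing id of the mapping's corresponding volume layer, and the new repair script rewrites already-migrated updates accordingly. A sketch of the effect on a single update action (ids and action name are placeholder values):

    # before the repair (hypothetical ids):
    {"name": "mergeAgglomerate", "value": {"actionTracingId": "<editable_mapping_id>", ...}}
    # after: actionTracingId points at the mapping's corresponding volume tracing
    {"name": "mergeAgglomerate", "value": {"actionTracingId": "<volume_tracing_id>", ...}}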
1 change: 1 addition & 0 deletions tools/migration-unified-annotation-versioning/.gitignore
@@ -4,3 +4,4 @@ counts.py
 logs/
 *.dat
 result.json
+mapping_tracing_mapping.json
91 changes: 91 additions & 0 deletions tools/migration-unified-annotation-versioning/find_mapping_tracing_mapping.py
@@ -0,0 +1,91 @@
import logging
from utils import setup_logging, log_since
import argparse
from connections import connect_to_fossildb, connect_to_postgres, assert_grpc_success
import psycopg2
import psycopg2.extras
import time
import fossildbapi_pb2 as proto
import VolumeTracing_pb2 as Volume
from typing import Optional
import msgspec

logger = logging.getLogger(__name__)


def main():
    logger.info("Hello from find_mapping_tracing_mapping")
    setup_logging()
    parser = argparse.ArgumentParser()
    parser.add_argument("--src", type=str, help="Source fossildb host and port. Example: localhost:7155", required=True)
    parser.add_argument("--postgres", help="Postgres connection specifier, default is postgresql://postgres@localhost:5432/webknossos", type=str, default="postgresql://postgres@localhost:5432/webknossos")
    args = parser.parse_args()
    before = time.time()
    annotations = read_annotation_list(args)
    src_stub = connect_to_fossildb(args.src, "source")
    mappings = {}
    # For every annotation, map each editable mapping id to the tracing id of its volume layer.
    for annotation in annotations:
        annotation_id = annotation["_id"]
        id_mapping_for_annotation = {}
        for tracing_id, layer_type in annotation["layers"].items():
            if layer_type == "Volume":
                try:
                    editable_mapping_id = get_editable_mapping_id(src_stub, tracing_id, layer_type)
                    if editable_mapping_id is not None:
                        id_mapping_for_annotation[editable_mapping_id] = tracing_id
                except Exception as e:
                    logger.info(f"exception while checking layer {tracing_id} of {annotation_id}: {e}")
        if id_mapping_for_annotation:
            mappings[annotation_id] = id_mapping_for_annotation

    outfile_name = "mapping_tracing_mapping.json"
    logger.info(f"Writing mapping to {outfile_name}...")
    with open(outfile_name, "wb") as outfile:
        outfile.write(msgspec.json.encode(mappings))

    log_since(before, f"Wrote full id mapping to {outfile_name}. Checked {len(annotations)} annotations, wrote {len(mappings)} annotation id mappings.")


def get_newest_tracing_raw(src_stub, tracing_id, collection) -> Optional[bytes]:
    getReply = src_stub.Get(
        proto.GetRequest(collection=collection, key=tracing_id, mayBeEmpty=True)
    )
    assert_grpc_success(getReply)
    return getReply.value


def get_editable_mapping_id(src_stub, tracing_id: str, layer_type: str) -> Optional[str]:
    if layer_type == "Skeleton":
        return None
    tracing_raw = get_newest_tracing_raw(src_stub, tracing_id, "volumes")
    if tracing_raw is None:
        return None
    volume = Volume.VolumeTracing()
    volume.ParseFromString(tracing_raw)
    if volume.hasEditableMapping:
        return volume.mappingName
    return None


def read_annotation_list(args):
    before = time.time()
    connection = connect_to_postgres(args.postgres)
    cursor = connection.cursor(cursor_factory=psycopg2.extras.RealDictCursor)
    cursor.execute("SELECT COUNT(*) FROM webknossos.annotations")
    annotation_count = cursor.fetchone()["count"]
    logger.info(f"Loading infos of {annotation_count} annotations from postgres ...")
    # Aggregate all layers of each annotation into a single row: tracingId -> type, tracingId -> name.
    query = """SELECT
        a._id,
        JSON_OBJECT_AGG(al.tracingId, al.typ) AS layers,
        JSON_OBJECT_AGG(al.tracingId, al.name) AS layerNames
        FROM webknossos.annotation_layers al
        JOIN webknossos.annotations a ON al._annotation = a._id
        GROUP BY a._id
    """
    cursor.execute(query)
    annotations = cursor.fetchall()
    log_since(before, "Loading annotation infos from postgres")
    return annotations


if __name__ == '__main__':
    main()
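For reference, a sketch of the shape of the emitted mapping_tracing_mapping.json (ids are placeholders): one object per annotation that has editable mappings, mapping each editable mapping id to the tracing id of its volume layer.

    {
      "<annotation_id>": {
        "<editable_mapping_id>": "<volume_tracing_id>"
      }
    }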
12 changes: 6 additions & 6 deletions tools/migration-unified-annotation-versioning/migration.py
@@ -105,7 +105,7 @@ def build_mapping_id_map(self, annotation) -> MappingIdMap:
                     mapping_id_map[tracing_id] = editable_mapping_id
         return mapping_id_map
 
-    def fetch_updates(self, tracing_or_mapping_id: str, layer_type: str, collection: str, json_encoder, json_decoder) -> Tuple[List[Tuple[int, int, bytes]], bool]:
+    def fetch_updates(self, tracing_id: str, tracing_or_mapping_id: str, layer_type: str, collection: str, json_encoder, json_decoder) -> Tuple[List[Tuple[int, int, bytes]], bool]:
         batch_size = 100
         newest_version = self.get_newest_version(tracing_or_mapping_id, collection)
         updates_for_layer = []
@@ -118,7 +118,7 @@ def fetch_updates(self, tracing_or_mapping_id: str, layer_type: str, collection:
         for version, update_group in reversed(update_groups):
             if version > next_version:
                 continue
-            update_group, timestamp, revert_source_version = self.process_update_group(tracing_or_mapping_id, layer_type, update_group, json_encoder, json_decoder)
+            update_group, timestamp, revert_source_version = self.process_update_group(tracing_id, layer_type, update_group, json_encoder, json_decoder)
             if revert_source_version is not None:
                 next_version = revert_source_version
                 included_revert = True
@@ -135,7 +135,7 @@ def includes_revert(self, annotation) -> bool:
         layers = list(annotation["layers"].items())
         for tracing_id, layer_type in layers:
             collection = self.update_collection_for_layer_type(layer_type)
-            _, layer_included_revert = self.fetch_updates(tracing_id, layer_type, collection, json_encoder=json_encoder, json_decoder=json_decoder)
+            _, layer_included_revert = self.fetch_updates(tracing_id, tracing_id, layer_type, collection, json_encoder=json_encoder, json_decoder=json_decoder)
             if layer_included_revert:
                 return True
         return False
@@ -148,12 +148,12 @@ def migrate_updates(self, annotation, mapping_id_map: MappingIdMap) -> LayerVers
         tracing_ids_and_mapping_ids = []
         for tracing_id, layer_type in layers:
             collection = self.update_collection_for_layer_type(layer_type)
-            layer_updates, _ = self.fetch_updates(tracing_id, layer_type, collection, json_encoder=json_encoder, json_decoder=json_decoder)
+            layer_updates, _ = self.fetch_updates(tracing_id, tracing_id, layer_type, collection, json_encoder=json_encoder, json_decoder=json_decoder)
             all_update_groups.append(layer_updates)
             tracing_ids_and_mapping_ids.append(tracing_id)
             if tracing_id in mapping_id_map:
                 mapping_id = mapping_id_map[tracing_id]
-                layer_updates, _ = self.fetch_updates(mapping_id, "editableMapping", "editableMappingUpdates", json_encoder=json_encoder, json_decoder=json_decoder)
+                layer_updates, _ = self.fetch_updates(tracing_id, mapping_id, "editableMapping", "editableMappingUpdates", json_encoder=json_encoder, json_decoder=json_decoder)
                 all_update_groups.append(layer_updates)
                 tracing_ids_and_mapping_ids.append(mapping_id)

@@ -239,7 +239,7 @@ def process_update_group(self, tracing_id: str, layer_type: str, update_group_ra

             # add actionTracingId
             if not name == "updateTdCamera":
-                update["value"]["actionTracingId"] = tracing_id
+                update["value"]["actionTracingId"] = tracing_id  # even for mappings, this is the tracing_id of their corresponding volume layer
 
             # identify compact update actions, and mark them
             if (name == "updateBucket" and "position" not in update_value) \
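The signature change threads two ids through fetch_updates: tracing_or_mapping_id selects the fossildb key to read update groups from, while tracing_id is what process_update_group stamps into actionTracingId. A condensed sketch of the two call shapes, as they appear in the diff:

    # volume layer itself: both ids coincide
    self.fetch_updates(tracing_id, tracing_id, layer_type, collection, ...)
    # its editable mapping: read under mapping_id, but stamp the volume layer's tracing_id
    self.fetch_updates(tracing_id, mapping_id, "editableMapping", "editableMappingUpdates", ...)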
91 changes: 91 additions & 0 deletions tools/migration-unified-annotation-versioning/repair_editable_mapping_updates.py
@@ -0,0 +1,91 @@
import logging
from utils import setup_logging, log_since, batch_range
import argparse
from connections import connect_to_fossildb, assert_grpc_success
import time
import fossildbapi_pb2 as proto
from typing import List, Tuple
import msgspec

logger = logging.getLogger("migration-logs")


def main():
    logger.info("Hello from repair_editable_mapping_updates")
    setup_logging()
    parser = argparse.ArgumentParser()
    parser.add_argument("--fossil", type=str, help="Fossildb host and port. Example: localhost:7155", required=True)
    parser.add_argument("--id_mapping", type=str, help="json file containing the id mapping determined by find_mapping_tracing_mapping.py", required=True)
    args = parser.parse_args()
    before = time.time()
    stub = connect_to_fossildb(args.fossil, "target")

    json_encoder = msgspec.json.Encoder()
    json_decoder = msgspec.json.Decoder()
    with open(args.id_mapping, "rb") as infile:
        id_mapping = json_decoder.decode(infile.read())
    for annotation_id in id_mapping.keys():
        repair_updates_of_annotation(stub, annotation_id, id_mapping[annotation_id], json_encoder, json_decoder)

    log_since(before, f"Repairing all {len(id_mapping)} annotations")


def repair_updates_of_annotation(stub, annotation_id, id_mapping_for_annotation, json_encoder, json_decoder):
    get_batch_size = 100  # in update groups
    put_buffer_size = 100  # in update groups

    before = time.time()
    put_buffer = []
    changed_update_count = 0
    newest_version = get_newest_version(stub, annotation_id, "annotationUpdates")
    if newest_version > 10000:
        logger.info(f"Newest version of {annotation_id} is {newest_version}. This may take some time...")
    for batch_start, batch_end in list(batch_range(newest_version + 1, get_batch_size)):
        update_groups_batch = get_update_batch(stub, annotation_id, batch_start, batch_end - 1)
        for version, update_group_bytes in update_groups_batch:
            update_group = json_decoder.decode(update_group_bytes)
            group_changed = False
            for update in update_group:
                if "value" in update:
                    update_value = update["value"]
                    # Rewrite actionTracingId from the editable mapping id to the tracing id of its volume layer.
                    if "actionTracingId" in update_value and update_value["actionTracingId"] in id_mapping_for_annotation:
                        update_value["actionTracingId"] = id_mapping_for_annotation[update_value["actionTracingId"]]
                        group_changed = True
                        changed_update_count += 1
            if group_changed:
                versioned_key_value_pair = proto.VersionedKeyValuePairProto()
                versioned_key_value_pair.key = annotation_id
                versioned_key_value_pair.version = version
                versioned_key_value_pair.value = json_encoder.encode(update_group)
                put_buffer.append(versioned_key_value_pair)
                if len(put_buffer) >= put_buffer_size:
                    put_multiple_keys_versions(stub, "annotationUpdates", put_buffer)
                    put_buffer = []
    if len(put_buffer) > 0:  # flush the remainder
        put_multiple_keys_versions(stub, "annotationUpdates", put_buffer)
    log_since(before, f"Repaired {changed_update_count} updates of annotation {annotation_id}")


def put_multiple_keys_versions(stub, collection: str, to_put) -> None:
    reply = stub.PutMultipleKeysWithMultipleVersions(proto.PutMultipleKeysWithMultipleVersionsRequest(collection=collection, versionedKeyValuePairs=to_put))
    assert_grpc_success(reply)


def get_update_batch(stub, annotation_id: str, batch_start: int, batch_end_inclusive: int) -> List[Tuple[int, bytes]]:
    reply = stub.GetMultipleVersions(
        proto.GetMultipleVersionsRequest(collection="annotationUpdates", key=annotation_id, oldestVersion=batch_start, newestVersion=batch_end_inclusive)
    )
    assert_grpc_success(reply)
    return list(zip(reply.versions, reply.values))


def get_newest_version(stub, tracing_id: str, collection: str) -> int:
    reply = stub.Get(
        proto.GetRequest(collection=collection, key=tracing_id, mayBeEmpty=True)
    )
    assert_grpc_success(reply)
    return reply.actualVersion


if __name__ == '__main__':
    main()
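Assuming both scripts live next to each other and the defaults above are used, a repair run might look like this (host and connection string are placeholders):

    python find_mapping_tracing_mapping.py --src localhost:7155 --postgres postgresql://postgres@localhost:5432/webknossos
    python repair_editable_mapping_updates.py --fossil localhost:7155 --id_mapping mapping_tracing_mapping.json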
