From 2171b8f36bd2356fe30088ca241a08950b9bfb84 Mon Sep 17 00:00:00 2001 From: Florian M Date: Thu, 30 Jan 2025 12:34:35 +0100 Subject: [PATCH] Set actionTracingId to tracingId for editableMappingUpdates (#8361) * Set actionTracingId to tracingId for editableMappingUpdates * fix sql syntax * include annotation id * wip repair update actions script * wip fetch relevant updates * iterate on repairing updates * put updated updates * skip the reverse, the order in which we deal with the update groups doesn't matter here * undo application.conf change --- .../.gitignore | 1 + .../find_mapping_tracing_mapping.py | 91 +++++++++++++++++++ .../migration.py | 12 +-- .../repair_editable_mapping_updates.py | 91 +++++++++++++++++++ 4 files changed, 189 insertions(+), 6 deletions(-) create mode 100644 tools/migration-unified-annotation-versioning/find_mapping_tracing_mapping.py create mode 100644 tools/migration-unified-annotation-versioning/repair_editable_mapping_updates.py diff --git a/tools/migration-unified-annotation-versioning/.gitignore b/tools/migration-unified-annotation-versioning/.gitignore index e7f2901fb60..1b17bf64243 100644 --- a/tools/migration-unified-annotation-versioning/.gitignore +++ b/tools/migration-unified-annotation-versioning/.gitignore @@ -4,3 +4,4 @@ counts.py logs/ *.dat result.json +mapping_tracing_mapping.json diff --git a/tools/migration-unified-annotation-versioning/find_mapping_tracing_mapping.py b/tools/migration-unified-annotation-versioning/find_mapping_tracing_mapping.py new file mode 100644 index 00000000000..1288d39f612 --- /dev/null +++ b/tools/migration-unified-annotation-versioning/find_mapping_tracing_mapping.py @@ -0,0 +1,91 @@ +import logging +from utils import setup_logging, log_since +import argparse +from connections import connect_to_fossildb, connect_to_postgres, assert_grpc_success +import psycopg2 +import psycopg2.extras +import time +import fossildbapi_pb2 as proto +import VolumeTracing_pb2 as Volume +from typing import 
Optional +import msgspec + +logger = logging.getLogger(__name__) + + +def main(): + logger.info("Hello from find_mapping_tracing_mapping") + setup_logging() + parser = argparse.ArgumentParser() + parser.add_argument("--src", type=str, help="Source fossildb host and port. Example: localhost:7155", required=True) + parser.add_argument("--postgres", help="Postgres connection specifier, default is postgresql://postgres@localhost:5432/webknossos", type=str, default="postgresql://postgres@localhost:5432/webknossos") + args = parser.parse_args() + before = time.time() + annotations = read_annotation_list(args) + src_stub = connect_to_fossildb(args.src, "source") + mappings = {} + for annotation in annotations: + annotation_id = annotation["_id"] + id_mapping_for_annotation = {} + for tracing_id, layer_type in annotation["layers"].items(): + if layer_type == 'Volume': + try: + editable_mapping_id = get_editable_mapping_id(src_stub, tracing_id, layer_type) + if editable_mapping_id is not None: + id_mapping_for_annotation[editable_mapping_id] = tracing_id + except Exception as e: + logger.info(f"exception while checking layer {tracing_id} of {annotation_id}: {e}") + if id_mapping_for_annotation: + mappings[annotation_id] = id_mapping_for_annotation + + outfile_name = "mapping_tracing_mapping.json" + logger.info(f"Writing mapping to {outfile_name}...") + with open(outfile_name, "wb") as outfile: + outfile.write(msgspec.json.encode(mappings)) + + log_since(before, f"Wrote full id mapping to {outfile_name}. 
Checked {len(annotations)} annotations, wrote {len(mappings)} annotation id mappings.") + + +def get_newest_tracing_raw(src_stub, tracing_id, collection) -> Optional[bytes]: + getReply = src_stub.Get( + proto.GetRequest(collection=collection, key=tracing_id, mayBeEmpty=True) + ) + assert_grpc_success(getReply) + return getReply.value + + +def get_editable_mapping_id(src_stub, tracing_id: str, layer_type: str) -> Optional[str]: + if layer_type == "Skeleton": + return None + tracing_raw = get_newest_tracing_raw(src_stub, tracing_id, "volumes") + if tracing_raw is None: + return None + volume = Volume.VolumeTracing() + volume.ParseFromString(tracing_raw) + if volume.hasEditableMapping: + return volume.mappingName + return None + + +def read_annotation_list(args): + before = time.time() + connection = connect_to_postgres(args.postgres) + cursor = connection.cursor(cursor_factory=psycopg2.extras.RealDictCursor) + cursor.execute(f"SELECT COUNT(*) FROM webknossos.annotations") + annotation_count = cursor.fetchone()['count'] + logger.info(f"Loading infos of {annotation_count} annotations from postgres ...") + query = f"""SELECT + a._id, + JSON_OBJECT_AGG(al.tracingId, al.typ) AS layers, + JSON_OBJECT_AGG(al.tracingId, al.name) AS layerNames + FROM webknossos.annotation_layers al + JOIN webknossos.annotations a on al._annotation = a._id + GROUP BY a._id + """ + cursor.execute(query) + annotations = cursor.fetchall() + log_since(before, "Loading annotation infos from postgres") + return annotations + +if __name__ == '__main__': + main() diff --git a/tools/migration-unified-annotation-versioning/migration.py b/tools/migration-unified-annotation-versioning/migration.py index 209611e3c7b..d8122a4f77e 100644 --- a/tools/migration-unified-annotation-versioning/migration.py +++ b/tools/migration-unified-annotation-versioning/migration.py @@ -105,7 +105,7 @@ def build_mapping_id_map(self, annotation) -> MappingIdMap: mapping_id_map[tracing_id] = editable_mapping_id return 
mapping_id_map - def fetch_updates(self, tracing_or_mapping_id: str, layer_type: str, collection: str, json_encoder, json_decoder) -> Tuple[List[Tuple[int, int, bytes]], bool]: + def fetch_updates(self, tracing_id: str, tracing_or_mapping_id: str, layer_type: str, collection: str, json_encoder, json_decoder) -> Tuple[List[Tuple[int, int, bytes]], bool]: batch_size = 100 newest_version = self.get_newest_version(tracing_or_mapping_id, collection) updates_for_layer = [] @@ -118,7 +118,7 @@ def fetch_updates(self, tracing_or_mapping_id: str, layer_type: str, collection: for version, update_group in reversed(update_groups): if version > next_version: continue - update_group, timestamp, revert_source_version = self.process_update_group(tracing_or_mapping_id, layer_type, update_group, json_encoder, json_decoder) + update_group, timestamp, revert_source_version = self.process_update_group(tracing_id, layer_type, update_group, json_encoder, json_decoder) if revert_source_version is not None: next_version = revert_source_version included_revert = True @@ -135,7 +135,7 @@ def includes_revert(self, annotation) -> bool: layers = list(annotation["layers"].items()) for tracing_id, layer_type in layers: collection = self.update_collection_for_layer_type(layer_type) - _, layer_included_revert = self.fetch_updates(tracing_id, layer_type, collection, json_encoder=json_encoder, json_decoder=json_decoder) + _, layer_included_revert = self.fetch_updates(tracing_id, tracing_id, layer_type, collection, json_encoder=json_encoder, json_decoder=json_decoder) if layer_included_revert: return True return False @@ -148,12 +148,12 @@ def migrate_updates(self, annotation, mapping_id_map: MappingIdMap) -> LayerVers tracing_ids_and_mapping_ids = [] for tracing_id, layer_type in layers: collection = self.update_collection_for_layer_type(layer_type) - layer_updates, _ = self.fetch_updates(tracing_id, layer_type, collection, json_encoder=json_encoder, json_decoder=json_decoder) + layer_updates, _ = 
self.fetch_updates(tracing_id, tracing_id, layer_type, collection, json_encoder=json_encoder, json_decoder=json_decoder) all_update_groups.append(layer_updates) tracing_ids_and_mapping_ids.append(tracing_id) if tracing_id in mapping_id_map: mapping_id = mapping_id_map[tracing_id] - layer_updates, _ = self.fetch_updates(mapping_id, "editableMapping", "editableMappingUpdates", json_encoder=json_encoder, json_decoder=json_decoder) + layer_updates, _ = self.fetch_updates(tracing_id, mapping_id, "editableMapping", "editableMappingUpdates", json_encoder=json_encoder, json_decoder=json_decoder) all_update_groups.append(layer_updates) tracing_ids_and_mapping_ids.append(mapping_id) @@ -239,7 +239,7 @@ def process_update_group(self, tracing_id: str, layer_type: str, update_group_ra # add actionTracingId if not name == "updateTdCamera": - update["value"]["actionTracingId"] = tracing_id + update["value"]["actionTracingId"] = tracing_id # even for mappings, this is the tracing_id of their corresponding volume layer # identify compact update actions, and mark them if (name == "updateBucket" and "position" not in update_value) \ diff --git a/tools/migration-unified-annotation-versioning/repair_editable_mapping_updates.py b/tools/migration-unified-annotation-versioning/repair_editable_mapping_updates.py new file mode 100644 index 00000000000..3ea99bbaff5 --- /dev/null +++ b/tools/migration-unified-annotation-versioning/repair_editable_mapping_updates.py @@ -0,0 +1,91 @@ +import logging +from utils import setup_logging, log_since, batch_range +import argparse +from connections import connect_to_fossildb, assert_grpc_success +import time +import fossildbapi_pb2 as proto +from typing import List, Tuple +import msgspec + +logger = logging.getLogger("migration-logs") + + +def main(): + logger.info("Hello from repair_editable_mapping_updates") + setup_logging() + parser = argparse.ArgumentParser() + parser.add_argument("--fossil", type=str, help="Fossildb host and port. 
Example: localhost:7155", required=True) + parser.add_argument("--id_mapping", type=str, help="json file containing the id mapping determined by find_mapping_tracing_mapping.py", required=True) + args = parser.parse_args() + before = time.time() + stub = connect_to_fossildb(args.fossil, "target") + + json_encoder = msgspec.json.Encoder() + json_decoder = msgspec.json.Decoder() + with open(args.id_mapping, "rb") as infile: + id_mapping = json_decoder.decode(infile.read()) + for annotation_id in id_mapping.keys(): + repair_updates_of_annotation(stub, annotation_id, id_mapping[annotation_id], json_encoder, json_decoder) + + log_since(before, f"Repairing all {len(id_mapping)} annotations") + + +def repair_updates_of_annotation(stub, annotation_id, id_mapping_for_annotation, json_encoder, json_decoder): + get_batch_size = 100 # in update groups + put_buffer_size = 100 # in update groups + + before = time.time() + put_buffer = [] + changed_update_count = 0 + newest_version = get_newest_version(stub, annotation_id, "annotationUpdates") + if newest_version > 10000: + logger.info(f"Newest version of {annotation_id} is {newest_version}. 
This may take some time...") + for batch_start, batch_end in list(batch_range(newest_version + 1, get_batch_size)): + update_groups_batch = get_update_batch(stub, annotation_id, batch_start, batch_end - 1) + for version, update_group_bytes in update_groups_batch: + update_group = json_decoder.decode(update_group_bytes) + group_changed = False + for update in update_group: + if "value" in update: + update_value = update["value"] + if "actionTracingId" in update_value and update_value["actionTracingId"] in id_mapping_for_annotation: + update_value["actionTracingId"] = id_mapping_for_annotation[update_value["actionTracingId"]] + group_changed = True + changed_update_count += 1 + if group_changed: + versioned_key_value_pair = proto.VersionedKeyValuePairProto() + versioned_key_value_pair.key = annotation_id + versioned_key_value_pair.version = version + versioned_key_value_pair.value = json_encoder.encode(update_group) + put_buffer.append(versioned_key_value_pair) + if len(put_buffer) >= put_buffer_size: + put_multiple_keys_versions(stub, "annotationUpdates", put_buffer) + put_buffer = [] + if len(put_buffer) > 0: + put_multiple_keys_versions(stub, "annotationUpdates", put_buffer) + log_since(before, f"Repaired {changed_update_count} updates of annotation {annotation_id},") + + +def put_multiple_keys_versions(stub, collection: str, to_put) -> None: + reply = stub.PutMultipleKeysWithMultipleVersions(proto.PutMultipleKeysWithMultipleVersionsRequest(collection=collection, versionedKeyValuePairs = to_put)) + assert_grpc_success(reply) + + +def get_update_batch(stub, annotation_id: str, batch_start: int, batch_end_inclusive: int) -> List[Tuple[int, bytes]]: + reply = stub.GetMultipleVersions( + proto.GetMultipleVersionsRequest(collection="annotationUpdates", key=annotation_id, oldestVersion=batch_start, newestVersion=batch_end_inclusive) + ) + assert_grpc_success(reply) + return list(zip(reply.versions, reply.values)) + + +def get_newest_version(stub, tracing_id: str, 
collection: str) -> int: + reply = stub.Get( + proto.GetRequest(collection=collection, key=tracing_id, mayBeEmpty=True) + ) + assert_grpc_success(reply) + return reply.actualVersion + + +if __name__ == '__main__': + main()