Commit

feature: added background task for computing object checksum in multipart upload

mesemus committed Mar 9, 2025
1 parent ad4e195 commit d321cf5
Showing 6 changed files with 256 additions and 11 deletions.
65 changes: 64 additions & 1 deletion invenio_records_resources/services/files/tasks.py
@@ -9,15 +9,20 @@

"""Files tasks."""

import hashlib
import traceback

import requests
from celery import shared_task
from flask import current_app
from invenio_access.permissions import system_identity
from invenio_db import db
from invenio_files_rest.models import FileInstance
from invenio_files_rest.proxies import current_files_rest

from ...proxies import current_service_registry
from ...services.errors import FileKeyNotFoundError
from ..errors import TransferException
from .transfer.constants import LOCAL_TRANSFER_TYPE


@@ -73,5 +78,63 @@ def fetch_file(service_id, record_id, file_key):
except Exception as e:
current_app.logger.error(e)
traceback.print_exc()
# re-raise the exception so that the celery task is marked as errored
raise


@shared_task(ignore_result=True)
def recompute_multipart_checksum_task(file_instance_id):
"""Create checksum for a single object from multipart upload."""
try:
file_instance = FileInstance.query.filter_by(id=file_instance_id).one()
checksum = file_instance.checksum
if not checksum or not checksum.startswith("multipart:"):
return
# multipart checksum looks like: multipart:<s3 multipart checksum>-<part_size>
# s3 multipart checksum is the etag of the multipart object and looks like
# hex(md5(<md5(part1) + md5(part2) + ...>))-<number of parts>
original_checksum_hex, _number_of_parts_str, part_size_str = checksum[10:].rsplit("-")
part_size = int(part_size_str)

storage = current_files_rest.storage_factory(fileinstance=file_instance)
with storage.open("rb") as f:
object_checksum = hashlib.md5()
part_checksums = []
while part_checksum := compute_checksum(f, object_checksum, part_size):
part_checksums.append(part_checksum)
piecewise_checksum = hashlib.md5(b"".join(part_checksums)).hexdigest()

if piecewise_checksum != original_checksum_hex:
raise TransferException(
f"Checksums do not match - recorded checksum: {original_checksum_hex}, "
f"computed checksum: {piecewise_checksum}"
)

file_instance.checksum = "md5:" + object_checksum.hexdigest()
db.session.add(file_instance)
db.session.commit()

except FileKeyNotFoundError as e:
current_app.logger.error(e)
return
except Exception as e:
current_app.logger.error(e)
traceback.print_exc()
raise


def compute_checksum(file_stream, object_checksum, part_size):
"""Compute checksum for a single object from multipart upload."""
buffer_size = min(1024 * 1024, part_size)
bytes_remaining = part_size
part_checksum = hashlib.md5()
while bytes_remaining > 0:
chunk = file_stream.read(min(buffer_size, bytes_remaining))
if not chunk:
break
object_checksum.update(chunk)
part_checksum.update(chunk)
bytes_remaining -= len(chunk)
if bytes_remaining == part_size:
# nothing was read
return None
return part_checksum.digest()
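
For reference, a minimal standalone sketch of the checksum arithmetic this task relies on, assuming an S3-style multipart ETag of the form hex(md5(md5(part1) + md5(part2) + ...))-<number of parts>; the helper name below is illustrative and not part of this module:

import hashlib

def expected_multipart_etag(data: bytes, part_size: int) -> str:
    # md5 digest of each part, concatenated and hashed again - the body of the S3 ETag
    part_digests = [
        hashlib.md5(data[i : i + part_size]).digest()
        for i in range(0, len(data), part_size)
    ]
    return f"{hashlib.md5(b''.join(part_digests)).hexdigest()}-{len(part_digests)}"

data = b"example payload " * 500_000                     # ~8 MB of deterministic bytes
print(expected_multipart_etag(data, 5 * 1024 * 1024))    # e.g. "<hex>-2"
print("md5:" + hashlib.md5(data).hexdigest())             # what the task finally stores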
2 changes: 1 addition & 1 deletion invenio_records_resources/services/files/transfer/base.py
@@ -60,7 +60,7 @@ def __init__(
self.record = record
self.key = key
self.file_service = file_service
self._file_record = file_record # need to store it as it might be deleted
self._file_record = file_record # need to store it as it might be deleted
self.uow = uow

def init_file(self, record, file_metadata):
@@ -10,12 +10,16 @@
from datetime import datetime, timedelta

import marshmallow as ma
from invenio_db import db
from invenio_db.uow import ModelCommitOp
from invenio_files_rest import current_files_rest
from invenio_files_rest.models import FileInstance, ObjectVersion

from ....errors import TransferException
from ....uow import RecordCommitOp
from ...schema import BaseTransferSchema
from ...tasks import (
recompute_multipart_checksum_task,
)
from ..base import Transfer, TransferStatus
from ..constants import LOCAL_TRANSFER_TYPE, MULTIPART_TRANSFER_TYPE

@@ -107,9 +111,10 @@ def multipart_commit_upload(self, **multipart_metadata):
:param multipart_metadata: The metadata returned by the multipart_initialize_upload
and the metadata returned by the multipart_set_content for each part.
:returns: None or a multipart checksum, if it was computed by the backend.
"""
if hasattr(self._storage, "multipart_commit_upload"):
self._storage.multipart_commit_upload(**multipart_metadata)
return self._storage.multipart_commit_upload(**multipart_metadata)

def multipart_abort_upload(self, **multipart_metadata):
"""
@@ -198,12 +203,12 @@ def init_file(self, record, file_metadata):
file_record.object_version = version
file_record.object_version_id = version.version_id

file_record.commit()
self.uow.register(RecordCommitOp(file_record))

# create the file instance that will be used to get the storage factory.
# it might also be used to initialize the file (preallocate its size)
file_instance = FileInstance.create()
db.session.add(file_instance)
self.uow.register(ModelCommitOp(file_instance))
version.set_file(file_instance)

storage = self._get_storage(
@@ -226,12 +231,12 @@
file_instance.set_uri(
storage.fileurl,
size,
checksum or "mutlipart:unknown",
checksum,
storage_class=storage_class,
)

db.session.add(file_instance)
file_record.commit() # updated transfer metadata, so need to commit
self.uow.register(ModelCommitOp(file_instance))
self.uow.register(RecordCommitOp(file_record))
return file_record

def set_file_content(self, stream, content_length):
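
The init_file hunks above replace direct session commits with unit-of-work operations, so the file record and the FileInstance row are persisted together when the surrounding service operation commits. A rough sketch of that pattern, assuming UnitOfWork from invenio_records_resources.services.uow and an already existing file_record:

from invenio_db import db
from invenio_db.uow import ModelCommitOp
from invenio_files_rest.models import FileInstance
from invenio_records_resources.services.uow import RecordCommitOp, UnitOfWork

with UnitOfWork(db.session) as uow:
    file_instance = FileInstance.create()
    uow.register(ModelCommitOp(file_instance))  # staged, flushed when the UoW commits
    uow.register(RecordCommitOp(file_record))   # file_record assumed to exist already
    uow.commit()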
@@ -276,11 +281,30 @@ def commit_file(self):
super().commit_file()

storage = self._get_storage()
storage.multipart_commit_upload(**self.multipart_metadata)
checksum = storage.multipart_commit_upload(**self.multipart_metadata)

recompute_checksum_needed = False

file_instance = self.file_record.object_version.file
if not file_instance.checksum:
recompute_checksum_needed = True
# get the multipart ETag and set it as the file checksum
if checksum is not None:
# set the checksum to the multipart checksum. This can later be picked
# up by a background job to compute the whole-file checksum reliably.
file_instance.checksum = (
f"multipart:{checksum}-{self.multipart_metadata['part_size']}"
)
self.uow.register(ModelCommitOp(file_instance))

# change the transfer type to local
self.file_record.transfer.transfer_type = LOCAL_TRANSFER_TYPE
self.file_record.commit()

if recompute_checksum_needed:
recompute_multipart_checksum_task.delay(
str(file_instance.id)
)

def delete_file(self):
"""If this method is called, we are deleting a file with an active multipart upload."""
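
For orientation, the intermediate value written in commit_file above is exactly the string later parsed by recompute_multipart_checksum_task; a tiny sketch of the round trip, with a placeholder ETag:

# placeholder values for illustration only
stored = "multipart:0b1f8f2e4c9d5e6f7a8b9c0d1e2f3a4b-2-10485760"
etag_hex, n_parts, part_size = stored[len("multipart:"):].rsplit("-", 2)
assert etag_hex == "0b1f8f2e4c9d5e6f7a8b9c0d1e2f3a4b"
assert (int(n_parts), int(part_size)) == (2, 10 * 1024 * 1024)
# once the background task has verified the ETag, the stored checksum becomes "md5:<whole-file hex>"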
2 changes: 1 addition & 1 deletion run-tests.sh
@@ -33,7 +33,7 @@ trap cleanup EXIT
python -m check_manifest
python -m setup extract_messages --output-file /dev/null
python -m sphinx.cmd.build -qnNW docs docs/_build/html
eval "$(docker-services-cli up --db ${DB:-postgresql} --search ${SEARCH:-opensearch} --cache ${CACHE:-redis} --env)"
eval "$(docker-services-cli up --db ${DB:-postgresql} --search ${SEARCH:-opensearch} --cache ${CACHE:-redis} --s3 ${S3BACKEND:-minio} --env)"
python -m pytest $@
tests_exit_code=$?
python -m sphinx.cmd.build -qnNW -b doctest docs docs/_build/doctest
63 changes: 63 additions & 0 deletions tests/services/files/s3/conftest.py
@@ -0,0 +1,63 @@
import os

import pytest
from mock_module.api import RecordWithFiles


#
# Need to redefine the app_config fixture to include the S3 configuration
# and remove the default file-based storage.
#
@pytest.fixture(scope="module")
def app_config(app_config):
"""Override pytest-invenio app_config fixture.
Needed to set the fields on the custom fields schema.
"""
app_config["RECORDS_RESOURCES_FILES_ALLOWED_DOMAINS"] = [
"inveniordm.test",
]
app_config["RECORDS_RESOURCES_FILES_ALLOWED_REMOTE_DOMAINS"] = [
"inveniordm.test",
]
app_config["FILES_REST_STORAGE_CLASS_LIST"] = {
"S": "Standard",
}

app_config["FILES_REST_DEFAULT_STORAGE_CLASS"] = "S"

app_config["RECORDS_REFRESOLVER_CLS"] = (
"invenio_records.resolver.InvenioRefResolver"
)
app_config["RECORDS_REFRESOLVER_STORE"] = (
"invenio_jsonschemas.proxies.current_refresolver_store"
)
app_config["FILES_REST_STORAGE_FACTORY"] = "invenio_s3.s3fs_storage_factory"

# s3 configuration
app_config["S3_ENDPOINT_URL"] = os.environ["S3_ENDPOINT_URL"]
app_config["S3_ACCESS_KEY_ID"] = os.environ["S3_ACCESS_KEY_ID"]
app_config["S3_SECRET_ACCESS_KEY"] = os.environ["S3_SECRET_ACCESS_KEY"]

return app_config

@pytest.fixture()
def s3_location(app, db):
"""Creates an s3 location for a test."""
from invenio_files_rest.models import Location

location_obj = Location(name="pytest-s3-location", uri="s3://default", default=True)

db.session.add(location_obj)
db.session.commit()

yield location_obj


@pytest.fixture()
def example_s3_file_record(db, input_data, s3_location):
"""Example record."""
record = RecordWithFiles.create({}, **input_data)
record.commit()
db.session.commit()
return record
95 changes: 95 additions & 0 deletions tests/services/files/s3/test_service_s3_backend.py
@@ -0,0 +1,95 @@
import base64

import hashlib
import struct

import requests


def test_multipart_file_upload_s3(
app,
file_service,
s3_location,
example_s3_file_record,
identity_simple,
):

recid = example_s3_file_record["id"]
key = "dataset.bin"
total_size = 17 * 1024 * 1024 # 17MB
part_size = 10 * 1024 * 1024 # 10MB

# deterministic content of total_size bytes: consecutive 32-bit little-endian integers (0, 1, 2, ...)
content = b"".join(struct.pack("<I", idx) for idx in range(0, total_size // 4))

file_to_initialise = [
{
"key": key,
"size": total_size, # 2kB
"metadata": {
"description": "Published dataset.",
},
"transfer": {
"type": "M",
"parts": 2,
"part_size": part_size,
},
}
]
# Initialize file saving
result = file_service.init_files(identity_simple, recid, file_to_initialise)
result = result.to_dict()

assert result["entries"][0]["key"] == key
assert "parts" in result["entries"][0]["links"]

def upload_part(part_url, part_content, part_size):
part_checksum = base64.b64encode(hashlib.md5(part_content).digest())
resp = requests.put(
part_url,
data=part_content,
headers={
"Content-Length": str(part_size),
"Content-MD5": part_checksum,
},
)
if resp.status_code != 200:
raise Exception(f"Failed to upload part: {resp.text}")

parts_by_number = {
x["part"]: x["url"] for x in result["entries"][0]["links"]["parts"]
}

upload_part(parts_by_number[1], content[:part_size], part_size)

upload_part(parts_by_number[2], content[part_size:], total_size - part_size)

result = file_service.commit_file(identity_simple, recid, key).to_dict()
assert result["key"] == file_to_initialise[0]["key"]

# List files
result = file_service.list_files(identity_simple, recid).to_dict()
assert result["entries"][0]["key"] == file_to_initialise[0]["key"]
assert result["entries"][0]["storage_class"] == "S"

# Read file metadata
result = file_service.read_file_metadata(identity_simple, recid, key).to_dict()
assert result["key"] == file_to_initialise[0]["key"]
assert result["transfer"]["type"] == "L"

# Note: Invenio tests run Celery tasks eagerly, so the intermediate multipart
# checksum is replaced as soon as the file is committed and cannot be observed
# here. We assume it was stored correctly, as recompute_multipart_checksum_task
# would have failed otherwise.
# assert result["checksum"] == "multipart:562d3945b531e9c597d98b6bc7607a7d-2-10485760"

# Instead, we test that the final MD5 checksum has been generated by the
# `recompute_multipart_checksum_task`.
assert result["checksum"] == "md5:a5a5934a531b88a83b63f6a64611d177"

# Retrieve file
result = file_service.get_file_content(identity_simple, recid, key)
assert result.file_id == key

# get the content from S3 and make sure it matches the original content
sent_file = result.send_file()
assert content == requests.get(sent_file.headers["Location"]).content
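
For reference, both checksums asserted (and commented) above should be reproducible offline from the same deterministic payload; a small standalone sketch assuming the 10 MB part size used in the test:

import hashlib
import struct

total_size = 17 * 1024 * 1024
part_size = 10 * 1024 * 1024
content = b"".join(struct.pack("<I", idx) for idx in range(total_size // 4))

parts = [content[i : i + part_size] for i in range(0, total_size, part_size)]
etag = hashlib.md5(b"".join(hashlib.md5(p).digest() for p in parts)).hexdigest()
print(f"multipart:{etag}-{len(parts)}-{part_size}")  # intermediate checksum format
print("md5:" + hashlib.md5(content).hexdigest())     # final checksum stored on the FileInstance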
