diff --git a/invenio_vocabularies/cli.py b/invenio_vocabularies/cli.py
index dca2df78..cd8a7391 100644
--- a/invenio_vocabularies/cli.py
+++ b/invenio_vocabularies/cli.py
@@ -13,11 +13,12 @@
 import click
 from flask.cli import with_appcontext
 from invenio_access.permissions import system_identity
+from invenio_logging.structlog import LoggerFactory
 from invenio_pidstore.errors import PIDDeletedError, PIDDoesNotExistError
 
 from .datastreams import DataStreamFactory
 from .factories import get_vocabulary_config
-from invenio_logging.structlog import LoggerFactory
+
 
 @click.group()
 def vocabularies():
@@ -35,9 +36,9 @@ def _process_vocab(config, num_samples=None):
     cli_logger.info("Starting processing")
     success, errored, filtered = 0, 0, 0
     left = num_samples or -1
-    batch_size=config.get("batch_size", 1000)
-    write_many=config.get("write_many", False)
-
+    batch_size = config.get("batch_size", 1000)
+    write_many = config.get("write_many", False)
+
     for result in ds.process(batch_size=batch_size, write_many=write_many):
         left = left - 1
         if result.filtered:
@@ -46,7 +47,12 @@ def _process_vocab(config, num_samples=None):
         if result.errors:
             for err in result.errors:
                 click.secho(err, fg="red")
-            cli_logger.error("Error", entry=result.entry, operation=result.op_type, errors=result.errors)
+            cli_logger.error(
+                "Error",
+                entry=result.entry,
+                operation=result.op_type,
+                errors=result.errors,
+            )
             errored += 1
         else:
             success += 1
@@ -54,7 +60,9 @@ def _process_vocab(config, num_samples=None):
         if left == 0:
             click.secho(f"Number of samples reached {num_samples}", fg="green")
             break
-    cli_logger.info("Finished processing", success=success, errored=errored, filtered=filtered)
+    cli_logger.info(
+        "Finished processing", success=success, errored=errored, filtered=filtered
+    )
 
     return success, errored, filtered
 
@@ -159,7 +167,7 @@ def delete(vocabulary, identifier, all):
     if not identifier and not all:
         click.secho("An identifier or the --all flag must be present.", fg="red")
         exit(1)
-
+
     vc = get_vocabulary_config(vocabulary)
     service = vc.get_service()
     if identifier:
@@ -175,4 +183,4 @@ def delete(vocabulary, identifier, all):
             if service.delete(system_identity, item["id"]):
                 click.secho(f"{item['id']} deleted from {vocabulary}.", fg="green")
         except (PIDDeletedError, PIDDoesNotExistError):
-            click.secho(f"PID {item['id']} not found.")
\ No newline at end of file
+            click.secho(f"PID {item['id']} not found.")
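Not part of the patch: a minimal, self-contained sketch of the tallying loop that `_process_vocab` runs over `ds.process(...)` results, using a stand-in result object. `FakeResult` and `tally` are illustrative names only, not library APIs.

```python
# Sketch only: mirrors the success/errored/filtered counting in _process_vocab
# against fake results instead of a real datastream.
from dataclasses import dataclass, field


@dataclass
class FakeResult:
    entry: dict
    filtered: bool = False
    op_type: str = "create"
    errors: list = field(default_factory=list)


def tally(results, num_samples=None):
    """Count successes, errors and filtered entries the way the CLI does."""
    success, errored, filtered = 0, 0, 0
    left = num_samples or -1
    for result in results:
        left -= 1
        if result.filtered:
            filtered += 1
        if result.errors:
            errored += 1
        else:
            success += 1
        if left == 0:  # stop once the requested sample size is reached
            break
    return success, errored, filtered


print(tally([FakeResult({"id": 1}), FakeResult({"id": 2}, errors=["boom"])]))
# -> (1, 1, 0)
```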
-"""ORCID number of days to sync.""" \ No newline at end of file +"""ORCID number of days to sync.""" diff --git a/invenio_vocabularies/contrib/names/datastreams.py b/invenio_vocabularies/contrib/names/datastreams.py index cb1b983e..c9b15dd9 100644 --- a/invenio_vocabularies/contrib/names/datastreams.py +++ b/invenio_vocabularies/contrib/names/datastreams.py @@ -8,19 +8,20 @@ """Names datastreams, transformers, writers and readers.""" +import io +import tarfile +from concurrent.futures import ThreadPoolExecutor, as_completed +from datetime import datetime, timedelta + +import s3fs +from flask import current_app from invenio_records.dictutils import dict_lookup from ...datastreams.errors import TransformerError -from ...datastreams.readers import SimpleHTTPReader, BaseReader +from ...datastreams.readers import BaseReader, SimpleHTTPReader from ...datastreams.transformers import BaseTransformer from ...datastreams.writers import ServiceWriter -import s3fs -from flask import current_app -from datetime import datetime -from datetime import timedelta -import tarfile -import io -from concurrent.futures import ThreadPoolExecutor, as_completed + class OrcidDataSyncReader(BaseReader): """ORCiD Data Sync Reader.""" @@ -29,33 +30,35 @@ def _fetch_orcid_data(self, orcid_to_sync, fs, bucket): """Fetches a single ORCiD record from S3.""" # The ORCiD file key is located in a folder which name corresponds to the last three digits of the ORCiD suffix = orcid_to_sync[-3:] - key = f'{suffix}/{orcid_to_sync}.xml' + key = f"{suffix}/{orcid_to_sync}.xml" try: - with fs.open(f's3://{bucket}/{key}', 'rb') as f: + with fs.open(f"s3://{bucket}/{key}", "rb") as f: file_response = f.read() return file_response except Exception as e: # TODO: log return None - + def _process_lambda_file(self, fileobj): """Process the ORCiD lambda file and returns a list of ORCiDs to sync. - + The decoded fileobj looks like the following: orcid,last_modified,created 0000-0001-5109-3700,2021-08-02 15:00:00.000,2021-08-02 15:00:00.000 - + Yield ORCiDs to sync until the last sync date is reached. 
""" - date_format = '%Y-%m-%d %H:%M:%S.%f' - date_format_no_millis = '%Y-%m-%d %H:%M:%S' - - last_sync = datetime.now() - timedelta(days=current_app.config["VOCABULARIES_ORCID_SYNC_DAYS"]) - - file_content = fileobj.read().decode('utf-8') - + date_format = "%Y-%m-%d %H:%M:%S.%f" + date_format_no_millis = "%Y-%m-%d %H:%M:%S" + + last_sync = datetime.now() - timedelta( + days=current_app.config["VOCABULARIES_ORCID_SYNC_DAYS"] + ) + + file_content = fileobj.read().decode("utf-8") + for line in file_content.splitlines()[1:]: # Skip the header line - elements = line.split(',') + elements = line.split(",") orcid = elements[0] # Lambda file is ordered by last modified date @@ -63,33 +66,43 @@ def _process_lambda_file(self, fileobj): try: last_modified_date = datetime.strptime(last_modified_str, date_format) except ValueError: - last_modified_date = datetime.strptime(last_modified_str, date_format_no_millis) + last_modified_date = datetime.strptime( + last_modified_str, date_format_no_millis + ) if last_modified_date >= last_sync: yield orcid else: break - def _iter(self, orcids, fs): """Iterates over the ORCiD records yielding each one.""" - with ThreadPoolExecutor(max_workers=current_app.config["VOCABULARIES_ORCID_SYNC_MAX_WORKERS"]) as executor: - futures = [executor.submit(self._fetch_orcid_data, orcid, fs, current_app.config["VOCABULARIES_ORCID_SUMMARIES_BUCKET"]) for orcid in orcids] + with ThreadPoolExecutor( + max_workers=current_app.config["VOCABULARIES_ORCID_SYNC_MAX_WORKERS"] + ) as executor: + futures = [ + executor.submit( + self._fetch_orcid_data, + orcid, + fs, + current_app.config["VOCABULARIES_ORCID_SUMMARIES_BUCKET"], + ) + for orcid in orcids + ] for future in as_completed(futures): result = future.result() if result is not None: yield result - def read(self, item=None, *args, **kwargs): """Streams the ORCiD lambda file, process it to get the ORCiDS to sync and yields it's data.""" fs = s3fs.S3FileSystem( key=current_app.config["VOCABULARIES_ORCID_ACCESS_KEY"], - secret=current_app.config["VOCABULARIES_ORCID_SECRET_KEY"] + secret=current_app.config["VOCABULARIES_ORCID_SECRET_KEY"], ) # Read the file from S3 - with fs.open('s3://orcid-lambda-file/last_modified.csv.tar', 'rb') as f: + with fs.open("s3://orcid-lambda-file/last_modified.csv.tar", "rb") as f: tar_content = f.read() orcids_to_sync = [] @@ -102,9 +115,8 @@ def read(self, item=None, *args, **kwargs): if extracted_file: # Process the file and get the ORCiDs to sync orcids_to_sync.extend(self._process_lambda_file(extracted_file)) - + yield from self._iter(orcids_to_sync, fs) - class OrcidHTTPReader(SimpleHTTPReader): @@ -207,7 +219,7 @@ def _entry_id(self, entry): { "type": "async", "args": { - "writer":{ + "writer": { "type": "names-service", } }, diff --git a/invenio_vocabularies/datastreams/datastreams.py b/invenio_vocabularies/datastreams/datastreams.py index a2f851a3..37bf8b04 100644 --- a/invenio_vocabularies/datastreams/datastreams.py +++ b/invenio_vocabularies/datastreams/datastreams.py @@ -8,24 +8,26 @@ """Base data stream.""" -from .errors import ReaderError, TransformerError, WriterError from invenio_logging.structlog import LoggerFactory +from .errors import ReaderError, TransformerError, WriterError + + class StreamEntry: """Object to encapsulate streams processing.""" def __init__(self, entry, errors=None, op_type=None): - """Constructor for the StreamEntry class. - - Args: - entry (object): The entry object, usually a record dict. - errors (list, optional): List of errors. Defaults to None. 
diff --git a/invenio_vocabularies/datastreams/datastreams.py b/invenio_vocabularies/datastreams/datastreams.py
index a2f851a3..37bf8b04 100644
--- a/invenio_vocabularies/datastreams/datastreams.py
+++ b/invenio_vocabularies/datastreams/datastreams.py
@@ -8,24 +8,26 @@
 
 """Base data stream."""
 
-from .errors import ReaderError, TransformerError, WriterError
 from invenio_logging.structlog import LoggerFactory
 
+from .errors import ReaderError, TransformerError, WriterError
+
+
 class StreamEntry:
     """Object to encapsulate streams processing."""
 
     def __init__(self, entry, errors=None, op_type=None):
-        """Constructor for the StreamEntry class.
-
-        Args:
-            entry (object): The entry object, usually a record dict.
-            errors (list, optional): List of errors. Defaults to None.
-            op_type (str, optional): The operation type. Defaults to None.
-        """
-        self.entry = entry
-        self.filtered = False
-        self.errors = errors or []
-        self.op_type = op_type
+        """Constructor for the StreamEntry class.
+
+        :param entry (object): The entry object, usually a record dict.
+        :param errors (list, optional): List of errors. Defaults to None.
+        :param op_type (str, optional): The operation type. Defaults to None.
+        """
+        self.entry = entry
+        self.filtered = False
+        self.errors = errors or []
+        self.op_type = op_type
+
 
 class DataStream:
     """Data stream."""
@@ -44,7 +46,7 @@ def __init__(self, readers, writers, transformers=None, *args, **kwargs):
     def filter(self, stream_entry, *args, **kwargs):
         """Checks if an stream_entry should be filtered out (skipped)."""
         return False
-
+
     def process_batch(self, batch, write_many=False):
         transformed_entries = []
         for stream_entry in batch:
@@ -77,9 +79,11 @@ def process(self, batch_size=100, write_many=False, logger=None, *args, **kwargs
         """
         if not logger:
             logger = LoggerFactory.get_logger("datastreams")
-
+
         batch = []
-        logger.info(f"Start reading datastream with batch_size={batch_size} and write_many={write_many}")
+        logger.info(
+            f"Start reading datastream with batch_size={batch_size} and write_many={write_many}"
+        )
         for stream_entry in self.read():
             batch.append(stream_entry)
             if len(batch) >= batch_size:
@@ -136,7 +140,7 @@ def write(self, stream_entry, *args, **kwargs):
                 stream_entry.errors.append(f"{writer.__class__.__name__}: {str(err)}")
 
         return stream_entry
-
+
     def batch_write(self, stream_entries, *args, **kwargs):
         """Apply the transformations to an stream_entry. Errors are handler in the service layer."""
         for writer in self._writers:
diff --git a/invenio_vocabularies/datastreams/readers.py b/invenio_vocabularies/datastreams/readers.py
index 954b4a1b..cbef525e 100644
--- a/invenio_vocabularies/datastreams/readers.py
+++ b/invenio_vocabularies/datastreams/readers.py
@@ -26,7 +26,6 @@
 from .errors import ReaderError
 from .xml import etree_to_dict
 
-
 try:
     import oaipmh_scythe
 except ImportError:
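Not part of the patch: the batching pattern that `DataStream.process()` now follows, as a standalone generator sketch (`batched` is an illustrative name, not the library API) — collect entries, flush every `batch_size`, then flush the remainder after the loop.

```python
# Sketch: accumulate entries into fixed-size batches and flush the leftovers.
def batched(entries, batch_size=100):
    batch = []
    for entry in entries:
        batch.append(entry)
        if len(batch) >= batch_size:
            yield batch
            batch = []
    if batch:  # remainder smaller than a full batch
        yield batch


print(list(batched(range(7), batch_size=3)))
# -> [[0, 1, 2], [3, 4, 5], [6]]
```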
diff --git a/invenio_vocabularies/datastreams/tasks.py b/invenio_vocabularies/datastreams/tasks.py
index 0b8cb196..dd0e443a 100644
--- a/invenio_vocabularies/datastreams/tasks.py
+++ b/invenio_vocabularies/datastreams/tasks.py
@@ -9,10 +9,11 @@
 """Data Streams Celery tasks."""
 
 from celery import shared_task
+from invenio_logging.structlog import LoggerFactory
 
 from ..datastreams import StreamEntry
 from ..datastreams.factories import WriterFactory
-from invenio_logging.structlog import LoggerFactory
+
 
 @shared_task(ignore_result=True)
 def write_entry(writer_config, entry):
@@ -24,6 +25,7 @@ def write_entry(writer_config, entry):
     writer = WriterFactory.create(config=writer_config)
     writer.write(StreamEntry(entry))
 
+
 @shared_task(ignore_result=True)
 def write_many_entry(writer_config, entries, logger=None):
     """Write many entries.
@@ -41,4 +43,4 @@ def write_many_entry(writer_config, entries, logger=None):
     logger.info("Entries written", succeeded=succeeded)
     if errored:
         for entry in errored:
-            logger.error("Error writing entry", entry=entry.entry, errors=entry.errors)
\ No newline at end of file
+            logger.error("Error writing entry", entry=entry.entry, errors=entry.errors)
diff --git a/invenio_vocabularies/datastreams/writers.py b/invenio_vocabularies/datastreams/writers.py
index b3e8c5f1..ab60ee42 100644
--- a/invenio_vocabularies/datastreams/writers.py
+++ b/invenio_vocabularies/datastreams/writers.py
@@ -50,6 +50,7 @@ def write_many(self, stream_entries, *args, **kwargs):
         """
         pass
 
+
 class ServiceWriter(BaseWriter):
     """Writes the entries to an RDM instance using a Service object."""
 
@@ -98,17 +99,21 @@ def write(self, stream_entry, *args, **kwargs):
         except InvalidRelationValue as err:
             # TODO: Check if we can get the error message easier
             raise WriterError([{"InvalidRelationValue": err.args[0]}])
-
+
     def write_many(self, stream_entries, *args, **kwargs):
         entries = [entry.entry for entry in stream_entries]
         entries_with_id = [(self._entry_id(entry), entry) for entry in entries]
         records = self._service.create_or_update_many(self._identity, entries_with_id)
-        stream_entries_processed= []
+        stream_entries_processed = []
         for op_type, record, errors in records:
             if errors == []:
-                stream_entries_processed.append(StreamEntry(entry=record, op_type=op_type))
+                stream_entries_processed.append(
+                    StreamEntry(entry=record, op_type=op_type)
+                )
             else:
-                stream_entries_processed.append(StreamEntry(entry=record, errors=errors, op_type=op_type))
+                stream_entries_processed.append(
+                    StreamEntry(entry=record, errors=errors, op_type=op_type)
+                )
 
         return stream_entries_processed
@@ -154,6 +159,8 @@ def write(self, stream_entry, *args, **kwargs):
 
     def write_many(self, stream_entries, *args, **kwargs):
         """Launches a celery task to write an entry."""
-        write_many_entry.delay(self._writer, [stream_entry.entry for stream_entry in stream_entries])
+        write_many_entry.delay(
+            self._writer, [stream_entry.entry for stream_entry in stream_entries]
+        )
 
-        return stream_entries
\ No newline at end of file
+        return stream_entries
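Not part of the patch: the nested writer configuration that the async path passes around, mirroring the `names-service` config shown above and the YAML config used in the tests. The `"type"` keys are resolved against the configured writer registry; the async wrapper just stores the inner config and hands it to the Celery task.

```python
# Sketch of the writer configs involved (plain dicts, runnable as-is).
async_names_writer = {
    "type": "async",
    "args": {
        "writer": {
            "type": "names-service",
        }
    },
}

# A YAML writer config of the same shape, as used in the tests:
yaml_writer_config = {"type": "yaml", "args": {"filepath": "writer_test.yaml"}}

# AsyncWriter.write_many() effectively does:
#     write_many_entry.delay(inner_writer_config, [e.entry for e in stream_entries])
```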
stream_logger.info("Success", entry=result.entry, operation=result.op_type) - stream_logger.info("Finished processing", success=success, errored=errored, filtered=filtered) + stream_logger.info( + "Success", entry=result.entry, operation=result.op_type + ) + stream_logger.info( + "Finished processing", success=success, errored=errored, filtered=filtered + ) except Exception as e: stream_logger.exception("Error processing stream", error=e) diff --git a/tests/conftest.py b/tests/conftest.py index d2e27270..a6d8fc5c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -79,12 +79,12 @@ def app_config(app_config): app_config["JSONSCHEMAS_HOST"] = "localhost" app_config["BABEL_DEFAULT_LOCALE"] = "en" app_config["I18N_LANGUAGES"] = [("da", "Danish")] - app_config["RECORDS_REFRESOLVER_CLS"] = ( - "invenio_records.resolver.InvenioRefResolver" - ) - app_config["RECORDS_REFRESOLVER_STORE"] = ( - "invenio_jsonschemas.proxies.current_refresolver_store" - ) + app_config[ + "RECORDS_REFRESOLVER_CLS" + ] = "invenio_records.resolver.InvenioRefResolver" + app_config[ + "RECORDS_REFRESOLVER_STORE" + ] = "invenio_jsonschemas.proxies.current_refresolver_store" return app_config diff --git a/tests/datastreams/conftest.py b/tests/datastreams/conftest.py index 9c31960c..f727ed30 100644 --- a/tests/datastreams/conftest.py +++ b/tests/datastreams/conftest.py @@ -18,8 +18,11 @@ import pytest -from invenio_vocabularies.config import VOCABULARIES_DATASTREAM_READERS, \ - VOCABULARIES_DATASTREAM_TRANSFORMERS, VOCABULARIES_DATASTREAM_WRITERS +from invenio_vocabularies.config import ( + VOCABULARIES_DATASTREAM_READERS, + VOCABULARIES_DATASTREAM_TRANSFORMERS, + VOCABULARIES_DATASTREAM_WRITERS, +) from invenio_vocabularies.datastreams.errors import TransformerError, WriterError from invenio_vocabularies.datastreams.readers import BaseReader, JsonReader, ZipReader from invenio_vocabularies.datastreams.transformers import BaseTransformer @@ -82,7 +85,7 @@ def app_config(app_config): } app_config["VOCABULARIES_DATASTREAM_TRANSFORMERS"] = { **VOCABULARIES_DATASTREAM_TRANSFORMERS, - "test": TestTransformer + "test": TestTransformer, } app_config["VOCABULARIES_DATASTREAM_WRITERS"] = { **VOCABULARIES_DATASTREAM_WRITERS, diff --git a/tests/datastreams/test_datastreams_tasks.py b/tests/datastreams/test_datastreams_tasks.py index d907422f..55b72ff0 100644 --- a/tests/datastreams/test_datastreams_tasks.py +++ b/tests/datastreams/test_datastreams_tasks.py @@ -17,13 +17,8 @@ def test_write_entry(app): - filepath = 'writer_test.yaml' - yaml_writer_config = { - "type": "yaml", - "args": { - "filepath": filepath - } - } + filepath = "writer_test.yaml" + yaml_writer_config = {"type": "yaml", "args": {"filepath": filepath}} entry = {"key_one": [{"inner_one": 1}]} write_entry(yaml_writer_config, entry) diff --git a/tests/datastreams/test_writers.py b/tests/datastreams/test_writers.py index d25d2c2f..1f70cb57 100644 --- a/tests/datastreams/test_writers.py +++ b/tests/datastreams/test_writers.py @@ -16,8 +16,11 @@ from invenio_vocabularies.datastreams import StreamEntry from invenio_vocabularies.datastreams.errors import WriterError -from invenio_vocabularies.datastreams.writers import AsyncWriter, \ - ServiceWriter, YamlWriter +from invenio_vocabularies.datastreams.writers import ( + AsyncWriter, + ServiceWriter, + YamlWriter, +) ## # Service Writer @@ -72,6 +75,7 @@ def test_service_writer_update_non_existing(lang_type, lang_data, service, ident assert dict(record, **updated_lang) == record + ## # YAML Writer ## @@ -90,25 +94,18 @@ def 
test_yaml_writer(): filepath.unlink() + ## # Async Writer ## def test_async_writer(app): - filepath = 'writer_test.yaml' - yaml_writer_config = { - "type": "yaml", - "args": { - "filepath": filepath - } - } + filepath = "writer_test.yaml" + yaml_writer_config = {"type": "yaml", "args": {"filepath": filepath}} async_writer = AsyncWriter(yaml_writer_config) - test_output = [ - {"key_one": [{"inner_one": 1}]}, - {"key_two": [{"inner_two": "two"}]} - ] + test_output = [{"key_one": [{"inner_one": 1}]}, {"key_two": [{"inner_two": "two"}]}] for output in test_output: async_writer.write(stream_entry=StreamEntry(output))
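Not part of the patch: a hypothetical follow-up check in the spirit of the existing YAML writer tests — load what the writer produced and compare it with the input entries. It assumes the `writer_test.yaml` file from the test above exists and that PyYAML is installed; the `expected` list is the test data shown in the diff.

```python
# Sketch: verify the YAML output written by the (async) YAML writer.
import yaml

filepath = "writer_test.yaml"
expected = [{"key_one": [{"inner_one": 1}]}, {"key_two": [{"inner_two": "two"}]}]

with open(filepath) as f:
    written = yaml.safe_load(f)

assert written == expected
```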