diff --git a/invenio_vocabularies/cli.py b/invenio_vocabularies/cli.py
index a9e1e55a..07187f1a 100644
--- a/invenio_vocabularies/cli.py
+++ b/invenio_vocabularies/cli.py
@@ -26,6 +26,8 @@ def vocabularies():
 
 def _process_vocab(config, num_samples=None):
     """Import a vocabulary."""
+    import time
+    start_time = time.time()
     ds = DataStreamFactory.create(
         readers_config=config["readers"],
         transformers_config=config.get("transformers"),
@@ -34,7 +36,8 @@
 
     success, errored, filtered = 0, 0, 0
     left = num_samples or -1
-    for result in ds.process():
+    for result in ds.process(batch_size=config.get("batch_size", 100),
+                             write_many=config.get("write_many", False)):
         left = left - 1
         if result.filtered:
             filtered += 1
@@ -47,6 +50,20 @@
 
         if left == 0:
             click.secho(f"Number of samples reached {num_samples}", fg="green")
             break
+
+    end_time = time.time()
+
+    elapsed_time = end_time - start_time
+    friendly_time = time.strftime("%H:%M:%S", time.gmtime(elapsed_time))
+    friendly_time_per_record = 0
+    if success:
+        elapsed_time_per_record = elapsed_time / success * 1000
+        friendly_time_per_record = time.strftime("%H:%M:%S", time.gmtime(elapsed_time_per_record))
+
+    print(f"CLI elapsed time: {friendly_time} for {success} entries. An average of {friendly_time_per_record} per 1000 entries.\n")
+    with open("/tmp/elapsed_time.txt", "a") as file:
+        file.write(f"CLI elapsed time: {friendly_time} for {success} entries. An average of {friendly_time_per_record} per 1000 entries.\n")
+
 
     return success, errored, filtered
@@ -101,7 +118,10 @@ def update(vocabulary, filepath=None, origin=None):
     config = vc.get_config(filepath, origin)
 
     for w_conf in config["writers"]:
-        w_conf["args"]["update"] = True
+        if w_conf["type"] == "async":
+            w_conf["args"]["writer"]["args"]["update"] = True
+        else:
+            w_conf["args"]["update"] = True
 
     success, errored, filtered = _process_vocab(config)
 
diff --git a/invenio_vocabularies/config.py b/invenio_vocabularies/config.py
index 27ac282d..9780a5e5 100644
--- a/invenio_vocabularies/config.py
+++ b/invenio_vocabularies/config.py
@@ -24,7 +24,7 @@
     ZipReader,
 )
 from .datastreams.transformers import XMLTransformer
-from .datastreams.writers import ServiceWriter, YamlWriter
+from .datastreams.writers import AsyncWriter, ServiceWriter, YamlWriter
 from .resources import VocabulariesResourceConfig
 from .services.config import VocabulariesServiceConfig
 
@@ -134,6 +134,7 @@
 VOCABULARIES_DATASTREAM_WRITERS = {
     "service": ServiceWriter,
     "yaml": YamlWriter,
+    "async": AsyncWriter,
 }
 """Data Streams writers."""
 
@@ -154,3 +155,9 @@
     "sort": ["name", "count"],
 }
 """Vocabulary type search configuration."""
+
+VOCABULARIES_ORCID_ACCESS_KEY = "CHANGE_ME"
+VOCABULARIES_ORCID_SECRET_KEY = "CHANGE_ME"
+VOCABULARIES_ORCID_FOLDER = "/tmp/ORCID_public_data_files/"
+VOCABULARIES_ORCID_SUMMARIES_BUCKET = "v3.0-summaries"
+VOCABULARIES_DATASTREAM_BATCH_SIZE = 100
diff --git a/invenio_vocabularies/contrib/names/datastreams.py b/invenio_vocabularies/contrib/names/datastreams.py
index ce050f3e..2296042e 100644
--- a/invenio_vocabularies/contrib/names/datastreams.py
+++ b/invenio_vocabularies/contrib/names/datastreams.py
@@ -12,9 +12,94 @@
 from invenio_records.dictutils import dict_lookup
 
 from ...datastreams.errors import TransformerError
-from ...datastreams.readers import SimpleHTTPReader
+from ...datastreams.readers import SimpleHTTPReader, BaseReader
 from ...datastreams.transformers import BaseTransformer
 from ...datastreams.writers import ServiceWriter
+import boto3
+from flask import current_app
+from datetime import datetime
+from datetime import timedelta
+import tarfile
+import io
+from concurrent.futures import ThreadPoolExecutor
+
+class OrcidDataSyncReader(BaseReader):
+    """ORCiD Data Sync Reader."""
+
+    def _iter(self, fp, *args, **kwargs):
+        """Not used; this reader fetches the files itself."""
+        raise NotImplementedError(
+            "OrcidDataSyncReader downloads one file and therefore does not iterate through items"
+        )
+
+    def read(self, item=None, *args, **kwargs):
+        """Downloads the ORCiD lambda file and yields the summary XML of each changed record."""
+
+        path = current_app.config["VOCABULARIES_ORCID_FOLDER"]
+        date_format = '%Y-%m-%d %H:%M:%S.%f'
+        date_format_no_millis = '%Y-%m-%d %H:%M:%S'
+
+        s3client = boto3.client('s3', aws_access_key_id=current_app.config["VOCABULARIES_ORCID_ACCESS_KEY"], aws_secret_access_key=current_app.config["VOCABULARIES_ORCID_SECRET_KEY"])
+        response = s3client.get_object(Bucket='orcid-lambda-file', Key='last_modified.csv.tar')
+        tar_content = response['Body'].read()
+
+        days_to_sync = 60*9
+        last_sync = datetime.now() - timedelta(minutes=days_to_sync)
+        # TODO: Do we want to use last_run to keep track of the last time the sync was run?
+        # Might not be ideal as it seems the file is updated at midnight
+
+        # last_ran_path = os.path.join(path, 'last_ran.config')
+        # if os.path.isfile(last_ran_path):
+        #     with open(last_ran_path, 'r') as f:
+        #         date_string = f.readline()
+        #         last_sync = datetime.strptime(date_string, date_format)
+
+        # with open(last_ran_path, 'w') as f:
+        #     f.write(datetime.now().strftime(date_format))
+
+
+        def process_file(fileobj):
+            file_content = fileobj.read().decode('utf-8')
+            orcids = []
+            for line in file_content.splitlines()[1:]:  # Skip the header line
+                elements = line.split(',')
+                orcid = elements[0]
+
+                last_modified_str = elements[3]
+                try:
+                    last_modified_date = datetime.strptime(last_modified_str, date_format)
+                except ValueError:
+                    last_modified_date = datetime.strptime(last_modified_str, date_format_no_millis)
+
+                if last_modified_date >= last_sync:
+                    orcids.append(orcid)
+                else:
+                    break
+            return orcids
+
+        orcids_to_sync = []
+        with tarfile.open(fileobj=io.BytesIO(tar_content)) as tar:
+            for member in tar.getmembers():
+                f = tar.extractfile(member)
+                if f:
+                    orcids_to_sync.extend(process_file(f))
+
+        def fetch_orcid_data(orcid_to_sync, bucket):
+            suffix = orcid_to_sync[-3:]
+            key = f'{suffix}/{orcid_to_sync}.xml'
+            try:
+                file_response = s3client.get_object(Bucket=bucket, Key=key)
+                return file_response['Body'].read()
+            except Exception as e:
+                # TODO: log
+                return None
+
+        with ThreadPoolExecutor(max_workers=40) as executor:  # TODO allow to configure max_workers / test to use asyncio
+            futures = [executor.submit(fetch_orcid_data, orcid, current_app.config["VOCABULARIES_ORCID_SUMMARIES_BUCKET"]) for orcid in orcids_to_sync]
+            for future in futures:
+                result = future.result()
+                if result is not None:
+                    yield result
 
 
 class OrcidHTTPReader(SimpleHTTPReader):
@@ -89,6 +174,7 @@ def _entry_id(self, entry):
 
 VOCABULARIES_DATASTREAM_READERS = {
     "orcid-http": OrcidHTTPReader,
+    "orcid-data-sync": OrcidDataSyncReader,
 }
 
 
@@ -107,22 +193,32 @@ def _entry_id(self, entry):
 
 DATASTREAM_CONFIG = {
     "readers": [
         {
-            "type": "tar",
-            "args": {
-                "regex": "\\.xml$",
-            },
+            "type": "orcid-data-sync",
         },
         {"type": "xml"},
     ],
     "transformers": [{"type": "orcid"}],
+    # "writers": [
+    #     {
+    #         "type": "names-service",
+    #         "args": {
+    #             "identity": system_identity,
+    #         },
+    #     }
+    # ],
     "writers": [
{ - "type": "names-service", + "type": "async", "args": { - "identity": system_identity, + "writer":{ + "type": "names-service", + "args": {}, + } }, } ], + "batch_size": 1000, # TODO: current_app.config["VOCABULARIES_DATASTREAM_BATCH_SIZE"], + "write_many": True, } """ORCiD Data Stream configuration. diff --git a/invenio_vocabularies/datastreams/datastreams.py b/invenio_vocabularies/datastreams/datastreams.py index 3fc2d1e4..bd5f0593 100644 --- a/invenio_vocabularies/datastreams/datastreams.py +++ b/invenio_vocabularies/datastreams/datastreams.py @@ -10,7 +10,6 @@ from .errors import ReaderError, TransformerError, WriterError - class StreamEntry: """Object to encapsulate streams processing.""" @@ -38,16 +37,10 @@ def __init__(self, readers, writers, transformers=None, *args, **kwargs): def filter(self, stream_entry, *args, **kwargs): """Checks if an stream_entry should be filtered out (skipped).""" return False - - def process(self, *args, **kwargs): - """Iterates over the entries. - - Uses the reader to get the raw entries and transforms them. - It will iterate over the `StreamEntry` objects returned by - the reader, apply the transformations and yield the result of - writing it. - """ - for stream_entry in self.read(): + + def process_batch(self, batch, write_many=False): + transformed_entries = [] + for stream_entry in batch: if stream_entry.errors: yield stream_entry # reading errors else: @@ -58,7 +51,33 @@ def process(self, *args, **kwargs): transformed_entry.filtered = True yield transformed_entry else: - yield self.write(transformed_entry) + transformed_entries.append(transformed_entry) + if transformed_entries: + if write_many: + print(f"write_many {len(transformed_entries)} entries.") + yield from self.batch_write(transformed_entries) + else: + print(f"write {len(transformed_entries)} entries.") + yield from (self.write(entry) for entry in transformed_entries) + + def process(self, batch_size=100, write_many=False, *args, **kwargs): + """Iterates over the entries. + + Uses the reader to get the raw entries and transforms them. + It will iterate over the `StreamEntry` objects returned by + the reader, apply the transformations and yield the result of + writing it. 
+ """ + batch = [] + for stream_entry in self.read(): + batch.append(stream_entry) + if len(batch) >= batch_size: + yield from self.process_batch(batch, write_many=write_many) + batch = [] + + # Process any remaining entries in the last batch + if batch: + yield from self.process_batch(batch, write_many=write_many) def read(self): """Recursively read the entries.""" @@ -106,6 +125,20 @@ def write(self, stream_entry, *args, **kwargs): stream_entry.errors.append(f"{writer.__class__.__name__}: {str(err)}") return stream_entry + + def batch_write(self, stream_entries, *args, **kwargs): + """Apply the transformations to an stream_entry.""" + for writer in self._writers: + try: + success, errors = writer.write_many(stream_entries) + for record in success: + yield StreamEntry(entry=record) + for error in errors: + yield StreamEntry(entry=error["record"], errors=error["errors"]) + except WriterError as err: + for stream_entry in stream_entries: + stream_entry.errors.append(f"{writer.__class__.__name__}: {str(err)}") + yield stream_entry def total(self, *args, **kwargs): """The total of entries obtained from the origin.""" diff --git a/invenio_vocabularies/datastreams/readers.py b/invenio_vocabularies/datastreams/readers.py index 736c44ca..954b4a1b 100644 --- a/invenio_vocabularies/datastreams/readers.py +++ b/invenio_vocabularies/datastreams/readers.py @@ -21,11 +21,12 @@ import requests import yaml from lxml import etree -from lxml.html import parse as html_parse +from lxml.html import fromstring from .errors import ReaderError from .xml import etree_to_dict + try: import oaipmh_scythe except ImportError: @@ -226,8 +227,8 @@ class XMLReader(BaseReader): def _iter(self, fp, *args, **kwargs): """Read and parse an XML file to dict.""" # NOTE: We parse HTML, to skip XML validation and strip XML namespaces - xml_tree = html_parse(fp).getroot() - record = etree_to_dict(xml_tree)["html"]["body"].get("record") + xml_tree = fromstring(fp) + record = etree_to_dict(xml_tree).get("record") if not record: raise ReaderError(f"Record not found in XML entry.") diff --git a/invenio_vocabularies/datastreams/tasks.py b/invenio_vocabularies/datastreams/tasks.py index 9407c051..ac5e0ad6 100644 --- a/invenio_vocabularies/datastreams/tasks.py +++ b/invenio_vocabularies/datastreams/tasks.py @@ -15,11 +15,22 @@ @shared_task(ignore_result=True) -def write_entry(writer, entry): +def write_entry(writer_config, entry): """Write an entry. :param writer: writer configuration as accepted by the WriterFactory. :param entry: dictionary, StreamEntry is not serializable. """ - writer = WriterFactory.create(config=writer) + writer = WriterFactory.create(config=writer_config) writer.write(StreamEntry(entry)) + +@shared_task(ignore_result=True) +def write_many_entry(writer_config, entries): + """Write many entries. + + :param writer: writer configuration as accepted by the WriterFactory. + :param entry: lisf ot dictionaries, StreamEntry is not serializable. 
+ """ + writer = WriterFactory.create(config=writer_config) + stream_entries = [StreamEntry(entry) for entry in entries] + writer.write_many(stream_entries) diff --git a/invenio_vocabularies/datastreams/writers.py b/invenio_vocabularies/datastreams/writers.py index abb63dca..bad6c673 100644 --- a/invenio_vocabularies/datastreams/writers.py +++ b/invenio_vocabularies/datastreams/writers.py @@ -20,12 +20,17 @@ from .datastreams import StreamEntry from .errors import WriterError -from .tasks import write_entry +from .tasks import write_entry, write_many_entry class BaseWriter(ABC): """Base writer.""" + def __init__(self, *args, **kwargs): + """Base initialization logic.""" + # Add any base initialization here if needed + pass + @abstractmethod def write(self, stream_entry, *args, **kwargs): """Writes the input stream entry to the target output. @@ -36,6 +41,14 @@ def write(self, stream_entry, *args, **kwargs): """ pass + def write_many(self, stream_entries, *args, **kwargs): + """Writes the input streams entry to the target output. + + :returns: A List of StreamEntry. The result of writing the entry. + Raises WriterException in case of errors. + + """ + pass class ServiceWriter(BaseWriter): """Writes the entries to an RDM instance using a Service object.""" @@ -85,6 +98,11 @@ def write(self, stream_entry, *args, **kwargs): except InvalidRelationValue as err: # TODO: Check if we can get the error message easier raise WriterError([{"InvalidRelationValue": err.args[0]}]) + + def write_many(self, stream_entries, *args, **kwargs): + entries = [entry.entry for entry in stream_entries] + entries_with_id = [(self._entry_id(entry), entry) for entry in entries] + return self._service.create_update_many(self._identity, entries_with_id) class YamlWriter(BaseWriter): @@ -117,11 +135,17 @@ def __init__(self, writer, *args, **kwargs): :param writer: writer to use. """ - self._writer = writer super().__init__(*args, **kwargs) + self._writer = writer def write(self, stream_entry, *args, **kwargs): """Launches a celery task to write an entry.""" write_entry.delay(self._writer, stream_entry.entry) return stream_entry + + def write_many(self, stream_entries, *args, **kwargs): + """Launches a celery task to write an entry.""" + write_many_entry.delay(self._writer, [stream_entry.entry for stream_entry in stream_entries]) + + return stream_entries, [] # TODO: Returning this way for consistency with other writers. It's assuming all succeded... \ No newline at end of file diff --git a/setup.cfg b/setup.cfg index eb426a95..9efd4233 100644 --- a/setup.cfg +++ b/setup.cfg @@ -32,6 +32,10 @@ install_requires = invenio-administration>=2.0.0,<3.0.0 lxml>=4.5.0 PyYAML>=5.4.1 + awscli>=1.33.23 + boto3>=1.12.6 + botocore>=1.34.141 + iso8601>=0.1.11 [options.extras_require] oaipmh =