[WIP/REF] wip/weep #31

Closed · wants to merge 3 commits
246 changes: 156 additions & 90 deletions ns_extract/pipelines/base.py
@@ -21,6 +21,11 @@

from abc import ABC, abstractmethod
import concurrent.futures
from ns_extract.pipelines.normalize import (
normalize_string,
load_abbreviations,
resolve_abbreviations,
)
from datetime import datetime
from functools import reduce
import hashlib
@@ -31,6 +36,7 @@
from typing import (
Any,
Dict,
List,
Optional,
Tuple,
Type,
@@ -685,8 +691,8 @@ def transform_dataset(
f"Failed to load inputs for study {db_id}: {str(e)}"
)

# Process loaded inputs and get results
transform_outputs = self.transform(loaded_study_inputs, **kwargs)
# Process loaded inputs with fit_transform
transform_outputs = self.fit_transform(loaded_study_inputs, **kwargs)
cleaned_results, raw_results, validation_status = transform_outputs

if cleaned_results:
@@ -735,49 +741,31 @@ def transform_dataset(
logger.error(f"Failed to cleanup after error: {str(cleanup_error)}")
raise

def validate_results(self, results, **kwargs):
"""Apply validation to each study's results individually.

Args:
results: Dict mapping study IDs to their results
**kwargs: Additional validation arguments

Returns:
Tuple of:
- Dict mapping study IDs to their validated results
- Dict mapping study IDs to validation status (True/False)
"""
validation_status = {}

for db_id, study_results in results.items():
try:
# Validate each study's results against the schema
self._output_schema.model_validate(study_results)
validation_status[db_id] = True
except Exception as e:
logging.error(f"Output validation error for study {db_id}: {str(e)}")
validation_status[db_id] = False

return validation_status
# Removed validate_results method - now using unified validate() method


class Extractor(ABC):
"""Base class for data transformation logic.

This class defines the core interface for transforming input study data into
structured outputs validated against a schema. It separates the transformation
logic from I/O operations (handled by Pipeline classes).
structured outputs validated against a schema. It follows the scikit-learn
fit/transform pattern while supporting both pre-trained and trainable models.

Required Class Variables:
_version: str - Version identifier for the extractor implementation
_output_schema: Type[BaseModel] - Pydantic model defining expected output format

Key Methods:
_transform: Core transformation logic (must be implemented by subclasses)
transform: Main entry point that coordinates transformation and validation
validate_output: Validates outputs against the schema
fit: Train the extractor on input data (must be implemented by subclasses)
transform: Transform input data (must be implemented by subclasses)
fit_transform: Convenience method to fit and transform in one step
validate: Validate outputs against the schema
post_process: Optional hook for modifying results before validation

The interface handles both independent and dependent pipeline processing:
- Independent: processes individual studies (single dict input)
- Dependent: processes all studies together (dict of dicts input)

Example:
>>> class MyExtractor(Extractor[MyOutputSchema]):
... _version = "1.0.0"
@@ -791,10 +779,15 @@ class Extractor(ABC):
_version: str = None
_output_schema: Type[BaseModel] = None

def __init__(self, **kwargs: Any) -> None:
def __init__(
self,
expand_abbreviations_fields: List = [],
normalizable_string_fields: List = [],
**kwargs: Any) -> None:
"""Initialize extractor and verify required configuration.

Args:
expand_abbreviations_fields: List of fields to expand abbreviations in
normalizable_string_fields: List of fields to apply string normalization to
**kwargs: Configuration parameters for the extractor

Raises:
@@ -805,6 +798,21 @@ def __init__(self, **kwargs: Any) -> None:
if not self._version:
raise ValueError("Subclass must define _version class variable")

self.normalizable_string_fields = normalizable_string_fields
self.expand_abbreviations_fields = expand_abbreviations_fields
self._nlp = None

# Pre-load NLP model if we'll need it
if expand_abbreviations_fields:
import spacy
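# Note: "en_core_sci_sm" and the "abbreviation_detector" pipe are provided by
# scispaCy, which must be installed alongside spaCy for this preload to succeed.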
try:
self._nlp = spacy.load("en_core_sci_sm", disable=["parser", "ner"])
if "abbreviation_detector" not in self._nlp.pipe_names:
from scispacy.abbreviation import AbbreviationDetector  # noqa: F401
self._nlp.add_pipe("abbreviation_detector")
except Exception as e:
print(f"Warning: Failed to load NLP model: {e}")

if isinstance(self, IndependentPipeline):
IndependentPipeline.__init__(self, extractor=self)
if isinstance(self, DependentPipeline):
@@ -813,8 +821,8 @@ def __init__(self, **kwargs: Any) -> None:
@abstractmethod
def _transform(
self, inputs: Dict[str, Dict[str, Any]], **kwargs: Any
) -> Dict[str, Any]:
"""Transform input data into output format.
) -> None:
"""Implementation of model training logic.

This is the core transformation method that must be implemented by subclasses.
It should convert raw study data into the expected output format defined by T.
@@ -827,90 +835,148 @@ def _transform(
...
}
}
**kwargs: Additional transformation arguments

Returns:
Dict mapping study IDs to their transformed outputs
Format: {
"study_id": transformed_data # type T
}
For independent pipelines, will contain single study.
For dependent pipelines, will contain all studies.
**kwargs: Additional training arguments

Raises:
ProcessingError: If transformation fails for any study
ProcessingError: If training fails
"""
pass

def transform(
self, inputs: Dict[str, Dict[str, Any]], **kwargs: Any
) -> Tuple[Dict[str, Any], Dict[str, Any], Union[bool, Dict[str, bool]]]:
"""Transform and validate input data.
def _process_text_field(
self,
text: str,
field_name: str,
abbreviations: Optional[List[Dict]] = None
) -> str:
"""Process a text field with abbreviation expansion and/or normalization."""
if not isinstance(text, str):
return text

# First expand abbreviations if needed
if field_name in self.expand_abbreviations_fields and abbreviations:
text = resolve_abbreviations(text, abbreviations)

# Then normalize if needed
if field_name in self.normalizable_string_fields:
text = normalize_string(text)

return text

def _normalize_nested_fields(
self,
data: Any,
abbreviations: Optional[List[Dict]] = None
) -> Any:
"""Recursively process fields in nested data structures."""
if isinstance(data, dict):
normalized = {}
for key, value in data.items():
if isinstance(value, str):
normalized[key] = self._process_text_field(value, key, abbreviations)
else:
normalized[key] = self._normalize_nested_fields(value, abbreviations)
return normalized
elif isinstance(data, list):
return [self._normalize_nested_fields(item, abbreviations) for item in data]
return data
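# Example (assuming normalize_string lowercases and strips whitespace; its actual
# behavior lives in ns_extract.pipelines.normalize): with "name" listed in
# normalizable_string_fields, {"groups": [{"name": "  Healthy Controls "}]} is
# returned as {"groups": [{"name": "healthy controls"}]}, after any configured
# abbreviation expansion has been applied to the same field.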

This method orchestrates the complete transformation process:
1. Calls _transform for core data processing
2. Applies optional post-processing
3. Validates results against schema
def post_process(self, results: Dict[str, Any]) -> Dict[str, Any]:
"""Optional hook for post-processing transform results.

This can be overridden by subclasses to modify results before validation,
for example to clean or normalize data.

Args:
inputs: Dict mapping study IDs to their input data
**kwargs: Additional transformation arguments
results: Dict mapping study IDs to their raw transformed data

Returns:
Tuple containing:
- Post-processed results (Dict[str, T])
- Raw results before post-processing (Dict[str, T])
- Validation status:
- bool for IndependentPipeline (single result)
- Dict[str, bool] for DependentPipeline (per-study validation)

Raises:
ProcessingError: If transformation fails
Dict with same structure as input, potentially modified
"""
# Get raw results from transform
raw_results = self._transform(inputs, **kwargs)

# Post-process results
cleaned_results = self.post_process(raw_results)

# Validate results based on pipeline type
if isinstance(self, DependentPipeline):
# For dependent pipelines, validate each study individually
validation_status = self.validate_results(cleaned_results)
# Process each study with its own abbreviations
processed_results = {}
for study_id, study_data in results.items():
# Extract abbreviations from this study's source text
source_inputs = next(s for s in study_data.values() if isinstance(s, dict))
abbreviations = self._extract_source_abbreviations(source_inputs)
# Process this study's data with its abbreviations
processed_results[study_id] = self._normalize_nested_fields(
study_data, abbreviations
)
return processed_results
else:
# For independent pipelines, validate single study
validation_status = self.validate_output(cleaned_results)
# Process single study with its abbreviations
abbreviations = self._extract_source_abbreviations(results)
return self._normalize_nested_fields(results, abbreviations)

return (cleaned_results, raw_results, validation_status)

def validate_output(self, output: Dict[str, Any]) -> bool:
def validate(self, output: Dict[str, Any]) -> Union[bool, Dict[str, bool]]:
"""Validate transformed data against schema.

Unified validation method that handles both independent and dependent
pipeline cases.

Args:
output: Dict mapping study IDs to transformed data

Returns:
bool: True if validation passes, False otherwise
For independent pipelines: bool indicating if validation passed
For dependent pipelines: Dict mapping study IDs to validation status

Note:
Validation failures are logged but don't raise exceptions
to allow graceful handling of invalid results
"""
try:
self._output_schema.model_validate(output)
return True
except Exception as e:
logging.error(f"Output validation error: {str(e)}")
return False

def post_process(self, results: Dict[str, Any]) -> Dict[str, Any]:
"""Optional hook for post-processing transform results.
if isinstance(self, DependentPipeline):
# Validate each study individually
validation_status = {}
for study_id, study_output in output.items():
try:
self._output_schema.model_validate(study_output)
validation_status[study_id] = True
except Exception as e:
logging.error(f"Output validation error for study {study_id}: {str(e)}")
validation_status[study_id] = False
return validation_status
else:
# Validate single output
try:
self._output_schema.model_validate(output)
return True
except Exception as e:
logging.error(f"Output validation error: {str(e)}")
return False
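# Usage note: for an IndependentPipeline extractor, validate(study_output) returns a
# single bool; for a DependentPipeline extractor, validate({"s1": out1, "s2": out2})
# returns a per-study mapping such as {"s1": True, "s2": False}.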

This can be overridden by subclasses to modify results before validation,
for example to clean or normalize data.
def _extract_source_abbreviations(self, inputs: Dict[str, Dict[str, Any]]) -> List[Dict]:
"""Extract abbreviations from the source text (ace or pubget).

Args:
results: Dict mapping study IDs to their raw transformed data
inputs: Dict containing input data with source text

Returns:
Dict with same structure as input, potentially modified
List of abbreviation dictionaries from the source text
"""
return results
if not self._nlp or not self.expand_abbreviations_fields:
return []

# Look for text in data pond inputs with priority order
for sources, input_types in self._data_pond_inputs.items():
if "text" not in input_types:
continue

for source in sources: # e.g., try pubget first, then ace
if source not in inputs:
continue

if "text" not in inputs[source]:
continue

source_text = inputs[source]["text"]
if not isinstance(source_text, str):
continue

# Found valid source text, extract abbreviations
return load_abbreviations(source_text, model=self._nlp)

return []
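For orientation, a minimal usage sketch of the reworked interface follows. The output schema, class names, and import path are illustrative assumptions rather than code from this PR, and the exact abstract method a subclass must implement (fit/transform versus _transform) is still being settled in this WIP.

from pydantic import BaseModel

from ns_extract.pipelines.base import Extractor, IndependentPipeline


class GroupInfo(BaseModel):  # hypothetical output schema, for illustration only
    name: str
    diagnosis: str


class GroupExtractor(Extractor, IndependentPipeline):
    _version = "1.0.0"
    _output_schema = GroupInfo

    def _transform(self, inputs, **kwargs):
        # Stand-in for real extraction logic.
        return {"name": "  Healthy Controls ", "diagnosis": "MDD"}


extractor = GroupExtractor(
    expand_abbreviations_fields=["diagnosis"],         # expanded via resolve_abbreviations
    normalizable_string_fields=["name", "diagnosis"],  # cleaned via normalize_string
)

cleaned = extractor.post_process({"name": "  Healthy Controls ", "diagnosis": "MDD"})
print(extractor.validate(cleaned))  # bool here; a per-study dict for DependentPipeline

As in _process_text_field, abbreviation expansion runs before string normalization, so configured fields are first expanded against the source text's abbreviations and then cleaned.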