Skip to content

Commit e13b9c5

Browse files
committed
feat: initial work to add xsd validation in xml readers
1 parent adfc216 commit e13b9c5

File tree

9 files changed

+310
-416
lines changed

9 files changed

+310
-416
lines changed

src/dve/core_engine/backends/implementations/duckdb/readers/xml.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,19 +12,33 @@
1212
from dve.core_engine.backends.readers.xml import XMLStreamReader
1313
from dve.core_engine.backends.utilities import get_polars_type_from_annotation, stringify_model
1414
from dve.core_engine.type_hints import URI
15+
from dve.parser.file_handling.service import get_parent
16+
from dve.pipeline.utils import dump_errors
1517

1618

1719
@duckdb_write_parquet
1820
class DuckDBXMLStreamReader(XMLStreamReader):
1921
"""A reader for XML files"""
2022

21-
def __init__(self, ddb_connection: Optional[DuckDBPyConnection] = None, **kwargs):
23+
def __init__(
    self,
    ddb_connection: Optional[DuckDBPyConnection] = None,
    **kwargs,
):
    """Bind the reader to a DuckDB connection and initialise the parent reader.

    A falsy `ddb_connection` (including the default `None`) falls back to
    `default_connection`. Every other keyword argument is forwarded unchanged
    to the parent initialiser.
    """
    self.ddb_connection = ddb_connection or default_connection
    super().__init__(**kwargs)
2428

2529
@read_function(DuckDBPyRelation)
2630
def read_to_relation(self, resource: URI, entity_name: str, schema: Type[BaseModel]):
2731
"""Returns a relation object from the source xml"""
32+
if self.xsd_location:
33+
msg = self._run_xmllint(file_uri=resource)
34+
if msg:
35+
working_folder = get_parent(resource)
36+
dump_errors(
37+
working_folder=working_folder,
38+
step_name="file_transformation",
39+
messages=[msg]
40+
)
41+
2842
polars_schema: Dict[str, pl.DataType] = { # type: ignore
2943
fld.name: get_polars_type_from_annotation(fld.annotation)
3044
for fld in stringify_model(schema).__fields__.values()

src/dve/core_engine/backends/implementations/spark/readers/xml.py

Lines changed: 30 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,12 @@
1818
get_type_from_annotation,
1919
spark_write_parquet,
2020
)
21-
from dve.core_engine.backends.readers.xml import XMLStreamReader
21+
from dve.core_engine.backends.readers.xml import BasicXMLFileReader, XMLStreamReader
22+
from dve.core_engine.backends.readers.xml_linting import run_xmllint
2223
from dve.core_engine.type_hints import URI, EntityName
23-
from dve.parser.file_handling import get_content_length
24+
from dve.parser.file_handling import get_content_length, get_parent
2425
from dve.parser.file_handling.service import open_stream
26+
from dve.pipeline.utils import dump_errors
2527

2628
SparkXMLMode = Literal["PERMISSIVE", "FAILFAST", "DROPMALFORMED"]
2729
"""The mode to use when parsing XML files with Spark."""
@@ -51,7 +53,7 @@ def read_to_dataframe(
5153

5254

5355
@spark_write_parquet
54-
class SparkXMLReader(BaseFileReader): # pylint: disable=too-many-instance-attributes
56+
class SparkXMLReader(BasicXMLFileReader): # pylint: disable=too-many-instance-attributes
5557
"""A reader for XML files built atop Spark-XML."""
5658

5759
def __init__(
@@ -69,21 +71,31 @@ def __init__(
6971
sanitise_multiline: bool = True,
7072
namespace=None,
7173
trim_cells=True,
74+
xsd_location: Optional[URI] = None,
75+
xsd_error_code: Optional[str] = None,
76+
xsd_error_message: Optional[str] = None
7277
**_,
7378
) -> None:
74-
self.record_tag = record_tag
79+
80+
super().__init__(
81+
record_tag=record_tag,
82+
root_tag=root_tag,
83+
trim_cells=trim_cells,
84+
null_values=null_values,
85+
sanitise_multiline=sanitise_multiline,
86+
xsd_location=xsd_location,
87+
xsd_error_code=xsd_error_code,
88+
xsd_error_message=xsd_error_message
89+
)
90+
7591
self.spark_session = spark_session or SparkSession.builder.getOrCreate()
7692
self.sampling_ratio = sampling_ratio
7793
self.exclude_attribute = exclude_attribute
7894
self.mode = mode
7995
self.infer_schema = infer_schema
8096
self.ignore_namespace = ignore_namespace
81-
self.root_tag = root_tag
8297
self.sanitise_multiline = sanitise_multiline
83-
self.null_values = null_values
8498
self.namespace = namespace
85-
self.trim_cells = trim_cells
86-
super().__init__()
8799

88100
def read_to_py_iterator(
89101
self, resource: URI, entity_name: EntityName, schema: Type[BaseModel]
@@ -104,6 +116,16 @@ def read_to_dataframe(
104116
"""
105117
if get_content_length(resource) == 0:
106118
raise EmptyFileError(f"File at {resource} is empty.")
119+
120+
if self.xsd_location:
121+
msg = self._run_xmllint(file_uri=resource)
122+
if msg:
123+
working_folder = get_parent(resource)
124+
dump_errors(
125+
working_folder=working_folder,
126+
step_name="file_transformation",
127+
messages=[msg]
128+
)
107129

108130
spark_schema: StructType = get_type_from_annotation(schema)
109131
kwargs = {

src/dve/core_engine/backends/readers/xml.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,10 @@
1111

1212
from dve.core_engine.backends.base.reader import BaseFileReader
1313
from dve.core_engine.backends.exceptions import EmptyFileError
14+
from dve.core_engine.backends.readers.xml_linting import run_xmllint
1415
from dve.core_engine.backends.utilities import get_polars_type_from_annotation, stringify_model
1516
from dve.core_engine.loggers import get_logger
17+
from dve.core_engine.message import FeedbackMessage
1618
from dve.core_engine.type_hints import URI, EntityName
1719
from dve.parser.file_handling import NonClosingTextIOWrapper, get_content_length, open_stream
1820
from dve.parser.file_handling.implementations.file import (
@@ -114,6 +116,9 @@ def __init__(
114116
sanitise_multiline: bool = True,
115117
encoding: str = "utf-8-sig",
116118
n_records_to_read: Optional[int] = None,
119+
xsd_location: Optional[URI] = None,
120+
xsd_error_code: Optional[str] = None,
121+
xsd_error_message: Optional[str] = None,
117122
**_,
118123
):
119124
"""Init function for the base XML reader.
@@ -148,6 +153,12 @@ def __init__(
148153
"""Encoding of the XML file."""
149154
self.n_records_to_read = n_records_to_read
150155
"""The maximum number of records to read from a document."""
156+
self.xsd_location = xsd_location
157+
"""The relative URI of the xsd file if wishing to perform xsd validation"""
158+
self.xsd_error_code = xsd_error_code
159+
"""The error code to be reported if xsd validation fails (if xsd)"""
160+
self.xsd_error_message = xsd_error_message
161+
"""The error message to be reported if xsd validation fails"""
151162
super().__init__()
152163
self._logger = get_logger(__name__)
153164

@@ -260,6 +271,12 @@ def _parse_xml(
260271

261272
for element in elements:
262273
yield self._parse_element(element, template_row)
274+
275+
def _run_xmllint(self, file_uri: URI) -> FeedbackMessage:
276+
return run_xmllint(file_uri=file_uri,
277+
schema_uri=self.xsd_location,
278+
error_code=self.xsd_error_code,
279+
error_message=self.xsd_error_message)
263280

264281
def read_to_py_iterator(
265282
self,
Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
"""Implement XML linting for files."""
2+
3+
import shutil
4+
import tempfile
5+
from contextlib import ExitStack
6+
from pathlib import Path
7+
from subprocess import PIPE, STDOUT, Popen
8+
from typing import Sequence, Union
9+
from uuid import uuid4
10+
11+
from dve.core_engine.message import FeedbackMessage
12+
from dve.parser.file_handling import (
13+
copy_resource,
14+
get_file_name,
15+
get_resource_exists,
16+
open_stream,
17+
)
18+
from dve.parser.file_handling.implementations.file import file_uri_to_local_path
19+
from dve.parser.type_hints import URI
20+
21+
ErrorMessage = str
22+
"""Error message for xml issues"""
23+
ErrorCode = str
24+
"""Error code for xml feedback errors"""
25+
26+
FIVE_MEBIBYTES = 5 * (1024**2)
27+
"""The size of 5 binary megabytes, in bytes."""
28+
29+
30+
def _ensure_schema_and_resources(
    schema_uri: URI, schema_resources: Sequence[URI], temp_dir: Path
) -> Path:
    """Resolve the schema to a local path, copying resources only if needed.

    When the schema and every supporting resource are `file:` URIs whose local
    paths share a single parent directory, the schema's existing local path is
    returned and nothing is copied. In every other case each resource is copied
    into `temp_dir` under its own file name, and the path of the copied schema
    is returned.

    Raises `IOError` if the schema, or any of the supporting resources, does
    not exist.

    """
    if not get_resource_exists(schema_uri):
        raise IOError(f"No resource accessible at schema URI {schema_uri!r}")

    missing_resources = [uri for uri in schema_resources if not get_resource_exists(uri)]
    if missing_resources:
        raise IOError(f"Some schema resources missing: {missing_resources!r}")

    all_resources = [schema_uri, *schema_resources]

    if all(uri.startswith("file:") for uri in all_resources):
        local_paths = [file_uri_to_local_path(uri) for uri in all_resources]
        # The schema is first in the list, so its path is `local_paths[0]`.
        if len({path.parent for path in local_paths}) == 1:
            return local_paths[0]

    # Remote resources, or local ones spread over several folders: gather
    # everything side by side in the temp dir.
    for uri in all_resources:
        destination = temp_dir / get_file_name(uri)
        copy_resource(uri, destination.as_uri())

    return temp_dir / get_file_name(schema_uri)
65+
66+
67+
def run_xmllint(
    file_uri: URI,
    schema_uri: URI,
    *schema_resources: URI,
    error_code: ErrorCode,
    error_message: ErrorMessage,
) -> Union[None, FeedbackMessage]:
    """Run `xmllint`, given a file and information about the schemas to apply.

    The schema and associated resources will be copied to a temporary directory
    for validation, unless they are all already in the same local folder.

    Args:
     - `file_uri`: the URI of the file to be streamed into `xmllint`
     - `schema_uri`: the URI of the XSD schema for the file.
     - `*schema_resources`: URIs for additional XSD files required by the schema.
     - `error_code`: The error_code to use in FeedbackMessage if the linting fails.
     - `error_message`: The error_message to use in FeedbackMessage if the linting fails.

    Returns `None` if `xmllint` exits with status 0. For any other exit status a
    single `FeedbackMessage` carrying `error_code` and `error_message` is returned.

    Raises:
     - `OSError`: if the `xmllint` binary cannot be found.
     - `IOError`: if the file, the schema or a schema resource does not exist.

    """
    if not shutil.which("xmllint"):
        raise OSError("Unable to find `xmllint` binary")

    if not get_resource_exists(file_uri):
        raise IOError(f"No resource accessible at file URI {file_uri!r}")

    # Ensure the schema and resources are local file paths so they can be
    # read by xmllint.
    # Lots of resources to manage here.
    with tempfile.TemporaryDirectory() as temp_dir_str:
        temp_dir = Path(temp_dir_str)
        schema_path = _ensure_schema_and_resources(schema_uri, schema_resources, temp_dir)
        message_file_path = temp_dir.joinpath(uuid4().hex)

        with ExitStack() as linting_context:
            # Need to write lint output to a file to avoid deadlock. Kinder to mem this way anyway.
            # NOTE: the output is not read back here; the file is removed together
            # with the temporary directory, so only the exit status is used.
            message_file_bytes = linting_context.enter_context(message_file_path.open("wb"))

            # Open an `xmllint` process to pipe into.
            command = ["xmllint", "--stream", "--schema", str(schema_path), "-"]
            process = linting_context.enter_context(
                Popen(command, stdin=PIPE, stdout=message_file_bytes, stderr=STDOUT)
            )
            # This should never trigger, bad typing in stdlib.
            if process.stdin is None:
                raise ValueError("Unable to pipe file into subprocess")

            # Pipe the XML file contents into xmllint.
            try:
                with open_stream(file_uri, "rb") as byte_stream:
                    while True:
                        block = byte_stream.read(FIVE_MEBIBYTES)
                        if not block:
                            break
                        process.stdin.write(block)
            except BrokenPipeError:
                # xmllint stopped reading before the whole file was sent; its
                # exit status (collected below) decides the outcome.
                pass
            finally:
                # Close the input stream and await the response code.
                # Output will be written to the message file.
                try:
                    process.stdin.close()
                except BrokenPipeError:
                    # Closing flushes any buffered bytes, which can fail the same
                    # way as the write above if the process has already exited.
                    pass
                # TODO: Identify an appropriate timeout.
                return_code = process.wait()

        if return_code == 0:
            return None

        return FeedbackMessage(
            entity="xsd_validation",
            record={},
            failure_type="submission",
            is_informational=False,
            error_type="xsd check",
            error_location="Whole File",
            error_message=error_message,
            error_code=error_code,
        )

0 commit comments

Comments
 (0)