Skip to content

Commit e3e19f8

Browse files
authored
Make fetcher and resolver configurable. (#240)
1 parent c0cf74d commit e3e19f8

File tree

4 files changed: 120 additions and 37 deletions.

4 files changed: 120 additions and 37 deletions.

cwltool/load_tool.py

+54 -20
Original file line number | Diff line number | Diff line change
@@ -6,23 +6,33 @@
66
import logging
77
import re
88
import urlparse
9-
from schema_salad.ref_resolver import Loader
9+
10+
from schema_salad.ref_resolver import Loader, Fetcher, DefaultFetcher
1011
import schema_salad.validate as validate
1112
from schema_salad.validate import ValidationException
1213
import schema_salad.schema as schema
14+
import requests
15+
16+
from typing import Any, AnyStr, Callable, cast, Dict, Text, Tuple, Union
17+
1318
from avro.schema import Names
19+
1420
from . import update
1521
from . import process
1622
from .process import Process, shortname
1723
from .errors import WorkflowException
18-
from typing import Any, AnyStr, Callable, cast, Dict, Text, Tuple, Union
1924

2025
_logger = logging.getLogger("cwltool")
2126

22-
def fetch_document(argsworkflow, resolver=None):
23-
# type: (Union[Text, dict[Text, Any]], Any) -> Tuple[Loader, Dict[Text, Any], Text]
27+
def fetch_document(argsworkflow, # type: Union[Text, dict[Text, Any]]
28+
resolver=None, # type: Callable[[Loader, Union[Text, dict[Text, Any]]], Text]
29+
fetcher_constructor=DefaultFetcher # type: Callable[[Dict[unicode, unicode], requests.sessions.Session], Fetcher]
30+
):
31+
# type: (...) -> Tuple[Loader, Dict[Text, Any], Text]
2432
"""Retrieve a CWL document."""
25-
document_loader = Loader({"cwl": "https://w3id.org/cwl/cwl#", "id": "@id"})
33+
34+
document_loader = Loader({"cwl": "https://w3id.org/cwl/cwl#", "id": "@id"},
35+
fetcher_constructor=fetcher_constructor)
2636

2737
uri = None # type: Text
2838
workflowobj = None # type: Dict[Text, Any]
@@ -95,16 +105,23 @@ def _convert_stdstreams_to_files(workflowobj):
95105
for entry in workflowobj:
96106
_convert_stdstreams_to_files(entry)
97107

98-
def validate_document(document_loader, workflowobj, uri,
99-
enable_dev=False, strict=True, preprocess_only=False):
100-
# type: (Loader, Dict[Text, Any], Text, bool, bool, bool) -> Tuple[Loader, Names, Union[Dict[Text, Any], List[Dict[Text, Any]]], Dict[Text, Any], Text]
108+
def validate_document(document_loader, # type: Loader
109+
workflowobj, # type: Dict[Text, Any]
110+
uri, # type: Text
111+
enable_dev=False, # type: bool
112+
strict=True, # type: bool
113+
preprocess_only=False, # type: bool
114+
fetcher_constructor=DefaultFetcher # type: Callable[[Dict[unicode, unicode], requests.sessions.Session], Fetcher]
115+
):
116+
# type: (...) -> Tuple[Loader, Names, Union[Dict[Text, Any], List[Dict[Text, Any]]], Dict[Text, Any], Text]
101117
"""Validate a CWL document."""
118+
102119
jobobj = None
103120
if "cwl:tool" in workflowobj:
104121
jobobj, _ = document_loader.resolve_all(workflowobj, uri)
105122
uri = urlparse.urljoin(uri, workflowobj["https://w3id.org/cwl/cwl#tool"])
106123
del cast(dict, jobobj)["https://w3id.org/cwl/cwl#tool"]
107-
workflowobj = fetch_document(uri)[1]
124+
workflowobj = fetch_document(uri, fetcher_constructor=fetcher_constructor)[1]
108125

109126
if isinstance(workflowobj, list):
110127
workflowobj = {
@@ -130,12 +147,16 @@ def validate_document(document_loader, workflowobj, uri,
130147
workflowobj["$graph"] = workflowobj["@graph"]
131148
del workflowobj["@graph"]
132149

133-
(document_loader, avsc_names) = \
150+
(sch_document_loader, avsc_names) = \
134151
process.get_schema(workflowobj["cwlVersion"])[:2]
135152

136153
if isinstance(avsc_names, Exception):
137154
raise avsc_names
138155

156+
document_loader = Loader(sch_document_loader.ctx, schemagraph=sch_document_loader.graph,
157+
idx=document_loader.idx, cache=sch_document_loader.cache,
158+
fetcher_constructor=fetcher_constructor)
159+
139160
workflowobj["id"] = fileuri
140161
processobj, metadata = document_loader.resolve_all(workflowobj, fileuri)
141162
if not isinstance(processobj, (dict, list)):
@@ -165,8 +186,14 @@ def validate_document(document_loader, workflowobj, uri,
165186
return document_loader, avsc_names, processobj, metadata, uri
166187

167188

168-
def make_tool(document_loader, avsc_names, metadata, uri, makeTool, kwargs):
169-
# type: (Loader, Names, Dict[Text, Any], Text, Callable[..., Process], Dict[AnyStr, Any]) -> Process
189+
def make_tool(document_loader, # type: Loader
190+
avsc_names, # type: Names
191+
metadata, # type: Dict[Text, Any]
192+
uri, # type: Text
193+
makeTool, # type: Callable[..., Process]
194+
kwargs # type: dict
195+
):
196+
# type: (...) -> Process
170197
"""Make a Python CWL object."""
171198
resolveduri = document_loader.resolve_ref(uri)[0]
172199

@@ -179,8 +206,10 @@ def make_tool(document_loader, avsc_names, metadata, uri, makeTool, kwargs):
179206
"one of #%s" % ", #".join(
180207
urlparse.urldefrag(i["id"])[1] for i in resolveduri
181208
if "id" in i))
182-
else:
209+
elif isinstance(resolveduri, dict):
183210
processobj = resolveduri
211+
else:
212+
raise Exception("Must resolve to list or dict")
184213

185214
kwargs = kwargs.copy()
186215
kwargs.update({
@@ -200,14 +229,19 @@ def make_tool(document_loader, avsc_names, metadata, uri, makeTool, kwargs):
200229
return tool
201230

202231

203-
def load_tool(argsworkflow, makeTool, kwargs=None,
204-
enable_dev=False,
205-
strict=True,
206-
resolver=None):
207-
# type: (Union[Text, dict[Text, Any]], Callable[...,Process], Dict[AnyStr, Any], bool, bool, Any) -> Any
208-
document_loader, workflowobj, uri = fetch_document(argsworkflow, resolver=resolver)
232+
def load_tool(argsworkflow, # type: Union[Text, Dict[Text, Any]]
233+
makeTool, # type: Callable[..., Process]
234+
kwargs=None, # type: dict
235+
enable_dev=False, # type: bool
236+
strict=True, # type: bool
237+
resolver=None, # type: Callable[[Loader, Union[Text, dict[Text, Any]]], Text]
238+
fetcher_constructor=DefaultFetcher # type: Callable[[Dict[unicode, unicode], requests.sessions.Session], Fetcher]
239+
):
240+
# type: (...) -> Process
241+
242+
document_loader, workflowobj, uri = fetch_document(argsworkflow, resolver=resolver, fetcher_constructor=fetcher_constructor)
209243
document_loader, avsc_names, processobj, metadata, uri = validate_document(
210244
document_loader, workflowobj, uri, enable_dev=enable_dev,
211-
strict=strict)
245+
strict=strict, fetcher_constructor=fetcher_constructor)
212246
return make_tool(document_loader, avsc_names, metadata, uri,
213247
makeTool, kwargs if kwargs else {})

cwltool/main.py

+21 -16
Original file line number | Diff line number | Diff line change
@@ -14,10 +14,11 @@
1414
import functools
1515

1616
import rdflib
17+
import requests
1718
from typing import (Union, Any, AnyStr, cast, Callable, Dict, Sequence, Text,
1819
Tuple, Type, IO)
1920

20-
from schema_salad.ref_resolver import Loader
21+
from schema_salad.ref_resolver import Loader, Fetcher, DefaultFetcher
2122
import schema_salad.validate as validate
2223
import schema_salad.jsonld_context
2324
import schema_salad.makedoc
@@ -392,7 +393,7 @@ def generate_parser(toolparser, tool, namemap, records):
392393

393394
def load_job_order(args, t, stdin, print_input_deps=False, relative_deps=False,
394395
stdout=sys.stdout, make_fs_access=None):
395-
# type: (argparse.Namespace, Process, IO[Any], bool, bool, IO[Any], Type[StdFsAccess]) -> Union[int, Tuple[Dict[Text, Any], Text]]
396+
# type: (argparse.Namespace, Process, IO[Any], bool, bool, IO[Any], Callable[[Text], StdFsAccess]) -> Union[int, Tuple[Dict[Text, Any], Text]]
396397

397398
job_order_object = None
398399

@@ -553,18 +554,21 @@ def versionstring():
553554
return u"%s %s" % (sys.argv[0], "unknown version")
554555

555556

556-
def main(argsl=None,
557-
args=None,
558-
executor=single_job_executor,
559-
makeTool=workflow.defaultMakeTool,
560-
selectResources=None,
561-
stdin=sys.stdin,
562-
stdout=sys.stdout,
563-
stderr=sys.stderr,
564-
versionfunc=versionstring,
565-
job_order_object=None,
566-
make_fs_access=StdFsAccess):
567-
# type: (List[str], argparse.Namespace, Callable[..., Union[Text, Dict[Text, Text]]], Callable[..., Process], Callable[[Dict[Text, int]], Dict[Text, int]], IO[Any], IO[Any], IO[Any], Callable[[], Text], Union[int, Tuple[Dict[Text, Any], Text]], Type[StdFsAccess]) -> int
557+
def main(argsl=None, # type: List[str]
558+
args=None, # type: argparse.Namespace
559+
executor=single_job_executor, # type: Callable[..., Union[Text, Dict[Text, Text]]]
560+
makeTool=workflow.defaultMakeTool, # type: Callable[..., Process]
561+
selectResources=None, # type: Callable[[Dict[Text, int]], Dict[Text, int]]
562+
stdin=sys.stdin, # type: IO[Any]
563+
stdout=sys.stdout, # type: IO[Any]
564+
stderr=sys.stderr, # type: IO[Any]
565+
versionfunc=versionstring, # type: Callable[[], Text]
566+
job_order_object=None, # type: Union[Tuple[Dict[Text, Any], Text], int]
567+
make_fs_access=StdFsAccess, # type: Callable[[Text], StdFsAccess]
568+
fetcher_constructor=DefaultFetcher, # type: Callable[[Dict[unicode, unicode], requests.sessions.Session], Fetcher]
569+
resolver=tool_resolver
570+
):
571+
# type: (...) -> int
568572

569573
_logger.removeHandler(defaultStreamHandler)
570574
stderr_handler = logging.StreamHandler(stderr)
@@ -624,7 +628,7 @@ def main(argsl=None,
624628
draft2tool.ACCEPTLIST_RE = draft2tool.ACCEPTLIST_EN_RELAXED_RE
625629

626630
try:
627-
document_loader, workflowobj, uri = fetch_document(args.workflow, resolver=tool_resolver)
631+
document_loader, workflowobj, uri = fetch_document(args.workflow, resolver=resolver, fetcher_constructor=fetcher_constructor)
628632

629633
if args.print_deps:
630634
printdeps(workflowobj, document_loader, stdout, args.relative_deps, uri)
@@ -633,7 +637,8 @@ def main(argsl=None,
633637
document_loader, avsc_names, processobj, metadata, uri \
634638
= validate_document(document_loader, workflowobj, uri,
635639
enable_dev=args.enable_dev, strict=args.strict,
636-
preprocess_only=args.print_pre or args.pack)
640+
preprocess_only=args.print_pre or args.pack,
641+
fetcher_constructor=fetcher_constructor)
637642

638643
if args.pack:
639644
stdout.write(print_pack(document_loader, processobj, uri, metadata))

cwltool/stdfsaccess.py

-1
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,6 @@
33
import glob
44
import os
55

6-
76
class StdFsAccess(object):
87

98
def __init__(self, basedir): # type: (Text) -> None

tests/test_fetch.py

+45
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,45 @@
1+
import unittest
2+
import schema_salad.ref_resolver
3+
import schema_salad.main
4+
import schema_salad.schema
5+
from schema_salad.jsonld_context import makerdf
6+
from pkg_resources import Requirement, resource_filename, ResolutionError # type: ignore
7+
import rdflib
8+
import ruamel.yaml as yaml
9+
import json
10+
import os
11+
12+
from cwltool.main import main
13+
from cwltool.workflow import defaultMakeTool
14+
from cwltool.load_tool import load_tool
15+
16+
class FetcherTest(unittest.TestCase):
17+
def test_fetcher(self):
18+
class TestFetcher(schema_salad.ref_resolver.Fetcher):
19+
def __init__(self, a, b):
20+
pass
21+
22+
def fetch_text(self, url): # type: (unicode) -> unicode
23+
if url == "baz:bar/foo.cwl":
24+
return """
25+
cwlVersion: v1.0
26+
class: CommandLineTool
27+
baseCommand: echo
28+
inputs: []
29+
outputs: []
30+
"""
31+
else:
32+
raise RuntimeError("Not foo.cwl")
33+
34+
def check_exists(self, url): # type: (unicode) -> bool
35+
if url == "baz:bar/foo.cwl":
36+
return True
37+
else:
38+
return False
39+
40+
def test_resolver(d, a):
41+
return "baz:bar/" + a
42+
43+
load_tool("foo.cwl", defaultMakeTool, resolver=test_resolver, fetcher_constructor=TestFetcher)
44+
45+
self.assertEquals(0, main(["--print-pre", "--debug", "foo.cwl"], resolver=test_resolver, fetcher_constructor=TestFetcher))

0 commit comments

Comments
 (0)