Skip to content

Commit fae698e

Browse files
committed
freshen cwltool/command_line_tool.py
1 parent 19756e7 commit fae698e

File tree

1 file changed

+61
-39
lines changed

1 file changed

+61
-39
lines changed

cwltool/command_line_tool.py

Lines changed: 61 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
import urllib
1414
from functools import cmp_to_key, partial
1515
from typing import (
16-
IO,
1716
Any,
1817
Callable,
1918
Dict,
@@ -24,6 +23,7 @@
2423
MutableSequence,
2524
Optional,
2625
Set,
26+
TextIO,
2727
Union,
2828
cast,
2929
)
@@ -60,6 +60,7 @@
6060
from .utils import (
6161
CWLObjectType,
6262
CWLOutputType,
63+
DirectoryType,
6364
JobsGeneratorType,
6465
OutputCallbackType,
6566
adjustDirObjs,
@@ -130,9 +131,20 @@ def run(
130131
try:
131132
normalizeFilesDirs(self.builder.job)
132133
ev = self.builder.do_eval(self.script)
133-
normalizeFilesDirs(ev)
134+
normalizeFilesDirs(
135+
cast(
136+
Optional[
137+
Union[
138+
MutableSequence[MutableMapping[str, Any]],
139+
MutableMapping[str, Any],
140+
DirectoryType,
141+
]
142+
],
143+
ev,
144+
)
145+
)
134146
if self.output_callback:
135-
self.output_callback(ev, "success")
147+
self.output_callback(cast(Optional[CWLObjectType], ev), "success")
136148
except WorkflowException as err:
137149
_logger.warning(
138150
"Failed to evaluate expression:\n%s",
@@ -385,7 +397,7 @@ def make_job_runner(self, runtimeContext: RuntimeContext) -> Type[JobBase]:
385397

386398
def make_path_mapper(
387399
self,
388-
reffiles: List[Any],
400+
reffiles: List[CWLObjectType],
389401
stagedir: str,
390402
runtimeContext: RuntimeContext,
391403
separateDirs: bool,
@@ -561,7 +573,7 @@ def calc_checksum(location: str) -> Optional[str]:
561573

562574
def update_status_output_callback(
563575
output_callbacks: OutputCallbackType,
564-
jobcachelock: IO[Any],
576+
jobcachelock: TextIO,
565577
outputs: Optional[CWLObjectType],
566578
processStatus: str,
567579
) -> None:
@@ -620,9 +632,11 @@ def update_status_output_callback(
620632

621633
initialWorkdir, _ = self.get_requirement("InitialWorkDirRequirement")
622634
if initialWorkdir is not None:
623-
ls = [] # type: List[Dict[str, Any]]
635+
ls = [] # type: List[CWLObjectType]
624636
if isinstance(initialWorkdir["listing"], str):
625-
ls = builder.do_eval(initialWorkdir["listing"])
637+
ls = cast(
638+
List[CWLObjectType], builder.do_eval(initialWorkdir["listing"])
639+
)
626640
else:
627641
for t in cast(
628642
MutableSequence[Union[str, CWLObjectType]],
@@ -644,13 +658,13 @@ def update_status_output_callback(
644658
if et["entry"] is not None:
645659
ls.append(et)
646660
else:
647-
initwd_item = builder.do_eval(cast(str, t))
661+
initwd_item = builder.do_eval(t)
648662
if not initwd_item:
649663
continue
650664
if isinstance(initwd_item, MutableSequence):
651-
ls.extend(initwd_item)
665+
ls.extend(cast(List[CWLObjectType], initwd_item))
652666
else:
653-
ls.append(initwd_item)
667+
ls.append(cast(CWLObjectType, initwd_item))
654668
for i, t2 in enumerate(ls):
655669
if "entry" in t2:
656670
if isinstance(t2["entry"], str):
@@ -663,10 +677,11 @@ def update_status_output_callback(
663677
else:
664678
if t2.get("entryname") or t2.get("writable"):
665679
t2 = copy.deepcopy(t2)
680+
t2entry = cast(CWLObjectType, t2["entry"])
666681
if t2.get("entryname"):
667-
t2["entry"]["basename"] = t2["entryname"]
668-
t2["entry"]["writable"] = t2.get("writable")
669-
ls[i] = t2["entry"]
682+
t2entry["basename"] = t2["entryname"]
683+
t2entry["writable"] = t2.get("writable")
684+
ls[i] = cast(CWLObjectType, t2["entry"])
670685
j.generatefiles["listing"] = ls
671686
for entry in ls:
672687
self.updatePathmap(builder.outdir, builder.pathmapper, entry)
@@ -689,13 +704,13 @@ def update_status_output_callback(
689704

690705
if self.tool.get("stdin"):
691706
with SourceLine(self.tool, "stdin", ValidationException, debug):
692-
j.stdin = builder.do_eval(self.tool["stdin"])
707+
j.stdin = cast(str, builder.do_eval(self.tool["stdin"]))
693708
if j.stdin:
694709
reffiles.append({"class": "File", "path": j.stdin})
695710

696711
if self.tool.get("stderr"):
697712
with SourceLine(self.tool, "stderr", ValidationException, debug):
698-
j.stderr = builder.do_eval(self.tool["stderr"])
713+
j.stderr = cast(str, builder.do_eval(self.tool["stderr"]))
699714
if j.stderr:
700715
if os.path.isabs(j.stderr) or ".." in j.stderr:
701716
raise ValidationException(
@@ -704,7 +719,7 @@ def update_status_output_callback(
704719

705720
if self.tool.get("stdout"):
706721
with SourceLine(self.tool, "stdout", ValidationException, debug):
707-
j.stdout = builder.do_eval(self.tool["stdout"])
722+
j.stdout = cast(str, builder.do_eval(self.tool["stdout"]))
708723
if j.stdout:
709724
if os.path.isabs(j.stdout) or ".." in j.stdout or not j.stdout:
710725
raise ValidationException(
@@ -738,21 +753,21 @@ def update_status_output_callback(
738753
j.inplace_update = cast(bool, inplaceUpdateReq["inplaceUpdate"])
739754
normalizeFilesDirs(j.generatefiles)
740755

741-
readers = {} # type: Dict[str, Any]
756+
readers = {} # type: Dict[str, CWLObjectType]
742757
muts = set() # type: Set[str]
743758

744759
if builder.mutation_manager is not None:
745760

746-
def register_mut(f): # type: (Dict[str, Any]) -> None
761+
def register_mut(f: CWLObjectType) -> None:
747762
mm = cast(MutationManager, builder.mutation_manager)
748-
muts.add(f["location"])
763+
muts.add(cast(str, f["location"]))
749764
mm.register_mutation(j.name, f)
750765

751-
def register_reader(f): # type: (Dict[str, Any]) -> None
766+
def register_reader(f: CWLObjectType) -> None:
752767
mm = cast(MutationManager, builder.mutation_manager)
753-
if f["location"] not in muts:
768+
if cast(str, f["location"]) not in muts:
754769
mm.register_reader(j.name, f)
755-
readers[f["location"]] = copy.deepcopy(f)
770+
readers[cast(str, f["location"])] = copy.deepcopy(f)
756771

757772
for li in j.generatefiles["listing"]:
758773
if li.get("writable") and j.inplace_update:
@@ -770,8 +785,9 @@ def register_reader(f): # type: (Dict[str, Any]) -> None
770785
timelimit, _ = self.get_requirement("ToolTimeLimit")
771786
if timelimit is not None:
772787
with SourceLine(timelimit, "timelimit", ValidationException, debug):
773-
j.timelimit = builder.do_eval(
774-
cast(Union[int, str], timelimit["timelimit"])
788+
j.timelimit = cast(
789+
Optional[int],
790+
builder.do_eval(cast(Union[int, str], timelimit["timelimit"])),
775791
)
776792
if not isinstance(j.timelimit, int) or j.timelimit < 0:
777793
raise WorkflowException(
@@ -781,8 +797,11 @@ def register_reader(f): # type: (Dict[str, Any]) -> None
781797
networkaccess, _ = self.get_requirement("NetworkAccess")
782798
if networkaccess is not None:
783799
with SourceLine(networkaccess, "networkAccess", ValidationException, debug):
784-
j.networkaccess = builder.do_eval(
785-
cast(Union[bool, str], networkaccess["networkAccess"])
800+
j.networkaccess = cast(
801+
bool,
802+
builder.do_eval(
803+
cast(Union[bool, str], networkaccess["networkAccess"])
804+
),
786805
)
787806
if not isinstance(j.networkaccess, bool):
788807
raise WorkflowException(
@@ -792,8 +811,10 @@ def register_reader(f): # type: (Dict[str, Any]) -> None
792811
j.environment = {}
793812
evr, _ = self.get_requirement("EnvVarRequirement")
794813
if evr is not None:
795-
for t2 in cast(List[Dict[str, str]], evr["envDef"]):
796-
j.environment[t2["envName"]] = builder.do_eval(t2["envValue"])
814+
for t3 in cast(List[Dict[str, str]], evr["envDef"]):
815+
j.environment[t3["envName"]] = cast(
816+
str, builder.do_eval(t3["envValue"])
817+
)
797818

798819
shellcmd, _ = self.get_requirement("ShellCommandRequirement")
799820
if shellcmd is not None:
@@ -822,13 +843,13 @@ def register_reader(f): # type: (Dict[str, Any]) -> None
822843

823844
def collect_output_ports(
824845
self,
825-
ports: Union[CommentedSeq, Set[Dict[str, Any]]],
846+
ports: Union[CommentedSeq, Set[CWLObjectType]],
826847
builder: Builder,
827848
outdir: str,
828849
rcode: int,
829850
compute_checksum: bool = True,
830851
jobname: str = "",
831-
readers: Optional[Dict[str, Any]] = None,
852+
readers: Optional[MutableMapping[str, CWLObjectType]] = None,
832853
) -> OutputPortsType:
833854
ret = {} # type: OutputPortsType
834855
debug = _logger.isEnabledFor(logging.DEBUG)
@@ -869,9 +890,7 @@ def collect_output_ports(
869890
if ret:
870891
revmap = partial(revmap_file, builder, outdir)
871892
adjustDirObjs(ret, trim_listing)
872-
visit_class(
873-
ret, ("File", "Directory"), cast(Callable[[Any], Any], revmap)
874-
)
893+
visit_class(ret, ("File", "Directory"), revmap)
875894
visit_class(ret, ("File", "Directory"), remove_path)
876895
normalizeFilesDirs(ret)
877896
visit_class(
@@ -915,7 +934,10 @@ def collect_output(
915934
empty_and_optional = False
916935
debug = _logger.isEnabledFor(logging.DEBUG)
917936
if "outputBinding" in schema:
918-
binding = cast(Dict[str, Any], schema["outputBinding"])
937+
binding = cast(
938+
MutableMapping[str, Union[bool, str, List[str]]],
939+
schema["outputBinding"],
940+
)
919941
globpatterns = [] # type: List[str]
920942

921943
revmap = partial(revmap_file, builder, outdir)
@@ -1015,8 +1037,8 @@ def collect_output(
10151037
if "outputEval" in binding:
10161038
with SourceLine(binding, "outputEval", WorkflowException, debug):
10171039
result = builder.do_eval(
1018-
binding["outputEval"], context=r
1019-
) # type: CWLOutputType
1040+
cast(CWLOutputType, binding["outputEval"]), context=r
1041+
)
10201042
else:
10211043
result = cast(CWLOutputType, r)
10221044

@@ -1088,7 +1110,7 @@ def collect_output(
10881110
if "format" in schema:
10891111
for primary in aslist(result):
10901112
primary["format"] = builder.do_eval(
1091-
cast(Union[str, List[str]], schema["format"]), context=primary
1113+
schema["format"], context=primary
10921114
)
10931115

10941116
# Ensure files point to local references outside of the run environment
@@ -1108,8 +1130,8 @@ def collect_output(
11081130
and schema["type"]["type"] == "record"
11091131
):
11101132
out = {}
1111-
for field in cast(List[Dict[str, Any]], schema["type"]["fields"]):
1112-
out[shortname(field["name"])] = self.collect_output(
1133+
for field in cast(List[CWLObjectType], schema["type"]["fields"]):
1134+
out[shortname(cast(str, field["name"]))] = self.collect_output(
11131135
field, builder, outdir, fs_access, compute_checksum=compute_checksum
11141136
)
11151137
return out

0 commit comments

Comments
 (0)