Skip to content

Commit

Permalink
Added type hinting to resource.py #938 (#4463)
Browse files Browse the repository at this point in the history
* Added type hinting to resource.py, resolves #938

- Added type hinting to functions in resource.py
- Added return type to inVirtualEnv() in __init__.py
- Added None checks to some function bodies

* Fix issues in type hinting PR

- Changed functions that returned `Any` to return more specific types
- Import certain modules unconditionally
- Removed all `:rtype` comments
- `FileResource` functions return binary streams and open files in
  binary mode
    - `ModuleDescriptor` open calls unchanged
- `ModuleDescriptor` specifies member types
  - `forModule()` and `_runningOnWorker()` check `__file__` for `None`

* Add type hinting to resource.py, #938

Same as pull #4463

* Remove redundant byte type check

* Fix checks

- Moved the `__file__` check into the `try` block so that
  `AttributeError` can be caught

* Change mypy type-checked files

* Update issues/938-type-resource from main repo to personal

* Fix check here too

---------

Co-authored-by: Adam Novak <[email protected]>
  • Loading branch information
stxue1 and adamnovak authored May 17, 2023
1 parent 28758cf commit 4318cb8
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 49 deletions.
1 change: 0 additions & 1 deletion contrib/admin/mypy-with-ignore.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ def main():
'src/toil/job.py',
'src/toil/leader.py',
'src/toil/__init__.py',
'src/toil/resource.py',
'src/toil/deferred.py',
'src/toil/version.py',
'src/toil/wdl/utils.py',
Expand Down
2 changes: 1 addition & 1 deletion src/toil/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def toilPackageDirPath() -> str:
return result


def inVirtualEnv():
def inVirtualEnv() -> bool:
"""Test if we are inside a virtualenv or Conda virtual environment."""
return ('VIRTUAL_ENV' in os.environ or
'CONDA_DEFAULT_ENV' in os.environ or
Expand Down
109 changes: 62 additions & 47 deletions src/toil/resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,19 +24,30 @@
from io import BytesIO
from pydoc import locate
from tempfile import mkdtemp
from typing import Any
from urllib.error import HTTPError
from urllib.request import urlopen
from zipfile import ZipFile

from typing import (TYPE_CHECKING,
Optional,
Callable,
IO,
Type,
Sequence,
BinaryIO)

from toil import inVirtualEnv
from toil.lib.iterables import concat
from toil.lib.memoize import strict_bool
from toil.lib.retry import ErrorCondition, retry
from toil.version import exactPython

logger = logging.getLogger(__name__)
from types import ModuleType

if TYPE_CHECKING:
from toil.jobStores.abstractJobStore import AbstractJobStore

logger = logging.getLogger(__name__)

class Resource(namedtuple('Resource', ('name', 'pathHash', 'url', 'contentHash'))):
"""
Expand All @@ -63,7 +74,7 @@ class Resource(namedtuple('Resource', ('name', 'pathHash', 'url', 'contentHash')
rootDirPathEnvName = resourceEnvNamePrefix + 'ROOT'

@classmethod
def create(cls, jobStore, leaderPath):
def create(cls, jobStore: "AbstractJobStore", leaderPath: str) -> "Resource":
"""
Saves the content of the file or directory at the given path to the given job store
and returns a resource object representing that content for the purpose of obtaining it
Expand All @@ -72,8 +83,6 @@ def create(cls, jobStore, leaderPath):
:param toil.jobStores.abstractJobStore.AbstractJobStore jobStore:
:param str leaderPath:
:rtype: Resource
"""
pathHash = cls._pathHash(leaderPath)
contentHash = hashlib.md5()
Expand All @@ -88,7 +97,7 @@ def create(cls, jobStore, leaderPath):
url=jobStore.getSharedPublicUrl(sharedFileName=pathHash),
contentHash=contentHash.hexdigest())

def refresh(self, jobStore):
def refresh(self, jobStore: "AbstractJobStore") -> "Resource":
return type(self)(name=self.name,
pathHash=self.pathHash,
url=jobStore.get_shared_public_url(shared_file_name=self.pathHash),
Expand Down Expand Up @@ -120,12 +129,12 @@ def cleanSystem(cls) -> None:
if k.startswith(cls.resourceEnvNamePrefix):
os.environ.pop(k)

def register(self):
def register(self) -> None:
"""Register this resource for later retrieval via lookup(), possibly in a child process."""
os.environ[self.resourceEnvNamePrefix + self.pathHash] = self.pickle()

@classmethod
def lookup(cls, leaderPath: str) -> "Resource":
def lookup(cls, leaderPath: str) -> Optional["Resource"]:
"""
Return a resource object representing a resource created from a file or directory at the given path on the leader.
Expand All @@ -146,7 +155,7 @@ def lookup(cls, leaderPath: str) -> "Resource":
assert self.pathHash == pathHash
return self

def download(self, callback=None):
def download(self, callback: Optional[Callable[[str], None]] = None) -> None:
"""
Download this resource from its URL to a file on the local system.
Expand All @@ -171,7 +180,7 @@ def download(self, callback=None):
raise

@property
def localPath(self):
def localPath(self) -> str:
"""
Get the path to resource on the worker.
Expand All @@ -181,39 +190,38 @@ def localPath(self):
raise NotImplementedError

@property
def localDirPath(self):
def localDirPath(self) -> str:
"""
The path to the directory containing the resource on the worker.
"""
rootDirPath = os.environ[self.rootDirPathEnvName]
return os.path.join(rootDirPath, self.contentHash)

def pickle(self):
def pickle(self) -> str:
return self.__class__.__module__ + "." + self.__class__.__name__ + ':' + json.dumps(self)

@classmethod
def unpickle(cls, s) -> "Resource":
def unpickle(cls, s: str) -> "Resource":
className, _json = s.split(':', 1)
return locate(className)(*json.loads(_json))
return locate(className)(*json.loads(_json)) # type: ignore

@classmethod
def _pathHash(cls, path):
def _pathHash(cls, path: str) -> str:
return hashlib.md5(path.encode('utf-8')).hexdigest()

@classmethod
def _load(cls, path):
def _load(cls, path: str) -> IO[bytes]:
"""
Returns a readable file-like object for the given path. If the path refers to a regular
file, this method returns the result of invoking open() on the given path. If the path
refers to a directory, this method returns a ZIP file with all files and subdirectories
in the directory at the given path.
:type path: str
:rtype: io.FileIO
"""
raise NotImplementedError()

def _save(self, dirPath):
def _save(self, dirPath: str) -> None:
"""
Save this resource to the directory at the given parent path.
Expand All @@ -226,7 +234,7 @@ def _save(self, dirPath):
error=HTTPError,
error_codes=[400])
])
def _download(self, dstFile):
def _download(self, dstFile: IO[bytes]) -> None:
"""
Download this resource from its URL to the given file object.
Expand All @@ -243,15 +251,15 @@ class FileResource(Resource):
"""A resource read from a file on the leader."""

@classmethod
def _load(cls, path):
return open(path)
def _load(cls, path: str) -> BinaryIO:
return open(path, 'rb')

def _save(self, dirPath):
with open(os.path.join(dirPath, self.name), mode='w') as localFile:
def _save(self, dirPath: str) -> None:
with open(os.path.join(dirPath, self.name), mode='wb') as localFile:
self._download(localFile)

@property
def localPath(self):
def localPath(self) -> str:
return os.path.join(self.localDirPath, self.name)


Expand Down Expand Up @@ -294,15 +302,15 @@ def _load(cls, path: str) -> BytesIO:
bytesIO.seek(0)
return bytesIO

def _save(self, dirPath):
def _save(self, dirPath: str) -> None:
bytesIO = BytesIO()
self._download(bytesIO)
bytesIO.seek(0)
with ZipFile(file=bytesIO, mode='r') as zipFile:
zipFile.extractall(path=dirPath)

@property
def localPath(self):
def localPath(self) -> str:
return self.localDirPath


Expand Down Expand Up @@ -367,9 +375,11 @@ class ModuleDescriptor(namedtuple('ModuleDescriptor', ('dirPath', 'name', 'fromV
Clean up
>>> rmtree( dirPath )
"""
dirPath: str
name: str

@classmethod
def forModule(cls, name: str) -> Any:
def forModule(cls, name: str) -> "ModuleDescriptor":
"""
Return an instance of this class representing the module of the given name.
Expand All @@ -378,8 +388,10 @@ def forModule(cls, name: str) -> Any:
method assumes that the module with the specified name has already been loaded.
"""
module = sys.modules[name]
filePath = os.path.abspath(module.__file__)
filePath = filePath.split(os.path.sep)
if module.__file__ is None:
raise Exception(f'Module {name} does not exist.')
fileAbsPath = os.path.abspath(module.__file__)
filePath = fileAbsPath.split(os.path.sep)
filePath[-1], extension = os.path.splitext(filePath[-1])
if extension not in (".py", ".pyc"):
raise Exception("The name of a user script/module must end in .py or .pyc.")
Expand All @@ -389,12 +401,12 @@ def forModule(cls, name: str) -> Any:
if module.__package__:
# Invoked as a module via python -m foo.bar
logger.debug("Script was invoked as a module")
name = [filePath.pop()]
nameList = [filePath.pop()]
for package in reversed(module.__package__.split('.')):
dirPathTail = filePath.pop()
assert dirPathTail == package
name.append(dirPathTail)
name = '.'.join(reversed(name))
nameList.append(dirPathTail)
name = '.'.join(reversed(nameList))
dirPath = os.path.sep.join(filePath)
else:
# Invoked as a script via python foo/bar.py
Expand All @@ -420,7 +432,7 @@ def forModule(cls, name: str) -> Any:
return cls(dirPath=dirPath, name=name, fromVirtualEnv=fromVirtualEnv)

@classmethod
def _check_conflict(cls, dirPath, name):
def _check_conflict(cls, dirPath: str, name: str) -> None:
"""
Check whether the module of the given name conflicts with another module on the sys.path.
Expand All @@ -442,13 +454,13 @@ def _check_conflict(cls, dirPath, name):
sys.path = old_sys_path

@property
def belongsToToil(self):
def belongsToToil(self) -> bool:
"""
True if this module is part of the Toil distribution
"""
return self.name.startswith('toil.')

def saveAsResourceTo(self, jobStore) -> Resource:
def saveAsResourceTo(self, jobStore: "AbstractJobStore") -> Resource:
"""
Store the file containing this module--or even the Python package directory hierarchy
containing that file--as a resource to the given job store and return the
Expand All @@ -458,10 +470,11 @@ def saveAsResourceTo(self, jobStore) -> Resource:
"""
return self._getResourceClass().create(jobStore, self._resourcePath)

def _getResourceClass(self):
def _getResourceClass(self) -> Type[Resource]:
"""
Return the concrete subclass of Resource that's appropriate for auto-deploying this module.
"""
subcls: Type[Resource]
if self.fromVirtualEnv:
subcls = VirtualEnvResource
elif os.path.isdir(self._resourcePath):
Expand All @@ -474,7 +487,7 @@ def _getResourceClass(self):
raise AssertionError("No such file or directory: '%s'" % self._resourcePath)
return subcls

def localize(self) -> Resource:
def localize(self) -> "ModuleDescriptor":
"""
Check if this module was saved as a resource.
Expand All @@ -488,7 +501,7 @@ def localize(self) -> Resource:
if resource is None:
return self
else:
def stash(tmpDirPath):
def stash(tmpDirPath: str) -> None:
# Save the original dirPath such that we can restore it in globalize()
with open(os.path.join(tmpDirPath, '.stash'), 'w') as f:
f.write('1' if self.fromVirtualEnv else '0')
Expand All @@ -499,7 +512,7 @@ def stash(tmpDirPath):
name=self.name,
fromVirtualEnv=self.fromVirtualEnv)

def _runningOnWorker(self):
def _runningOnWorker(self) -> bool:
try:
mainModule = sys.modules['__main__']
except KeyError:
Expand All @@ -511,6 +524,8 @@ def _runningOnWorker(self):
# we can reasonably assume that we are not running
# on a worker node.
try:
if mainModule.__file__ is None:
return False
mainModuleFile = os.path.basename(mainModule.__file__)
except AttributeError:
return False
Expand All @@ -537,7 +552,7 @@ def globalize(self) -> "ModuleDescriptor":
fromVirtualEnv=fromVirtualEnv)

@property
def _resourcePath(self):
def _resourcePath(self) -> str:
"""
The path to the directory that should be used when shipping this module and its siblings
around as a resource.
Expand All @@ -557,35 +572,35 @@ def _resourcePath(self):
return self.dirPath

@classmethod
def _initModuleName(cls, dirPath):
def _initModuleName(cls, dirPath: str) -> Optional[str]:
for name in ('__init__.py', '__init__.pyc', '__init__.pyo'):
if os.path.exists(os.path.join(dirPath, name)):
return name
return None

def _rootPackage(self):
def _rootPackage(self) -> str:
try:
head, tail = self.name.split('.', 1)
except ValueError:
raise ValueError('%r is stand-alone module.' % self)
raise ValueError('%r is stand-alone module.' % self.__repr__())
else:
return head

def toCommand(self):
def toCommand(self) -> Sequence[str]:
return tuple(map(str, self))

@classmethod
def fromCommand(cls, command):
def fromCommand(cls, command: Sequence[str]) -> "ModuleDescriptor":
assert len(command) == 3
return cls(dirPath=command[0], name=command[1], fromVirtualEnv=strict_bool(command[2]))

def makeLoadable(self):
def makeLoadable(self) -> "ModuleDescriptor":
module = self if self.belongsToToil else self.localize()
if module.dirPath not in sys.path:
sys.path.append(module.dirPath)
return module

def load(self):
def load(self) -> Optional[ModuleType]:
module = self.makeLoadable()
try:
return importlib.import_module(module.name)
Expand Down

0 comments on commit 4318cb8

Please sign in to comment.