-
Notifications
You must be signed in to change notification settings - Fork 69
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Loading status checks…
add vista2d plugins
Signed-off-by: binliu <[email protected]>
Showing
2 changed files
with
904 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,813 @@ | ||
################################# | ||
# | ||
# Imports from useful Python libraries | ||
# | ||
################################# | ||
|
||
import cgi | ||
import http.client | ||
import json | ||
import logging | ||
import mimetypes | ||
import os | ||
import re | ||
import ssl | ||
import tempfile | ||
from pathlib import Path | ||
from urllib.parse import quote_plus, unquote, urlencode, urlparse | ||
|
||
import requests | ||
import skimage | ||
from cellprofiler_core.module.image_segmentation import ImageSegmentation | ||
from cellprofiler_core.object import Objects | ||
from cellprofiler_core.setting.choice import Choice | ||
from cellprofiler_core.setting.text import Text | ||
|
||
################################# | ||
# | ||
# Imports from CellProfiler | ||
# | ||
################################## | ||
|
||
|
||
VISTA_link = "https://doi.org/10.48550/arXiv.2406.05285" | ||
LOGGER = logging.getLogger(__name__) | ||
|
||
__doc__ = f"""\ | ||
RunVISTA2D | ||
=========== | ||
**RunVISTA2D** uses a pre-trained VISTA2D model to detect cells in an image. | ||
This module is useful for automating simple segmentation tasks in CellProfiler. | ||
The module accepts tiff input images and produces an object set. | ||
This module is a client/frontend of a MONAI Label server. A VISTA2D based MONAI Label server needs to be set up and the address of the server needs to be passed to | ||
this module, before running. | ||
Installation: | ||
This module has no external dependencies other than the python(>3.8) build-in dependencies. | ||
You'll need to set up the VISTA2D based MONAI Label server based on the tutorial https://github.com/Project-MONAI/MONAILabel/tree/main/plugins/cellprofiler. After setting up the server, please | ||
provide the server address to this plugin.s | ||
Yufan He, Pengfei Guo, Yucheng Tang, Andriy Myronenko, Vishwesh Nath, Ziyue Xu, Dong Yang, Can Zhao, Benjamin Simon, Mason Belue, Stephanie Harmon, Baris Turkbey, Daguang Xu, & Wenqi Li. (2024). VISTA3D: Versatile Imaging SegmenTation and Annotation model for 3D Computed Tomography.{VISTA_link} | ||
============ ============ =============== | ||
Supports 2D? Supports 3D? Respects masks? | ||
============ ============ =============== | ||
YES No NO | ||
============ ============ =============== | ||
""" | ||
|
||
|
||
def bytes_to_str(b): | ||
return b.decode("utf-8") if isinstance(b, bytes) else b | ||
|
||
|
||
class MONAILabelClient: | ||
""" | ||
Basic MONAILabel Client to invoke infer/train APIs over http/https | ||
""" | ||
|
||
def __init__(self, server_url, tmpdir=None, client_id=None): | ||
""" | ||
:param server_url: Server URL for MONAILabel (e.g. http://127.0.0.1:8000) | ||
:param tmpdir: Temp directory to save temporary files. If None then it uses tempfile.tempdir | ||
:param client_id: Client ID that will be added for all basic requests | ||
""" | ||
|
||
self._server_url = server_url.rstrip("/").strip() | ||
self._tmpdir = tmpdir if tmpdir else tempfile.tempdir if tempfile.tempdir else "/tmp" | ||
self._client_id = client_id | ||
self._headers = {} | ||
|
||
def _update_client_id(self, params): | ||
if params: | ||
params["client_id"] = self._client_id | ||
else: | ||
params = {"client_id": self._client_id} | ||
return params | ||
|
||
def update_auth(self, token): | ||
if token: | ||
self._headers["Authorization"] = f"{token['token_type']} {token['access_token']}" | ||
|
||
def get_server_url(self): | ||
""" | ||
Return server url | ||
:return: the url for monailabel server | ||
""" | ||
return self._server_url | ||
|
||
def set_server_url(self, server_url): | ||
""" | ||
Set url for monailabel server | ||
:param server_url: server url for monailabel | ||
""" | ||
self._server_url = server_url.rstrip("/").strip() | ||
|
||
def auth_enabled(self) -> bool: | ||
""" | ||
Check if Auth is enabled | ||
""" | ||
selector = "/auth/" | ||
status, response, _, _ = MONAILabelUtils.http_method("GET", self._server_url, selector) | ||
if status != 200: | ||
return False | ||
|
||
response = bytes_to_str(response) | ||
LOGGER.debug(f"Response: {response}") | ||
enabled = json.loads(response).get("enabled", False) | ||
return True if enabled else False | ||
|
||
def auth_token(self, username, password): | ||
""" | ||
Fetch Auth Token. Currently only basic authentication is supported. | ||
:param username: UserName for basic authentication | ||
:param password: Password for basic authentication | ||
""" | ||
selector = "/auth/token" | ||
data = urlencode({"username": username, "password": password, "grant_type": "password"}) | ||
status, response, _, _ = MONAILabelUtils.http_method( | ||
"POST", self._server_url, selector, data, None, "application/x-www-form-urlencoded" | ||
) | ||
if status != 200: | ||
raise MONAILabelClientException( | ||
MONAILabelError.SERVER_ERROR, f"Status: {status}; Response: {bytes_to_str(response)}", status, response | ||
) | ||
|
||
response = bytes_to_str(response) | ||
LOGGER.debug(f"Response: {response}") | ||
return json.loads(response) | ||
|
||
def auth_valid_token(self) -> bool: | ||
selector = "/auth/token/valid" | ||
status, _, _, _ = MONAILabelUtils.http_method("GET", self._server_url, selector, headers=self._headers) | ||
return True if status == 200 else False | ||
|
||
def info(self): | ||
""" | ||
Invoke /info/ request over MONAILabel Server | ||
:return: json response | ||
""" | ||
selector = "/info/" | ||
status, response, _, _ = MONAILabelUtils.http_method("GET", self._server_url, selector, headers=self._headers) | ||
if status != 200: | ||
raise MONAILabelClientException( | ||
MONAILabelError.SERVER_ERROR, f"Status: {status}; Response: {bytes_to_str(response)}", status, response | ||
) | ||
|
||
response = bytes_to_str(response) | ||
logging.debug(f"Response: {response}") | ||
return json.loads(response) | ||
|
||
def next_sample(self, strategy, params): | ||
""" | ||
Get Next sample | ||
:param strategy: Name of strategy to be used for fetching next sample | ||
:param params: Additional JSON params as part of strategy request | ||
:return: json response which contains information about next image selected for annotation | ||
""" | ||
params = self._update_client_id(params) | ||
selector = f"/activelearning/{MONAILabelUtils.urllib_quote_plus(strategy)}" | ||
status, response, _, _ = MONAILabelUtils.http_method( | ||
"POST", self._server_url, selector, params, headers=self._headers | ||
) | ||
if status != 200: | ||
raise MONAILabelClientException( | ||
MONAILabelError.SERVER_ERROR, f"Status: {status}; Response: {bytes_to_str(response)}", status, response | ||
) | ||
|
||
response = bytes_to_str(response) | ||
logging.debug(f"Response: {response}") | ||
return json.loads(response) | ||
|
||
def create_session(self, image_in, params=None): | ||
""" | ||
Create New Session | ||
:param image_in: filepath for image to be sent to server as part of session creation | ||
:param params: additional JSON params as part of session reqeust | ||
:return: json response which contains session id and other details | ||
""" | ||
selector = "/session/" | ||
params = self._update_client_id(params) | ||
|
||
status, response, _ = MONAILabelUtils.http_upload( | ||
"PUT", self._server_url, selector, params, [image_in], headers=self._headers | ||
) | ||
if status != 200: | ||
raise MONAILabelClientException( | ||
MONAILabelError.SERVER_ERROR, f"Status: {status}; Response: {bytes_to_str(response)}", status, response | ||
) | ||
|
||
response = bytes_to_str(response) | ||
logging.debug(f"Response: {response}") | ||
return json.loads(response) | ||
|
||
def get_session(self, session_id): | ||
""" | ||
Get Session | ||
:param session_id: Session Id | ||
:return: json response which contains more details about the session | ||
""" | ||
selector = f"/session/{MONAILabelUtils.urllib_quote_plus(session_id)}" | ||
status, response, _, _ = MONAILabelUtils.http_method("GET", self._server_url, selector, headers=self._headers) | ||
if status != 200: | ||
raise MONAILabelClientException( | ||
MONAILabelError.SERVER_ERROR, f"Status: {status}; Response: {bytes_to_str(response)}", status, response | ||
) | ||
|
||
response = bytes_to_str(response) | ||
logging.debug(f"Response: {response}") | ||
return json.loads(response) | ||
|
||
def remove_session(self, session_id): | ||
""" | ||
Remove any existing Session | ||
:param session_id: Session Id | ||
:return: json response | ||
""" | ||
selector = f"/session/{MONAILabelUtils.urllib_quote_plus(session_id)}" | ||
status, response, _, _ = MONAILabelUtils.http_method( | ||
"DELETE", self._server_url, selector, headers=self._headers | ||
) | ||
if status != 200: | ||
raise MONAILabelClientException( | ||
MONAILabelError.SERVER_ERROR, f"Status: {status}; Response: {bytes_to_str(response)}", status, response | ||
) | ||
|
||
response = bytes_to_str(response) | ||
logging.debug(f"Response: {response}") | ||
return json.loads(response) | ||
|
||
def upload_image(self, image_in, image_id=None, params=None): | ||
""" | ||
Upload New Image to MONAILabel Datastore | ||
:param image_in: Image File Path | ||
:param image_id: Force Image ID; If not provided then Server it auto generate new Image ID | ||
:param params: Additional JSON params | ||
:return: json response which contains image id and other details | ||
""" | ||
selector = f"/datastore/?image={MONAILabelUtils.urllib_quote_plus(image_id)}" | ||
|
||
files = {"file": image_in} | ||
params = self._update_client_id(params) | ||
fields = {"params": json.dumps(params) if params else "{}"} | ||
|
||
status, response, _, _ = MONAILabelUtils.http_multipart( | ||
"PUT", self._server_url, selector, fields, files, headers=self._headers | ||
) | ||
if status != 200: | ||
raise MONAILabelClientException( | ||
MONAILabelError.SERVER_ERROR, | ||
f"Status: {status}; Response: {bytes_to_str(response)}", | ||
) | ||
|
||
response = bytes_to_str(response) | ||
logging.debug(f"Response: {response}") | ||
return json.loads(response) | ||
|
||
def save_label(self, image_id, label_in, tag="", params=None): | ||
""" | ||
Save/Submit Label | ||
:param image_id: Image Id for which label needs to saved/submitted | ||
:param label_in: Label File path which shall be saved/submitted | ||
:param tag: Save label against tag in datastore | ||
:param params: Additional JSON params for the request | ||
:return: json response | ||
""" | ||
selector = f"/datastore/label?image={MONAILabelUtils.urllib_quote_plus(image_id)}" | ||
if tag: | ||
selector += f"&tag={MONAILabelUtils.urllib_quote_plus(tag)}" | ||
|
||
params = self._update_client_id(params) | ||
fields = { | ||
"params": json.dumps(params), | ||
} | ||
files = {"label": label_in} | ||
|
||
status, response, _, _ = MONAILabelUtils.http_multipart( | ||
"PUT", self._server_url, selector, fields, files, headers=self._headers | ||
) | ||
if status != 200: | ||
raise MONAILabelClientException( | ||
MONAILabelError.SERVER_ERROR, | ||
f"Status: {status}; Response: {bytes_to_str(response)}", | ||
) | ||
|
||
response = bytes_to_str(response) | ||
logging.debug(f"Response: {response}") | ||
return json.loads(response) | ||
|
||
def datastore(self): | ||
selector = "/datastore/?output=all" | ||
status, response, _, _ = MONAILabelUtils.http_method("GET", self._server_url, selector, headers=self._headers) | ||
if status != 200: | ||
raise MONAILabelClientException( | ||
MONAILabelError.SERVER_ERROR, f"Status: {status}; Response: {bytes_to_str(response)}", status, response | ||
) | ||
|
||
response = bytes_to_str(response) | ||
logging.debug(f"Response: {response}") | ||
return json.loads(response) | ||
|
||
def download_label(self, label_id, tag): | ||
selector = "/datastore/label?label={}&tag={}".format( | ||
MONAILabelUtils.urllib_quote_plus(label_id), MONAILabelUtils.urllib_quote_plus(tag) | ||
) | ||
status, response, _, headers = MONAILabelUtils.http_method( | ||
"GET", self._server_url, selector, headers=self._headers | ||
) | ||
if status != 200: | ||
raise MONAILabelClientException( | ||
MONAILabelError.SERVER_ERROR, f"Status: {status}; Response: {bytes_to_str(response)}", status, response | ||
) | ||
|
||
content_disposition = headers.get("content-disposition") | ||
|
||
if not content_disposition: | ||
logging.warning("Filename not found. Fall back to no loaded labels") | ||
file_name = MONAILabelUtils.get_filename(content_disposition) | ||
|
||
file_ext = "".join(Path(file_name).suffixes) | ||
local_filename = tempfile.NamedTemporaryFile(dir=self._tmpdir, suffix=file_ext).name | ||
with open(local_filename, "wb") as f: | ||
f.write(response) | ||
|
||
return local_filename | ||
|
||
def infer(self, model, image_id, params, label_in=None, file=None, session_id=None): | ||
""" | ||
Run Infer | ||
:param model: Name of Model | ||
:param image_id: Image Id | ||
:param params: Additional configs/json params as part of Infer request | ||
:param label_in: File path for label mask which is needed to run Inference (e.g. In case of Scribbles) | ||
:param file: File path for Image (use raw image instead of image_id) | ||
:param session_id: Session ID (use existing session id instead of image_id) | ||
:return: response_file (label mask), response_body (json result/output params) | ||
""" | ||
selector = "/infer/{}?image={}".format( | ||
MONAILabelUtils.urllib_quote_plus(model), | ||
MONAILabelUtils.urllib_quote_plus(image_id), | ||
) | ||
if session_id: | ||
selector += f"&session_id={MONAILabelUtils.urllib_quote_plus(session_id)}" | ||
|
||
params = self._update_client_id(params) | ||
fields = {"params": json.dumps(params) if params else "{}"} | ||
files = {"label": label_in} if label_in else {} | ||
files.update({"file": file} if file and not session_id else {}) | ||
|
||
status, form, files, _ = MONAILabelUtils.http_multipart( | ||
"POST", self._server_url, selector, fields, files, headers=self._headers | ||
) | ||
if status != 200: | ||
raise MONAILabelClientException( | ||
MONAILabelError.SERVER_ERROR, | ||
f"Status: {status}; Response: {bytes_to_str(form)}", | ||
) | ||
|
||
form = json.loads(form) if isinstance(form, str) else form | ||
params = form.get("params") if files else form | ||
params = json.loads(params) if isinstance(params, str) else params | ||
|
||
image_out = MONAILabelUtils.save_result(files, self._tmpdir) | ||
return image_out, params | ||
|
||
def wsi_infer(self, model, image_id, body=None, output="dsa", session_id=None): | ||
""" | ||
Run WSI Infer in case of Pathology App | ||
:param model: Name of Model | ||
:param image_id: Image Id | ||
:param body: Additional configs/json params as part of Infer request | ||
:param output: Output File format (dsa|asap|json) | ||
:param session_id: Session ID (use existing session id instead of image_id) | ||
:return: response_file (None), response_body | ||
""" | ||
selector = "/infer/wsi/{}?image={}".format( | ||
MONAILabelUtils.urllib_quote_plus(model), | ||
MONAILabelUtils.urllib_quote_plus(image_id), | ||
) | ||
if session_id: | ||
selector += f"&session_id={MONAILabelUtils.urllib_quote_plus(session_id)}" | ||
if output: | ||
selector += f"&output={MONAILabelUtils.urllib_quote_plus(output)}" | ||
|
||
body = self._update_client_id(body if body else {}) | ||
status, form, _, _ = MONAILabelUtils.http_method("POST", self._server_url, selector, body) | ||
if status != 200: | ||
raise MONAILabelClientException( | ||
MONAILabelError.SERVER_ERROR, | ||
f"Status: {status}; Response: {bytes_to_str(form)}", | ||
) | ||
|
||
return None, form | ||
|
||
def train_start(self, model, params): | ||
""" | ||
Run Train Task | ||
:param model: Name of Model | ||
:param params: Additional configs/json params as part of Train request | ||
:return: json response | ||
""" | ||
params = self._update_client_id(params) | ||
|
||
selector = "/train/" | ||
if model: | ||
selector += MONAILabelUtils.urllib_quote_plus(model) | ||
|
||
status, response, _, _ = MONAILabelUtils.http_method( | ||
"POST", self._server_url, selector, params, headers=self._headers | ||
) | ||
if status != 200: | ||
raise MONAILabelClientException( | ||
MONAILabelError.SERVER_ERROR, | ||
f"Status: {status}; Response: {bytes_to_str(response)}", | ||
) | ||
|
||
response = bytes_to_str(response) | ||
logging.debug(f"Response: {response}") | ||
return json.loads(response) | ||
|
||
def train_stop(self): | ||
""" | ||
Stop any running Train Task(s) | ||
:return: json response | ||
""" | ||
selector = "/train/" | ||
status, response, _, _ = MONAILabelUtils.http_method( | ||
"DELETE", self._server_url, selector, headers=self._headers | ||
) | ||
if status != 200: | ||
raise MONAILabelClientException( | ||
MONAILabelError.SERVER_ERROR, | ||
f"Status: {status}; Response: {bytes_to_str(response)}", | ||
) | ||
|
||
response = bytes_to_str(response) | ||
logging.debug(f"Response: {response}") | ||
return json.loads(response) | ||
|
||
def train_status(self, check_if_running=False): | ||
""" | ||
Check Train Task Status | ||
:param check_if_running: Fast mode. Only check if training is Running | ||
:return: boolean if check_if_running is enabled; else json response that contains of full details | ||
""" | ||
selector = "/train/" | ||
if check_if_running: | ||
selector += "?check_if_running=true" | ||
status, response, _, _ = MONAILabelUtils.http_method("GET", self._server_url, selector, headers=self._headers) | ||
if check_if_running: | ||
return status == 200 | ||
|
||
if status != 200: | ||
raise MONAILabelClientException( | ||
MONAILabelError.SERVER_ERROR, | ||
f"Status: {status}; Response: {bytes_to_str(response)}", | ||
) | ||
|
||
response = bytes_to_str(response) | ||
logging.debug(f"Response: {response}") | ||
return json.loads(response) | ||
|
||
|
||
class MONAILabelError: | ||
""" | ||
Type of Inference Model | ||
Attributes: | ||
SERVER_ERROR - Server Error | ||
SESSION_EXPIRED - Session Expired | ||
UNKNOWN - Unknown Error | ||
""" | ||
|
||
SERVER_ERROR = 1 | ||
SESSION_EXPIRED = 2 | ||
UNKNOWN = 3 | ||
|
||
|
||
class MONAILabelClientException(Exception): | ||
""" | ||
MONAILabel Client Exception | ||
""" | ||
|
||
__slots__ = ["error", "msg"] | ||
|
||
def __init__(self, error, msg, status_code=None, response=None): | ||
""" | ||
:param error: Error code represented by MONAILabelError | ||
:param msg: Error message | ||
:param status_code: HTTP Response code | ||
:param response: HTTP Response | ||
""" | ||
self.error = error | ||
self.msg = msg | ||
self.status_code = status_code | ||
self.response = response | ||
|
||
|
||
class MONAILabelUtils: | ||
@staticmethod | ||
def http_method(method, server_url, selector, body=None, headers=None, content_type=None): | ||
logging.debug(f"{method} {server_url}{selector}") | ||
|
||
parsed = urlparse(server_url) | ||
path = parsed.path.rstrip("/") | ||
selector = path + "/" + selector.lstrip("/") | ||
logging.debug(f"URI Path: {selector}") | ||
|
||
parsed = urlparse(server_url) | ||
if parsed.scheme == "https": | ||
LOGGER.debug("Using HTTPS mode") | ||
# noinspection PyProtectedMember | ||
conn = http.client.HTTPSConnection(parsed.hostname, parsed.port, context=ssl._create_unverified_context()) | ||
else: | ||
conn = http.client.HTTPConnection(parsed.hostname, parsed.port) | ||
|
||
headers = headers if headers else {} | ||
if body: | ||
if not content_type: | ||
if isinstance(body, dict): | ||
body = json.dumps(body) | ||
content_type = "application/json" | ||
else: | ||
content_type = "text/plain" | ||
headers.update({"content-type": content_type, "content-length": str(len(body))}) | ||
|
||
conn.request(method, selector, body=body, headers=headers) | ||
return MONAILabelUtils.send_response(conn) | ||
|
||
@staticmethod | ||
def http_upload(method, server_url, selector, fields, files, headers=None): | ||
logging.debug(f"{method} {server_url}{selector}") | ||
|
||
url = server_url.rstrip("/") + "/" + selector.lstrip("/") | ||
logging.debug(f"URL: {url}") | ||
|
||
files = [("files", (os.path.basename(f), open(f, "rb"))) for f in files] | ||
headers = headers if headers else {} | ||
response = ( | ||
requests.post(url, files=files, headers=headers) | ||
if method == "POST" | ||
else requests.put(url, files=files, data=fields, headers=headers) | ||
) | ||
return response.status_code, response.text, None | ||
|
||
@staticmethod | ||
def http_multipart(method, server_url, selector, fields, files, headers={}): | ||
logging.debug(f"{method} {server_url}{selector}") | ||
|
||
content_type, body = MONAILabelUtils.encode_multipart_formdata(fields, files) | ||
headers = headers if headers else {} | ||
headers.update({"content-type": content_type, "content-length": str(len(body))}) | ||
|
||
parsed = urlparse(server_url) | ||
path = parsed.path.rstrip("/") | ||
selector = path + "/" + selector.lstrip("/") | ||
logging.debug(f"URI Path: {selector}") | ||
|
||
if parsed.scheme == "https": | ||
LOGGER.debug("Using HTTPS mode") | ||
# noinspection PyProtectedMember | ||
conn = http.client.HTTPSConnection(parsed.hostname, parsed.port, context=ssl._create_unverified_context()) | ||
else: | ||
conn = http.client.HTTPConnection(parsed.hostname, parsed.port) | ||
|
||
conn.request(method, selector, body, headers) | ||
return MONAILabelUtils.send_response(conn, content_type) | ||
|
||
@staticmethod | ||
def send_response(conn, content_type="application/json"): | ||
response = conn.getresponse() | ||
logging.debug(f"HTTP Response Code: {response.status}") | ||
logging.debug(f"HTTP Response Message: {response.reason}") | ||
logging.debug(f"HTTP Response Headers: {response.getheaders()}") | ||
|
||
response_content_type = response.getheader("content-type", content_type) | ||
logging.debug(f"HTTP Response Content-Type: {response_content_type}") | ||
|
||
if "multipart" in response_content_type: | ||
if response.status == 200: | ||
form, files = MONAILabelUtils.parse_multipart(response.fp if response.fp else response, response.msg) | ||
logging.debug(f"Response FORM: {form}") | ||
logging.debug(f"Response FILES: {files.keys()}") | ||
return response.status, form, files, response.headers | ||
else: | ||
return response.status, response.read(), None, response.headers | ||
|
||
logging.debug("Reading status/content from simple response!") | ||
return response.status, response.read(), None, response.headers | ||
|
||
@staticmethod | ||
def save_result(files, tmpdir): | ||
for name in files: | ||
data = files[name] | ||
result_file = os.path.join(tmpdir, name) | ||
|
||
logging.debug(f"Saving {name} to {result_file}; Size: {len(data)}") | ||
dir_path = os.path.dirname(os.path.realpath(result_file)) | ||
if not os.path.exists(dir_path): | ||
os.makedirs(dir_path) | ||
|
||
with open(result_file, "wb") as f: | ||
if isinstance(data, bytes): | ||
f.write(data) | ||
else: | ||
f.write(data.encode("utf-8")) | ||
|
||
# Currently only one file per response supported | ||
return result_file | ||
|
||
@staticmethod | ||
def encode_multipart_formdata(fields, files): | ||
limit = "----------lImIt_of_THE_fIle_eW_$" | ||
lines = [] | ||
for key, value in fields.items(): | ||
lines.append("--" + limit) | ||
lines.append('Content-Disposition: form-data; name="%s"' % key) | ||
lines.append("") | ||
lines.append(value) | ||
for key, filename in files.items(): | ||
lines.append("--" + limit) | ||
lines.append(f'Content-Disposition: form-data; name="{key}"; filename="{filename}"') | ||
lines.append("Content-Type: %s" % MONAILabelUtils.get_content_type(filename)) | ||
lines.append("") | ||
with open(filename, mode="rb") as f: | ||
data = f.read() | ||
lines.append(data) | ||
lines.append("--" + limit + "--") | ||
lines.append("") | ||
|
||
body = bytearray() | ||
for line in lines: | ||
body.extend(line if isinstance(line, bytes) else line.encode("utf-8")) | ||
body.extend(b"\r\n") | ||
|
||
content_type = "multipart/form-data; boundary=%s" % limit | ||
return content_type, body | ||
|
||
@staticmethod | ||
def get_content_type(filename): | ||
return mimetypes.guess_type(filename)[0] or "application/octet-stream" | ||
|
||
@staticmethod | ||
def parse_multipart(fp, headers): | ||
fs = cgi.FieldStorage( | ||
fp=fp, | ||
environ={"REQUEST_METHOD": "POST"}, | ||
headers=headers, | ||
keep_blank_values=True, | ||
) | ||
form = {} | ||
files = {} | ||
if hasattr(fs, "list") and isinstance(fs.list, list): | ||
for f in fs.list: | ||
LOGGER.debug(f"FILE-NAME: {f.filename}; NAME: {f.name}; SIZE: {len(f.value)}") | ||
if f.filename: | ||
files[f.filename] = f.value | ||
else: | ||
form[f.name] = f.value | ||
return form, files | ||
|
||
@staticmethod | ||
def urllib_quote_plus(s): | ||
return quote_plus(s) | ||
|
||
@staticmethod | ||
def get_filename(content_disposition): | ||
file_name = re.findall(r"filename\*=([^;]+)", content_disposition, flags=re.IGNORECASE) | ||
if not file_name: | ||
file_name = re.findall('filename="(.+)"', content_disposition, flags=re.IGNORECASE) | ||
if "utf-8''" in file_name[0].lower(): | ||
file_name = re.sub("utf-8''", "", file_name[0], flags=re.IGNORECASE) | ||
file_name = unquote(file_name) | ||
else: | ||
file_name = file_name[0] | ||
return file_name | ||
|
||
|
||
class RunVISTA2D(ImageSegmentation): | ||
category = "Object Processing" | ||
|
||
module_name = "RunVISTA2D" | ||
|
||
variable_revision_number = 1 | ||
|
||
doi = { | ||
"Please cite the following when using RunVISTA2D:": "https://doi.org/10.48550/arXiv.2406.05285", | ||
} | ||
|
||
def create_settings(self): | ||
super().create_settings() | ||
|
||
self.server_address = Text( | ||
text="MONAI label server address", | ||
value="http://127.0.0.1:8000", | ||
doc="""\ | ||
Please set up the MONAI label server in local/cloud environment and fill the server address here. | ||
""", | ||
) | ||
|
||
self.model_name = Choice( | ||
text="The model for running the inference", | ||
choices=["vista2d"], | ||
value="vista2d", | ||
doc=""" | ||
Pick the model for running infernce. Now only VISTA2D is available. | ||
""", | ||
) | ||
|
||
def settings(self): | ||
return [ | ||
self.x_name, | ||
self.y_name, | ||
self.server_address, | ||
self.model_name, | ||
] | ||
|
||
def visible_settings(self): | ||
return [ | ||
self.x_name, | ||
self.y_name, | ||
self.server_address, | ||
self.model_name, | ||
] | ||
|
||
def run(self, workspace): | ||
x_name = self.x_name.value | ||
y_name = self.y_name.value | ||
images = workspace.image_set | ||
x = images.get_image(x_name) | ||
dimensions = x.dimensions | ||
x_data = x.pixel_data | ||
|
||
with tempfile.TemporaryDirectory() as temp_dir: | ||
temp_img_dir = os.path.join(temp_dir, "img") | ||
os.makedirs(temp_img_dir, exist_ok=True) | ||
temp_img_path = os.path.join(temp_img_dir, x_name + ".tiff") | ||
temp_mask_dir = os.path.join(temp_dir, "mask") | ||
os.makedirs(temp_mask_dir, exist_ok=True) | ||
skimage.io.imsave(temp_img_path, x_data) | ||
monailabel_client = MONAILabelClient(server_url=self.server_address.value, tmpdir=temp_mask_dir) | ||
image_out, params = monailabel_client.infer( | ||
model=self.model_name.value, image_id="", params={}, file=temp_img_path | ||
) | ||
print(f"Image out:\n{image_out}") | ||
print(f"Params:\n{params}") | ||
y_data = skimage.io.imread(image_out) | ||
|
||
y = Objects() | ||
y.segmented = y_data | ||
y.parent_image = x.parent_image | ||
objects = workspace.object_set | ||
objects.add_objects(y, y_name) | ||
|
||
self.add_measurements(workspace) | ||
|
||
if self.show_window: | ||
workspace.display_data.x_data = x_data | ||
workspace.display_data.y_data = y_data | ||
workspace.display_data.dimensions = dimensions | ||
|
||
def display(self, workspace, figure): | ||
layout = (2, 1) | ||
figure.set_subplots(dimensions=workspace.display_data.dimensions, subplots=layout) | ||
|
||
figure.subplot_imshow( | ||
colormap="gray", | ||
image=workspace.display_data.x_data, | ||
title="Input Image", | ||
x=0, | ||
y=0, | ||
) | ||
|
||
figure.subplot_imshow_labels( | ||
image=workspace.display_data.y_data, | ||
sharexy=figure.subplot(0, 0), | ||
title=self.y_name.value, | ||
x=1, | ||
y=0, | ||
) | ||
|
||
# def upgrade_settings(self, setting_values, variable_revision_number, module_name): | ||
# ... |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
import os | ||
|
||
import cellprofiler_core.image | ||
import cellprofiler_core.measurement | ||
import cellprofiler_core.object | ||
import cellprofiler_core.pipeline | ||
import cellprofiler_core.setting | ||
import cellprofiler_core.workspace | ||
import numpy | ||
import pytest | ||
from active_plugins.runvista2d import MONAILabelClient, MONAILabelClientException, MONAILabelUtils, RunVISTA2D | ||
|
||
IMAGE_NAME = "my_image" | ||
OBJECTS_NAME = "my_objects" | ||
MODEL_NAME = "vista2d" | ||
SERVER_ADDRESS = "http://127.0.0.1:8000" | ||
|
||
|
||
class MockResponse: | ||
@staticmethod | ||
def infer(*args, **kwargs): | ||
filepath = os.path.abspath(__file__) | ||
dir = os.path.dirname(filepath) | ||
image = os.path.join(dir, "resources", "vista2d_test.tiff") | ||
return image, {} | ||
|
||
|
||
class MockErrResponse: | ||
@staticmethod | ||
def http_multipart(*args, **kwargs): | ||
return 400, {}, {}, {} | ||
|
||
|
||
def test_mock_failed(): | ||
x = RunVISTA2D() | ||
x.y_name.value = OBJECTS_NAME | ||
x.x_name.value = IMAGE_NAME | ||
x.server_address.value = SERVER_ADDRESS | ||
x.model_name.value = MODEL_NAME | ||
|
||
img = numpy.zeros((128, 128, 3)) | ||
image = cellprofiler_core.image.Image(img) | ||
image_set_list = cellprofiler_core.image.ImageSetList() | ||
image_set = image_set_list.get_image_set(0) | ||
image_set.providers.append(cellprofiler_core.image.VanillaImage(IMAGE_NAME, image)) | ||
object_set = cellprofiler_core.object.ObjectSet() | ||
measurements = cellprofiler_core.measurement.Measurements() | ||
pipeline = cellprofiler_core.pipeline.Pipeline() | ||
|
||
pytest.MonkeyPatch().setattr(MONAILabelUtils, "http_multipart", MockErrResponse.http_multipart) | ||
with pytest.raises(MONAILabelClientException): | ||
x.run(cellprofiler_core.workspace.Workspace(pipeline, x, image_set, object_set, measurements, None)) | ||
|
||
|
||
def test_mock_successful(): | ||
x = RunVISTA2D() | ||
x.y_name.value = OBJECTS_NAME | ||
x.x_name.value = IMAGE_NAME | ||
x.server_address.value = SERVER_ADDRESS | ||
x.model_name.value = MODEL_NAME | ||
|
||
img = numpy.zeros((128, 128, 3)) | ||
image = cellprofiler_core.image.Image(img) | ||
image_set_list = cellprofiler_core.image.ImageSetList() | ||
image_set = image_set_list.get_image_set(0) | ||
image_set.providers.append(cellprofiler_core.image.VanillaImage(IMAGE_NAME, image)) | ||
object_set = cellprofiler_core.object.ObjectSet() | ||
measurements = cellprofiler_core.measurement.Measurements() | ||
pipeline = cellprofiler_core.pipeline.Pipeline() | ||
|
||
pytest.MonkeyPatch().setattr(MONAILabelClient, "infer", MockResponse.infer) | ||
x.run(cellprofiler_core.workspace.Workspace(pipeline, x, image_set, object_set, measurements, None)) | ||
assert len(object_set.object_names) == 1 | ||
assert OBJECTS_NAME in object_set.object_names | ||
objects = object_set.get_objects(OBJECTS_NAME) | ||
segmented = objects.segmented | ||
assert numpy.all(segmented == 0) | ||
assert "Image" in measurements.get_object_names() | ||
assert OBJECTS_NAME in measurements.get_object_names() | ||
|
||
assert f"Count_{OBJECTS_NAME}" in measurements.get_feature_names("Image") | ||
count = measurements.get_current_measurement("Image", f"Count_{OBJECTS_NAME}") | ||
assert count == 0 | ||
assert "Location_Center_X" in measurements.get_feature_names(OBJECTS_NAME) | ||
location_center_x = measurements.get_current_measurement(OBJECTS_NAME, "Location_Center_X") | ||
assert isinstance(location_center_x, numpy.ndarray) | ||
assert numpy.product(location_center_x.shape) == 0 | ||
assert "Location_Center_Y" in measurements.get_feature_names(OBJECTS_NAME) | ||
location_center_y = measurements.get_current_measurement(OBJECTS_NAME, "Location_Center_Y") | ||
assert isinstance(location_center_y, numpy.ndarray) | ||
assert numpy.product(location_center_y.shape) == 0 |