Skip to content

Commit

Permalink
Merge pull request #13 from Oemercetin06/mlst
Browse files Browse the repository at this point in the history
Added MLST feature
  • Loading branch information
aromberg authored Jan 20, 2025
2 parents faf40eb + 4a263a9 commit 93d1206
Show file tree
Hide file tree
Showing 8 changed files with 708 additions and 2 deletions.
6 changes: 6 additions & 0 deletions src/xspect/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,9 @@ def get_xspect_runs_path():
runs_path = get_xspect_root_path() / "runs"
runs_path.mkdir(exist_ok=True, parents=True)
return runs_path

def get_xspect_mlst_path():
"""Return the path to the XspecT runs directory."""
mlst_path = get_xspect_root_path() / "mlst"
mlst_path.mkdir(exist_ok=True, parents=True)
return mlst_path
53 changes: 51 additions & 2 deletions src/xspect/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,16 @@
from xspect.models.result import (
StepType,
)
from xspect.definitions import get_xspect_runs_path, fasta_endings, fastq_endings
from xspect.definitions import (
get_xspect_runs_path,
fasta_endings,
fastq_endings,
get_xspect_model_path
)
from xspect.pipeline import ModelExecution, Pipeline, PipelineStep
from src.xspect.mlst_feature.mlst_helper import pick_scheme, pick_scheme_from_models_dir
from src.xspect.mlst_feature.pub_mlst_handler import PubMLSTHandler
from src.xspect.models.probabilistic_filter_mlst_model import ProbabilisticFilterMlstSchemeModel


@click.group()
Expand Down Expand Up @@ -117,12 +125,53 @@ def train(genus, bf_assembly_path, svm_assembly_path, svm_step):
except ValueError as e:
raise click.ClickException(str(e)) from e

@cli.command()
@click.option(
"-c",
"--choose_schemes",
is_flag=True,
help="Choose your own schemes."
"Default setting is Oxford and Pasteur scheme of A.baumannii.",
)
def mlst_train(choose_schemes):
"""Download alleles and train bloom filters."""
click.echo("Updating alleles")
handler = PubMLSTHandler()
handler.download_alleles(choose_schemes)
click.echo("Download finished")
scheme_path = pick_scheme(handler.get_scheme_paths())
species_name = str(scheme_path).split("/")[-2]
scheme_name = str(scheme_path).split("/")[-1]
model = ProbabilisticFilterMlstSchemeModel(
31,
f"{species_name}:{scheme_name}",
get_xspect_model_path()
)
click.echo("Creating mlst model")
model.fit(scheme_path)
model.save()
click.echo(f"Saved at {model.cobs_path}")

@cli.command()
@click.option(
"-p",
"--path",
help="Path to FASTA-file for mlst identification.",
type=click.Path(exists=True, dir_okay=True, file_okay=True)
)
def mlst_classify(path):
"""Download alleles and train bloom filters."""
click.echo("Classifying...")
path = Path(path)
scheme_path = pick_scheme_from_models_dir()
model = ProbabilisticFilterMlstSchemeModel.load(scheme_path)
model.predict(scheme_path, path).save(model.model_display_name, path)
click.echo(f"Run saved at {get_xspect_runs_path()}.")

@cli.command()
def api():
"""Open the XspecT FastAPI."""
uvicorn.run(fastapi.app, host="0.0.0.0", port=8000)


if __name__ == "__main__":
cli()
Empty file.
136 changes: 136 additions & 0 deletions src/xspect/mlst_feature/mlst_helper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
""" Module for utility functions used in other modules regarding MLST. """

__author__ = "Cetin, Oemer"

import requests
import json
from io import StringIO
from pathlib import Path
from Bio import SeqIO
from src.xspect.definitions import get_xspect_model_path, get_xspect_runs_path

def create_fasta_files(locus_path:Path, fasta_batch:str):
"""Create Fasta-Files for every allele of a locus."""
# fasta_batch = full string of a fasta file containing every allele sequence of a locus
for record in SeqIO.parse(StringIO(fasta_batch), "fasta"):
number = record.id.split("_")[-1] # example id = Oxf_cpn60_263
output_fasta_file = locus_path / f"Allele_ID_{number}.fasta"
if output_fasta_file.exists(): continue # Ignore existing ones
with (open(output_fasta_file, "w") as allele):
SeqIO.write(record, allele, "fasta")

def pick_species_number_from_db(available_species:dict) -> str:
"""Returns the chosen species from all available ones in the database."""
# The "database" string can look like this: pubmlst_abaumannii_seqdef
for counter, database in available_species.items():
print(str(counter) + ":" + database.split("_")[1])
print("\nPick one of the above databases")
while True:
try:
choice = input("Choose a species by selecting the corresponding number:")
if int(choice) in available_species.keys():
chosen_species = available_species.get(int(choice))
return chosen_species
else:
print("Wrong input! Try again with a number that is available in the list above.")
except ValueError:
print("Wrong input! Try again with a number that is available in the list above.")

def pick_scheme_number_from_db(available_schemes:dict) -> str:
"""Returns the chosen schemes from all available ones of a species."""
# List all available schemes of a species database
for counter, scheme in available_schemes.items():
print(str(counter) + ":" + scheme[0])
print("\nPick any available scheme that is listed for download")
while True:
try:
choice = input("Choose a scheme by selecting the corresponding number:")
if int(choice) in available_schemes.keys():
chosen_scheme = available_schemes.get(int(choice))[1]
return chosen_scheme
else:
print("Wrong input! Try again with a number that is available in the above list.")
except ValueError:
print("Wrong input! Try again with a number that is available in the above list.")

def scheme_list_to_dict(scheme_list:list[str]):
"""Converts the scheme list attribute into a dictionary with a number as the key."""
return dict(zip(range(1, len(scheme_list) + 1), scheme_list))

def pick_scheme_from_models_dir() -> Path:
"""Returns the chosen scheme from models that have been fitted prior."""
schemes = {}; counter = 1
for entry in sorted((get_xspect_model_path() / "MLST").iterdir()):
schemes[counter] = entry
counter += 1
return pick_scheme(schemes)

def pick_scheme(available_schemes:dict) -> Path:
"""Returns the chosen scheme from the scheme list."""
if not available_schemes:
raise ValueError("No scheme has been chosen for download yet!")

if len(available_schemes.items()) == 1:
return next(iter(available_schemes.values()))

# List available schemes
for counter, scheme in available_schemes.items():
# For Strain Typing with an API-POST Request to the db
if str(scheme).startswith("http"):
scheme_json = requests.get(scheme).json()
print(str(counter) + ":" + scheme_json["description"])

# To pick a scheme after download for fitting
else:
print(str(counter) + ":" + str(scheme).split("/")[-1])

print("\nPick a scheme for strain type prediction")
while True:
try:
choice = input("Choose a scheme by selecting the corresponding number:")
if int(choice) in available_schemes.keys():
chosen_scheme = available_schemes.get(int(choice))
return chosen_scheme
else:
print("Wrong input! Try again with a number that is available in the above list.")
except ValueError:
print("Wrong input! Try again with a number that is available in the above list.")

class MlstResult:
"""Class for storing mlst results."""
def __init__(
self,
scheme_model:str,
steps:int,
hits: dict[str,list[dict]],
):
self.scheme_model = scheme_model
self.steps = steps
self.hits = hits

def get_results(self) -> dict:
"""Stores the result of a prediction in a dictionary."""
results = {
seq_id: result
for seq_id, result in self.hits.items()
}
return results

def to_dict(self) -> dict:
"""Converts all attributes into one dictionary."""
result = {
"Scheme":self.scheme_model,
"Steps":self.steps,
"Results": self.get_results()
}
return result

def save(self, display:str, file_path:Path) -> None:
"""Saves the result inside the "runs" directory"""
file_name = str(file_path).split("/")[-1]
json_path = get_xspect_runs_path() / "MLST" / f"{file_name}-{display}.json"
json_path.parent.mkdir(exist_ok=True, parents=True)
json_object = json.dumps(self.to_dict(), indent=4)

with open(json_path, "w", encoding="utf-8") as file:
file.write(json_object)
107 changes: 107 additions & 0 deletions src/xspect/mlst_feature/pub_mlst_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
"""Module for connecting with the PubMLST database via API requests and downloading allele files."""

__author__ = "Cetin, Oemer"

import requests
import json
from src.xspect.mlst_feature.mlst_helper import (
create_fasta_files,
pick_species_number_from_db,
pick_scheme_number_from_db,
pick_scheme,
scheme_list_to_dict
)
from src.xspect.definitions import (
get_xspect_mlst_path,
get_xspect_upload_path
)

class PubMLSTHandler:
"""Class for communicating with PubMLST and downloading alleles (FASTA-Format) from all loci."""
base_url = "http://rest.pubmlst.org/db"

def __init__(self):
# Default values: Oxford (1) and Pasteur (2) schemes of A.baumannii species
self.scheme_list = [
self.base_url + "/pubmlst_abaumannii_seqdef/schemes/1",
self.base_url + "/pubmlst_abaumannii_seqdef/schemes/2"
]
self.scheme_paths = []

def get_scheme_paths(self) -> dict:
"""Returns the scheme paths in a dictionary"""
return scheme_list_to_dict(self.scheme_paths)

def choose_schemes(self) -> None:
"""Changes the scheme list attribute to feature other schemes from some species"""
available_species = {}; available_schemes = {}; chosen_schemes = []; counter = 1
# retrieve all available species
species_url = PubMLSTHandler.base_url
for species_databases in requests.get(species_url).json():
for database in species_databases["databases"]:
if database["name"].endswith("seqdef"):
available_species[counter] = database["name"]
counter += 1
# pick a species out of the available ones
chosen_species = pick_species_number_from_db(available_species)

counter = 1
scheme_url = f"{species_url}/{chosen_species}/schemes"
for scheme in requests.get(scheme_url).json()["schemes"]:
# scheme["description"] stores the name of a scheme.
# scheme["scheme"] stores the URL that is needed for downloading all loci.
available_schemes[counter] = [scheme["description"], scheme["scheme"]]
counter += 1

# Selection process of available scheme from a species for download (doubles are caught!)
while True:
chosen_scheme = pick_scheme_number_from_db(available_schemes)
chosen_schemes.append(chosen_scheme) if chosen_scheme not in chosen_schemes else None
choice = input("Do you want to pick another scheme to download? (y/n):").lower()
if choice != "y":
break
self.scheme_list = chosen_schemes

def download_alleles(self, choice:False):
"""Downloads every allele FASTA-file from all loci of the scheme list attribute"""
if choice: # pick an own scheme if not Oxford or Pasteur
self.choose_schemes() # changes the scheme_list attribute

for scheme in self.scheme_list:
scheme_json = requests.get(scheme).json()
# We only want the name and the respective featured loci of a scheme
scheme_name = scheme_json["description"]
locus_list = scheme_json["loci"]

species_name = scheme.split("_")[1] # name = pubmlst_abaumannii_seqdef
scheme_path = get_xspect_mlst_path() / species_name / scheme_name
self.scheme_paths.append(scheme_path)

for locus_url in locus_list:
# After using split the last part ([-1]) of the url is the locus name
locus_name = locus_url.split("/")[-1]
locus_path = get_xspect_mlst_path() / species_name / scheme_name / locus_name

if not locus_path.exists():
locus_path.mkdir(exist_ok=True, parents=True)

alleles = requests.get(f"{locus_url}/alleles_fasta").text
create_fasta_files(locus_path,alleles)

def assign_strain_type_by_db(self):
"""Sends an API-POST-Request to the database for MLST without bloom filters"""
scheme_url = str(pick_scheme(scheme_list_to_dict(self.scheme_list))) + "/sequence"
fasta_file = get_xspect_upload_path() / "Test.fna"
with open(fasta_file, 'r') as file:
data = file.read()
payload = { # Essential API-POST-Body
"sequence": data,
"filetype": "fasta",
}
response = requests.post(scheme_url, data=json.dumps(payload)).json()

for locus, meta_data in response["exact_matches"].items():
# meta_data is a list containing a dictionary, therefore [0] and then key value.
# Example: 'Pas_fusA': [{'href': some URL, 'allele_id': '2'}]
print(locus + ":" + meta_data[0]["allele_id"], end= "; ")
print("\nStrain Type:", response["fields"])
Loading

0 comments on commit 93d1206

Please sign in to comment.