Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

attempt parallelized demultiplexing #200

Merged
merged 8 commits into the base branch
Feb 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions harpy/_validations.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ def validate_demuxschema(infile):
try:
sample, segment_id = line.rstrip().split()
if not segment_pattern.match(segment_id):
print_error("invalid segment format", f"Segment ID '{segment_id}' does not follow the expected format.")
print_error("invalid segment format", f"Segment ID [green]{segment_id}[/green] does not follow the expected format.")
print_solution("This haplotagging design expects segments to follow the format of letter [green bold]A-D[/green bold] followed by [bold]two digits[/bold], e.g. [green bold]C51[/green bold]). Check that your ID segments or formatted correctly and that you are attempting to demultiplex with a workflow appropriate for your data design.")
sys.exit(1)
code_letters.add(segment_id[0])
Expand All @@ -283,10 +283,15 @@ def validate_demuxschema(infile):
# skip rows without two columns
continue
if not code_letters:
print_error("incorrect schema format", f"Schema file {os.path.basename(infile)} has no valid rows. Rows should be sample<tab>segment, e.g. sample_01<tab>C75")
print_error("incorrect schema format", f"Schema file [blue]{os.path.basename(infile)}[/blue] has no valid rows. Rows should be sample<tab>segment, e.g. sample_01<tab>C75")
sys.exit(1)
if len(code_letters) > 1:
print("invalid schema", f"Schema file {os.path.basename(infile)} has sample IDs expected to be identified across multiple barcode segments. All sample IDs for this technology should be in a single segment, such as [bold green]C[/bold green] or [bold green]D[/bold green].")
print_error("invalid schema", f"Schema file [blue]{os.path.basename(infile)}[/blue] has sample IDs occurring in different barcode segments.")
print_solution_with_culprits(
"All sample IDs for this barcode design should be in a single segment, such as [bold green]C[/bold green] or [bold green]D[/bold green]. Make sure the schema contains only one segment.",
"The segments identified in the schema:"
)
click.echo(", ".join(code_letters))
sys.exit(1)

def validate_regions(regioninput, genome):
Expand Down
2 changes: 1 addition & 1 deletion harpy/bin/bx_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def writestats(x, writechrom, destination):
all_bx = set()
LAST_CONTIG = None

for read in alnfile.fetch():
for read in alnfile.fetch(until_eof=True):
chrom = read.reference_name
# check if the current chromosome is different from the previous one
# if so, print the dict to file and empty it (a consideration for RAM usage)
Expand Down
8 changes: 4 additions & 4 deletions harpy/bin/check_bam.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@
parser.error(f"{args.input} was not found")

bam_in = args.input
if bam_in.lower().endswith(".bam"):
if not os.path.exists(bam_in + ".bai"):
pysam.index(bam_in)
#if bam_in.lower().endswith(".bam"):
# if not os.path.exists(bam_in + ".bai"):
# pysam.index(bam_in)

# regex for EXACTLY AXXCXXBXXDXX
haplotag = re.compile('^A[0-9][0-9]C[0-9][0-9]B[0-9][0-9]D[0-9][0-9]')
Expand All @@ -52,7 +52,7 @@
BX_NOT_LAST = 0
NO_MI = 0

for record in alnfile.fetch():
for record in alnfile.fetch(until_eof=True):
N_READS += 1
tags = [i[0] for i in record.get_tags()]
# is there a bx tag?
Expand Down
35 changes: 17 additions & 18 deletions harpy/bin/concatenate_bam.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,13 +59,13 @@
haplotag_limit = 96**4

# instantiate output file
if aln_list[0].lower().endswith(".bam"):
if not os.path.exists(f"{aln_list[0]}.bai"):
pysam.index(aln_list[0])
# for housekeeping
DELETE_FIRST_INDEX = True
else:
DELETE_FIRST_INDEX = False
#if aln_list[0].lower().endswith(".bam"):
# if not os.path.exists(f"{aln_list[0]}.bai"):
# pysam.index(aln_list[0])
# # for housekeeping
# DELETE_FIRST_INDEX = True
# else:
# DELETE_FIRST_INDEX = False
with pysam.AlignmentFile(aln_list[0]) as xam_in:
header = xam_in.header.to_dict()
# Remove all @PG lines
Expand Down Expand Up @@ -99,14 +99,14 @@
# create MI dict for this sample
MI_LOCAL = {}
# create index if it doesn't exist
if xam.lower().endswith(".bam"):
if not os.path.exists(f"{xam}.bai"):
pysam.index(xam)
DELETE_INDEX = True
else:
DELETE_INDEX = False
#if xam.lower().endswith(".bam"):
# if not os.path.exists(f"{xam}.bai"):
# pysam.index(xam)
# DELETE_INDEX = True
# else:
# DELETE_INDEX = False
with pysam.AlignmentFile(xam) as xamfile:
for record in xamfile.fetch():
for record in xamfile.fetch(until_eof=True):
try:
mi = record.get_tag("MI")
# if previously converted for this sample, use that
Expand Down Expand Up @@ -136,8 +136,7 @@
except KeyError:
pass
bam_out.write(record)
if DELETE_INDEX:
Path.unlink(f"{xam}.bai")

# just for consistent housekeeping
if DELETE_FIRST_INDEX:
Path.unlink(f"{aln_list[0]}.bai")
#if DELETE_FIRST_INDEX:
# Path.unlink(f"{aln_list[0]}.bai")
14 changes: 7 additions & 7 deletions harpy/bin/deconvolve_alignments.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,18 +95,18 @@ def write_missingbx(bam, alnrecord):
# MI is the name of the current molecule, starting a 1 (0+1)
MI = 0

if bam_input.lower().endswith(".bam") and not os.path.exists(bam_input + ".bai"):
try:
pysam.index(bam_input)
except (OSError, pysam.SamtoolsError) as e:
sys.stderr.write(f"Error indexing BAM file: {e}\n")
sys.exit(1)
#if bam_input.lower().endswith(".bam") and not os.path.exists(bam_input + ".bai"):
# try:
# pysam.index(bam_input)
# except (OSError, pysam.SamtoolsError) as e:
# sys.stderr.write(f"Error indexing BAM file: {e}\n")
# sys.exit(1)

with (
pysam.AlignmentFile(bam_input) as alnfile,
pysam.AlignmentFile(sys.stdout.buffer, "wb", template = alnfile) as outfile
):
for record in alnfile.fetch():
for record in alnfile.fetch(until_eof=True):
chrm = record.reference_name
bp = record.query_alignment_length
# check if the current chromosome is different from the previous one
Expand Down
6 changes: 3 additions & 3 deletions harpy/bin/leviathan_bx_shim.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@
args = parser.parse_args()
if not os.path.exists(args.input):
parser.error(f"{args.input} was not found")
if not os.path.exists(args.input + ".bai"):
parser.error(f"{args.input}.bai was not found")
#if not os.path.exists(args.input + ".bai"):
# parser.error(f"{args.input}.bai was not found")

# set up a generator for the BX tags
bc_range = [f"{i}".zfill(2) for i in range(1,97)]
Expand All @@ -37,7 +37,7 @@

with pysam.AlignmentFile(args.input) as bam_in, pysam.AlignmentFile(sys.stdout.buffer, 'wb', header=bam_in.header) as bam_out:
# iterate through the bam file
for record in bam_in.fetch():
for record in bam_in.fetch(until_eof=True):
try:
mi = record.get_tag("MI")
if mi not in MI_BX:
Expand Down
12 changes: 7 additions & 5 deletions harpy/demultiplex.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def demultiplex():
"harpy demultiplex gen1": [
{
"name": "Parameters",
"options": ["--qx-rx","--schema"],
"options": ["--keep-unknown", "--qx-rx","--schema"],
},
{
"name": "Workflow Controls",
Expand All @@ -38,10 +38,11 @@ def demultiplex():
}

@click.command(no_args_is_help = True, context_settings=dict(allow_interspersed_args=False), epilog = "Documentation: https://pdimens.github.io/harpy/workflows/demultiplex/")
@click.option('-u', '--keep-unknown', is_flag = True, default = False, help = 'Keep reads that could not be demultiplexed')
@click.option('-q', '--qx-rx', is_flag = True, default = False, help = 'Include the `QX:Z` and `RX:Z` tags in the read header')
@click.option('-s', '--schema', required = True, type=click.Path(exists=True, dir_okay=False, readable=True), help = 'File of `sample`\\<TAB\\>`barcode`')
@click.option('-t', '--threads', default = 4, show_default = True, type = click.IntRange(min = 1, max_open = True), help = 'Number of threads to use')
@click.option('-t', '--threads', default = 4, show_default = True, type = click.IntRange(min = 2, max_open = True), help = 'Number of threads to use')
@click.option('-o', '--output-dir', type = click.Path(exists = False), default = "Demultiplex", show_default=True, help = 'Output directory name')
@click.option('-q', '--qx-rx', is_flag = True, default = False, help = 'Include the `QX:Z` and `RX:Z` tags in the read header')
@click.option('--container', is_flag = True, default = False, help = 'Use a container instead of conda')
@click.option('--setup-only', is_flag = True, hidden = True, default = False, help = 'Setup the workflow and exit')
@click.option('--hpc', type = HPCProfile(), help = 'Directory with HPC submission `config.yaml` file')
Expand All @@ -52,7 +53,7 @@ def demultiplex():
@click.argument('R2_FQ', required=True, type=click.Path(exists=True, dir_okay=False, readable=True))
@click.argument('I1_FQ', required=True, type=click.Path(exists=True, dir_okay=False, readable=True))
@click.argument('I2_FQ', required=True, type=click.Path(exists=True, dir_okay=False, readable=True))
def gen1(r1_fq, r2_fq, i1_fq, i2_fq, output_dir, schema, qx_rx, threads, snakemake, skip_reports, quiet, hpc, container, setup_only):
def gen1(r1_fq, r2_fq, i1_fq, i2_fq, output_dir, keep_unknown, schema, qx_rx, threads, snakemake, skip_reports, quiet, hpc, container, setup_only):
"""
Demultiplex Generation I haplotagged FASTQ files

Expand Down Expand Up @@ -84,12 +85,13 @@ def gen1(r1_fq, r2_fq, i1_fq, i2_fq, output_dir, schema, qx_rx, threads, snakema
"workflow" : "demultiplex gen1",
"snakemake_log" : sm_log,
"output_directory" : output_dir,
"include_qx_rx_tags" : qx_rx,
"keep_unknown" : keep_unknown,
"workflow_call" : command.rstrip(),
"conda_environments" : conda_envs,
"reports" : {
"skip": skip_reports
},
"include_qx_rx_tags" : qx_rx,
"inputs" : {
"demultiplex_schema" : Path(schema).resolve().as_posix(),
"R1": Path(r1_fq).resolve().as_posix(),
Expand Down
70 changes: 21 additions & 49 deletions harpy/scripts/demultiplex_gen1.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,25 +20,6 @@ def read_barcodes(file_path, segment):
continue
return data_dict

def read_schema(file_path):
    """Parse a demultiplexing schema file of sample<TAB>id_segment rows.

    Each valid row maps one barcode segment ID (e.g. ``C51``) to a sample
    name; a sample may appear on several rows. Rows that do not contain
    exactly two whitespace-separated columns are silently skipped.

    Returns a tuple ``(id_letter, data_dict)`` where ``data_dict`` maps
    segment ID -> sample name and ``id_letter`` is one of the segment
    letters (A-D) observed in the file.
    """
    segment_to_sample = {}
    # letters (A/B/C/D) seen at the start of the segment IDs
    seen_letters = set()
    with open(file_path, 'r') as schema:
        for row in schema:
            fields = row.rstrip().split()
            if len(fields) != 2:
                # malformed row: wrong column count, skip it
                continue
            sample, segment = fields
            segment_to_sample[segment] = sample
            seen_letters.add(segment[0])
    # NOTE: pops an arbitrary letter if more than one is present;
    # upstream validation is expected to guarantee a single letter
    return seen_letters.pop(), segment_to_sample

def get_min_dist(needle, code_letter):
minDist = 999
nbFound = 0
Expand Down Expand Up @@ -81,7 +62,9 @@ def get_read_codes(index_read, left_segment, right_segment):
sys.stderr = sys.stdout = f
outdir = snakemake.params.outdir
qxrx = snakemake.params.qxrx
schema = snakemake.input.schema
sample_name = snakemake.params.sample
id_segments = snakemake.params.id_segments
id_letter = id_segments[0][0]
r1 = snakemake.input.R1
r2 = snakemake.input.R2
i1 = snakemake.input.I1
Expand All @@ -97,14 +80,6 @@ def get_read_codes(index_read, left_segment, right_segment):
"D" : read_barcodes(bx_d, "D"),
}

#read schema
id_letter, samples_dict = read_schema(schema)
samples = list(set(samples_dict.values()))
samples.append("unidentified_data")
#create an array of files (one per sample) for writing
R1_output = {sample: open(f"{outdir}/{sample}.R1.fq", 'w') for sample in samples}
R2_output = {sample: open(f"{outdir}/{sample}.R2.fq", 'w') for sample in samples}

segments = {'A':'', 'B':'', 'C':'', 'D':''}
unclear_read_map={}
clear_read_map={}
Expand All @@ -113,49 +88,46 @@ def get_read_codes(index_read, left_segment, right_segment):
pysam.FastxFile(r2) as R2,
pysam.FastxFile(i1, persist = False) as I1,
pysam.FastxFile(i2, persist = False) as I2,
open(snakemake.output.valid, 'w') as clearBC_log,
open(snakemake.output.invalid, 'w') as unclearBC_log
open(f"{outdir}/{sample_name}.R1.fq", 'w') as R1_out,
open(f"{outdir}/{sample_name}.R2.fq", 'w') as R2_out,
open(snakemake.output.bx_info, 'w') as BC_log
):
for r1_rec, r2_rec, i1_rec, i2_rec in zip(R1, R2, I1, I2):
segments['A'], segments['C'], statusR1 = get_read_codes(i1_rec.sequence, "C", "A")
segments['B'], segments['D'], statusR2 = get_read_codes(i2_rec.sequence, "D", "B")
segments['A'], segments['C'], R1_status = get_read_codes(i1_rec.sequence, "C", "A")
segments['B'], segments['D'], R2_status = get_read_codes(i2_rec.sequence, "D", "B")
if segments[id_letter] not in id_segments:
continue
statuses = [R1_status, R2_status]
BX_code = segments['A'] + segments['C'] + segments['B']+ segments['D']
bc_tags = f"BX:Z:{BX_code}"
if qxrx:
bc_tags = f"RX:Z:{i1_rec.sequence}+{i2_rec.sequence}\tQX:Z:{i1_rec.quality}+{i2_rec.quality}\t{bc_tags}"
r1_rec.comment += f"\t{bc_tags}"
r2_rec.comment += f"\t{bc_tags}"
# search sample name
sample_name = samples_dict.get(segments[id_letter], "unidentified_data")
R1_output[sample_name].write(f"{r1_rec}\n")
R2_output[sample_name].write(f"{r2_rec}\n")
R1_out.write(f"{r1_rec}\n")
R2_out.write(f"{r2_rec}\n")

if (statusR1 == "unclear" or statusR2 == "unclear"):
# logging barcode identification
if "unclear" in statuses:
if BX_code in unclear_read_map:
unclear_read_map[BX_code] += 1
else:
unclear_read_map[BX_code] = 1
else:
if (statusR1 == "corrected" or statusR2 == "corrected"):
if "corrected" in statuses:
if BX_code in clear_read_map:
clear_read_map[BX_code][1] += 1
else:
clear_read_map[BX_code] = [0,1]
else:
if (statusR1 == "found" and statusR2 == "found"):
if all(status == "found" for status in statuses):
if BX_code in clear_read_map:
clear_read_map[BX_code][0] += 1
else:
clear_read_map[BX_code] = [1,0]

for sample_name in samples:
R1_output[sample_name].close()
R2_output[sample_name].close()
clear_read_map[BX_code] = [1,0]

clearBC_log.write("Barcode\tCorrect reads\tCorrected reads\n" )
BC_log.write("Barcode\tTotal_Reads\tCorrect_Reads\tCorrected_Reads\n")
for code in clear_read_map:
clearBC_log.write(code+"\t"+"\t".join(str(x) for x in clear_read_map[code])+"\n")

unclearBC_log.write("Barcode\tReads\n")
BC_log.write(f"{code}\t{sum(clear_read_map[code])}\t{clear_read_map[code][0]}\t{clear_read_map[code][1]}\n")
for code in unclear_read_map:
unclearBC_log.write(code +"\t"+str(unclear_read_map [code])+"\n")
BC_log.write(f"{code}\t{unclear_read_map[code]}\t0\t0\n")
Loading
Loading