Add seqio index genome loading
ThijsMaas committed Mar 18, 2024
1 parent e2bc0d8 commit cd2e1a6
Showing 2 changed files with 53 additions and 68 deletions.
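The change drops the in-memory SeqIO.parse iteration (and the memmap fallback for oversized records) in favour of a Bio.SeqIO.index lookup, so records can be fetched lazily by ID. A minimal sketch of the two access patterns, assuming a hypothetical genome.fasta containing a record named chr1:

    from Bio import SeqIO

    # Old pattern: stream every record through the parser; each record is a
    # full SeqRecord held in memory while it is handled.
    for record in SeqIO.parse("genome.fasta", "fasta"):
        print(record.id, len(record.seq))

    # New pattern: build a dict-like index keyed by record id; sequence data
    # is only read from disk when a record is looked up.
    fasta_index = SeqIO.index("genome.fasta", "fasta")
    record = fasta_index["chr1"]  # hypothetical record id
    print(record.id, len(record.seq))
    fasta_index.close()
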
67 changes: 34 additions & 33 deletions iss/app.py
@@ -7,8 +7,6 @@
import os
import sys

from Bio import SeqIO

from iss import util
from iss.generator import (
generate_work_divider,
@@ -71,39 +69,42 @@ def generate_reads(args):
try:
# list holding the prefix for each cpu's temp file
temp_file_list = [f"{args.output}.iss.tmp.{i}" for i in range(args.cpus)]
f = open(genome_file, "r") # re-opens the file
with f:
fasta_file = SeqIO.parse(f, "fasta")
# TODO check if a SeqIO.index (db) leads to better memory usage and is not slower
# fasta_dict = SeqIO.index(f, 'fasta')

# Calculate how many reads we want each cpu to generate
n_read_pairs = n_reads // 2
chunk_size = -((n_read_pairs) // -args.cpus) # this is ceildiv, see https://stackoverflow.com/a/17511341
logger.debug("Chunk size: %s" % chunk_size)
# Calculate how many reads we want each cpu to generate
n_read_pairs = n_reads // 2
chunk_size = -((n_read_pairs) // -args.cpus) # this is ceildiv, see https://stackoverflow.com/a/17511341
logger.debug("Chunk size: %s" % chunk_size)

# Divide the work of generating n_reads for each record into chunks
work_chunks = generate_work_divider(
fasta_file,
readcount_dic,
abundance_dic,
n_reads,
args.coverage,
args.coverage_file,
error_model,
args.output,
chunk_size,
)
# Divide the work of generating n_reads for each record into chunks
work_chunks = generate_work_divider(
genome_file,
readcount_dic,
abundance_dic,
n_reads,
args.coverage,
args.coverage_file,
error_model,
chunk_size,
)

# Generate reads for each chunk in parallel
with mp.Pool(args.cpus) as pool:
pool.starmap(
worker_iterator,
[
(work, error_model, cpu_number, worker_prefix, args.seed, args.sequence_type, args.gc_bias)
for cpu_number, (work, worker_prefix) in enumerate(zip(work_chunks, temp_file_list))
],
)
# Generate reads for each chunk in parallel
with mp.Pool(args.cpus) as pool:
pool.starmap(
worker_iterator,
[
(
work,
error_model,
cpu_number,
worker_prefix,
args.seed,
args.sequence_type,
args.gc_bias,
genome_file,
)
for cpu_number, (work, worker_prefix) in enumerate(zip(work_chunks, temp_file_list))
],
)

except KeyboardInterrupt as e:
logger.error("iss generate interrupted: %s" % e)
@@ -122,7 +123,7 @@ def generate_reads(args):
# and reads were appended to the same temp file.
temp_R1 = [temp_file + "_R1.fastq" for temp_file in temp_file_list]
temp_R2 = [temp_file + "_R2.fastq" for temp_file in temp_file_list]
temp_mut = [temp_file + ".vcf" for temp_file in temp_file_list] if args.store_mutations else []
temp_mut = [temp_file + ".vcf" for temp_file in temp_file_list]
util.concatenate(temp_R1, args.output + "_R1.fastq")
util.concatenate(temp_R2, args.output + "_R2.fastq")
if args.store_mutations:
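With the index-based loading, the work divider hands each worker plain (record_id, n_pairs) tuples instead of full SeqRecord objects, which sidesteps the 2**31-byte pickling limit that the old memmap branch worked around; each worker process then opens its own index. A rough sketch of that pattern, with simplified, illustrative names (not the iss API):

    import multiprocessing as mp
    from Bio import SeqIO

    GENOME = "genome.fasta"  # hypothetical input file


    def worker(chunk):
        # Each process opens its own index; only small (id, count) tuples
        # cross the process boundary, so no large objects are pickled.
        index = SeqIO.index(GENOME, "fasta")
        for record_id, n_pairs in chunk:
            record = index[record_id]
            print(record_id, len(record.seq), n_pairs)  # placeholder for read simulation
        index.close()


    if __name__ == "__main__":
        chunks = [[("chr1", 1000)], [("chr2", 500)]]  # hypothetical work chunks
        with mp.Pool(2) as pool:
            pool.map(worker, chunks)
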
54 changes: 19 additions & 35 deletions iss/generator.py
@@ -1,7 +1,6 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import gc
import logging
import os
import random
@@ -15,7 +14,7 @@

from iss import abundance, download, util
from iss.error_models import basic, kde, perfect
from iss.util import load, rev_comp
from iss.util import rev_comp


def simulate_reads(
@@ -28,7 +27,6 @@ def simulate_reads(
mutations_handle,
sequence_type,
gc_bias=False,
mode="default",
):
"""Simulate reads from one genome (or sequence) according to an ErrorModel
@@ -51,11 +49,6 @@
"""
logger = logging.getLogger(__name__)

# load the record from disk if mode is memmap
if mode == "memmap":
record_mmap = load(record)
record = record_mmap

logger.debug("Cpu #%s: Generating %s read pairs" % (cpu_number, n_pairs))

for forward_record, reverse_record, mutations in reads_generator(
@@ -220,7 +213,7 @@ def to_fastq(generator, output):
SeqIO.write(read_tuple[1], r, "fastq-sanger")


def worker_iterator(work, error_model, cpu_number, worker_prefix, seed, sequence_type, gc_bias):
def worker_iterator(work, error_model, cpu_number, worker_prefix, seed, sequence_type, gc_bias, genome_file):
"""A utility function to run the reads simulation of each record in a loop for a specific cpu"""
logger = logging.getLogger(__name__)
try:
@@ -234,13 +227,14 @@ def worker_iterator(work, error_model, cpu_number, worker_prefix, seed, sequence
if seed is not None:
random.seed(seed + cpu_number)
np.random.seed(seed + cpu_number)
fasta_index = SeqIO.index(genome_file, "fasta")

with forward_handle, reverse_handle, mutation_handle:
for record, n_pairs, mode in work:
for record_id, n_pairs in work:
record = fasta_index[record_id]
simulate_reads(
record=record,
n_pairs=n_pairs,
mode=mode,
error_model=error_model,
cpu_number=cpu_number,
forward_handle=forward_handle,
@@ -252,7 +246,14 @@ def worker_iterator(work, error_model, cpu_number, worker_prefix, seed, sequence


def generate_work_divider(
fasta_file, readcount_dic, abundance_dic, n_reads, coverage, coverage_file, error_model, output, chunk_size
genome_file,
readcount_dic,
abundance_dic,
n_reads,
coverage,
coverage_file,
error_model,
chunk_size,
):
"""Yields a list of tuples containing the records and the number of reads to generate for each record
@@ -267,8 +268,9 @@ def generate_work_divider(
total_reads_generated_unrounded = 0

chunk_work = []
fasta_index = SeqIO.index(genome_file, "fasta")

for record in fasta_file:
for record in fasta_index.values():
# generate reads for records
if readcount_dic is not None:
if record.id not in readcount_dic:
@@ -310,37 +312,17 @@ def generate_work_divider(
if n_pairs == 0:
continue

# due to a bug in multiprocessing
# https://bugs.python.org/issue17560
# we can't send records taking more than 2**31 bytes
# through serialisation.
# In those cases we use memmapping
if sys.getsizeof(str(record.seq)) >= 2**31 - 1:
logger.warning("record %s unusually big." % record.id)
logger.warning("Using a memory map.")
mode = "memmap"

record_mmap = "%s.memmap" % output
if os.path.exists(record_mmap):
os.unlink(record_mmap)
util.dump(record, record_mmap)
del record
record = record_mmap
gc.collect()
else:
mode = "default"

n_pairs_remaining = n_pairs
while n_pairs_remaining > 0:
chunk_remaining = chunk_size - current_chunk

if n_pairs_remaining <= chunk_remaining:
# Record fits in the current chunk
chunk_work.append((record, n_pairs_remaining, mode))
chunk_work.append((record.id, n_pairs_remaining))
n_pairs_added = n_pairs_remaining
else:
# Record does not fit in the current chunk
chunk_work.append((record, chunk_remaining, mode))
chunk_work.append((record.id, chunk_remaining))
n_pairs_added = chunk_remaining

n_pairs_remaining -= n_pairs_added
Expand All @@ -351,6 +333,8 @@ def generate_work_divider(
chunk_work = []
current_chunk = 0

fasta_index.close()

if chunk_work:
# Yield the last (not full) chunk
yield chunk_work
