From 59a4403a532d8d521198baa10e9e589c22e55130 Mon Sep 17 00:00:00 2001
From: Kyle Bittinger
Date: Thu, 25 Apr 2019 12:37:56 -0400
Subject: [PATCH] Add ability to write table of total read counts

---
 dnabclib/main.py    | 5 +++++
 dnabclib/writer.py  | 5 +++++
 test/test_main.py   | 8 ++++++++
 test/test_writer.py | 6 ++++++
 4 files changed, 24 insertions(+)

diff --git a/dnabclib/main.py b/dnabclib/main.py
index fe94dda..a2f5c5f 100644
--- a/dnabclib/main.py
+++ b/dnabclib/main.py
@@ -61,6 +61,9 @@ def main(argv=None):
     p.add_argument(
         "--manifest-file", type=argparse.FileType("w"), help=(
             "Write manifest file for QIIME2"))
+    p.add_argument(
+        "--total-reads-file", type=argparse.FileType("w"), help=(
+            "Write TSV table of total read counts"))
     args = p.parse_args(argv)
 
     samples = list(Sample.load(args.barcode_file))
@@ -78,3 +81,5 @@ def main(argv=None):
 
     if args.manifest_file:
         writer.write_qiime2_manifest(args.manifest_file)
+    if args.total_reads_file:
+        writer.write_read_counts(args.total_reads_file, assigner.read_counts)
diff --git a/dnabclib/writer.py b/dnabclib/writer.py
index e7bd65a..7d99b6a 100644
--- a/dnabclib/writer.py
+++ b/dnabclib/writer.py
@@ -30,6 +30,11 @@ def write_qiime2_manifest(self, f):
             fp1 = os.path.abspath(f1.name)
             f.write("{0},{1},forward\n".format(sample.name, fp1))
 
+    def write_read_counts(self, f, read_counts):
+        f.write("SampleID\tNumReads\n")
+        for sample_name, n in read_counts.items():
+            f.write("{0}\t{1}\n".format(sample_name, n))
+
     def _get_output_file(self, sample):
         f = self._open_files.get(sample)
         if f is None:
diff --git a/test/test_main.py b/test/test_main.py
index 26483a5..206d27b 100644
--- a/test/test_main.py
+++ b/test/test_main.py
@@ -46,6 +46,7 @@ def setUp(self):
         self.output_dir = os.path.join(self.temp_dir, "output")
         self.summary_fp = os.path.join(self.temp_dir, "summary.json")
         self.manifest_fp = os.path.join(self.temp_dir, "manifest.csv")
+        self.total_reads_fp = os.path.join(self.temp_dir, "read_counts.tsv")
 
     def tearDown(self):
         shutil.rmtree(self.temp_dir)
@@ -58,6 +59,7 @@ def test_regular(self):
             "--barcode-file", self.barcode_fp,
             "--output-dir", self.output_dir,
             "--manifest-file", self.manifest_fp,
+            "--total-reads-file", self.total_reads_fp,
             "--revcomp",
         ])
         self.assertEqual(
@@ -72,6 +74,12 @@ def test_regular(self):
                 self.assertIn(direction, ["forward\n", "reverse\n"])
                 self.assertIn(sample_id, ["SampleA", "SampleB"])
 
+        with open(self.total_reads_fp) as f:
+            self.assertEqual(next(f), "SampleID\tNumReads\n")
+            self.assertEqual(next(f), "SampleA\t1\n")
+            self.assertEqual(next(f), "SampleB\t1\n")
+            self.assertEqual(next(f), "unassigned\t1\n")
+
 
 class SampleNameTests(unittest.TestCase):
     def test_get_sample_names_main(self):
diff --git a/test/test_writer.py b/test/test_writer.py
index 071e3a3..a318d04 100644
--- a/test/test_writer.py
+++ b/test/test_writer.py
@@ -64,6 +64,12 @@ def test_write(self):
             "h56,{0},forward\n".format(fp),
         ])
 
+        f2 = MockFile()
+        w.write_read_counts(f2, {"s1": 365})
+        self.assertEqual(f2.contents, [
+            "SampleID\tNumReads\n",
+            "s1\t365\n"
+        ])
 
 class MockFile:
     def __init__(self):