Skip to content

support other dataset types #1134

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 58 additions & 0 deletions v03_pipeline/lib/misc/male_non_par.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import hail as hl

from v03_pipeline.lib.annotations import sv
from v03_pipeline.lib.misc.pedigree import Family
from v03_pipeline.lib.model import DatasetType, ReferenceGenome, Sex


def overwrite_male_non_par_calls(
mt: hl.MatrixTable,
dataset_type: DatasetType,
families: set[Family],
) -> hl.MatrixTable:
male_sample_ids = {
s.sample_id for f in families for s in f.samples.values() if s.sex == Sex.MALE
}
male_sample_ids = (
hl.set(male_sample_ids) if male_sample_ids else hl.empty_set(hl.tstr)
)
par_intervals = hl.array(
[
i
for i in hl.get_reference(ReferenceGenome.GRCh38).par
if i.start.contig == ReferenceGenome.GRCh38.x_contig
],
)
non_par_interval = hl.interval(
par_intervals[0].end,
par_intervals[1].start,
)
# NB: making use of existing formatting_annotation_fns.
# We choose to annotate & drop here as the sample level
# fields are dropped by the time we format variants.
if dataset_type == DatasetType.SV:
mt = mt.annotate_rows(
start_locus=sv.start_locus(mt),
end_locus=sv.end_locus(mt),
)
mt = mt.annotate_entries(
GT=hl.if_else(
(
male_sample_ids.contains(mt.s)
& non_par_interval.overlaps(
hl.interval(
mt.start_locus,
mt.end_locus,
)
if dataset_type == DatasetType.SV
else hl.interval(mt.locus.position, mt.locus.position),
)
& mt.GT.is_het()
),
hl.Call([1], phased=False),
mt.GT,
),
)
if dataset_type == DatasetType.SV:
return mt.drop('start_locus', 'end_locus')
return mt
64 changes: 64 additions & 0 deletions v03_pipeline/lib/misc/male_non_par_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import unittest

import hail as hl

from v03_pipeline.lib.misc.io import import_callset, select_relevant_fields
from v03_pipeline.lib.misc.male_non_par import (
overwrite_male_non_par_calls,
)
from v03_pipeline.lib.misc.pedigree import Family, Sample
from v03_pipeline.lib.misc.sample_ids import subset_samples
from v03_pipeline.lib.model import DatasetType, ReferenceGenome, Sex

TEST_SV_VCF = 'v03_pipeline/var/test/callsets/sv_1.vcf'


class MaleNonParTest(unittest.TestCase):
def test_overwrite_male_non_par_calls(self) -> None:
mt = import_callset(TEST_SV_VCF, ReferenceGenome.GRCh38, DatasetType.SV)
mt = select_relevant_fields(
mt,
DatasetType.SV,
)
mt = subset_samples(
mt,
hl.Table.parallelize(
[{'s': sample_id} for sample_id in ['RGP_164_1', 'RGP_164_2']],
hl.tstruct(s=hl.dtype('str')),
key='s',
),
)
mt = overwrite_male_non_par_calls(
mt,
DatasetType.SV,
{
Family(
family_guid='family_1',
samples={
'RGP_164_1': Sample(sample_id='RGP_164_1', sex=Sex.FEMALE),
'RGP_164_2': Sample(sample_id='RGP_164_2', sex=Sex.MALE),
},
),
},
)
mt = mt.filter_rows(mt.locus.contig == 'chrX')
self.assertEqual(
[
hl.Locus(contig='chrX', position=3, reference_genome='GRCh38'),
hl.Locus(contig='chrX', position=2781700, reference_genome='GRCh38'),
],
mt.locus.collect(),
)
self.assertEqual(
[
hl.Call(alleles=[0, 0], phased=False),
# END of this variant < start of the non-par region.
hl.Call(alleles=[0, 1], phased=False),
hl.Call(alleles=[0, 0], phased=False),
hl.Call(alleles=[1], phased=False),
],
mt.GT.collect(),
)
self.assertFalse(
hasattr(mt, 'start_locus'),
)
49 changes: 0 additions & 49 deletions v03_pipeline/lib/misc/sv.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@
import hail as hl

from v03_pipeline.lib.annotations import sv
from v03_pipeline.lib.misc.pedigree import Family
from v03_pipeline.lib.model import ReferenceGenome, Sex

WRONG_CHROM_PENALTY = 1e9

Expand Down Expand Up @@ -106,50 +104,3 @@ def deduplicate_merged_sv_concordance_calls(
),
},
)


def overwrite_male_non_par_calls(
mt: hl.MatrixTable,
families: set[Family],
) -> hl.MatrixTable:
male_sample_ids = {
s.sample_id for f in families for s in f.samples.values() if s.sex == Sex.MALE
}
male_sample_ids = (
hl.set(male_sample_ids) if male_sample_ids else hl.empty_set(hl.tstr)
)
par_intervals = hl.array(
[
i
for i in hl.get_reference(ReferenceGenome.GRCh38).par
if i.start.contig == ReferenceGenome.GRCh38.x_contig
],
)
non_par_interval = hl.interval(
par_intervals[0].end,
par_intervals[1].start,
)
# NB: making use of existing formatting_annotation_fns.
# We choose to annotate & drop here as the sample level
# fields are dropped by the time we format variants.
mt = mt.annotate_rows(
start_locus=sv.start_locus(mt),
end_locus=sv.end_locus(mt),
)
mt = mt.annotate_entries(
GT=hl.if_else(
(
male_sample_ids.contains(mt.s)
& non_par_interval.overlaps(
hl.interval(
mt.start_locus,
mt.end_locus,
),
)
& mt.GT.is_het()
),
hl.Call([1], phased=False),
mt.GT,
),
)
return mt.drop('start_locus', 'end_locus')
53 changes: 0 additions & 53 deletions v03_pipeline/lib/misc/sv_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,9 @@

import hail as hl

from v03_pipeline.lib.misc.io import import_callset, select_relevant_fields
from v03_pipeline.lib.misc.pedigree import Family, Sample
from v03_pipeline.lib.misc.sample_ids import subset_samples
from v03_pipeline.lib.misc.sv import (
deduplicate_merged_sv_concordance_calls,
overwrite_male_non_par_calls,
)
from v03_pipeline.lib.model import DatasetType, ReferenceGenome, Sex

TEST_SV_VCF = 'v03_pipeline/var/test/callsets/sv_1.vcf'
ANNOTATIONS_HT = hl.Table.parallelize(
Expand Down Expand Up @@ -52,54 +47,6 @@


class SVTest(unittest.TestCase):
def test_overwrite_male_non_par_calls(self) -> None:
mt = import_callset(TEST_SV_VCF, ReferenceGenome.GRCh38, DatasetType.SV)
mt = select_relevant_fields(
mt,
DatasetType.SV,
)
mt = subset_samples(
mt,
hl.Table.parallelize(
[{'s': sample_id} for sample_id in ['RGP_164_1', 'RGP_164_2']],
hl.tstruct(s=hl.dtype('str')),
key='s',
),
)
mt = overwrite_male_non_par_calls(
mt,
{
Family(
family_guid='family_1',
samples={
'RGP_164_1': Sample(sample_id='RGP_164_1', sex=Sex.FEMALE),
'RGP_164_2': Sample(sample_id='RGP_164_2', sex=Sex.MALE),
},
),
},
)
mt = mt.filter_rows(mt.locus.contig == 'chrX')
self.assertEqual(
[
hl.Locus(contig='chrX', position=3, reference_genome='GRCh38'),
hl.Locus(contig='chrX', position=2781700, reference_genome='GRCh38'),
],
mt.locus.collect(),
)
self.assertEqual(
[
hl.Call(alleles=[0, 0], phased=False),
# END of this variant < start of the non-par region.
hl.Call(alleles=[0, 1], phased=False),
hl.Call(alleles=[0, 0], phased=False),
hl.Call(alleles=[1], phased=False),
],
mt.GT.collect(),
)
self.assertFalse(
hasattr(mt, 'start_locus'),
)

def test_deduplicate_merged_sv_concordance_calls(self) -> None:
mt = (
hl.MatrixTable.from_parts(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,12 @@
import_pedigree,
remap_pedigree_hash,
)
from v03_pipeline.lib.misc.male_non_par import overwrite_male_non_par_calls
from v03_pipeline.lib.misc.pedigree import (
parse_pedigree_ht_to_families,
parse_pedigree_ht_to_remap_ht,
)
from v03_pipeline.lib.misc.sample_ids import remap_sample_ids, subset_samples
from v03_pipeline.lib.misc.sv import overwrite_male_non_par_calls
from v03_pipeline.lib.misc.validation import SeqrValidationError
from v03_pipeline.lib.model.feature_flag import FeatureFlag
from v03_pipeline.lib.paths import (
Expand Down Expand Up @@ -203,7 +203,7 @@ def create_table(self) -> hl.MatrixTable:
mt = mt.drop(field)

if self.dataset_type.overwrite_male_non_par_calls:
mt = overwrite_male_non_par_calls(mt, loadable_families)
mt = overwrite_male_non_par_calls(mt, self.dataset_type, loadable_families)
return mt.select_globals(
remap_pedigree_hash=remap_pedigree_hash(
self.project_pedigree_paths[self.project_i],
Expand Down