Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions lambeq/experimental/discocirc/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,12 +174,20 @@ def _tree2sandwiches_rec(self,
noun_box = Box(node.word, Ty(), NOUN)
return Id(NOUN), [noun_box], [node.ind], node.ind

if not isinstance(pruned_ids, set):
pruned_ids = set(pruned_ids)

subdiags = []
nouns = []
noun_inds = []
noun2wire = {}
noun_cursor = previous_noun

if node.typ == NOUN and node.children:
local_head = self._find_head_noun(node, pruned_ids)
if local_head is not None:
noun_cursor = local_head

bigdiag = Id()

if NOUN.l in node.typ or NOUN.r in node.typ:
Expand Down Expand Up @@ -213,6 +221,12 @@ def _tree2sandwiches_rec(self,

ancilla_nouns = ancilla_nouns - {noun2wire[nid]}
wire_ids.append(noun2wire[nid])
if noun2wire.get(nid) is not None:
idx = noun2wire[nid]
if (idx < len(nouns)
and c_nouns[j].name
and not nouns[idx].name):
nouns[idx] = c_nouns[j]

wire_ids = list(sorted(ancilla_nouns))+wire_ids

Expand Down Expand Up @@ -262,6 +276,19 @@ def _tree2sandwiches_rec(self,

return bigdiag, nouns, noun_inds, noun_cursor

def _find_head_noun(self, node, pruned_ids: set[int]) -> int | None:
if node.typ == NOUN and not node.children:
if node.ind in pruned_ids:
return None
return node.ind

for child in node.children:
head = self._find_head_noun(child, pruned_ids)
if head is not None:
return head

return None

def _get_index(self, s, pnoun):
for j, w in enumerate(s):
if w == pnoun:
Expand Down
23 changes: 23 additions & 0 deletions tests/test_discocirc_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,17 @@
(Box('Bob', Ty(), n) @ Box('Claire', Ty(), n)) >> Box('hates', n @ n, n @ n)
]

def _copular_tree():
return PregroupTreeNode('is', 1, Ty('s'), children=[
PregroupTreeNode('He', 0, n),
PregroupTreeNode('', 2, n, children=[
PregroupTreeNode('very', 3, n, children=[
PregroupTreeNode('talented', 4, n @ n.l),
PregroupTreeNode('programmer', 5, n)
])
])
])


class MockBobcatParser(BobcatParser):
def __init__(self):
Expand Down Expand Up @@ -98,3 +109,15 @@ def test_discocirc_reader_w_different_parsers(monkeypatch):
parser.sentence2diagram.assert_called_once_with(
sentence, tokenised=True,
)


def test_sandwich_prefers_local_head_noun():
parser = MockBobcatParser()
r = DisCoCircReader(parser=parser,
coref_resolver=MockCorefResolver())

tree = _copular_tree()
_, nouns, nids, _ = r._tree2sandwiches_rec(tree, pruned_ids=set())

assert [box.name for box in nouns] == ['He', 'programmer']
assert nids == [0, 5]