genbmm

An Extension for an Efficient Inside Algorithm

A CUDA kernel extension for an efficient Inside Algorithm, built on the genbmm library. It computes inside scores iteratively, one diagonal of the chart (i.e., one span width) at a time.

Quickstart

pip install git+https://github.com/lyutyuh/genbmm

Usage

genbmm.logbmminside_rule(inside, rule, width)

computes the following values for all $\texttt{col} - \texttt{row} = \texttt{width}$: $$inside[\texttt{batch}, \texttt{row}, \texttt{col}] = rule[\texttt{batch}, \texttt{row}, \texttt{col}] + \log \sum_{i=\texttt{row}}^{\texttt{col}-1} \exp (inside[\texttt{batch}, \texttt{row}, i] + inside[\texttt{batch}, i+1, \texttt{col}]) $$
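
To make the indexing in this formula concrete, here is a minimal pure-Python sketch of what one call computes. It assumes `inside` and `rule` are nested lists indexed `[batch][row][col]` and that the result is written back into `inside`. The helper name `logbmminside_rule_ref` is hypothetical; this is a slow reference for the formula, not the CUDA kernel or its API.

```python
import math


def logsumexp(xs):
    """Numerically stable log(sum(exp(x) for x in xs))."""
    m = max(xs)
    if m == -math.inf:
        return -math.inf
    return m + math.log(sum(math.exp(x - m) for x in xs))


def logbmminside_rule_ref(inside, rule, width):
    """Hypothetical pure-Python reference for the logbmminside_rule formula.

    `inside` and `rule` are nested lists indexed [batch][row][col]. Every cell
    with col - row == width is overwritten in place; the narrower spans it
    reads must already hold their inside scores.
    """
    if width < 1:
        raise ValueError("width must be >= 1")
    n = len(inside[0])
    for b in range(len(inside)):
        for row in range(n - width):
            col = row + width
            # One term per split point i: left span [row, i], right span [i+1, col].
            splits = [inside[b][row][i] + inside[b][i + 1][col]
                      for i in range(row, col)]
            inside[b][row][col] = rule[b][row][col] + logsumexp(splits)
    return inside
```

Each cell on diagonal `width` reads only narrower spans, which is why the diagonals have to be filled in increasing order of width.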

genbmm.logbmminside(inside, width)

computes the following values for all $\texttt{col} - \texttt{row} = \texttt{width}$: $$inside[\texttt{batch}, \texttt{row}, \texttt{col}] = \log \sum_{i=\texttt{row}}^{\texttt{col}-1} \exp (inside[\texttt{batch}, \texttt{row}, i] + inside[\texttt{batch}, i+1, \texttt{col}]) $$
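
Since the kernel works on the chart diagonals iteratively, the formula suggests filling the whole chart by applying the update once per width, from 1 up to n - 1, after initializing the width-0 diagonal with per-token scores. The sketch below does this in pure Python with a hypothetical reference `logbmminside_ref` and a hypothetical driver `inside_chart`; the loop structure is an assumption about intended use, not code taken from the library.

```python
import math


def logsumexp(xs):
    """Numerically stable log(sum(exp(x) for x in xs))."""
    m = max(xs)
    if m == -math.inf:
        return -math.inf
    return m + math.log(sum(math.exp(x - m) for x in xs))


def logbmminside_ref(inside, width):
    """Hypothetical pure-Python reference for the logbmminside formula.

    Overwrites every cell with col - row == width, in place.
    """
    n = len(inside[0])
    for b in range(len(inside)):
        for row in range(n - width):
            col = row + width
            inside[b][row][col] = logsumexp(
                [inside[b][row][i] + inside[b][i + 1][col] for i in range(row, col)]
            )
    return inside


def inside_chart(leaf_scores):
    """Fill an inside chart one diagonal at a time (width = 1 .. n-1).

    `leaf_scores` is a batch of equal-length lists of per-token scores,
    used as the width-0 diagonal. Unfilled cells start at -inf.
    """
    n = len(leaf_scores[0])
    if any(len(scores) != n for scores in leaf_scores):
        raise ValueError("all batch entries must have the same length")
    chart = [[[scores[r] if r == c else -math.inf for c in range(n)]
              for r in range(n)]
             for scores in leaf_scores]
    for width in range(1, n):
        logbmminside_ref(chart, width)
    return chart
```

With all leaf scores set to 0, every binary bracketing of a span contributes exp(0) = 1, so `exp(chart[b][0][n - 1])` equals the number of binary trees over n tokens, the Catalan number C(n-1) (for example 5 for n = 4). This gives a simple sanity check for any implementation of the recurrence.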

Example

An example is available as a Colab notebook (Open In Colab).

This kernel is used in the structured span selector, for example here:

https://github.com/lyutyuh/structured-span-selector/blob/main/outside_mp.py#L62


Old Readme

This library is a collection of missing matrix-multiply-like operations for PyTorch. It was developed to provide the operators needed by PyTorch-Struct.

The library currently has two components. It supports CUDA only.

  • Generalized matrix-multiplication with gradients (log-space, max, sample)
  • Banded sparse matrices

Generalized Matrix Multiplication

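Computing matrix multiplies over non-standard semirings in PyTorch requires creating large intermediate tensors: for a (batch, I, K) by (batch, K, J) product, the broadcasting equivalents below build a (batch, I, K, J) tensor before reducing over K. This is particularly costly because these intermediates are stored internally for backprop. This library implements the alternative matrix multiplies directly in CUDA to avoid this issue.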
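
The semiring view behind these operations can be illustrated with one generic batched matmul that takes the "addition" (a reduction over k) and the "multiplication" as parameters. The pure-Python sketch below is illustrative only: `semiring_bmm` and the wrapper names are hypothetical, not part of genbmm. It computes each output entry from a length-K list rather than materializing the full (batch, I, K, J) tensor that the broadcasting equivalents create.

```python
import math


def semiring_bmm(a, b, plus, times):
    """Batched matmul over a semiring: c[n][i][j] = plus_k times(a[n][i][k], b[n][k][j]).

    `a` is (batch, I, K) and `b` is (batch, K, J) as nested lists; `plus`
    reduces a list of values and `times` combines two values.
    """
    out = []
    for an, bn in zip(a, b):
        K, J = len(bn), len(bn[0])
        out.append([[plus([times(row[k], bn[k][j]) for k in range(K)])
                     for j in range(J)]
                    for row in an])
    return out


def logsumexp(xs):
    """Numerically stable log(sum(exp(x) for x in xs))."""
    m = max(xs)
    if m == -math.inf:
        return -math.inf
    return m + math.log(sum(math.exp(x - m) for x in xs))


def log_bmm(a, b):       # mirrors genbmm.logbmm: logsumexp_k (a + b)
    return semiring_bmm(a, b, logsumexp, lambda x, y: x + y)


def max_bmm(a, b):       # mirrors genbmm.maxbmm: max_k (a + b)
    return semiring_bmm(a, b, max, lambda x, y: x + y)


def prodmax_bmm(a, b):   # mirrors genbmm.prodmaxbmm: max_k (a * b)
    return semiring_bmm(a, b, max, lambda x, y: x * y)


def standard_bmm(a, b):  # ordinary (sum, *) semiring, for comparison
    return semiring_bmm(a, b, sum, lambda x, y: x * y)
```

Swapping only `plus` and `times` turns the ordinary product into the log, max, and product-max variants, which is why one kernel family can cover all of them.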

import torch
import genbmm

a = torch.rand(10, 3, 4).cuda().requires_grad_(True)
b = torch.rand(10, 4, 5).cuda().requires_grad_(True)

# Broadcast views for the pure-PyTorch equivalents (kept separate so that
# a and b remain the leaf tensors passed to torch.autograd.grad):
# a_ is (10, 3, 4, 1), b_ is (10, 1, 4, 5), so a_ + b_ is (10, 3, 4, 5)
a_ = a.unsqueeze(-1)
b_ = b.unsqueeze(-3)

# Log-Sum-Exp
c = genbmm.logbmm(a, b)
# Equivalent
c2 = (a_ + b_).logsumexp(-2)
# Grad
prob_a, prob_b = torch.autograd.grad(c.sum(), (a, b))

# Max
c = genbmm.maxbmm(a, b)
# Equivalent (max returns (values, indices))
c2, _ = (a_ + b_).max(-2)
# Grad
argmax_a, argmax_b = torch.autograd.grad(c.sum(), (a, b))

# Sample
c = genbmm.samplebmm(a, b)
# Equivalent (forward pass)
c2 = (a_ + b_).logsumexp(-2)
# Grad
sample_a, sample_b = torch.autograd.grad(c.sum(), (a, b))
# Pseudocode (tensors have no .sample method):
# (a_ + b_).softmax(-2).sample(-2)

# Product-Max
c = genbmm.prodmaxbmm(a, b)
# Equivalent
c2, _ = (a_ * b_).max(-2)
# Grad
grad_a, grad_b = torch.autograd.grad(c.sum(), (a, b))
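
The sampling pseudocode in the `# Sample` example (a softmax over dim -2 followed by a sample) is not runnable as written. One reading, assumed here and not confirmed by the library, is that for each output cell (i, j) the gradient of `samplebmm` is a one-hot vector over k drawn from the softmax of the scores `a[i, k] + b[k, j]`, whereas the gradient of `logbmm` is that softmax itself. The pure-Python sketch below (all names hypothetical) shows both for a single cell: the mean of many one-hot samples approaches the finite-difference gradient of log-sum-exp.

```python
import math
import random


def softmax(xs):
    """Numerically stable softmax of a list of scores."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def sample_one_hot(scores, rng):
    """Draw k with probability softmax(scores)[k] and return it one-hot."""
    probs = softmax(scores)
    k = rng.choices(range(len(scores)), weights=probs, k=1)[0]
    return [1.0 if i == k else 0.0 for i in range(len(scores))]


def logsumexp_grad_fd(scores, eps=1e-6):
    """Central finite-difference gradient of logsumexp(scores)."""
    def lse(xs):
        m = max(xs)
        return m + math.log(sum(math.exp(x - m) for x in xs))

    grad = []
    for i in range(len(scores)):
        up = list(scores)
        up[i] += eps
        down = list(scores)
        down[i] -= eps
        grad.append((lse(up) - lse(down)) / (2 * eps))
    return grad
```

For scores `[0.0, math.log(3)]` the softmax is `[0.25, 0.75]`; the log-sum-exp gradient matches it, and averaging many draws from `sample_one_hot` converges to the same values.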

Banded Sparse Matrices

See https://nbviewer.jupyter.org/github/harvardnlp/genbmm/blob/master/notebooks/Sparse.ipynb.