Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add time pooled AR #15

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions cpc/cpc_default_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,10 @@ def set_default_cpc_config(parser):
'batchNorm'],
help="Type of normalization to use in the encoder "
"network (default is layerNorm).")
group.add_argument('--paddingMode', type=str, default='zeros',
choices=['zeros', 'reflect', 'replicate',
'circular'],
help="Conv padding mode (default: zeros).")
group.add_argument('--onEncoder', action='store_true',
help="(Supervised mode only) Perform the "
"classification on the encoder's output.")
Expand All @@ -93,6 +97,8 @@ def set_default_cpc_config(parser):
"network (default is lstm).")
group.add_argument('--nLevelsGRU', type=int, default=1,
help='Number of layers in the autoregressive network.')
group.add_argument('--arLenReduction', type=float, default=None,
help='Length reduction of the autreg net.')
group.add_argument('--rnnMode', type=str, default='transformer',
choices=['transformer', 'RNN', 'LSTM', 'linear',
'ffd', 'conv4', 'conv8', 'conv12'],
Expand Down
7 changes: 4 additions & 3 deletions cpc/feature_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def getEncoder(args):
return LFBEnconder(args.hiddenEncoder)
else:
from .model import CPCEncoder
return CPCEncoder(args.hiddenEncoder, args.normMode)
return CPCEncoder(args.hiddenEncoder, args.normMode, paddingMode=args.paddingMode)


def getAR(args):
Expand All @@ -165,7 +165,8 @@ def getAR(args):
args.samplingType == "sequential",
args.nLevelsGRU,
mode=args.arMode,
reverse=args.cpc_mode == "reverse")
reverse=args.cpc_mode == "reverse",
final_lengt_factor=args.arLenReduction)
return arNet


Expand Down Expand Up @@ -368,4 +369,4 @@ def buildFeature_batch(featureMaker, seqPath, strict=False,
out.append(features.detach().cpu())

out = torch.cat(out, dim=1)
return out
return out
117 changes: 107 additions & 10 deletions cpc/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,8 @@ class CPCEncoder(nn.Module):

def __init__(self,
sizeHidden=512,
normMode="layerNorm"):
normMode="layerNorm",
paddingMode="zeros"):

super(CPCEncoder, self).__init__()

Expand All @@ -80,16 +81,15 @@ def normLayer(x): return nn.InstanceNorm1d(x, affine=True)
normLayer = nn.BatchNorm1d

self.dimEncoded = sizeHidden
self.conv0 = nn.Conv1d(1, sizeHidden, 10, stride=5, padding=3)
self.conv0 = nn.Conv1d(1, sizeHidden, 10, stride=5, padding=3, padding_mode=paddingMode)
self.batchNorm0 = normLayer(sizeHidden)
self.conv1 = nn.Conv1d(sizeHidden, sizeHidden, 8, stride=4, padding=2)
self.conv1 = nn.Conv1d(sizeHidden, sizeHidden, 8, stride=4, padding=2, padding_mode=paddingMode)
self.batchNorm1 = normLayer(sizeHidden)
self.conv2 = nn.Conv1d(sizeHidden, sizeHidden, 4,
stride=2, padding=1)
self.conv2 = nn.Conv1d(sizeHidden, sizeHidden, 4, stride=2, padding=1, padding_mode=paddingMode)
self.batchNorm2 = normLayer(sizeHidden)
self.conv3 = nn.Conv1d(sizeHidden, sizeHidden, 4, stride=2, padding=1)
self.conv3 = nn.Conv1d(sizeHidden, sizeHidden, 4, stride=2, padding=1, padding_mode=paddingMode)
self.batchNorm3 = normLayer(sizeHidden)
self.conv4 = nn.Conv1d(sizeHidden, sizeHidden, 4, stride=2, padding=1)
self.conv4 = nn.Conv1d(sizeHidden, sizeHidden, 4, stride=2, padding=1, padding_mode=paddingMode)
self.batchNorm4 = normLayer(sizeHidden)
self.DOWNSAMPLING = 160

Expand Down Expand Up @@ -152,6 +152,78 @@ def forward(self, x):
return x


def sequence_segmenter(encodedData, final_lengt_factor, step_reduction=0.2):
assert not torch.isnan(encodedData).any()
device = encodedData.device
encFlat = F.pad(encodedData.reshape(-1, encodedData.size(-1)).detach(), (0, 0, 1, 0))
feat_csum = encFlat.cumsum(0)
feat_csum2 = (encFlat**2).cumsum(0)
idx = torch.arange(feat_csum.size(0), device=feat_csum.device)

final_length = int(final_lengt_factor * len(encFlat))

while len(idx) > final_length:
begs = idx[:-2]
ends = idx[2:]

sum1 = (feat_csum.index_select(0, ends) - feat_csum.index_select(0, begs))
sum2 = (feat_csum2.index_select(0, ends) - feat_csum2.index_select(0, begs))
num_elem = (ends-begs).float().unsqueeze(1)

diffs = F.pad(torch.sqrt(((sum2/ num_elem - (sum1/ num_elem)**2) ).mean(1)) * num_elem.squeeze(1),
(1,1), value=1e10)

num_to_retain = max(final_length, int(idx.shape[-1] * step_reduction))
_, keep_idx = torch.topk(diffs, num_to_retain)
keep_idx = torch.sort(keep_idx)[0]
idx = idx.index_select(0, keep_idx)

# Ensure that minibatch boundaries are preserved
seq_end_idx = torch.arange(0, encodedData.size(0)*encodedData.size(1), encodedData.size(1), device=device)
idx = torch.unique(torch.cat((idx, seq_end_idx)), sorted=True)

# now work out cut indices in each minibatch element
batch_elem_idx = idx // encodedData.size(1)
transition_idx = F.pad(torch.nonzero(batch_elem_idx[1:] != batch_elem_idx[:-1]), (0,0, 1,0))
cutpoints = (torch.nonzero((idx % encodedData.size(1)) == 0))
compressed_lens = (cutpoints[1:]-cutpoints[:-1]).squeeze(1)

seq_idx = torch.nn.utils.rnn.pad_sequence(
torch.split(idx[1:] % encodedData.size(1), tuple(cutpoints[1:]-cutpoints[:-1])), batch_first=True)
seq_idx[seq_idx==0] = encodedData.size(1)
seq_idx = F.pad(seq_idx, (1,0,0,0))

frame_idxs = torch.arange(encodedData.size(1), device=device).view(1, 1, -1)
compress_matrices = (
(seq_idx[:,:-1, None] <= frame_idxs)
& (seq_idx[:,1:, None] > frame_idxs)
).float()

compressed_lens = compressed_lens.cpu()
assert compress_matrices.shape[0] == encodedData.shape[0]
return compress_matrices, compressed_lens


def compress_batch(encodedData, compress_matrices, compressed_lens, pack=False):
ret = torch.bmm(
compress_matrices / torch.maximum(compress_matrices.sum(-1, keepdim=True), torch.ones(1, device=compress_matrices.device)),
encodedData)
if pack:
ret = torch.nn.utils.rnn.pack_padded_sequence(
ret, compressed_lens, batch_first=True, enforce_sorted=False
)
return ret


def decompress_padded_batch(compressed_data, compress_matrices, compressed_lens):
if isinstance(compressed_data, torch.nn.utils.rnn.PackedSequence):
compressed_data, unused_lens = torch.nn.utils.rnn.pad_packed_sequence(
compressed_data, batch_first=True, total_length=compress_matrices.size(1))
assert (compress_matrices.sum(1) == 1).all()
return torch.bmm(
compress_matrices.transpose(1, 2), compressed_data)


class CPCAR(nn.Module):

def __init__(self,
Expand All @@ -160,7 +232,10 @@ def __init__(self,
keepHidden,
nLevelsGRU,
mode="GRU",
reverse=False):
reverse=False,
final_lengt_factor=None,
step_reduction=0.2
):

super(CPCAR, self).__init__()
self.RESIDUAL_STD = 0.1
Expand All @@ -178,6 +253,17 @@ def __init__(self,
self.hidden = None
self.keepHidden = keepHidden
self.reverse = reverse
self.final_lengt_factor = final_lengt_factor
self.step_reduction = step_reduction

def extra_repr(self):
extras = [
f'reverse={self.reverse}',
f'final_lengt_factor={self.final_lengt_factor}',
f'step_reduction={self.step_reduction}',
]
return ', '.join(extras)


def getDimOutput(self):
return self.baseNet.hidden_size
Expand All @@ -190,10 +276,21 @@ def forward(self, x):
self.baseNet.flatten_parameters()
except RuntimeError:
pass
x, h = self.baseNet(x, self.hidden)

if self.final_lengt_factor is None:
x, h = self.baseNet(x, self.hidden)
else:
compress_matrices, compressed_lens = sequence_segmenter(
x, self.final_lengt_factor, self.step_reduction)
packed_compressed_x = compress_batch(
x, compress_matrices, compressed_lens, pack=True
)
packed_x, packed_h = self.baseNet(packed_compressed_x)
x = decompress_padded_batch(packed_x, compress_matrices, compressed_lens)

if self.keepHidden:
if isinstance(h, tuple):
self.hidden = tuple(x.detach() for x in h)
self.hidden = tuple(h_elem.detach() for h_elem in h)
else:
self.hidden = h.detach()

Expand Down
1 change: 1 addition & 0 deletions train_ls100.sh
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ rsync --exclude '.*' \
$RVERB -lrpt $CPC_DIR/ ${SAVE_DIR}/code/

echo $0 "$@" >> ${SAVE_DIR}/out.txt
hostname >> ${SAVE_DIR}/out.txt
exec python -u cpc/train.py \
--pathDB /pio/data/zerospeech2021/LibriSpeech-wav/train-clean-100 \
--pathTrain /pio/scratch/2/jch/wav2vec/LibriSpeech100_labels_split/train_split.txt \
Expand Down