Skip to content

Commit

Permalink
corrected typo
Browse files Browse the repository at this point in the history
  • Loading branch information
tmb committed Sep 5, 2013
1 parent 6b9044a commit 9b0aee5
Showing 1 changed file with 28 additions and 1 deletion.
29 changes: 28 additions & 1 deletion ocrolib/lstm.py
Original file line number Diff line number Diff line change
Expand Up @@ -744,6 +744,11 @@ def forward_algorithm(match,skip=-5.0):
correspondence probabilities."""
v = skip*arange(len(match[0]))
result = []
# This is a fairly straightforward dynamic programming problem and is
# implemented in close analogy to the edit distance:
# we either stay in the same state at no extra cost or make a diagonal
# step (transition into a new state) at no extra cost; the only costs come
# from how well the symbols match the network output.
for i in range(0,len(match)):
w = roll(v,1).copy()
w[0] = skip*i
Expand All @@ -755,29 +760,51 @@ def forwardbackward(lmatch):
"""Apply the forward-backward algorithm to an array of log state
correspondence probabilities."""
lr = forward_algorithm(lmatch)
# backward is just forward applied to the reversed sequence
rl = forward_algorithm(lmatch[::-1,::-1])[::-1,::-1]
both = lr+rl
return both

def ctc_align_targets(outputs,targets,threshold=100.0,verbose=0,debug=0,lo=1e-5):
"""Perform alignment between the `outputs` of a neural network
classifier and a list of `targets`."""
classifier and some targets. The targets themselves are a time sequence
of vectors, usually a unary representation of each target class (but
possibly sequences of arbitrary posterior probability distributions
represented as vectors)."""

outputs = maximum(lo,outputs)
outputs = outputs * 1.0/sum(outputs,axis=1)[:,newaxis]

# first, we compute the match between the outputs and the targets
# and put the result in the log domain
match = dot(outputs,targets.T)
lmatch = log(match)

if debug:
figure("ctcalign"); clf();
subplot(411); imshow(outputs.T,interpolation='nearest',cmap=cm.hot)
subplot(412); imshow(lmatch.T,interpolation='nearest',cmap=cm.hot)
assert not isnan(lmatch).any()

# Now, we compute a forward-backward algorithm over the matches between
# the input and the output states.
both = forwardbackward(lmatch)

# We need posterior probabilities for the states, so we need to normalize
# the output. Instead of keeping track of the normalization
# factors, we just normalize the posterior distribution directly.
epath = exp(both-amax(both))
l = sum(epath,axis=0)[newaxis,:]
epath /= where(l==0.0,1e-9,l)

# The previous computation gives us an alignment between input time
# and output sequence position; however, we actually want the posterior
# probability distribution at each time step. This dot product gives
# us that result. We renormalize again afterwards.
aligned = maximum(lo,dot(epath,targets))
l = sum(aligned,axis=1)[:,newaxis]
aligned /= where(l==0.0,1e-9,l)

if debug:
subplot(413); imshow(epath.T,cmap=cm.hot,interpolation='nearest')
subplot(414); imshow(aligned.T,cmap=cm.hot,interpolation='nearest')
Expand Down

0 comments on commit 9b0aee5

Please sign in to comment.