Commit c87585d
Add files via upload
1 parent d4808ec commit c87585d

1 file changed: Dictionary_learning_v2.py (+338 −0 lines)
# -*- coding: utf-8 -*-
"""
Created on Mon Jun 05 23:31:51 2017

@author: Rehan Ahmad

Back Tracking Line Search taken from:
http://users.ece.utexas.edu/~cmcaram/EE381V_2012F/Lecture_4_Scribe_Notes.final.pdf
"""

import numpy as np
from sklearn import preprocessing
import matplotlib.pyplot as plt
from copy import deepcopy
import time
from omp import omp
from KSVD import KSVD
from FindDistanceBetweenDictionaries import FindDistanceBetweenDictionaries
from DictUpdate03 import DictUpdate03
import pdb
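
# `omp`, `KSVD`, `FindDistanceBetweenDictionaries`, and `DictUpdate03` are
# local modules shipped alongside this script. If the local `omp` were
# unavailable, a similar sparse-coding step could be sketched with
# scikit-learn (an alternative, not the author's module):
#   from sklearn.linear_model import orthogonal_mp
#   X = orthogonal_mp(D, Y, n_nonzero_coefs=sparsity)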

def awgn(x,snr_db):
    # Add white Gaussian noise at the requested SNR (in dB) to a 1-D signal.
    L = len(x)
    Es = np.sum(np.abs(x)**2)/L
    snr_lin = 10**(snr_db/10.0)
    noise = np.sqrt(Es/snr_lin)*np.random.randn(L)
    y = x + noise
    return y
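
# Illustrative check (an example, not part of the original script): for a
# long signal the empirical SNR should land near the requested value:
#   x = np.sin(2*np.pi*0.05*np.arange(10000))
#   y = awgn(x, 20.0)
#   snr_est = 10*np.log10(np.sum(x**2)/np.sum((y - x)**2))   # ~20 dB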

if __name__ == "__main__":
    tic = time.time()

    FlagPGD = True; FlagPGDMom = True; FlagMOD = False; FlagKSVD = True;
    FlagRSimCo = True; FlagPSimCo = True; FlagGDBTLS = True; FlagRGDBTLS = True

    drows = 16   #20 #16
    dcols = 32   #50 #32
    ycols = 78   #1500 #78
    alpha = 0.005

    iterations = 1000
    SNR = 20
    epochs = 1
    sparsity = 4

    # np.ndarray allocates uninitialized storage; entries are only
    # meaningful for the algorithms whose Flag is set to True above.
    count_success = np.ndarray((iterations,epochs))
    count_success_momen = np.ndarray((iterations,epochs))
    count_success_MOD = np.ndarray((iterations,epochs))
    count_success_KSVD = np.ndarray((iterations,epochs))
    count_success_RSimCo = np.ndarray((iterations,epochs))
    count_success_PSimCo = np.ndarray((iterations,epochs))
    count_success_GDBTLS = np.ndarray((iterations,epochs))
    count_success_RGDBTLS = np.ndarray((iterations,epochs))

    e = np.ndarray((iterations,epochs))
    e_momen = np.ndarray((iterations,epochs))
    e_GDBTLS = np.ndarray((iterations,epochs))
    e_MOD = np.ndarray((iterations,epochs))
    e_KSVD = np.ndarray((iterations,epochs))
    e_RSimCo = np.ndarray((iterations,epochs))
    e_PSimCo = np.ndarray((iterations,epochs))
    e_RGDBTLS = np.ndarray((iterations,epochs))

    for epoch in range(epochs):
        alpha = 0.005
        # np.random.seed(epoch)

        ################# make initial dictionary #############################
        # Commented-out MATLAB snippet (overcomplete DCT initialization):
        # Pn=ceil(sqrt(K));
        # DCT=zeros(bb,Pn);
        # for k=0:1:Pn-1,
        #     V=cos([0:1:bb-1]'*k*pi/Pn);
        #     if k>0, V=V-mean(V); end;
        #     DCT(:,k+1)=V/norm(V);
        # end;
        # DCT=kron(DCT,DCT);
        ######################################################################
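
        # A rough Python equivalent of the MATLAB snippet above (a sketch,
        # unused by this script; `bb` = atom height and `K` = atom count are
        # names taken from the MATLAB code):
        #   def overcomplete_dct(bb, K):
        #       Pn = int(np.ceil(np.sqrt(K)))
        #       DCT = np.zeros((bb, Pn))
        #       for k in range(Pn):
        #           V = np.cos(np.arange(bb)*k*np.pi/Pn)
        #           if k > 0:
        #               V = V - V.mean()
        #           DCT[:, k] = V/np.linalg.norm(V)
        #       return np.kron(DCT, DCT)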

        # Create a dictionary with iid uniform entries and normalize its
        # atoms (columns) to unit l2-norm.
        D = np.random.rand(drows,dcols)
        D = preprocessing.normalize(D,norm='l2',axis=0)
        # Create the data Y: each column is a linear combination of four
        # randomly selected atoms with iid uniform coefficients.
        Y = np.ndarray((drows,ycols))
        for i in range(ycols):
            PermIndx = np.random.permutation(dcols)
            Y[:,i] = np.random.rand()*D[:,PermIndx[0]] + \
                     np.random.rand()*D[:,PermIndx[1]] + \
                     np.random.rand()*D[:,PermIndx[2]] + \
                     np.random.rand()*D[:,PermIndx[3]]

        # Add awgn noise in data Y
        # for i in range(ycols):
        #     Y[:,i] = awgn(Y[:,i],SNR)

        # Initialize the estimate from dcols randomly chosen data columns.
        Dhat = deepcopy(Y[:,np.random.permutation(ycols)[0:dcols]])
        Dhat = preprocessing.normalize(Dhat,norm='l2',axis=0)
        Dhat_momen = deepcopy(Dhat)
        Dhat_MOD = deepcopy(Dhat)
        Dhat_KSVD = deepcopy(Dhat)
        Dhat_RSimCo = deepcopy(Dhat)
        Dhat_PSimCo = deepcopy(Dhat)
        Dhat_GDBTLS = deepcopy(Dhat)
        Dhat_RGDBTLS = deepcopy(Dhat)

        ########################################################
        # Applying Projected Gradient Descent without momentum #
        ########################################################
        if(FlagPGD==True):
            # Sparse codes are computed once against the true dictionary D
            # and held fixed; the commented line below would instead re-code
            # against the current estimate on every iteration.
            X = omp(D,Y,sparsity)
            for j in range(iterations):
                # X = omp(Dhat,Y,sparsity)
                # for i in range(dcols):
                #     R = Y-np.dot(Dhat,X)
                #     Dhat[:,i] = Dhat[:,i] + alpha*np.dot(R,X[i,:])
                Dhat = Dhat + alpha*np.dot(Y-np.dot(Dhat,X),X.T)   # parallel dictionary update
                Dhat = preprocessing.normalize(Dhat,norm='l2',axis=0)   # project atoms back onto the unit sphere

                e[j,epoch] = np.linalg.norm(Y-np.dot(Dhat,X),'fro')**2
                count = FindDistanceBetweenDictionaries(D,Dhat)
                count_success[j,epoch] = count
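
        # `FindDistanceBetweenDictionaries` is a local module; a common form
        # of this recovery metric (an assumption about its behavior, not a
        # copy of the repository's code) counts ground-truth atoms matched,
        # up to sign, by some learned atom:
        #   def atoms_recovered(D, Dhat, tol=0.01):
        #       corr = np.abs(np.dot(D.T, Dhat))    # |cosine| between atoms
        #       return int(np.sum(1.0 - corr.max(axis=1) < tol))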

        #####################################################
        # Applying Projected Gradient Descent with momentum #
        #####################################################
        if(FlagPGDMom==True):
            v = np.zeros((drows,dcols))
            gamma = 0.5
            X = omp(D,Y,sparsity)
            for j in range(iterations):
                # X = omp(Dhat_momen,Y,sparsity)
                # for i in range(dcols):
                #     R = Y-np.dot(Dhat_momen,X)
                #     v[:,i] = gamma*v[:,i] + alpha*np.dot(R,X[i,:])
                #     Dhat_momen[:,i] = Dhat_momen[:,i] + v[:,i]
                # v accumulates the ascent direction, so the update subtracts
                # it: the usual heavy-ball step with the opposite sign convention.
                v = gamma*v - alpha*np.dot(Y-np.dot(Dhat_momen,X),X.T)
                Dhat_momen = Dhat_momen - v

                Dhat_momen = preprocessing.normalize(Dhat_momen,norm='l2',axis=0)
                e_momen[j,epoch] = np.linalg.norm(Y-np.dot(Dhat_momen,X),'fro')**2
                count_momen = FindDistanceBetweenDictionaries(D,Dhat_momen)
                count_success_momen[j,epoch] = count_momen

        #####################################################
        # Applying Gradient Descent with back tracking line #
        # search algorithm                                  #
        #####################################################
        if(FlagGDBTLS==True):
            alpha = 1
            beta = np.random.rand()
            eta = np.random.rand()*0.5
            Grad = np.zeros((drows,dcols))

            X = omp(D,Y,sparsity)
            for j in range(iterations):
                alpha = 1
                # X = omp(Dhat_GDBTLS,Y,sparsity)
                Dhat_GDtemp = deepcopy(Dhat_GDBTLS)

                #################################################################
                # Back Tracking line search Algorithm (BTLS) to find optimal   #
                # value of alpha                                               #
                #################################################################
                Grad = -np.dot(Y-np.dot(Dhat_GDBTLS,X),X.T)
                oldfunc = np.linalg.norm(Y-np.dot(Dhat_GDBTLS,X),'fro')**2
                newfunc = np.linalg.norm(Y-np.dot(Dhat_GDtemp,X),'fro')**2
                # `not` replaces the original `~`, which only behaved as
                # logical negation because the comparison yields a NumPy
                # bool scalar; `not` states the intent directly.
                while not (newfunc <= oldfunc-eta*alpha*np.sum(Grad**2)):
                    alpha = beta*alpha
                    Dhat_GDtemp = deepcopy(Dhat_GDBTLS)
                    Dhat_GDtemp = Dhat_GDtemp + alpha*np.dot(Y-np.dot(Dhat_GDtemp,X),X.T)
                    Dhat_GDtemp = preprocessing.normalize(Dhat_GDtemp,norm='l2',axis=0)
                    newfunc = np.linalg.norm(Y-np.dot(Dhat_GDtemp,X),'fro')**2
                    if(alpha < 1e-9):
                        break
                #################################################################
                #################################################################
                Dhat_GDBTLS = Dhat_GDBTLS + alpha*np.dot(Y-np.dot(Dhat_GDBTLS,X),X.T)
                Dhat_GDBTLS = preprocessing.normalize(Dhat_GDBTLS,norm='l2',axis=0)

                e_GDBTLS[j,epoch] = np.linalg.norm(Y-np.dot(Dhat_GDBTLS,X),'fro')**2
                count_GDBTLS = FindDistanceBetweenDictionaries(D,Dhat_GDBTLS)
                count_success_GDBTLS[j,epoch] = count_GDBTLS

        #####################################################
        # Applying Gradient Descent with back tracking line #
        # search algorithm with regularization on X         #
        #####################################################
        if(FlagRGDBTLS==True):
            alpha = 1
            mu = 0.01
            # NOTE: beta and eta are reused from the GDBTLS block above,
            # so this branch assumes FlagGDBTLS has run first.
            # beta = np.random.rand()
            # eta = np.random.rand()*0.5
            # Grad = np.zeros((drows,dcols))
            # mu = 0.01

            X = omp(D,Y,sparsity)
            for j in range(iterations):
                alpha = 1
                # X = omp(Dhat_RGDBTLS,Y,sparsity)
                Dhat_RGDtemp = deepcopy(Dhat_RGDBTLS)

                #################################################################
                # Back Tracking line search Algorithm (BTLS) to find optimal   #
                # value of alpha                                               #
                #################################################################
                Grad = -np.dot(Y-np.dot(Dhat_RGDBTLS,X),X.T)
                oldfunc = np.linalg.norm(Y-np.dot(Dhat_RGDBTLS,X),'fro')**2 + mu*np.linalg.norm(X,'fro')**2
                newfunc = np.linalg.norm(Y-np.dot(Dhat_RGDtemp,X),'fro')**2 + mu*np.linalg.norm(X,'fro')**2
                while not (newfunc <= oldfunc-eta*alpha*np.sum(Grad**2)):
                    alpha = beta*alpha
                    Dhat_RGDtemp = deepcopy(Dhat_RGDBTLS)
                    Dhat_RGDtemp = Dhat_RGDtemp + alpha*np.dot(Y-np.dot(Dhat_RGDtemp,X),X.T)
                    Dhat_RGDtemp = preprocessing.normalize(Dhat_RGDtemp,norm='l2',axis=0)
                    newfunc = np.linalg.norm(Y-np.dot(Dhat_RGDtemp,X),'fro')**2 + mu*np.linalg.norm(X,'fro')**2
                    if(alpha < 1e-9):
                        break
                #################################################################
                #################################################################
                Dhat_RGDBTLS = Dhat_RGDBTLS + alpha*np.dot(Y-np.dot(Dhat_RGDBTLS,X),X.T)
                Dhat_RGDBTLS = preprocessing.normalize(Dhat_RGDBTLS,norm='l2',axis=0)
                ########## Update X considering the same sparsity pattern ##########
                Omega = X!=0
                ColUpdate = np.sum(Omega,axis=0)!=0
                YI = deepcopy(Y[:,ColUpdate])
                DI = deepcopy(Dhat_RGDBTLS)
                XI = deepcopy(X[:,ColUpdate])
                OmegaI = deepcopy(Omega[:,ColUpdate])
                OmegaL = np.sum(Omega,axis=0)
                mu_sqrt = np.sqrt(mu)

                for cn in range(ycols):
                    L = deepcopy(OmegaL[cn])
                    X[OmegaI[:,cn],cn] = np.linalg.lstsq(
                        np.append(DI[:,OmegaI[:,cn]],
                                  np.diag(mu_sqrt*np.ones((L,))),axis=0),
                        np.append(YI[:,cn],np.zeros((L,)),axis=0))[0]
                #################################################################
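
                # The stacked system above is the standard ridge trick: on the
                # fixed support, min_x ||y - D_I x||_2^2 + mu*||x||_2^2 equals
                # the plain least-squares solution of
                #   [ D_I        ]       [ y ]
                #   [ sqrt(mu)*I ] x  =  [ 0 ].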
                e_RGDBTLS[j,epoch] = np.linalg.norm(Y-np.dot(Dhat_RGDBTLS,X),'fro')**2
                count_RGDBTLS = FindDistanceBetweenDictionaries(D,Dhat_RGDBTLS)
                count_success_RGDBTLS[j,epoch] = count_RGDBTLS

        ############################################
        #          Applying MOD Algorithm          #
        ############################################
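        # MOD's closed-form dictionary update minimizes ||Y - D*X||_F^2 over
        # D for fixed X: D = Y*X.T*inv(X*X.T) = Y*pinv(X). Because X is held
        # fixed across iterations here, every pass produces the same D.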
        if(FlagMOD==True):
            X = omp(D,Y,sparsity)
            for j in range(iterations):
                # X = omp(Dhat_MOD,Y,sparsity)
                Dhat_MOD = np.dot(Y,np.linalg.pinv(X))
                Dhat_MOD = preprocessing.normalize(Dhat_MOD,norm='l2',axis=0)

                count_MOD = FindDistanceBetweenDictionaries(D,Dhat_MOD)
                count_success_MOD[j,epoch] = count_MOD
                e_MOD[j,epoch] = np.linalg.norm(Y-np.dot(Dhat_MOD,X),'fro')**2

        ############################################
        #         Applying KSVD Algorithm          #
        ############################################
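        # K-SVD sweeps the atoms one at a time: each atom and its coefficient
        # row are replaced by the leading rank-1 SVD pair of the residual
        # restricted to the signals that use that atom. The local
        # KSVD(Y, D, X) module returns the updated pair (D, X).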
        if(FlagKSVD==True):
            X = omp(D,Y,sparsity)
            for j in range(iterations):
                # X = omp(Dhat_KSVD,Y,sparsity)
                Dhat_KSVD,X = KSVD(Y,Dhat_KSVD,X)

                count_KSVD = FindDistanceBetweenDictionaries(D,Dhat_KSVD)
                count_success_KSVD[j,epoch] = count_KSVD
                e_KSVD[j,epoch] = np.linalg.norm(Y-np.dot(Dhat_KSVD,X),'fro')**2

        #############################################
        #   Applying Regularized SimCo Algorithm    #
        #############################################
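        # SimCo (Simultaneous Codeword Optimization, Dai et al.) updates all
        # atoms at once via a line search over unit-norm dictionaries; the
        # regularized variant adds a mu*||X||_F^2 term. `DictUpdate03` appears
        # to be a port of the reference MATLAB code (it cites
        # DictLineSearch03.m below).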
        if(FlagRSimCo==True):
            class IPara():
                pass
            IPara = IPara()
            IPara.I = range(D.shape[1])
            IPara.mu = 0.01
            IPara.dispN = 20
            IPara.DebugFlag = 0
            IPara.itN = 1
            IPara.gmin = 1e-5   # the minimum value of gradient
            IPara.Lmin = 1e-6   # t4-t1 should be larger than Lmin
            IPara.t4 = 1e-2     # the initial value of t4
            IPara.rNmax = 3     # number of iterative refinements in Part B of DictLineSearch03.m

            X = omp(D,Y,sparsity)
            for j in range(iterations):
                # X = omp(Dhat_RSimCo,Y,sparsity)
                Dhat_RSimCo,X,_ = DictUpdate03(Y,Dhat_RSimCo,X,IPara)

                count_RSimCo = FindDistanceBetweenDictionaries(D,Dhat_RSimCo)
                count_success_RSimCo[j,epoch] = count_RSimCo
                e_RSimCo[j,epoch] = np.linalg.norm(Y-np.dot(Dhat_RSimCo,X),'fro')**2

        #############################################
        #    Applying Primitive SimCo Algorithm     #
        #############################################
        if(FlagPSimCo==True):
            # NOTE: reuses the IPara object from the RSimCo block, so this
            # branch assumes FlagRSimCo has run first; mu = 0 disables the
            # regularization ("primitive" SimCo).
            IPara.mu = 0
            X = omp(D,Y,sparsity)
            for j in range(iterations):
                # X = omp(Dhat_PSimCo,Y,sparsity)
                Dhat_PSimCo,X,_ = DictUpdate03(Y,Dhat_PSimCo,X,IPara)

                count_PSimCo = FindDistanceBetweenDictionaries(D,Dhat_PSimCo)
                count_success_PSimCo[j,epoch] = count_PSimCo
                e_PSimCo[j,epoch] = np.linalg.norm(Y-np.dot(Dhat_PSimCo,X),'fro')**2
        #############################################
        #############################################
        print('epoch:',epoch,'completed')

    # Plot recovery counts (averaged over epochs) for each enabled algorithm.
    plt.close('all')
    if FlagPGD==True: plt.plot(np.sum(count_success,axis=1)/epochs,'b',label='PGD')
    if FlagPGDMom==True: plt.plot(np.sum(count_success_momen,axis=1)/epochs,'r',label='PGD_Momentum')
    if FlagMOD==True: plt.plot(np.sum(count_success_MOD,axis=1)/epochs,'g',label='MOD')
    if FlagKSVD==True: plt.plot(np.sum(count_success_KSVD,axis=1)/epochs,'y',label='KSVD')
    if FlagRSimCo==True: plt.plot(np.sum(count_success_RSimCo,axis=1)/epochs,'m',label='RSimCo')
    if FlagPSimCo==True: plt.plot(np.sum(count_success_PSimCo,axis=1)/epochs,'c',label='PSimCo')
    if FlagGDBTLS==True: plt.plot(np.sum(count_success_GDBTLS,axis=1)/epochs,':',label='GDBTLS')
    if FlagRGDBTLS==True: plt.plot(np.sum(count_success_RGDBTLS,axis=1)/epochs,'--',label='R_GDBTLS')

    plt.legend()
    plt.xlabel('iteration number')
    plt.ylabel('Success counts in iteration')
    plt.title('Dictionary Learning Algorithms applied on Synthetic data')

    # Plot the squared reconstruction error for each enabled algorithm.
    plt.figure()
    if FlagPGD==True: plt.plot(np.sum(e,axis=1)/epochs,'b',label='PGD')
    if FlagPGDMom==True: plt.plot(np.sum(e_momen,axis=1)/epochs,'r',label='PGD_Momentum')
    if FlagMOD==True: plt.plot(np.sum(e_MOD,axis=1)/epochs,'g',label='MOD')
    if FlagKSVD==True: plt.plot(np.sum(e_KSVD,axis=1)/epochs,'y',label='KSVD')
    if FlagRSimCo==True: plt.plot(np.sum(e_RSimCo,axis=1)/epochs,'m',label='RSimCo')
    if FlagPSimCo==True: plt.plot(np.sum(e_PSimCo,axis=1)/epochs,'c',label='PSimCo')
    if FlagGDBTLS==True: plt.plot(np.sum(e_GDBTLS,axis=1)/epochs,':',label='GDBTLS')
    if FlagRGDBTLS==True: plt.plot(np.sum(e_RGDBTLS,axis=1)/epochs,'--',label='R_GDBTLS')

    plt.legend()
    plt.xlabel('iteration number')
    plt.ylabel('Error: sum of squares')

    toc = time.time()
    print('Total time taken by code: %.2f min' % ((toc-tic)/60.0))
