-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdict.py
106 lines (95 loc) · 2.9 KB
/
dict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
from turtle import update
from cv2 import reduce
import torch
class Dictionary:
    """Linear dictionary with squared-norm penalties, fitted in closed form.

    A segment ``X`` (``s x f``) is approximated by ``A @ D`` where ``D``
    (``d x f``) is the dictionary stored on the instance and ``A``
    (``s x d``) holds the per-row codes of the segment.
    """

    def __init__(self, size=(200, 8100), l=0.5, gamma=0.005) -> None:
        """Create a dictionary filled with ``torch.rand`` values.

        Args:
            size: ``(d, f)`` shape of the dictionary matrix ``D``.
            l: Weight of the squared-norm penalty on the codes ``A``.
            gamma: Weight of the squared-norm penalty on ``D`` used when the
                dictionary is refitted.
        """
        self.D = torch.rand(size)
        self.l = l
        self.gamma = gamma

    def getError(self, X: torch.Tensor, A: torch.Tensor) -> torch.Tensor:
        """Return the penalised reconstruction error of ``X`` coded as ``A``.

        Computes ``0.5 / s * ||X - A @ D||^2 + l * ||A||^2`` with squared
        Frobenius norms, where ``s = X.shape[0]``.

        Args:
            X: Segment of shape ``s x f``.
            A: Codes of shape ``s x d``.

        Returns:
            A 0-dim tensor holding the error.

        Raises:
            ValueError: If the shapes of ``X``, ``A`` and ``D`` do not agree.
        """
        # Explicit raises instead of ``assert`` so the checks survive ``-O``.
        if X.shape[0] != A.shape[0]:
            raise ValueError("X and A must have the same number of rows")
        if A.shape[1] != self.D.shape[0]:
            raise ValueError("A must have one column per dictionary row")
        if self.D.shape[1] != X.shape[1]:
            raise ValueError("X must have as many columns as the dictionary")
        # X \in s*f
        # D \in d*f
        # A \in s*d
        seg_len = X.shape[0]
        residual = X - A @ self.D
        # Sum of squared row norms == squared Frobenius norm.
        fit = 0.5 * (1 / seg_len) * residual.pow(2).sum()
        penalty = self.l * A.pow(2).sum()
        return fit + penalty

    def reduceDict(self, X, lr=1.0):
        """Code ``X`` against the current dictionary.

        Solves ``A = X @ D^T @ (2 * l * s * I + D @ D^T)^-1`` where
        ``s = X.shape[0]``.

        Args:
            X: Segment of shape ``s x f``.
            lr: Unused; kept so existing callers keep working.

        Returns:
            Tuple ``(A, error)`` where ``A`` is ``s x d`` and ``error`` is
            ``getError(X, A)``.
        """
        s = X.shape[0]
        d = self.D.shape[0]
        eye = torch.eye(d, dtype=self.D.dtype, device=self.D.device)
        gram = 2 * self.l * s * eye + self.D @ self.D.transpose(0, 1)
        # ``gram`` is symmetric, so X @ D^T @ gram^-1 == (gram^-1 @ D @ X^T)^T;
        # solving avoids forming the explicit inverse.
        A = torch.linalg.solve(gram, self.D @ X.transpose(0, 1))
        A = A.transpose(0, 1).contiguous()
        return A, self.getError(X, A)

    def _refit(self, samples):
        """Replace ``self.D`` by the penalised least-squares fit to *samples*.

        Every sample is coded with the dictionary as it is on entry; ``D`` is
        only overwritten after all samples have been accumulated.
        """
        m = len(samples)
        # The segment length is read from the first sample only, so all
        # samples are assumed to have the same number of rows.
        s = samples[0].shape[0]
        d = self.D.shape[0]
        scale = 1 / float(m * s)
        rhs = torch.zeros_like(self.D)
        lhs = self.gamma * 2 * torch.eye(
            d, dtype=self.D.dtype, device=self.D.device
        )
        for x in samples:
            A, _ = self.reduceDict(x)
            At = A.transpose(0, 1)
            rhs = rhs + scale * (At @ x)
            lhs = lhs + scale * (At @ A)
        self.D = torch.linalg.solve(lhs, rhs)

    def updateDict(self, X, lr=1.0, window=10):
        """Refit the dictionary from the most recent entries of ``X``.

        Args:
            X: Sequence of ``(segment, codes)`` pairs. Only the segments are
                used; codes are recomputed against the current dictionary.
            lr: Unused; kept so existing callers keep working.
            window: Maximum number of trailing entries of ``X`` to use.

        Raises:
            ValueError: If ``window`` is not positive or ``X`` is empty.
        """
        if window < 1:
            raise ValueError("window must be a positive integer")
        recent = X[-window:]
        if len(recent) == 0:
            raise ValueError("updateDict needs at least one sample")
        # Normalise by the number of samples actually used rather than a
        # fixed 10, so short histories are weighted like in initDict.
        self._refit([x for x, _ in recent])

    def initDict(self, X, lr=1.0):
        """Fit the dictionary from scratch on all segments in ``X``.

        Args:
            X: Non-empty sequence of segments, each ``s x f``.
            lr: Unused; kept so existing callers keep working.

        Raises:
            ValueError: If ``X`` is empty.
        """
        if len(X) == 0:
            raise ValueError("initDict needs at least one sample")
        self._refit(X)
if __name__ == "__main__":
    # Synthetic stream, N*S*F = 300*50*8100: 100 all-ones segments, then the
    # same sine segment repeated 100 times, then 100 all-ones segments again.
    seg_shape = (50, 8100)
    sine_seg = torch.zeros((50, 8100)) + torch.sin(torch.arange(0, 8100))
    stream = [torch.ones(seg_shape) for _ in range(100)]
    stream.extend([sine_seg] * 100)
    stream.extend(torch.ones(seg_shape) for _ in range(100))
    print("get the data...")

    warmup = 30      # number of leading segments used to initialise the dictionary
    threshold = 2    # segments whose error is below this are skipped
    model = Dictionary()

    # Warm-up phase: the first segments are always kept and only collected.
    warmup_segs = stream[:warmup]
    picked = list(range(len(warmup_segs)))
    kept = []
    skipped = 0

    # Streaming phase: code each segment and refit when the error is large.
    for idx, seg in enumerate(stream[warmup:], start=warmup):
        if idx == warmup:
            print("init...")
            model.initDict(warmup_segs, 1)
        print(idx)
        codes, err = model.reduceDict(seg, 0.1)
        print(err)
        if err < threshold:
            skipped += 1
            continue
        print(f"update:{idx}")
        picked.append(idx)
        kept.append((seg, codes))
        model.updateDict(kept, 0.1)

    print(picked)
    print(skipped)