Commit b9d5a32

Add files via upload
1 parent 12eef76 commit b9d5a32

8 files changed: +2225 -0 lines changed

Diff for: CTF-SSCL-UP-knn.py

+800
Large diffs are not rendered by default.

Diff for: Overall.png

295 KB

Diff for: README.md

+13
@@ -0,0 +1,13 @@
# CTF-SSCL

This repo is the implementation of the following paper:

**CTF-SSCL: CNN-Transformer for Few-shot Hyperspectral Image Classification Assisted by Semisupervised Contrastive Learning** (TGRS 2024), [[paper]](https://doi.org/10.1109/TGRS.2024.3465225)

## Abstract

Few-shot learning (FSL) has rapidly advanced in hyperspectral image (HSI) classification, potentially reducing the need for laborious and expensive labeled-data collection. Due to its limited receptive field, however, the convolutional neural network (CNN) struggles to capture the long-range dependencies needed for extracting global features, while the transformer focuses on global correlation and overlooks the effective representation of local spatial and spectral features. Meanwhile, contrastive learning has emerged as a powerful technique for improving consistency across differently augmented views of samples of the same category.

To this end, we devise a novel CNN-Transformer network for few-shot HSI classification with semisupervised contrastive learning (CTF-SSCL) to boost classification performance. Specifically, the cascaded CNN-Transformer incorporates a lightweight spatial-spectral interactive convolution module (LSSICM) and a multi-scale transformer (MSFormer) to exploit local features from submaps and global information from the entire patch. The semisupervised contrastive loss, comprising unsupervised and supervised components, then serves as an auxiliary objective optimized jointly with the classification loss. In particular, recognizing the unified spectral-spatial information in HSI, we propose a spectral feature shift strategy (SFSS) to create sample pairs for unsupervised contrastive learning, utilizing the unsupervised contrastive loss among groups of samples with identical labels. Extensive experiments on four standard benchmarks demonstrate the effectiveness of the proposed CTF-SSCL with varying amounts of labeled samples. The code is available online at https://github.com/B-Xi/CTF-SSCL.
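
As a rough illustration of the SFSS idea, a second view of each patch can be formed by shifting its spectral bands, so that the original and shifted patches form a positive pair. The sketch below is only an assumption of how such pairs could be built; the `spectral_shift` helper, the cyclic wrap-around, and the shift size are hypothetical and not taken from the paper's code:

```python
import torch

def spectral_shift(x: torch.Tensor, shift: int = 2) -> torch.Tensor:
    """Cyclically shift the band axis of HSI patches shaped (batch, bands, H, W).

    Hypothetical augmentation for illustration, not the paper's exact SFSS.
    """
    return torch.roll(x, shifts=shift, dims=1)

patches = torch.randn(8, 100, 9, 9)              # a batch of 100-band 9x9 patches
view1, view2 = patches, spectral_shift(patches)  # positive pair for the unsupervised contrastive loss
```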

## Training and Test Process

1. Prepare the training and test data as described in the paper.
2. Run 'CTF-SSCL-UP-knn.py' to reproduce the CTF-SSCL results on the Pavia University data set (a quick environment check is sketched below).
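
Before launching the full script, one way to sanity-check the environment is to run the `gscvit.py` backbone on random patches, mirroring that file's own `__main__` block (the shapes below assume its 'hy' configuration with 100 bands and 32 classes):

```python
import torch
from gscvit import gscvit

net = gscvit(dataset='hy')           # backbone defined in this commit
patches = torch.randn(9, 100, 9, 9)  # (batch, bands, height, width)
print(net(patches).shape)            # expected: torch.Size([9, 32])
```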

Diff for: gscvit.py

+358
@@ -0,0 +1,358 @@
import random
from functools import partial
import torch
from torch import nn, einsum
from einops import rearrange, repeat
from einops.layers.torch import Rearrange, Reduce

from timm.models.vision_transformer import _cfg


def cast_tuple(val, length=1):
    return val if isinstance(val, tuple) else ((val,) * length)

class ChannelAdjustmentLayer1(nn.Module):
    """Channel calibration strategy 1: randomly duplicate or drop channels."""
    def __init__(self, target_channels=256):
        super(ChannelAdjustmentLayer1, self).__init__()
        self.target_channels = target_channels

    def forward(self, x):
        B, C, H, W = x.size()

        if C == self.target_channels:
            return x

        if C < self.target_channels:
            # copy randomly chosen channels one at a time, appending each copy at the end
            num_channels_to_copy = self.target_channels - C
            for i in range(num_channels_to_copy):
                channel_to_copy = torch.randint(0, C, (1,))
                x = torch.cat([x, x[:, channel_to_copy, :, :]], dim=1)
        else:
            # delete randomly chosen channels one at a time
            num_channels_to_remove = C - self.target_channels
            for i in range(num_channels_to_remove):
                rand = torch.randint(0, x.shape[1], (1,))
                x = torch.cat([x[:, :rand, :, :], x[:, rand + 1:, :, :]], dim=1)
        return x

# Channel calibration strategy 2: mirror-pad or trim channels at both ends
class ChannelAdjustmentLayer2(nn.Module):
    def __init__(self, target_channels=256):
        super(ChannelAdjustmentLayer2, self).__init__()
        self.target_channels = target_channels

    def forward(self, x):
        B, C, H, W = x.size()

        if C == self.target_channels:
            return x

        if C < self.target_channels:
            # number of channels to add
            num_channels_to_expand = self.target_channels - C
            # channels to add at each end
            channels_to_expand_per_side = num_channels_to_expand // 2

            # mirror-extend evenly at both ends
            x = torch.cat([x[:, :channels_to_expand_per_side + 1, :, :].flip(dims=(1,)),
                           x,
                           x[:, -channels_to_expand_per_side:, :, :].flip(dims=(1,))], dim=1)
        else:
            # number of channels to remove
            num_channels_to_remove = C - self.target_channels
            # channels to remove at each end
            channels_to_remove_per_side = num_channels_to_remove // 2

            # trim evenly from both ends; guard against a zero-width slice
            # (x[:, 0:-0] would empty the tensor when only one channel must go;
            # the final slice below handles that trim instead)
            if channels_to_remove_per_side > 0:
                x = x[:, channels_to_remove_per_side:-channels_to_remove_per_side, :, :]

        return x[:, :self.target_channels, :, :]

# Channel-wise layer normalization
class ChanLayerNorm(nn.Module):
    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
        self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))

    def forward(self, x):
        var = torch.var(x, dim=1, unbiased=False, keepdim=True)
        mean = torch.mean(x, dim=1, keepdim=True)
        return (x - mean) / (var + self.eps).sqrt() * self.g + self.b


class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = ChanLayerNorm(dim)
        self.fn = fn

    def forward(self, x):
        return self.fn(self.norm(x))

# Spectral calibration module: 1x1 conv mapping the input bands to a fixed width
class SpectralCalibration(nn.Module):
    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.conv = nn.Conv2d(dim_in, dim_out, 1)
        self.bn = nn.BatchNorm2d(dim_out)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x


class GSC(nn.Module):
    def __init__(self, dim_in, dim_out, padding=1, num_groups=8):
        super().__init__()
        self.dim_out = dim_out
        self.gpwc = nn.Conv2d(dim_in, dim_out, groups=num_groups, kernel_size=1)
        self.dwc1 = nn.Conv2d(dim_out // 2, dim_out // 2, groups=dim_out // 2, kernel_size=1)
        self.dwc2 = nn.Conv2d(dim_out // 2, dim_out // 2, groups=dim_out // 2, padding=1, kernel_size=3)
        self.dwc3 = nn.Conv2d(dim_out // 4, dim_out // 4, groups=dim_out // 4, padding=2, kernel_size=5)
        self.dwc4 = nn.Conv2d(dim_out // 4, dim_out // 4, groups=dim_out // 4, padding=3, kernel_size=7)
        self.bn = nn.BatchNorm2d(dim_out)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        # grouped pointwise conv, then split into halves; each half gets a 3x3
        # depthwise conv plus a small (1/10) cross-branch term from the other half
        # (dwc1, dwc3 and dwc4 are defined but unused in this forward pass)
        x = self.gpwc(x)
        x1 = x[:, :self.dim_out // 2, :, :]
        x2 = x[:, self.dim_out // 2:, :, :]
        x3 = self.dwc2(x1)
        x4 = self.dwc2(x2)
        x5 = x3 + x4 / 10
        x6 = x4 + x3 / 10
        x = torch.cat((x5, x6), dim=1) + x
        return self.relu(self.bn(x))

class GSSA(nn.Module):
    def __init__(
            self,
            dim,
            heads=8,
            dim_head=16,
            dropout=0.,
            group_spatial_size=3
    ):
        super().__init__()
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.group_spatial_size = group_spatial_size
        inner_dim = dim_head * heads

        self.attend = nn.Sequential(
            nn.Softmax(dim=-1),
            nn.Dropout(dropout)
        )

        self.to_qkv = nn.Conv1d(dim, inner_dim * 3, 1, bias=False)
        self.to_qkv1 = nn.Conv1d(128, 16, 1, bias=False)  # defined but unused
        self.group_tokens = nn.Parameter(torch.randn(dim))  # unused: group tokens come from self.gc below
        dim_out = 128
        self.gc = nn.Conv2d(dim_out, dim_out, kernel_size=7, groups=dim_out, stride=1)

        self.group_tokens_to_qk = nn.Sequential(
            nn.LayerNorm(dim_head),
            nn.GELU(),
            Rearrange('b h n c -> b (h c) n'),
            nn.Conv1d(inner_dim, inner_dim * 2, 1),
            Rearrange('b (h c) n -> b h n c', h=heads),
        )

        self.group_attend = nn.Sequential(
            nn.Softmax(dim=-1),
            nn.Dropout(dropout)
        )

        self.to_out = nn.Sequential(
            nn.Conv2d(inner_dim, dim, 1),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        batch, height, width, heads, gss = x.shape[0], *x.shape[-2:], self.heads, self.group_spatial_size
        assert (height % gss) == 0 and (width % gss) == 0, \
            f'height {height} and width {width} must be divisible by group spatial size {gss}'
        num_groups = (height // gss) * (width // gss)

        # depthwise 7x7 conv (no padding): on a 9x9 input with gss=3 this yields
        # a 3x3 map, i.e. one group token per spatial group
        w = self.gc(x)
        w = rearrange(w, 'b c h w -> (b h w) c 1')

        x = rearrange(x, 'b c (h g1) (w g2) -> (b h w) c (g1 g2)', g1=gss, g2=gss)

        # w = repeat(self.group_tokens, 'c -> b c 1', b=x.shape[0])

        # prepend the group token, then attend within each spatial group
        x = torch.cat((w, x), dim=-1)
        q, k, v = self.to_qkv(x).chunk(3, dim=1)

        q, k, v = map(lambda t: rearrange(t, 'b (h d) ... -> b h (...) d', h=heads), (q, k, v))

        q = q * self.scale

        dots = einsum('b h i d, b h j d -> b h i j', q, k)

        attn = self.attend(dots)

        out = torch.matmul(attn, v)
        group_tokens, grouped_fmaps = out[:, :, 0], out[:, :, 1:]

        if num_groups == 1:
            fmap = rearrange(grouped_fmaps, '(b x y) h (g1 g2) d -> b (h d) (x g1) (y g2)',
                             x=height // gss, y=width // gss, g1=gss, g2=gss)
            return self.to_out(fmap)

        # group_tokens = group_tokens + w1

        group_tokens = rearrange(group_tokens, '(b x y) h d -> b h (x y) d', x=height // gss, y=width // gss)

        grouped_fmaps = rearrange(grouped_fmaps, '(b x y) h n d -> b h (x y) n d', x=height // gss, y=width // gss)

        # second attention stage: group tokens attend across groups
        w_q, w_k = self.group_tokens_to_qk(group_tokens).chunk(2, dim=-1)

        w_q = w_q * self.scale

        w_dots = einsum('b h i d, b h j d -> b h i j', w_q, w_k)

        w_attn = self.group_attend(w_dots)

        aggregated_grouped_fmap = einsum('b h i j, b h j w d -> b h i w d', w_attn, grouped_fmaps)

        fmap = rearrange(aggregated_grouped_fmap, 'b h (x y) (g1 g2) d -> b (h d) (x g1) (y g2)',
                         x=height // gss, y=width // gss, g1=gss, g2=gss)
        return self.to_out(fmap)

class Transformer(nn.Module):
    def __init__(
            self,
            dim,
            depth,
            dim_head=16,
            heads=8,
            dropout=0.,
            norm_output=True,
            groupsize=4
    ):
        super().__init__()
        self.layers = nn.ModuleList([])

        for ind in range(depth):
            self.layers.append(
                PreNorm(dim, GSSA(dim, group_spatial_size=groupsize, heads=heads, dim_head=dim_head, dropout=dropout))
            )

        self.norm = ChanLayerNorm(dim) if norm_output else nn.Identity()

    def forward(self, x):
        for attn in self.layers:
            x = attn(x)

        return self.norm(x)

class GSCViT(nn.Module):
    def __init__(
            self,
            *,
            num_classes,
            depth,
            heads,
            group_spatial_size,
            channels=200,
            dropout=0.1,
            padding,
            dims=(256, 128, 64, 32),
            num_groups=[16, 16, 16]
    ):
        super().__init__()
        num_stages = 1  # len(depth)

        dim_pairs = tuple(zip(dims[:-1], dims[1:]))
        hyperparams_per_stage = [heads]
        hyperparams_per_stage = list(map(partial(cast_tuple, length=num_stages), hyperparams_per_stage))
        assert all(tuple(map(lambda arr: len(arr) == num_stages, hyperparams_per_stage)))
        self.sc = SpectralCalibration(channels, 256)
        self.bn_1 = nn.BatchNorm2d(256)
        self.relu_1 = nn.ReLU(inplace=True)
        self.layers_trans = nn.ModuleList([])
        # a single GSC + Transformer stage (256 -> 128 channels)
        is_last = 1
        layer_dim_in = 256
        layer_dim = 128
        p = 1
        num_group = 16
        layer_depth = 1
        ind = 0
        layer_heads = 1
        self.layers_trans.append(nn.ModuleList([
            GSC(layer_dim_in, layer_dim, p, num_group),
            Transformer(dim=int(layer_dim), depth=layer_depth, heads=layer_heads,
                        groupsize=group_spatial_size[ind], dropout=dropout, norm_output=not is_last),
            nn.BatchNorm2d(layer_dim),
            nn.ReLU(inplace=True),
            nn.Conv2d(layer_dim, layer_dim, 1),
        ]))

        self.conv_last = nn.Conv2d(dims[-1], 2 * dims[-1], 3)

        self.mlp_head = nn.Sequential(
            Reduce('b d h w -> b d', 'mean'),
            nn.LayerNorm(dims[-1]),
            nn.Linear(dims[-1], num_classes)
        )

    def forward(self, x):
        x = x.squeeze(dim=1)
        x = self.sc(x)
        x = self.bn_1(x)
        x = self.relu_1(x)

        for peg, transformer, bn, relu, pw in self.layers_trans:
            x = peg(x)
            y = x
            x = transformer(x)
            x = pw(x) + y  # residual connection around the transformer
            x = bn(x)
            x = relu(x)
        return self.mlp_head(x)

def gscvit(dataset):
    model = None
    if dataset == 'hy':
        model = GSCViT(
            num_classes=32,
            channels=100,
            heads=(1),  # (1, 1, 1)
            depth=(1),  # (1, 1, 1)
            group_spatial_size=[3, 3, 3],
            dropout=0.1,
            padding=[1, 1, 1],
            dims=(256, 128),
            num_groups=[16, 16, 16],
        )
    return model


if __name__ == '__main__':
    img = torch.randn(9, 100, 9, 9)
    print("input shape:", img.shape)
    net = gscvit(dataset='hy')
    net.default_cfg = _cfg()
    print("output shape:", net(img).shape)
