-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathexample_cifar10_processing.py .py
More file actions
163 lines (132 loc) · 6.15 KB
/
example_cifar10_processing.py .py
File metadata and controls
163 lines (132 loc) · 6.15 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Example LBP transform on CIFAR-10 dataset using LbpPytorch module
"""
from lib.lbplib import lbp_py, LbpPytorch
import numpy as np
import time
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torchvision
# Device configuration: use CUDA when PyTorch reports it is available,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")
# Compute LBP using the pure Python/NumPy implementation, one channel at a time
def py_extract_lbp_rgb(x_train):
    """Extract LBP features with ``lbp_py`` for every image and every channel.

    Args:
        x_train: image batch indexed as (N, rows, cols, channels).

    Returns:
        uint8 array of the same shape, where each [i, :, :, c] slice is
        ``lbp_py`` applied to the matching input slice.
    """
    n_images, rows, cols, n_channels = x_train.shape
    x_train_lbp = np.zeros(shape=(n_images, rows, cols, n_channels), dtype='uint8')
    for i in range(n_images):
        # Loop over the actual channel count instead of hard-coding 0, 1 and 2,
        # so inputs with other channel counts are fully processed.
        for c in range(n_channels):
            x_train_lbp[i, :, :, c] = lbp_py(x_train[i, :, :, c])
    return x_train_lbp
# Compute LBP using the PyTorch LbpPytorch module on all channels in one pass
def torch_extract_lbp_rgb_module(x_train_np):
    """Extract LBP features with ``LbpPytorch`` for all images and channels at once.

    Args:
        x_train_np: NumPy image batch in (N, H, W, C) layout.

    Returns:
        uint8 array in (N, H, W, C) layout holding the module output rescaled
        from [0, 1] to [0, 255].
    """
    lbp_module = LbpPytorch(input_format='NCHW').to(device)
    lbp_module.eval()  # inference only

    # (N, H, W, C) -> (N, C, H, W), the layout the module was configured for.
    x_train_nchw = np.transpose(x_train_np, (0, 3, 1, 2))
    x_train_torch = torch.from_numpy(x_train_nchw).to(device)

    with torch.no_grad():
        x_train_lbp_torch = lbp_module(x_train_torch)

    # Back to NumPy, (N, C, H, W) -> (N, H, W, C).
    x_train_lbp_nhwc = np.transpose(x_train_lbp_torch.cpu().numpy(), (0, 2, 3, 1))

    # The module output is treated as normalized to [0, 1] (TODO confirm
    # against LbpPytorch). Round to the nearest integer before casting:
    # astype() truncates, so float error such as 254.99998 would otherwise
    # become 254. Clip guards against wrap-around for values just outside range.
    return np.clip(np.rint(x_train_lbp_nhwc * 255), 0, 255).astype('uint8')
# Alternative: process each channel separately with the LbpPytorch module
def torch_extract_lbp_rgb_module_separate(x_train_np):
    """Extract LBP features with ``LbpPytorch``, feeding one channel at a time.

    Args:
        x_train_np: NumPy image batch in (N, H, W, C) layout.

    Returns:
        uint8 array in (N, H, W, C) layout holding the module output rescaled
        from [0, 1] to [0, 255].
    """
    n_images, rows, cols, n_channels = x_train_np.shape
    x_train_lbp = np.zeros(shape=(n_images, rows, cols, n_channels), dtype='uint8')

    lbp_module = LbpPytorch(input_format='NCHW').to(device)
    lbp_module.eval()  # inference only

    with torch.no_grad():
        for c in range(n_channels):
            # (N, H, W) -> (N, 1, H, W): add the channel axis the NCHW module expects.
            channel_data = torch.from_numpy(x_train_np[:, :, :, c]).unsqueeze(1).to(device)
            lbp_result = lbp_module(channel_data)
            # (N, 1, H, W) -> (N, H, W)
            lbp_channel = lbp_result.squeeze(1).cpu().numpy()
            # Output is treated as normalized to [0, 1] (TODO confirm against
            # LbpPytorch). Round before casting: astype() truncates, so float
            # error such as 254.99998 would otherwise become 254.
            x_train_lbp[:, :, :, c] = np.clip(np.rint(lbp_channel * 255), 0, 255).astype('uint8')
    return x_train_lbp
# Load CIFAR-10 through torchvision (downloaded into ./data on first use).
print("Loading CIFAR-10 dataset...")
trainset = torchvision.datasets.CIFAR10(root='./data', download=True)
# Keep only the first 200 images; slicing the leading axis alone selects
# whole images.
x_train = trainset.data[:200]
print(f"Test data shape: {x_train.shape}")
print("-" * 50)
# Warm-up: run the PyTorch path once on a single image so one-off start-up
# costs (e.g. CUDA context creation on first GPU use) are not charged to
# whichever method happens to be timed first.
torch_extract_lbp_rgb_module(x_train[:1])

# Timings use time.perf_counter(): monotonic and high-resolution, unlike the
# wall clock returned by time.time().

# 1. Extract LBP features using PyTorch LbpPytorch module (all channels at once)
print("Processing with PyTorch LbpPytorch module (batch processing)...")
start_time = time.perf_counter()
x_train_lbp_pt_batch = torch_extract_lbp_rgb_module(x_train)
elapsed_pt_batch = time.perf_counter() - start_time
print(f'PyTorch module (batch) elapsed time: {elapsed_pt_batch:.4f} seconds')

# 2. Extract LBP features using PyTorch LbpPytorch module (channel by channel)
print("Processing with PyTorch LbpPytorch module (channel-wise)...")
start_time = time.perf_counter()
x_train_lbp_pt_separate = torch_extract_lbp_rgb_module_separate(x_train)
elapsed_pt_separate = time.perf_counter() - start_time
print(f'PyTorch module (channel-wise) elapsed time: {elapsed_pt_separate:.4f} seconds')

# 3. Extract LBP features using Python/NumPy
print("Processing with Python/NumPy...")
start_time = time.perf_counter()
x_train_lbp_py = py_extract_lbp_rgb(x_train)
elapsed_py = time.perf_counter() - start_time
print(f'Python/NumPy elapsed time: {elapsed_py:.4f} seconds')
print("-" * 50)
# Check errors between different methods
def _total_abs_error(a, b):
    """Return the sum of absolute element-wise differences between two integer arrays."""
    # Widen to a signed int first: subtracting uint8 arrays directly wraps around.
    return np.sum(np.abs(a.astype(int) - b.astype(int)))


error_batch_vs_py = _total_abs_error(x_train_lbp_py, x_train_lbp_pt_batch)
error_separate_vs_py = _total_abs_error(x_train_lbp_py, x_train_lbp_pt_separate)
error_batch_vs_separate = _total_abs_error(x_train_lbp_pt_batch, x_train_lbp_pt_separate)
print("Error Analysis:")
print(f'Error (Python vs PyTorch batch): {error_batch_vs_py}')
print(f'Error (Python vs PyTorch channel-wise): {error_separate_vs_py}')
print(f'Error (PyTorch batch vs channel-wise): {error_batch_vs_separate}')

# Performance comparison. The PyTorch timings are the divisors, so they are
# what must be non-zero; a coarse clock can report 0.0 for a fast run.
if elapsed_pt_batch > 0 and elapsed_pt_separate > 0:
    speedup_batch = elapsed_py / elapsed_pt_batch
    speedup_separate = elapsed_py / elapsed_pt_separate
    print(f'\nSpeedup (PyTorch batch vs Python): {speedup_batch:.2f}x')
    print(f'Speedup (PyTorch channel-wise vs Python): {speedup_separate:.2f}x')
else:
    print('\nSpeedup not computed: a PyTorch timing was too small to measure.')
print("-" * 50)
# Example images---------------------------------------------------------------
def _show_image_grid(images, fig_num, rows=4, cols=6):
    """Draw the first rows*cols images of ``images`` as a grid in figure ``fig_num``.

    Images are placed row by row (index i*cols + j), ticks are hidden, and
    grid cells beyond the available images are left empty.
    """
    # Passing num= makes subplots create/reuse that figure directly; calling
    # plt.figure(n) first and then plt.subplots() would open a second figure
    # and leave figure n empty. clear=True avoids stacking axes on reruns.
    fig, axes = plt.subplots(rows, cols, num=fig_num, clear=True)
    for k, ax in enumerate(axes.flat):
        if k < len(images):
            ax.imshow(images[k])
        ax.set_xticks([])
        ax.set_yticks([])
    return fig


# Input images
_show_image_grid(x_train, 1)
# LBP transformed images (using PyTorch batch processing)
_show_image_grid(x_train_lbp_pt_batch, 2)
# Optional: Python/NumPy LBP for visual comparison; set to True to enable.
SHOW_PY_LBP = False
if SHOW_PY_LBP:
    _show_image_grid(x_train_lbp_py, 3)
# Without an explicit show(), figures never appear when run as a plain script.
plt.show()
print("\nTest completed successfully!")