-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathquality.py
157 lines (141 loc) · 5.44 KB
/
quality.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
import itertools
import cv2
from datasets import load_dataset
import numpy as np
from src import opencv
from src import skimage
from src import rust
from skimage.metrics import (
structural_similarity,
peak_signal_noise_ratio,
mean_squared_error,
)
from PIL import Image
import imagehash
import polars as pl
from rich.progress import track
from utils import downsample_image, get_public_functions
def convert_to_grayscale(image):
    """Convert an image array to a 2-D uint8 grayscale array.

    Squeezes a trailing singleton channel, converts 3-channel RGB to
    grayscale via OpenCV, and rescales floating-point images (assumed to
    be in [0, 1]) to uint8 in [0, 255].

    Args:
        image: numpy array — 2-D grayscale, (H, W, 1), or (H, W, 3) RGB;
            dtype uint8 or floating point.

    Returns:
        2-D numpy array; uint8 whenever the input was floating point.
    """
    if image.ndim == 3:
        if image.shape[-1] == 1:
            # Drop the redundant channel axis: (H, W, 1) -> (H, W).
            image = np.squeeze(image, axis=-1)
        elif image.shape[-1] == 3:
            image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    # issubdtype covers float16/32/64 (and any other float dtype) in one
    # check, replacing the explicit three-way `or` chain.
    if np.issubdtype(image.dtype, np.floating):
        # Clip before casting: values outside [0, 1] (e.g. interpolation
        # overshoot) would otherwise wrap around modulo 256 in uint8.
        image = (np.clip(image, 0.0, 1.0) * 255).astype(np.uint8)
    return image
def calculate_metrics(original_image, interpolated_image):
    """Compare an interpolated image against its original reference.

    The interpolated image is downsampled back to the original's shape,
    then pixel-level metrics (MSE, PSNR, SSIM) and perceptual-hash
    Hamming distances (aHash, pHash, dHash, wHash) are computed.

    Args:
        original_image: PIL image used as the reference.
        interpolated_image: PIL image produced by an upscaling function.

    Returns:
        dict with keys "mse", "psnr", "ssim", "average_hash_diff",
        "phash_diff", "dhash_diff", "whash_diff".
    """
    reference = np.array(original_image)
    candidate = downsample_image(np.array(interpolated_image), reference.shape)

    # SSIM is computed on the grayscale versions of both arrays.
    reference_gray = convert_to_grayscale(reference)
    candidate_gray = convert_to_grayscale(candidate)

    metrics = {
        "mse": mean_squared_error(reference, candidate),
        "psnr": peak_signal_noise_ratio(reference, candidate),
        "ssim": structural_similarity(reference_gray, candidate_gray),
    }
    # Subtracting two ImageHash objects yields their Hamming distance.
    hashers = (
        ("average_hash_diff", imagehash.average_hash),
        ("phash_diff", imagehash.phash),
        ("dhash_diff", imagehash.dhash),
        ("whash_diff", imagehash.whash),
    )
    for key, hasher in hashers:
        metrics[key] = int(hasher(original_image) - hasher(interpolated_image))
    return metrics
if __name__ == "__main__":
    # Datasets to evaluate. Each entry records the HF dataset id plus the
    # image/label column names, which differ between datasets.
    datasets_config = [
        {
            "name": "uoft-cs/cifar10",
            "image_col": "img",
            "label_col": "label",
        },
        {
            "name": "AI-Lab-Makerere/beans",
            "image_col": "image",
            "label_col": "labels",
        },
        {
            "name": "ylecun/mnist",
            "image_col": "image",
            "label_col": "label",
        },
        {
            "name": "blanchon/UC_Merced",
            "image_col": "image",
            "label_col": "label",
        },
    ]
    # Gather every public interpolation function from the three backends.
    modules = [opencv, skimage, rust]
    functions = list(
        itertools.chain.from_iterable(
            (get_public_functions(module) for module in modules)
        )
    )
    # Pair each function with a "backend.function" label.
    # NOTE(review): assumes __module__ is of the form "src.<backend>" so
    # index 1 is the backend name — confirm against the src package layout.
    functions = [
        (f"{func.__module__.split('.')[1]}.{func.__name__}", func) for func in functions
    ]
    # Upscale factors 2 and 4.
    scale_factors = [2**i for i in range(1, 3)]
    test_data_records = []  # accumulates one record per (image, algorithm, scale)
    for dataset_info in datasets_config:
        dataset_name = dataset_info["name"]
        image_col = dataset_info["image_col"]
        label_col = dataset_info["label_col"]
        # A dataset that fails to download is skipped rather than aborting
        # the whole benchmark run.
        try:
            dataset = load_dataset(dataset_name, split="train")
        except Exception as e:
            print(f"Error loading dataset {dataset_name}: {e}")
            continue
        label_names = dataset.features[label_col].names
        # Cap each dataset at 5000 examples to bound the run time.
        for example in track(
            dataset.select(range(min(len(dataset), 5000))), description=dataset_name
        ):
            original_image = example[image_col]
            # Some datasets yield raw arrays instead of PIL images.
            if not isinstance(original_image, Image.Image):
                original_image = Image.fromarray(original_image)
            original_image_np = np.array(original_image)
            # Normalize grayscale images to (H, W, 1) so the interpolation
            # functions always receive a 3-D array.
            if original_image_np.ndim == 2:
                original_image_np = np.expand_dims(original_image_np, axis=2)
            label_name = (
                label_names[example[label_col]] if label_col in example else "unknown"
            )
            # Evaluate every (algorithm, scale factor) combination.
            for interp_name, interp_func in functions:
                for scale_factor in scale_factors:
                    interpolated_image_np = interp_func(original_image_np, scale_factor)
                    # Squeeze (H, W, 1) back to (H, W) so PIL accepts it.
                    if (
                        interpolated_image_np.ndim == 3
                        and interpolated_image_np.shape[-1] == 1
                    ):
                        interpolated_image_np = np.squeeze(
                            interpolated_image_np, axis=2
                        )
                    interpolated_image = Image.fromarray(interpolated_image_np)
                    metrics = calculate_metrics(original_image, interpolated_image)
                    record = {
                        "dataset_name": dataset_name,
                        "label_name": label_name,
                        "interp_algorithm": interp_name,
                        "scale_factor": scale_factor,
                        "mse": metrics["mse"],
                        "psnr": metrics["psnr"],
                        "ssim": metrics["ssim"],
                        "average_hash_diff": metrics["average_hash_diff"],
                        "phash_diff": metrics["phash_diff"],
                        "dhash_diff": metrics["dhash_diff"],
                        "whash_diff": metrics["whash_diff"],
                    }
                    test_data_records.append(record)
    # Persist all records as a single JSON file for later analysis.
    df_results = pl.DataFrame(test_data_records)
    df_results.write_json("quality.json")