# SVS index serializer for benchmarks.
# Serializes datasets to SVS index format for use by C++ and Python benchmarks.

import numpy as np
import VecSim
import h5py
import os

# Determine working directory
location = os.path.abspath('.')
if location.endswith('/data'):
    location = os.path.join(location, '')
elif location.endswith('/VectorSimilarity'):
    location = os.path.join(location, 'tests', 'benchmark', 'data', '')
else:
    print('unexpected location:', location)
    print('expected to be in `./VectorSimilarity/tests/benchmark/data` or `./VectorSimilarity`')
    exit(1)
print('working at:', location)

DEFAULT_FILES = [
    {
        'filename': 'dbpedia-768',
        'nickname': 'dbpedia',
        'dim': 768,
        'metric': VecSim.VecSimMetric_Cosine,
        'hdf5_file': 'dbpedia-cosine-dim768.hdf5',
    },
    {
        'filename': 'fashion_images_multi_value',
        'nickname': 'fashion_images_multi_value',
        'hdf5_file': 'fashion_images_multi_value-cosine-dim512.hdf5',
        'dim': 512,
        'metric': VecSim.VecSimMetric_Cosine,
        'multi': True,
    },
]

TYPES_ATTR = {
    VecSim.VecSimType_FLOAT32: {"size_in_bytes": 4, "vector_type": np.float32},
}

def load_vectors_and_labels_from_hdf5(input_file):
    """
    Load vectors and labels from an HDF5 file.
    Returns: (vectors, labels) numpy arrays, or (None, None) on failure.
    """
    try:
        with h5py.File(input_file, 'r') as f:
            vectors = f['vectors'][:]
            labels = f['labels'][:]

        print(f"Loaded {input_file}: vectors {vectors.shape}, labels {labels.shape}")
        return vectors, labels

    except Exception as e:
        print(f"Error loading HDF5 file: {e}")
        return None, None

def serialize(files=DEFAULT_FILES):
    for file in files:
        filename = file['filename']
        nickname = file.get('nickname', filename)
        dim = file.get('dim', None)
        metric = file['metric']
        is_multi = file.get('multi', False)
        vec_type = file.get('type', VecSim.VecSimType_FLOAT32)

        # Load vectors/labels
        hdf5_file = file.get('hdf5_file', f"{filename}.hdf5")
        hdf5_path = os.path.join(location, hdf5_file)
        print(f"Loading vectors from {hdf5_path}")
        vectors, labels = load_vectors_and_labels_from_hdf5(hdf5_path)

        if vectors is None or labels is None:
            print(f"Failed to load data from {hdf5_path}, skipping...")
            continue

        if is_multi:
            # Flatten (N, M, D) -> (N * M, D) and repeat each label once per vector
            if vectors.ndim == 3:
                vectors = vectors.reshape(-1, vectors.shape[-1])
                labels = np.repeat(labels, vectors.shape[0] // labels.shape[0])
        else:
            # Handle shape (N, 1, D) -> (N, D)
            if vectors.ndim == 3 and vectors.shape[1] == 1:
                vectors = vectors.squeeze(axis=1)
            elif vectors.ndim != 2:
                print(f"Error: Expected 2D vectors, got shape {vectors.shape}")
                continue

        # Update dimension if not specified
        if dim is None:
            dim = vectors.shape[1]
            print(f"Auto-detected dimension: {dim}")

        assert dim == vectors.shape[1], f"Dimension mismatch: {dim} != {vectors.shape[1]}"

        # Create SVS parameters, one index per quantization mode
        bits_to_str = {
            VecSim.VecSimSvsQuant_NONE: '_none',
            VecSim.VecSimSvsQuant_8: '_8',
        }
        for bits in [VecSim.VecSimSvsQuant_8, VecSim.VecSimSvsQuant_NONE]:
            svs_params = VecSim.SVSParams()
            svs_params.type = vec_type
            svs_params.dim = dim
            svs_params.metric = metric
            svs_params.graph_max_degree = file.get('graph_max_degree', 128)
            svs_params.construction_window_size = file.get('construction_window_size', 512)
            svs_params.quantBits = bits
            svs_params.multi = is_multi

            print(f"Creating SVS index for {filename} (dim={dim}, metric={metric})")
            if vectors.dtype != np.float32:
                print(f"Converting vectors from {vectors.dtype} to float32")
                vectors = vectors.astype(np.float32)
            if labels.dtype != np.uint64:
                print(f"Converting labels from {labels.dtype} to uint64")
                labels = labels.astype(np.uint64)

            # Create index and add vectors
            svs_index = VecSim.SVSIndex(svs_params)
            print(f"Adding {len(vectors)} vectors...")
            svs_index.add_vector_parallel(vectors, labels)

            # Save index
            out_dir = os.path.join(location, nickname + '_svs' + bits_to_str[bits])
            os.makedirs(out_dir, exist_ok=True)
            svs_index.save_index(out_dir)
            print(f"Index saved to {out_dir}")
            print(f"Final index size: {svs_index.index_size()}")

            # Verify that the saved index loads back correctly
            print("Verifying saved index...")
            svs_index_verify = VecSim.SVSIndex(svs_params)
            svs_index_verify.load_index(out_dir)
            svs_index_verify.check_integrity()
            print(f"Verified index size: {svs_index_verify.index_size()}")
            # Labels are expected to be contiguous and zero-based
            assert svs_index_verify.get_labels_count() == labels.max() + 1


if __name__ == '__main__':
    serialize()
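
# Usage sketch (illustrative only): serialize a custom dataset by passing your
# own spec list instead of DEFAULT_FILES. The dataset and file names below are
# hypothetical placeholders; 'type', 'graph_max_degree', and
# 'construction_window_size' are the optional per-file overrides read by
# serialize() above.
#
# serialize([
#     {
#         'filename': 'my-dataset-128',
#         'nickname': 'my_dataset',      # output dirs: my_dataset_svs_8, my_dataset_svs_none
#         'hdf5_file': 'my-dataset-cosine-dim128.hdf5',
#         'dim': 128,
#         'metric': VecSim.VecSimMetric_Cosine,
#         'type': VecSim.VecSimType_FLOAT32,
#         'graph_max_degree': 64,
#         'construction_window_size': 256,
#     },
# ])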