Skip to content

Commit 3ce0c3b

Browse files
authored
Merge pull request #3 from quackscience/scan
Add SCAN, MGET support
2 parents 0e88222 + ba38301 commit 3ce0c3b

File tree

2 files changed

+151
-6
lines changed

2 files changed

+151
-6
lines changed

docs/README.md

+26
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,32 @@ SELECT redis_lpush('mylist', value, 'redis')
106106
FROM items;
107107
```
108108

109+
### Batch Operations
110+
```sql
111+
-- Get multiple keys at once
112+
SELECT redis_mget('key1,key2,key3', 'redis') as values;
113+
-- Returns comma-separated values for all keys
114+
115+
-- Scan keys matching a pattern
116+
SELECT redis_scan('0', 'user:*', 10, 'redis') as result;
117+
-- Returns: "cursor:key1,key2,key3" where cursor is the next position for scanning
118+
-- Use the returned cursor for the next scan until cursor is 0
119+
120+
-- Scan all keys matching a pattern
121+
WITH RECURSIVE scan(cursor, keys) AS (
122+
-- Initial scan
123+
SELECT split_part(redis_scan('0', 'user:*', 10, 'redis'), ':', 1),
124+
split_part(redis_scan('0', 'user:*', 10, 'redis'), ':', 2)
125+
UNION ALL
126+
-- Continue scanning until cursor is 0
127+
SELECT split_part(redis_scan(cursor, 'user:*', 10, 'redis'), ':', 1),
128+
split_part(redis_scan(cursor, 'user:*', 10, 'redis'), ':', 2)
129+
FROM scan
130+
WHERE cursor != '0'
131+
)
132+
SELECT keys FROM scan;
133+
```
134+
109135
## Error Handling
110136
The extension functions will throw exceptions with descriptive error messages when:
111137
- Redis secret is not found or invalid

src/redis_extension.cpp

+125-6
Original file line numberDiff line numberDiff line change
@@ -91,10 +91,14 @@ class RedisProtocol {
9191

9292
// Key scanning
9393
static std::string formatScan(const std::string& cursor, const std::string& pattern = "*", int64_t count = 10) {
94+
std::string cmd = "*6\r\n$4\r\nSCAN\r\n";
95+
cmd += "$" + std::to_string(cursor.length()) + "\r\n" + cursor + "\r\n";
96+
cmd += "$5\r\nMATCH\r\n";
97+
cmd += "$" + std::to_string(pattern.length()) + "\r\n" + pattern + "\r\n";
98+
cmd += "$5\r\nCOUNT\r\n";
9499
auto count_str = std::to_string(count);
95-
return "*6\r\n$4\r\nSCAN\r\n$" + std::to_string(cursor.length()) + "\r\n" + cursor +
96-
"\r\n$5\r\nMATCH\r\n$" + std::to_string(pattern.length()) + "\r\n" + pattern +
97-
"\r\n$5\r\nCOUNT\r\n$" + std::to_string(count_str.length()) + "\r\n" + count_str + "\r\n";
100+
cmd += "$" + std::to_string(count_str.length()) + "\r\n" + count_str + "\r\n";
101+
return cmd;
98102
}
99103

100104
static std::vector<std::string> parseArrayResponse(const std::string& response) {
@@ -120,6 +124,14 @@ class RedisProtocol {
120124
}
121125
return result;
122126
}
127+
128+
static std::string formatMGet(const std::vector<std::string>& keys) {
129+
std::string cmd = "*" + std::to_string(keys.size() + 1) + "\r\n$4\r\nMGET\r\n";
130+
for (const auto& key : keys) {
131+
cmd += "$" + std::to_string(key.length()) + "\r\n" + key + "\r\n";
132+
}
133+
return cmd;
134+
}
123135
};
124136

125137
// Redis connection class
@@ -392,8 +404,96 @@ static void RedisLRangeFunction(DataChunk &args, ExpressionState &state, Vector
392404
});
393405
}
394406

407+
static void RedisMGetFunction(DataChunk &args, ExpressionState &state, Vector &result) {
408+
auto &keys_list = args.data[0];
409+
auto &secret_vector = args.data[1];
410+
411+
UnaryExecutor::Execute<string_t, string_t>(
412+
keys_list, result, args.size(),
413+
[&](string_t keys_str) {
414+
try {
415+
// Split comma-separated keys
416+
std::vector<std::string> keys;
417+
std::string key_list = keys_str.GetString();
418+
size_t pos = 0;
419+
while ((pos = key_list.find(',')) != std::string::npos) {
420+
keys.push_back(key_list.substr(0, pos));
421+
key_list.erase(0, pos + 1);
422+
}
423+
if (!key_list.empty()) {
424+
keys.push_back(key_list);
425+
}
426+
427+
string host, port, password;
428+
if (!GetRedisSecret(state.GetContext(), secret_vector.GetValue(0).ToString(),
429+
host, port, password)) {
430+
throw InvalidInputException("Redis secret not found");
431+
}
432+
433+
auto conn = ConnectionPool::getInstance().getConnection(host, port, password);
434+
auto response = conn->execute(RedisProtocol::formatMGet(keys));
435+
auto values = RedisProtocol::parseArrayResponse(response);
436+
437+
// Join results with comma
438+
std::string joined;
439+
for (size_t i = 0; i < values.size(); i++) {
440+
if (i > 0) joined += ",";
441+
joined += values[i];
442+
}
443+
return StringVector::AddString(result, joined);
444+
} catch (std::exception &e) {
445+
throw InvalidInputException("Redis MGET error: %s", e.what());
446+
}
447+
});
448+
}
449+
450+
static void RedisScanFunction(DataChunk &args, ExpressionState &state, Vector &result) {
451+
auto &cursor_vector = args.data[0];
452+
auto &pattern_vector = args.data[1];
453+
auto &count_vector = args.data[2];
454+
auto &secret_vector = args.data[3];
455+
456+
BinaryExecutor::Execute<string_t, string_t, string_t>(
457+
cursor_vector, pattern_vector, result, args.size(),
458+
[&](string_t cursor, string_t pattern) {
459+
try {
460+
string host, port, password;
461+
if (!GetRedisSecret(state.GetContext(), secret_vector.GetValue(0).ToString(),
462+
host, port, password)) {
463+
throw InvalidInputException("Redis secret not found");
464+
}
465+
466+
auto count = count_vector.GetValue(0).GetValue<int64_t>();
467+
auto conn = ConnectionPool::getInstance().getConnection(host, port, password);
468+
auto response = conn->execute(RedisProtocol::formatScan(
469+
cursor.GetString(),
470+
pattern.GetString(),
471+
count
472+
));
473+
auto scan_result = RedisProtocol::parseArrayResponse(response);
474+
475+
if (scan_result.size() >= 2) {
476+
// First element is the new cursor, second element is array of keys
477+
std::string result_str = scan_result[0] + ":";
478+
auto keys = RedisProtocol::parseArrayResponse(scan_result[1]);
479+
for (size_t i = 0; i < keys.size(); i++) {
480+
if (i > 0) result_str += ",";
481+
result_str += keys[i];
482+
}
483+
return StringVector::AddString(result, result_str);
484+
}
485+
return StringVector::AddString(result, "0:");
486+
} catch (std::exception &e) {
487+
throw InvalidInputException("Redis SCAN error: %s", e.what());
488+
}
489+
});
490+
}
491+
395492
static void LoadInternal(DatabaseInstance &instance) {
396-
// Register Redis GET function
493+
// Register the secret functions first!
494+
CreateRedisSecretFunctions::Register(instance);
495+
496+
// Then register Redis functions
397497
auto redis_get_func = ScalarFunction(
398498
"redis_get",
399499
{LogicalType::VARCHAR, // key
@@ -460,8 +560,27 @@ static void LoadInternal(DatabaseInstance &instance) {
460560
);
461561
ExtensionUtil::RegisterFunction(instance, redis_lrange_func);
462562

463-
// Register the secret functions
464-
CreateRedisSecretFunctions::Register(instance);
563+
// Register MGET
564+
auto redis_mget_func = ScalarFunction(
565+
"redis_mget",
566+
{LogicalType::VARCHAR, // comma-separated keys
567+
LogicalType::VARCHAR}, // secret_name
568+
LogicalType::VARCHAR,
569+
RedisMGetFunction
570+
);
571+
ExtensionUtil::RegisterFunction(instance, redis_mget_func);
572+
573+
// Register SCAN
574+
auto redis_scan_func = ScalarFunction(
575+
"redis_scan",
576+
{LogicalType::VARCHAR, // cursor
577+
LogicalType::VARCHAR, // pattern
578+
LogicalType::BIGINT, // count
579+
LogicalType::VARCHAR}, // secret_name
580+
LogicalType::VARCHAR,
581+
RedisScanFunction
582+
);
583+
ExtensionUtil::RegisterFunction(instance, redis_scan_func);
465584
}
466585

467586
void RedisExtension::Load(DuckDB &db) {

0 commit comments

Comments
 (0)