@@ -91,10 +91,14 @@ class RedisProtocol {
91
91
92
92
// Key scanning
93
93
static std::string formatScan (const std::string& cursor, const std::string& pattern = " *" , int64_t count = 10 ) {
94
+ std::string cmd = " *6\r\n $4\r\n SCAN\r\n " ;
95
+ cmd += " $" + std::to_string (cursor.length ()) + " \r\n " + cursor + " \r\n " ;
96
+ cmd += " $5\r\n MATCH\r\n " ;
97
+ cmd += " $" + std::to_string (pattern.length ()) + " \r\n " + pattern + " \r\n " ;
98
+ cmd += " $5\r\n COUNT\r\n " ;
94
99
auto count_str = std::to_string (count);
95
- return " *6\r\n $4\r\n SCAN\r\n $" + std::to_string (cursor.length ()) + " \r\n " + cursor +
96
- " \r\n $5\r\n MATCH\r\n $" + std::to_string (pattern.length ()) + " \r\n " + pattern +
97
- " \r\n $5\r\n COUNT\r\n $" + std::to_string (count_str.length ()) + " \r\n " + count_str + " \r\n " ;
100
+ cmd += " $" + std::to_string (count_str.length ()) + " \r\n " + count_str + " \r\n " ;
101
+ return cmd;
98
102
}
99
103
100
104
static std::vector<std::string> parseArrayResponse (const std::string& response) {
@@ -120,6 +124,14 @@ class RedisProtocol {
120
124
}
121
125
return result;
122
126
}
127
+
128
+ static std::string formatMGet (const std::vector<std::string>& keys) {
129
+ std::string cmd = " *" + std::to_string (keys.size () + 1 ) + " \r\n $4\r\n MGET\r\n " ;
130
+ for (const auto & key : keys) {
131
+ cmd += " $" + std::to_string (key.length ()) + " \r\n " + key + " \r\n " ;
132
+ }
133
+ return cmd;
134
+ }
123
135
};
124
136
125
137
// Redis connection class
@@ -392,8 +404,96 @@ static void RedisLRangeFunction(DataChunk &args, ExpressionState &state, Vector
392
404
});
393
405
}
394
406
407
+ static void RedisMGetFunction (DataChunk &args, ExpressionState &state, Vector &result) {
408
+ auto &keys_list = args.data [0 ];
409
+ auto &secret_vector = args.data [1 ];
410
+
411
+ UnaryExecutor::Execute<string_t , string_t >(
412
+ keys_list, result, args.size (),
413
+ [&](string_t keys_str) {
414
+ try {
415
+ // Split comma-separated keys
416
+ std::vector<std::string> keys;
417
+ std::string key_list = keys_str.GetString ();
418
+ size_t pos = 0 ;
419
+ while ((pos = key_list.find (' ,' )) != std::string::npos) {
420
+ keys.push_back (key_list.substr (0 , pos));
421
+ key_list.erase (0 , pos + 1 );
422
+ }
423
+ if (!key_list.empty ()) {
424
+ keys.push_back (key_list);
425
+ }
426
+
427
+ string host, port, password;
428
+ if (!GetRedisSecret (state.GetContext (), secret_vector.GetValue (0 ).ToString (),
429
+ host, port, password)) {
430
+ throw InvalidInputException (" Redis secret not found" );
431
+ }
432
+
433
+ auto conn = ConnectionPool::getInstance ().getConnection (host, port, password);
434
+ auto response = conn->execute (RedisProtocol::formatMGet (keys));
435
+ auto values = RedisProtocol::parseArrayResponse (response);
436
+
437
+ // Join results with comma
438
+ std::string joined;
439
+ for (size_t i = 0 ; i < values.size (); i++) {
440
+ if (i > 0 ) joined += " ," ;
441
+ joined += values[i];
442
+ }
443
+ return StringVector::AddString (result, joined);
444
+ } catch (std::exception &e) {
445
+ throw InvalidInputException (" Redis MGET error: %s" , e.what ());
446
+ }
447
+ });
448
+ }
449
+
450
+ static void RedisScanFunction (DataChunk &args, ExpressionState &state, Vector &result) {
451
+ auto &cursor_vector = args.data [0 ];
452
+ auto &pattern_vector = args.data [1 ];
453
+ auto &count_vector = args.data [2 ];
454
+ auto &secret_vector = args.data [3 ];
455
+
456
+ BinaryExecutor::Execute<string_t , string_t , string_t >(
457
+ cursor_vector, pattern_vector, result, args.size (),
458
+ [&](string_t cursor, string_t pattern) {
459
+ try {
460
+ string host, port, password;
461
+ if (!GetRedisSecret (state.GetContext (), secret_vector.GetValue (0 ).ToString (),
462
+ host, port, password)) {
463
+ throw InvalidInputException (" Redis secret not found" );
464
+ }
465
+
466
+ auto count = count_vector.GetValue (0 ).GetValue <int64_t >();
467
+ auto conn = ConnectionPool::getInstance ().getConnection (host, port, password);
468
+ auto response = conn->execute (RedisProtocol::formatScan (
469
+ cursor.GetString (),
470
+ pattern.GetString (),
471
+ count
472
+ ));
473
+ auto scan_result = RedisProtocol::parseArrayResponse (response);
474
+
475
+ if (scan_result.size () >= 2 ) {
476
+ // First element is the new cursor, second element is array of keys
477
+ std::string result_str = scan_result[0 ] + " :" ;
478
+ auto keys = RedisProtocol::parseArrayResponse (scan_result[1 ]);
479
+ for (size_t i = 0 ; i < keys.size (); i++) {
480
+ if (i > 0 ) result_str += " ," ;
481
+ result_str += keys[i];
482
+ }
483
+ return StringVector::AddString (result, result_str);
484
+ }
485
+ return StringVector::AddString (result, " 0:" );
486
+ } catch (std::exception &e) {
487
+ throw InvalidInputException (" Redis SCAN error: %s" , e.what ());
488
+ }
489
+ });
490
+ }
491
+
395
492
static void LoadInternal (DatabaseInstance &instance) {
396
- // Register Redis GET function
493
+ // Register the secret functions first!
494
+ CreateRedisSecretFunctions::Register (instance);
495
+
496
+ // Then register Redis functions
397
497
auto redis_get_func = ScalarFunction (
398
498
" redis_get" ,
399
499
{LogicalType::VARCHAR, // key
@@ -460,8 +560,27 @@ static void LoadInternal(DatabaseInstance &instance) {
460
560
);
461
561
ExtensionUtil::RegisterFunction (instance, redis_lrange_func);
462
562
463
- // Register the secret functions
464
- CreateRedisSecretFunctions::Register (instance);
563
+ // Register MGET
564
+ auto redis_mget_func = ScalarFunction (
565
+ " redis_mget" ,
566
+ {LogicalType::VARCHAR, // comma-separated keys
567
+ LogicalType::VARCHAR}, // secret_name
568
+ LogicalType::VARCHAR,
569
+ RedisMGetFunction
570
+ );
571
+ ExtensionUtil::RegisterFunction (instance, redis_mget_func);
572
+
573
+ // Register SCAN
574
+ auto redis_scan_func = ScalarFunction (
575
+ " redis_scan" ,
576
+ {LogicalType::VARCHAR, // cursor
577
+ LogicalType::VARCHAR, // pattern
578
+ LogicalType::BIGINT, // count
579
+ LogicalType::VARCHAR}, // secret_name
580
+ LogicalType::VARCHAR,
581
+ RedisScanFunction
582
+ );
583
+ ExtensionUtil::RegisterFunction (instance, redis_scan_func);
465
584
}
466
585
467
586
void RedisExtension::Load (DuckDB &db) {
0 commit comments