Skip to content

Commit 38b5861

Browse files
committed
added script.c
1 parent a4bc774 commit 38b5861

File tree

1 file changed

+179
-0
lines changed

1 file changed

+179
-0
lines changed

src/redis_ai_objects/script.c

Lines changed: 179 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,179 @@
1+
/**
2+
* script.c
3+
*
4+
* Contains the helper methods for both creating, populating,
5+
* managing and destructing the PyTorch Script data structure.
6+
*
7+
*/
8+
9+
#include <pthread.h>
10+
#include "version.h"
11+
#include "script.h"
12+
#include "script_struct.h"
13+
#include "stats.h"
14+
#include "util/arr.h"
15+
#include "util/string_utils.h"
16+
#include "rmutil/alloc.h"
17+
#include "backends/backends.h"
18+
#include "execution/DAG/dag.h"
19+
#include "execution/run_info.h"
20+
21+
extern RedisModuleType *RedisAI_ScriptType;
22+
23+
RAI_Script *RAI_ScriptCreate(const char *devicestr, RedisModuleString *tag, const char *scriptdef,
24+
RAI_Error *err) {
25+
if (!RAI_backends.torch.script_create) {
26+
RAI_SetError(err, RAI_EBACKENDNOTLOADED, "ERR Backend not loaded: TORCH");
27+
return NULL;
28+
}
29+
RAI_Script *script = RAI_backends.torch.script_create(devicestr, scriptdef, err);
30+
31+
if (script) {
32+
if (tag) {
33+
script->tag = RAI_HoldString(tag);
34+
} else {
35+
script->tag = RedisModule_CreateString(NULL, "", 0);
36+
}
37+
}
38+
39+
return script;
40+
}
41+
42+
void RAI_ScriptFree(RAI_Script *script, RAI_Error *err) {
43+
if (__atomic_sub_fetch(&script->refCount, 1, __ATOMIC_RELAXED) > 0) {
44+
return;
45+
}
46+
47+
if (!RAI_backends.torch.script_free) {
48+
RAI_SetError(err, RAI_EBACKENDNOTLOADED, "ERR Backend not loaded: TORCH");
49+
return;
50+
}
51+
52+
RedisModule_FreeString(NULL, script->tag);
53+
54+
RAI_RemoveStatsEntry(script->infokey);
55+
56+
RAI_backends.torch.script_free(script, err);
57+
}
58+
59+
RAI_Script *RAI_ScriptGetShallowCopy(RAI_Script *script) {
60+
__atomic_fetch_add(&script->refCount, 1, __ATOMIC_RELAXED);
61+
return script;
62+
}
63+
64+
/* Return REDISMODULE_ERR if there was an error getting the Script.
65+
* Return REDISMODULE_OK if the model value stored at key was correctly
66+
* returned and available at *model variable. */
67+
int RAI_GetScriptFromKeyspace(RedisModuleCtx *ctx, RedisModuleString *keyName, RAI_Script **script,
68+
int mode, RAI_Error *err) {
69+
RedisModuleKey *key = RedisModule_OpenKey(ctx, keyName, mode);
70+
71+
if (RedisModule_KeyType(key) == REDISMODULE_KEYTYPE_EMPTY) {
72+
RedisModule_CloseKey(key);
73+
#ifndef LITE
74+
RedisModule_Log(ctx, "warning", "could not load %s from keyspace, key doesn't exist",
75+
RedisModule_StringPtrLen(keyName, NULL));
76+
RAI_SetError(err, RAI_EKEYEMPTY, "ERR script key is empty");
77+
return REDISMODULE_ERR;
78+
#else
79+
if (VerifyKeyInThisShard(ctx, keyName)) { // Relevant for enterprise cluster.
80+
RAI_SetError(err, RAI_EKEYEMPTY, "ERR script key is empty");
81+
} else {
82+
RAI_SetError(err, RAI_EKEYEMPTY,
83+
"ERR CROSSSLOT Keys in request don't hash to the same slot");
84+
}
85+
#endif
86+
return REDISMODULE_ERR;
87+
}
88+
if (RedisModule_ModuleTypeGetType(key) != RedisAI_ScriptType) {
89+
RedisModule_CloseKey(key);
90+
RAI_SetError(err, RAI_ESCRIPTRUN, REDISMODULE_ERRORMSG_WRONGTYPE);
91+
return REDISMODULE_ERR;
92+
}
93+
*script = RedisModule_ModuleTypeGetValue(key);
94+
RedisModule_CloseKey(key);
95+
return REDISMODULE_OK;
96+
}
97+
98+
int RedisAI_ScriptRun_IsKeysPositionRequest_ReportKeys(RedisModuleCtx *ctx,
99+
RedisModuleString **argv, int argc) {
100+
RedisModule_KeyAtPos(ctx, 1);
101+
size_t startpos = 3;
102+
if (startpos >= argc) {
103+
return REDISMODULE_ERR;
104+
}
105+
const char *str = RedisModule_StringPtrLen(argv[startpos], NULL);
106+
if (!strcasecmp(str, "TIMEOUT")) {
107+
startpos += 2;
108+
}
109+
startpos += 1;
110+
if (startpos >= argc) {
111+
return REDISMODULE_ERR;
112+
}
113+
for (size_t argpos = startpos; argpos < argc; argpos++) {
114+
str = RedisModule_StringPtrLen(argv[argpos], NULL);
115+
if (!strcasecmp(str, "OUTPUTS")) {
116+
continue;
117+
}
118+
RedisModule_KeyAtPos(ctx, argpos);
119+
}
120+
return REDISMODULE_OK;
121+
}
122+
123+
int RedisAI_ScriptExecute_IsKeysPositionRequest_ReportKeys(RedisModuleCtx *ctx,
124+
RedisModuleString **argv, int argc) {
125+
// AI.SCRIPTEXECUTE script_name func KEYS n key....
126+
if (argc < 6) {
127+
return REDISMODULE_ERR;
128+
}
129+
RedisModule_KeyAtPos(ctx, 1);
130+
size_t argpos = 3;
131+
long long count;
132+
while (argpos < argc) {
133+
const char *str = RedisModule_StringPtrLen(argv[argpos++], NULL);
134+
135+
// Inputs, outpus, keys, lists.
136+
if ((!strcasecmp(str, "INPUTS")) || (!strcasecmp(str, "OUTPUTS")) ||
137+
(!strcasecmp(str, "LIST_INPUTS")) || (!strcasecmp(str, "KEYS"))) {
138+
bool updateKeyAtPos = false;
139+
// The only scope where the inputs strings are 100% keys are in the KEYS and OUTPUTS
140+
// scopes.
141+
if ((!strcasecmp(str, "OUTPUTS")) || (!strcasecmp(str, "KEYS"))) {
142+
updateKeyAtPos = true;
143+
}
144+
if (argpos >= argc) {
145+
return REDISMODULE_ERR;
146+
}
147+
if (RedisModule_StringToLongLong(argv[argpos++], &count) != REDISMODULE_OK) {
148+
return REDISMODULE_ERR;
149+
}
150+
if (count <= 0) {
151+
return REDISMODULE_ERR;
152+
}
153+
if (argpos + count >= argc) {
154+
return REDISMODULE_ERR;
155+
}
156+
for (long long i = 0; i < count; i++) {
157+
if (updateKeyAtPos) {
158+
RedisModule_KeyAtPos(ctx, argpos);
159+
}
160+
argpos++;
161+
}
162+
continue;
163+
}
164+
// Timeout
165+
if (!strcasecmp(str, "TIMEOUT")) {
166+
argpos++;
167+
break;
168+
}
169+
// Undefinded input.
170+
return REDISMODULE_ERR;
171+
}
172+
if (argpos != argc) {
173+
return REDISMODULE_ERR;
174+
} else {
175+
return REDISMODULE_OK;
176+
}
177+
}
178+
179+
RedisModuleType *RAI_ScriptRedisType(void) { return RedisAI_ScriptType; }

0 commit comments

Comments
 (0)