Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,4 @@ ml-algorithms/build/
plugin/build/
.DS_Store
*/bin/
**/*.factorypath
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
package org.opensearch.ml.common.settings;

import static org.opensearch.remote.metadata.common.CommonValue.REMOTE_METADATA_ENDPOINT_KEY;
import static org.opensearch.remote.metadata.common.CommonValue.REMOTE_METADATA_GLOBAL_RESOURCE_CACHE_TTL_KEY;
import static org.opensearch.remote.metadata.common.CommonValue.REMOTE_METADATA_GLOBAL_TENANT_ID_KEY;
import static org.opensearch.remote.metadata.common.CommonValue.REMOTE_METADATA_REGION_KEY;
import static org.opensearch.remote.metadata.common.CommonValue.REMOTE_METADATA_SERVICE_NAME_KEY;
import static org.opensearch.remote.metadata.common.CommonValue.REMOTE_METADATA_TYPE_KEY;
Expand Down Expand Up @@ -372,4 +374,14 @@ private MLCommonsSettings() {}
.boolSetting("plugins.ml_commons.agentic_memory_enabled", false, Setting.Property.NodeScope, Setting.Property.Dynamic);
public static final String ML_COMMONS_AGENTIC_MEMORY_DISABLED_MESSAGE =
"The Agentic Memory APIs are not enabled. To enable, please update the setting " + ML_COMMONS_AGENTIC_MEMORY_ENABLED.getKey();

public static final Setting<String> REMOTE_METADATA_GLOBAL_TENANT_ID = Setting
.simpleString("plugins.ml-commons." + REMOTE_METADATA_GLOBAL_TENANT_ID_KEY, Setting.Property.NodeScope, Setting.Property.Final);

public static final Setting<String> REMOTE_METADATA_GLOBAL_RESOURCE_CACHE_TTL = Setting
.simpleString(
"plugins.ml-commons." + REMOTE_METADATA_GLOBAL_RESOURCE_CACHE_TTL_KEY,
Setting.Property.NodeScope,
Setting.Property.Final
);
}
1 change: 1 addition & 0 deletions ml-algorithms/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ dependencies {
// Multi-tenant SDK Client
implementation "org.opensearch:opensearch-remote-metadata-sdk:${opensearch_build}"
implementation 'commons-beanutils:commons-beanutils:1.11.0'
implementation "org.opensearch:opensearch-remote-metadata-sdk-ddb-client:${opensearch_build}"

def os = DefaultNativePlatform.currentOperatingSystem
//arm/macos doesn't support GPU
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,15 @@ public Predictable deploy(MLModel mlModel, Map<String, Object> params) {
return predictable;
}

public void deploy(MLModel mlModel, Map<String, Object> params, ActionListener<Predictable> listener) {
Predictable predictable = MLEngineClassLoader.initInstance(mlModel.getAlgorithm(), null, MLAlgoParams.class);
predictable.initModelAsync(mlModel, params, encryptor).thenAccept((b) -> listener.onResponse(predictable)).exceptionally(e -> {
log.error("Failed to init init model", e);
listener.onFailure(new RuntimeException(e));
return null;
});
}

public MLExecutable deployExecute(MLModel mlModel, Map<String, Object> params) {
MLExecutable executable = MLEngineClassLoader.initInstance(mlModel.getAlgorithm(), null, MLAlgoParams.class);
executable.initModel(mlModel, params);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.opensearch.ml.engine;

import java.util.Map;
import java.util.concurrent.CompletionStage;

import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.common.MLModel;
Expand All @@ -19,6 +20,8 @@
*/
public interface Predictable {

String METHOD_NOT_IMPLEMENTED_ERROR_MSG = "Method is not implemented";

/**
* Predict with given input data and model.
* Will reload model into memory with model content.
Expand All @@ -34,11 +37,11 @@ public interface Predictable {
* @return predicted results
*/
default MLOutput predict(MLInput mlInput) {
throw new IllegalStateException("Method is not implemented");
throw new IllegalStateException(METHOD_NOT_IMPLEMENTED_ERROR_MSG);
}

default void asyncPredict(MLInput mlInput, ActionListener<MLTaskResponse> actionListener) {
actionListener.onFailure(new IllegalStateException("Method is not implemented"));
actionListener.onFailure(new IllegalStateException(METHOD_NOT_IMPLEMENTED_ERROR_MSG));
}

/**
Expand All @@ -47,7 +50,13 @@ default void asyncPredict(MLInput mlInput, ActionListener<MLTaskResponse> action
* @param params other parameters
* @param encryptor encryptor
*/
void initModel(MLModel model, Map<String, Object> params, Encryptor encryptor);
default void initModel(MLModel model, Map<String, Object> params, Encryptor encryptor) {
throw new IllegalStateException(METHOD_NOT_IMPLEMENTED_ERROR_MSG);
}

default CompletionStage<Boolean> initModelAsync(MLModel model, Map<String, Object> params, Encryptor encryptor) {
throw new IllegalStateException(METHOD_NOT_IMPLEMENTED_ERROR_MSG);
}

/**
* Close resources like deployed model.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,20 @@
package org.opensearch.ml.engine.algorithms.remote;

import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.PREDICT;
import static org.opensearch.ml.common.settings.MLCommonsSettings.REMOTE_METADATA_GLOBAL_TENANT_ID;

import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.atomic.AtomicBoolean;

import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.TokenBucket;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLIndex;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.connector.Connector;
import org.opensearch.ml.common.connector.ConnectorAction.ActionType;
Expand All @@ -28,6 +33,7 @@
import org.opensearch.ml.engine.Predictable;
import org.opensearch.ml.engine.annotation.Function;
import org.opensearch.ml.engine.encryptor.Encryptor;
import org.opensearch.remote.metadata.client.SdkClient;
import org.opensearch.script.ScriptService;
import org.opensearch.transport.client.Client;

Expand All @@ -47,6 +53,8 @@ public class RemoteModel implements Predictable {
public static final String USER_RATE_LIMITER_MAP = "user_rate_limiter_map";
public static final String GUARDRAILS = "guardrails";
public static final String CONNECTOR_PRIVATE_IP_ENABLED = "connectorPrivateIpEnabled";
public static final String SDK_CLIENT = "sdk_client";
public static final String SETTINGS = "settings";

private RemoteConnectorExecutor connectorExecutor;

Expand Down Expand Up @@ -98,11 +106,14 @@ public boolean isModelReady() {
}

@Override
public void initModel(MLModel model, Map<String, Object> params, Encryptor encryptor) {
try {
public CompletionStage<Boolean> initModelAsync(MLModel model, Map<String, Object> params, Encryptor encryptor) {
SdkClient sdkClient = (SdkClient) params.get(SDK_CLIENT);
return sdkClient.isGlobalResource(MLIndex.MODEL.getIndexName(), model.getModelId()).thenCompose(isGlobalResource -> {
String decryptTenantId = Boolean.TRUE.equals(isGlobalResource)
? REMOTE_METADATA_GLOBAL_TENANT_ID.get((Settings) params.get(SETTINGS))
: model.getTenantId();
Connector connector = model.getConnector().cloneConnector();
connector
.decrypt(PREDICT.name(), (credential, tenantId) -> encryptor.decrypt(credential, model.getTenantId()), model.getTenantId());
connector.decrypt(PREDICT.name(), (credential, tenantId) -> encryptor.decrypt(credential, decryptTenantId), decryptTenantId);
// This situation can only happen for inline connector where we don't provide tenant id.
if (connector.getTenantId() == null && model.getTenantId() != null) {
connector.setTenantId(model.getTenantId());
Expand All @@ -116,13 +127,10 @@ public void initModel(MLModel model, Map<String, Object> params, Encryptor encry
this.connectorExecutor.setUserRateLimiterMap((Map<String, TokenBucket>) params.get(USER_RATE_LIMITER_MAP));
this.connectorExecutor.setMlGuard((MLGuard) params.get(GUARDRAILS));
this.connectorExecutor.setConnectorPrivateIpEnabled((AtomicBoolean) params.get(CONNECTOR_PRIVATE_IP_ENABLED));
} catch (RuntimeException e) {
log.error("Failed to init remote model.", e);
throw e;
} catch (Throwable e) {
return CompletableFuture.completedStage(true);
}).exceptionally(e -> {
log.error("Failed to init remote model.", e);
throw new MLException(e);
}
});
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,20 @@

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.opensearch.ml.engine.algorithms.remote.RemoteModel.SDK_CLIENT;
import static org.opensearch.ml.engine.algorithms.remote.RemoteModel.SETTINGS;

import java.util.Arrays;
import java.util.Collections;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;

import org.junit.Assert;
import org.junit.Before;
Expand All @@ -26,6 +31,7 @@
import org.mockito.Mock;
import org.mockito.MockedStatic;
import org.mockito.MockitoAnnotations;
import org.opensearch.common.settings.Settings;
import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.connector.Connector;
Expand All @@ -39,6 +45,7 @@
import org.opensearch.ml.engine.MLEngineClassLoader;
import org.opensearch.ml.engine.MLStaticMockBase;
import org.opensearch.ml.engine.encryptor.Encryptor;
import org.opensearch.remote.metadata.client.SdkClient;

import com.google.common.collect.ImmutableMap;

Expand All @@ -53,11 +60,15 @@ public class RemoteModelTest extends MLStaticMockBase {
@Mock
RemoteConnectorExecutor remoteConnectorExecutor;

@Mock
SdkClient sdkClient;

@Rule
public ExpectedException exceptionRule = ExpectedException.none();

RemoteModel remoteModel;
Encryptor encryptor;
Settings settings = Settings.builder().put("plugins.ml-commons.global_tenant_id", "_global_tenant_id").build();

@Before
public void setUp() {
Expand All @@ -66,6 +77,7 @@ public void setUp() {

encryptor = mock(Encryptor.class);
when(encryptor.decrypt(any(), any())).thenReturn("test_api_key");
when(sdkClient.isGlobalResource(any(), any())).thenReturn(CompletableFuture.completedFuture(true));
}

@Test
Expand Down Expand Up @@ -110,7 +122,11 @@ public void asyncPredict_With_RemoteInferenceInputDataSet() {
private void asyncPredict_ModelDeployed_WrongInput(String expExceptionMessage) {
Connector connector = createConnector(ImmutableMap.of("Authorization", "Bearer ${credential.key}"));
when(mlModel.getConnector()).thenReturn(connector);
remoteModel.initModel(mlModel, ImmutableMap.of(), encryptor);
boolean initModelResult = remoteModel
.initModelAsync(mlModel, ImmutableMap.of(SDK_CLIENT, sdkClient, SETTINGS, settings), encryptor)
.toCompletableFuture()
.join();
assertTrue(initModelResult);
ActionListener<MLTaskResponse> actionListener = mock(ActionListener.class);
remoteModel.asyncPredict(mlInput, actionListener);
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class);
Expand Down Expand Up @@ -152,7 +168,10 @@ private void asyncPredict_Failure_With_Throwable(
loader
.when(() -> MLEngineClassLoader.initInstance(connector.getProtocol(), connector, Connector.class))
.thenReturn(remoteConnectorExecutor);
remoteModel.initModel(mlModel, ImmutableMap.of(), encryptor);
remoteModel
.initModelAsync(mlModel, ImmutableMap.of(SDK_CLIENT, sdkClient, SETTINGS, settings), encryptor)
.toCompletableFuture()
.join();
remoteModel.asyncPredict(mlInput, actionListener);
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class);
verify(actionListener).onFailure(argumentCaptor.capture());
Expand All @@ -162,16 +181,16 @@ private void asyncPredict_Failure_With_Throwable(
}

@Test
public void initModel_Failure_With_RuntimeException() {
initModel_Failure_With_Throwable(new IllegalArgumentException("Tag mismatch!"), IllegalArgumentException.class, "Tag mismatch!");
public void initModelAsync_Failure_With_RuntimeException() {
initModelAsync_Failure_With_Throwable(new IllegalArgumentException("Tag mismatch!"), CompletionException.class, "Tag mismatch!");
}

@Test
public void initModel_Failure_With_Throwable() {
initModel_Failure_With_Throwable(new Error("Decryption Error!"), MLException.class, "Decryption Error!");
public void initModelAsync_Failure_With_Throwable() {
initModelAsync_Failure_With_Throwable(new Error("Decryption Error!"), CompletionException.class, "Decryption Error!");
}

private void initModel_Failure_With_Throwable(
private void initModelAsync_Failure_With_Throwable(
Throwable actualException,
Class<? extends Throwable> expExcepClass,
String expExceptionMessage
Expand All @@ -181,23 +200,32 @@ private void initModel_Failure_With_Throwable(
Connector connector = createConnector(null);
when(mlModel.getConnector()).thenReturn(connector);
doThrow(actualException).when(encryptor).decrypt(any(), any());
remoteModel.initModel(mlModel, ImmutableMap.of(), encryptor);
remoteModel
.initModelAsync(mlModel, ImmutableMap.of(SDK_CLIENT, sdkClient, SETTINGS, settings), encryptor)
.toCompletableFuture()
.join();
}

@Test
public void initModel_NullHeader() {
public void initModelAsync_NullHeader() {
Connector connector = createConnector(null);
when(mlModel.getConnector()).thenReturn(connector);
remoteModel.initModel(mlModel, ImmutableMap.of(), encryptor);
remoteModel
.initModelAsync(mlModel, ImmutableMap.of(SDK_CLIENT, sdkClient, SETTINGS, settings), encryptor)
.toCompletableFuture()
.join();
Map<String, String> decryptedHeaders = connector.getDecryptedHeaders();
assertNull(decryptedHeaders);
}

@Test
public void initModel_WithHeader() {
public void initModelAsync_WithHeader() {
Connector connector = createConnector(ImmutableMap.of("Authorization", "Bearer ${credential.key}"));
when(mlModel.getConnector()).thenReturn(connector);
remoteModel.initModel(mlModel, ImmutableMap.of(), encryptor);
remoteModel
.initModelAsync(mlModel, ImmutableMap.of(SDK_CLIENT, sdkClient, SETTINGS, settings), encryptor)
.toCompletableFuture()
.join();
Map<String, String> decryptedHeaders = connector.getDecryptedHeaders();
RemoteConnectorExecutor executor = remoteModel.getConnectorExecutor();
Assert.assertNotNull(executor);
Expand All @@ -210,47 +238,59 @@ public void initModel_WithHeader() {
}

@Test
public void initModel_setsTenantIdOnClonedConnector_whenMissing() {
public void initModelAsync_setsTenantIdOnClonedConnector_whenMissing() {
Connector connector = createConnector(ImmutableMap.of("Authorization", "Bearer ${credential.key}"));
when(mlModel.getConnector()).thenReturn(connector);
when(mlModel.getTenantId()).thenReturn("tenantId");
remoteModel.initModel(mlModel, ImmutableMap.of(), encryptor);
remoteModel
.initModelAsync(mlModel, ImmutableMap.of(SDK_CLIENT, sdkClient, SETTINGS, settings), encryptor)
.toCompletableFuture()
.join();
RemoteConnectorExecutor executor = remoteModel.getConnectorExecutor();
remoteModel.close();
assertNull(connector.getTenantId());
assertEquals("tenantId", executor.getConnector().getTenantId());
}

@Test
public void initModel_bothTenantIdsNull() {
public void initModelAsync_bothTenantIdsNull() {
Connector connector = createConnector(ImmutableMap.of("Authorization", "Bearer ${credential.key}"));
when(mlModel.getConnector()).thenReturn(connector);
when(mlModel.getTenantId()).thenReturn(null);
remoteModel.initModel(mlModel, ImmutableMap.of(), encryptor);
remoteModel
.initModelAsync(mlModel, ImmutableMap.of(SDK_CLIENT, sdkClient, SETTINGS, settings), encryptor)
.toCompletableFuture()
.join();
RemoteConnectorExecutor executor = remoteModel.getConnectorExecutor();
assertNull(connector.getTenantId());
assertNull(executor.getConnector().getTenantId());
}

@Test
public void initModel_connectorHasTenantId() {
public void initModelAsync_connectorHasTenantId() {
Connector connector = createConnector(ImmutableMap.of("Authorization", "Bearer ${credential.key}"));
connector.setTenantId("connectorTenantId");
when(mlModel.getConnector()).thenReturn(connector);
when(mlModel.getTenantId()).thenReturn(null);
remoteModel.initModel(mlModel, ImmutableMap.of(), encryptor);
remoteModel
.initModelAsync(mlModel, ImmutableMap.of(SDK_CLIENT, sdkClient, SETTINGS, settings), encryptor)
.toCompletableFuture()
.join();
RemoteConnectorExecutor executor = remoteModel.getConnectorExecutor();
assertEquals("connectorTenantId", connector.getTenantId());
assertEquals("connectorTenantId", executor.getConnector().getTenantId());
}

@Test
public void initModel_bothHaveTenantIds() {
public void initModelAsync_bothHaveTenantIds() {
Connector connector = createConnector(ImmutableMap.of("Authorization", "Bearer ${credential.key}"));
connector.setTenantId("connectorTenantId");
when(mlModel.getConnector()).thenReturn(connector);
when(mlModel.getTenantId()).thenReturn("modelTenantId");
remoteModel.initModel(mlModel, ImmutableMap.of(), encryptor);
remoteModel
.initModelAsync(mlModel, ImmutableMap.of(SDK_CLIENT, sdkClient, SETTINGS, settings), encryptor)
.toCompletableFuture()
.join();
RemoteConnectorExecutor executor = remoteModel.getConnectorExecutor();
assertEquals("connectorTenantId", connector.getTenantId());
assertEquals("connectorTenantId", executor.getConnector().getTenantId());
Expand Down
Loading