Skip to content

Provide a way for customer to inject their own data converter. #145

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 6 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,14 @@
import com.microsoft.azure.functions.internal.spi.middleware.Middleware;
import com.microsoft.azure.functions.internal.spi.middleware.MiddlewareChain;
import com.microsoft.azure.functions.internal.spi.middleware.MiddlewareContext;
import com.microsoft.durabletask.DataConverter;
import com.microsoft.durabletask.OrchestrationRunner;
import com.microsoft.durabletask.OrchestratorBlockedException;

import java.util.Iterator;
import java.util.ServiceLoader;
import java.util.concurrent.atomic.AtomicBoolean;

/**
* Durable Function Orchestration Middleware
*
Expand All @@ -21,14 +26,19 @@
public class OrchestrationMiddleware implements Middleware {

private static final String ORCHESTRATION_TRIGGER = "DurableOrchestrationTrigger";
private final Object dataConverterLock = new Object();
private volatile DataConverter dataConverter;
private final AtomicBoolean oneTimeLogicExecuted = new AtomicBoolean(false);

@Override
public void invoke(MiddlewareContext context, MiddlewareChain chain) throws Exception {
String parameterName = context.getParameterName(ORCHESTRATION_TRIGGER);
if (parameterName == null){
if (parameterName == null) {
chain.doNext(context);
return;
}
//invoked only for orchestrator function.
loadCustomizedDataConverterOnce();
String orchestratorRequestEncodedProtoBytes = (String) context.getParameterValue(parameterName);
String orchestratorOutputEncodedProtoBytes = OrchestrationRunner.loadAndRun(orchestratorRequestEncodedProtoBytes, taskOrchestrationContext -> {
try {
Expand All @@ -39,12 +49,29 @@ public void invoke(MiddlewareContext context, MiddlewareChain chain) throws Exce
// The OrchestratorBlockedEvent will be wrapped into InvocationTargetException by using reflection to
// invoke method. Thus get the cause to check if it's OrchestratorBlockedEvent.
Throwable cause = e.getCause();
if (cause instanceof OrchestratorBlockedException){
if (cause instanceof OrchestratorBlockedException) {
throw (OrchestratorBlockedException) cause;
}
throw new RuntimeException("Unexpected failure in the task execution", e);
}
});
}, this.dataConverter);
context.updateReturnValue(orchestratorOutputEncodedProtoBytes);
}

private void loadCustomizedDataConverterOnce() {
if (!oneTimeLogicExecuted.get()) {
synchronized (dataConverterLock) {
if (!oneTimeLogicExecuted.get()) {
Iterator<DataConverter> iterator = ServiceLoader.load(DataConverter.class).iterator();
if (iterator.hasNext()) {
this.dataConverter = iterator.next();
if (iterator.hasNext()) {
throw new IllegalStateException("Multiple implementations of DataConverter found on the classpath.");
}
}
oneTimeLogicExecuted.compareAndSet(false,true);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just curious why you can't use a regular Boolean here? Since we're already in a synchronized block, I wouldn't expect that you'd need to use any additional synchronization primitives.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I use a regular Boolean in java, I need to declare it with volatile which avoids using thread cache, that means that when a thread modifies the value of this Boolean, all other threads will see the updated value when they access it in Java. So I am thinking of just using AtomicBoolean which provides both visibility and atomicity at the same time.

}
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,11 @@ private OrchestrationRunner() {
*/
public static <R> String loadAndRun(
String base64EncodedOrchestratorRequest,
OrchestratorFunction<R> orchestratorFunc) {
OrchestratorFunction<R> orchestratorFunc,
DataConverter dataConverter) {
// Example string: CiBhOTMyYjdiYWM5MmI0MDM5YjRkMTYxMDIwNzlmYTM1YSIaCP///////////wESCwi254qRBhDk+rgocgAicgj///////////8BEgwIs+eKkQYQzMXjnQMaVwoLSGVsbG9DaXRpZXMSACJGCiBhOTMyYjdiYWM5MmI0MDM5YjRkMTYxMDIwNzlmYTM1YRIiCiA3ODEwOTA2N2Q4Y2Q0ODg1YWU4NjQ0OTNlMmRlMGQ3OA==
byte[] decodedBytes = Base64.getDecoder().decode(base64EncodedOrchestratorRequest);
byte[] resultBytes = loadAndRun(decodedBytes, orchestratorFunc);
byte[] resultBytes = loadAndRun(decodedBytes, orchestratorFunc, dataConverter);
return Base64.getEncoder().encodeToString(resultBytes);
}

Expand All @@ -55,7 +56,8 @@ public static <R> String loadAndRun(
*/
public static <R> byte[] loadAndRun(
byte[] orchestratorRequestBytes,
OrchestratorFunction<R> orchestratorFunc) {
OrchestratorFunction<R> orchestratorFunc,
DataConverter dataConverter) {
if (orchestratorFunc == null) {
throw new IllegalArgumentException("orchestratorFunc must not be null");
}
Expand All @@ -66,7 +68,7 @@ public static <R> byte[] loadAndRun(
ctx.complete(output);
};

return loadAndRun(orchestratorRequestBytes, orchestration);
return loadAndRun(orchestratorRequestBytes, orchestration, dataConverter);
}

/**
Expand All @@ -82,7 +84,7 @@ public static String loadAndRun(
String base64EncodedOrchestratorRequest,
TaskOrchestration orchestration) {
byte[] decodedBytes = Base64.getDecoder().decode(base64EncodedOrchestratorRequest);
byte[] resultBytes = loadAndRun(decodedBytes, orchestration);
byte[] resultBytes = loadAndRun(decodedBytes, orchestration, null);
return Base64.getEncoder().encodeToString(resultBytes);
}

Expand All @@ -95,7 +97,7 @@ public static String loadAndRun(
* @return a protobuf-encoded payload of orchestrator actions to be interpreted by the external orchestration engine
* @throws IllegalArgumentException if either parameter is {@code null} or if {@code orchestratorRequestBytes} is not valid protobuf
*/
public static byte[] loadAndRun(byte[] orchestratorRequestBytes, TaskOrchestration orchestration) {
public static byte[] loadAndRun(byte[] orchestratorRequestBytes, TaskOrchestration orchestration, DataConverter dataConverter) {
if (orchestratorRequestBytes == null || orchestratorRequestBytes.length == 0) {
throw new IllegalArgumentException("triggerStateProtoBytes must not be null or empty");
}
Expand Down Expand Up @@ -127,7 +129,7 @@ public TaskOrchestration create() {

TaskOrchestrationExecutor taskOrchestrationExecutor = new TaskOrchestrationExecutor(
orchestrationFactories,
new JacksonDataConverter(),
dataConverter != null ? dataConverter : new JacksonDataConverter(),
DEFAULT_MAXIMUM_TIMER_INTERVAL,
logger);

Expand Down
1 change: 1 addition & 0 deletions samples-azure-functions/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ dependencies {
implementation project(':azurefunctions')

implementation 'com.microsoft.azure.functions:azure-functions-java-library:3.0.0'
implementation 'com.google.code.gson:gson:2.9.0'
testImplementation 'org.junit.jupiter:junit-jupiter:5.6.2'
testImplementation 'io.rest-assured:rest-assured:5.3.0'
testImplementation 'io.rest-assured:json-path:5.3.0'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,14 @@ public HttpResponseMessage startOrchestration(
*/
@FunctionName("Cities")
public String citiesOrchestrator(
@DurableOrchestrationTrigger(name = "ctx") TaskOrchestrationContext ctx) {
@DurableOrchestrationTrigger(name = "ctx") TaskOrchestrationContext ctx,
final ExecutionContext context) {
String result = "";
result += ctx.callActivity("Capitalize", "Tokyo", String.class).await() + ", ";
result += ctx.callActivity("Capitalize", "London", String.class).await() + ", ";
result += ctx.callActivity("Capitalize", "Seattle", String.class).await() + ", ";
result += ctx.callActivity("Capitalize", "Austin", String.class).await();
context.getLogger().info("Orchestrator function completed!");
return result;
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
package com.functions;

import com.microsoft.azure.functions.ExecutionContext;
import com.microsoft.azure.functions.HttpMethod;
import com.microsoft.azure.functions.HttpRequestMessage;
import com.microsoft.azure.functions.HttpResponseMessage;
import com.microsoft.azure.functions.annotation.AuthorizationLevel;
import com.microsoft.azure.functions.annotation.FunctionName;
import com.microsoft.azure.functions.annotation.HttpTrigger;
import com.microsoft.durabletask.DurableTaskClient;
import com.microsoft.durabletask.TaskOrchestrationContext;
import com.microsoft.durabletask.azurefunctions.DurableActivityTrigger;
import com.microsoft.durabletask.azurefunctions.DurableClientContext;
import com.microsoft.durabletask.azurefunctions.DurableClientInput;
import com.microsoft.durabletask.azurefunctions.DurableOrchestrationTrigger;

import java.time.LocalDate;
import java.util.Optional;

public class CustomizeDataConverter {

@FunctionName("StartCustomize")
public HttpResponseMessage startExampleProcess(
@HttpTrigger(name = "req",
methods = {HttpMethod.GET, HttpMethod.POST},
authLevel = AuthorizationLevel.ANONYMOUS) final HttpRequestMessage<Optional<String>> request,
@DurableClientInput(name = "durableContext") final DurableClientContext durableContext,
final ExecutionContext context) {
context.getLogger().info("Java HTTP trigger processed a request");

final DurableTaskClient client = durableContext.getClient();
final String instanceId = client.scheduleNewOrchestrationInstance("Customize");
return durableContext.createCheckStatusResponse(request, instanceId);
}

@FunctionName("Customize")
public ExampleResponse exampleOrchestrator(
@DurableOrchestrationTrigger(name = "taskOrchestrationContext") final TaskOrchestrationContext context,
final ExecutionContext functionContext) {
return context.callActivity("ToLower", "Foo", ExampleResponse.class).await();
}

@FunctionName("ToLower")
public ExampleResponse toLower(
@DurableActivityTrigger(name = "value") final String value,
final ExecutionContext context) {
return new ExampleResponse(LocalDate.now(), value.toLowerCase());
}

static class ExampleResponse {
private final LocalDate date;
private final String value;

public ExampleResponse(LocalDate date, String value) {
this.date = date;
this.value = value;
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package com.functions.converter;

import com.google.gson.Gson;
import com.microsoft.durabletask.DataConverter;

public class MyConverter implements DataConverter {

private static final Gson gson = new Gson();
@Override
public String serialize(Object value) {
return gson.toJson(value);
}

@Override
public <T> T deserialize(String data, Class<T> target) {
return gson.fromJson(data, target);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
com.functions.converter.MyConverter
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,18 @@
@Tag("e2e")
public class EndToEndTests {

private String waitForCompletion(String statusQueryGetUri) throws InterruptedException {
String runTimeStatus = null;
for (int i = 0; i < 15; i++) {
Response statusResponse = get(statusQueryGetUri);
runTimeStatus = statusResponse.jsonPath().get("runtimeStatus");
if (!"Completed".equals(runTimeStatus)) {
Thread.sleep(1000);
} else break;
}
return runTimeStatus;
}

@Order(1)
@Test
public void setupHost() {
Expand Down Expand Up @@ -82,16 +94,13 @@ public void restart(boolean restartWithNewInstanceId) throws InterruptedExceptio
}
}

private String waitForCompletion(String statusQueryGetUri) throws InterruptedException {
String runTimeStatus = null;
for (int i = 0; i < 15; i++) {
Response statusResponse = get(statusQueryGetUri);
runTimeStatus = statusResponse.jsonPath().get("runtimeStatus");
if (!"Completed".equals(runTimeStatus)) {
Thread.sleep(1000);
} else break;
}
return runTimeStatus;
@Test
public void customizeDataConverter() throws InterruptedException {
String startOrchestrationPath = "/api/StartCustomize";
Response response = post(startOrchestrationPath);
JsonPath jsonPath = response.jsonPath();
String statusQueryGetUri = jsonPath.get("statusQueryGetUri");
String runTimeStatus = waitForCompletion(statusQueryGetUri);
assertEquals("Completed", runTimeStatus);
}

}