Skip to content

Commit

Permalink
Updating FineTuning API including DPO
Browse files Browse the repository at this point in the history
  • Loading branch information
sashirestela committed Dec 21, 2024
1 parent 5b5ed48 commit 8a891b2
Show file tree
Hide file tree
Showing 6 changed files with 124 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ public class FineTuning {
private List<Integration> integrations;
private Integer seed;
private Integer estimatedFinish;
private MethodFineTunning method;

@NoArgsConstructor
@Getter
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@ public class FineTuningRequest {

private String validationFile;

/**
* @deprecated OpenAI has deperecated this field in favor of method, and should be passed in under
* the method parameter.
*/
@Deprecated(since = "3.12.0", forRemoval = true)
private HyperParams hyperparameters;

private String suffix;
Expand All @@ -34,4 +39,6 @@ public class FineTuningRequest {

private Integer seed;

private MethodFineTunning method;

}
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@
@JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class)
public class HyperParams {

@ObjectType(baseClass = Integer.class)
@ObjectType(baseClass = String.class)
private Object beta;

@ObjectType(baseClass = Integer.class)
@ObjectType(baseClass = String.class)
private Object batchSize;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
package io.github.sashirestela.openai.domain.finetuning;

import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonInclude.Include;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.databind.PropertyNamingStrategies;
import com.fasterxml.jackson.databind.annotation.JsonNaming;
import lombok.Getter;
import lombok.NoArgsConstructor;
import lombok.ToString;

@Getter
@ToString
@NoArgsConstructor
@JsonInclude(Include.NON_EMPTY)
@JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class)
public class MethodFineTunning {

private MethodType type;
private Supervised supervised;
private Dpo dpo;

private MethodFineTunning(MethodType type, Supervised supervised, Dpo dpo) {
this.type = type;
this.supervised = supervised;
this.dpo = dpo;
}

public static MethodFineTunning supervised(HyperParams hyperParameters) {
return new MethodFineTunning(MethodType.SUPERVISED, new Supervised(hyperParameters), null);
}

public static MethodFineTunning dpo(HyperParams hyperParameters) {
return new MethodFineTunning(MethodType.DPO, null, new Dpo(hyperParameters));
}

public enum MethodType {

@JsonProperty("supervised")
SUPERVISED,

@JsonProperty("dpo")
DPO;

}

@Getter
@ToString
@NoArgsConstructor
@JsonInclude(Include.NON_EMPTY)
@JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class)
public static class Supervised {

HyperParams hyperParameters;

public Supervised(HyperParams hyperParameters) {
this.hyperParameters = hyperParameters;
}

}

@Getter
@ToString
@NoArgsConstructor
@JsonInclude(Include.NON_EMPTY)
@JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class)
public static class Dpo {

HyperParams hyperParameters;

public Dpo(HyperParams hyperParameters) {
this.hyperParameters = hyperParameters;
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -31,17 +31,43 @@ static void setup() {
}

@Test
void testFineTuningsCreate() throws IOException {
void testFineTuningsCreateDpo() throws IOException {
DomainTestingHelper.get().mockForObject(httpClient, "src/test/resources/finetunings_create.json");
var fineTuningRequest = FineTuningRequest.builder()
.trainingFile("fileId")
.validationFile("fileId")
.model("gpt-3.5-turbo-1106")
.hyperparameters(HyperParams.builder()
.suffix("suffix")
.integration(Integration.builder()
.type(IntegrationType.WANDB)
.wandb(WandbIntegration.builder()
.project("my-wandb-project")
.name("ft-run-display-name")
.entity("testing")
.tag("first-experiment")
.tag("v2")
.build())
.build())
.seed(99)
.method(MethodFineTunning.dpo(HyperParams.builder()
.beta("auto")
.batchSize("auto")
.learningRateMultiplier("auto")
.nEpochs("auto")
.build())
.build()))
.build();
var fineTuningResponse = openAI.fineTunings().create(fineTuningRequest).join();
System.out.println(fineTuningResponse);
assertNotNull(fineTuningResponse);
}

@Test
void testFineTuningsCreateSupervised() throws IOException {
DomainTestingHelper.get().mockForObject(httpClient, "src/test/resources/finetunings_create.json");
var fineTuningRequest = FineTuningRequest.builder()
.trainingFile("fileId")
.validationFile("fileId")
.model("gpt-3.5-turbo-1106")
.suffix("suffix")
.integration(Integration.builder()
.type(IntegrationType.WANDB)
Expand All @@ -54,6 +80,11 @@ void testFineTuningsCreate() throws IOException {
.build())
.build())
.seed(99)
.method(MethodFineTunning.supervised(HyperParams.builder()
.batchSize("auto")
.learningRateMultiplier("auto")
.nEpochs("auto")
.build()))
.build();
var fineTuningResponse = openAI.fineTunings().create(fineTuningRequest).join();
System.out.println(fineTuningResponse);
Expand Down
2 changes: 1 addition & 1 deletion src/test/resources/finetunings_create.json
Original file line number Diff line number Diff line change
@@ -1 +1 @@
{ "object": "fine_tuning.job", "id": "ftjob-35j8EBrZVsuyFFe2OD8Tkmvd", "model": "gpt-3.5-turbo-1106", "created_at": 1700533111, "finished_at": null, "fine_tuned_model": null, "organization_id": "org-4WdgDKZ75eLPEH6zqX5hFd5e", "result_files": [], "status": "validating_files", "validation_file": null, "training_file": "file-0e5BDWQYA1KsguTJRCCXqAa2", "hyperparameters": { "n_epochs": "auto", "batch_size": "auto", "learning_rate_multiplier": "auto" }, "trained_tokens": null, "error": null, "integrations":[{"type":"wandb","wandb":{"project":"my-wandb-project","name":"ft-run-display-name","tags":["first-experiment","v2"]}}], "seed": 99 }
{"object":"fine_tuning.job","id":"ftjob-35j8EBrZVsuyFFe2OD8Tkmvd","model":"gpt-3.5-turbo-1106","created_at":1700533111,"finished_at":null,"fine_tuned_model":null,"organization_id":"org-4WdgDKZ75eLPEH6zqX5hFd5e","result_files":[],"status":"validating_files","validation_file":null,"training_file":"file-0e5BDWQYA1KsguTJRCCXqAa2","trained_tokens":null,"error":null,"integrations":[{"type":"wandb","wandb":{"project":"my-wandb-project","name":"ft-run-display-name","tags":["first-experiment","v2"]}}],"seed":99,"method":{"type":"supervised","supervised":{"hyperparameters":{"batch_size":"auto","learning_rate_multiplier":"auto","n_epochs":"auto"}}}}

0 comments on commit 8a891b2

Please sign in to comment.