Skip to content

Commit

Permalink
[examples] fixing some dependencies and test requirements (#3605)
Browse files Browse the repository at this point in the history
  • Loading branch information
sindhuvahinis authored Feb 14, 2025
1 parent a9a0c12 commit 714e0ad
Show file tree
Hide file tree
Showing 7 changed files with 23 additions and 0 deletions.
5 changes: 5 additions & 0 deletions examples/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,11 @@
<artifactId>commons-cli</artifactId>
<version>1.9.0</version>
</dependency>
<dependency>
<groupId>commons-io</groupId>
<artifactId>commons-io</artifactId>
<version>2.17.0</version>
</dependency>
<dependency>
<groupId>org.apache.logging.log4j</groupId>
<artifactId>log4j-slf4j2-impl</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
*/
package ai.djl.examples.training;

import ai.djl.testing.TestRequirements;
import ai.djl.translate.TranslateException;

import org.testng.annotations.Test;
Expand All @@ -22,6 +23,8 @@ public class TrainBertOnGoemotionsTest {

@Test
public void testTrainBert() throws IOException, TranslateException {
TestRequirements.engine("PyTorch", "OnnxRuntime");

String[] args = {"-g", "1", "-m", "1", "-e", "1"};
TrainBertOnGoemotions.runExample(args);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
*/
package ai.djl.examples.training;

import ai.djl.testing.TestRequirements;
import ai.djl.translate.TranslateException;

import org.testng.annotations.Test;
Expand All @@ -22,6 +23,8 @@ public class TrainBertTest {

@Test
public void testTrainBert() throws IOException, TranslateException {
TestRequirements.engine("PyTorch", "OnnxRuntime");

String[] args = {"-g", "1", "-m", "1", "-e", "1"};
TrainBertOnCode.runExample(args);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import ai.djl.ModelException;
import ai.djl.examples.inference.cv.ImageClassification;
import ai.djl.modality.Classifications;
import ai.djl.testing.TestRequirements;
import ai.djl.training.TrainingResult;
import ai.djl.translate.TranslateException;

Expand All @@ -27,6 +28,8 @@ public class TrainMnistTest {

@Test
public void testTrainMnist() throws ModelException, TranslateException, IOException {
TestRequirements.engine("PyTorch", "TensorFlow", "OnnxRuntime");

double expectedProb;
if (Boolean.getBoolean("nightly")) {
String[] args = new String[] {"-g", "1"};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
*/
package ai.djl.examples.training;

import ai.djl.testing.TestRequirements;
import ai.djl.training.TrainingResult;
import ai.djl.translate.TranslateException;

Expand All @@ -24,6 +25,8 @@ public class TrainMnistWithLSTMTest {

@Test
public void testTrainMnistWithLSTM() throws IOException, TranslateException {
TestRequirements.engine("PyTorch", "TensorFlow", "OnnxRuntime");

String[] args = {"-g", "1", "-e", "1", "-m", "2"};
TrainingResult result = TrainMnistWithLSTM.runExample(args);
Assert.assertNotNull(result);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
package ai.djl.examples.training;

import ai.djl.engine.Engine;
import ai.djl.testing.TestRequirements;
import ai.djl.training.TrainingResult;

import org.testng.Assert;
Expand All @@ -24,6 +25,8 @@ public class TrainTicTacToeTest {

@Test
public void testTrainTicTacToe() throws IOException {
TestRequirements.engine("PyTorch");

Engine engine = Engine.getEngine("PyTorch");
if (Boolean.getBoolean("nightly") && engine.getGpuCount() > 0) {
String[] args = new String[] {"-g", "1", "-e", "6"};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

package ai.djl.examples.training;

import ai.djl.testing.TestRequirements;
import ai.djl.training.TrainingResult;
import ai.djl.translate.TranslateException;

Expand All @@ -25,6 +26,8 @@ public class TrainTimeSeriesTest {

@Test
public void testTrainTimeSeries() throws TranslateException, IOException {
TestRequirements.engine("PyTorch", "OnnxRuntime");

String[] args = {"-g", "1", "-e", "5", "-b", "32"};
TrainingResult result = TrainTimeSeries.runExample(args);
Assert.assertNotNull(result);
Expand Down

0 comments on commit 714e0ad

Please sign in to comment.