Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FLINK-33003] Support isolation forest algorithm in Flink ML ([fix] related ListStateWithCache) #255

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.ml.examples.anomalydetection;

import org.apache.flink.ml.anomalydetection.isolationforest.IsolationForest;
import org.apache.flink.ml.anomalydetection.isolationforest.IsolationForestModel;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.Vectors;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.types.Row;
import org.apache.flink.util.CloseableIterator;

/**
 * Simple program that creates an IsolationForest instance and uses it for anomaly detection.
 *
 * <p>The program builds a small in-memory table of two-dimensional vectors, fits an {@link
 * IsolationForest} on it, applies the resulting {@link IsolationForestModel} to the same table
 * and prints every input vector together with the prediction produced for it.
 */
public class IsolationForestExample {
    /**
     * Runs the example end to end.
     *
     * @param args command line arguments; not used
     * @throws Exception if closing the result iterator fails
     */
    public static void main(String[] args) throws Exception {
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);

        // Generates input data.
        DataStream<DenseVector> inputStream =
                env.fromElements(
                        Vectors.dense(1, 2),
                        Vectors.dense(1.1, 2),
                        Vectors.dense(1, 2.1),
                        Vectors.dense(1.1, 2.1),
                        Vectors.dense(0.1, 0.1));

        Table inputTable = tEnv.fromDataStream(inputStream).as("features");

        // Creates an IsolationForest object and initializes its parameters.
        IsolationForest isolationForest =
                new IsolationForest()
                        .setNumTrees(100)
                        .setMaxIter(10)
                        .setMaxSamples(256)
                        .setMaxFeatures(1.0);

        // Trains the IsolationForest model.
        IsolationForestModel isolationForestModel = isolationForest.fit(inputTable);

        // Uses the IsolationForest model for predictions.
        Table outputTable = isolationForestModel.transform(inputTable)[0];

        // Extracts and displays the results. The iterator is AutoCloseable, so it is opened in a
        // try-with-resources statement to guarantee it is closed even if reading a row fails.
        try (CloseableIterator<Row> it = outputTable.execute().collect()) {
            while (it.hasNext()) {
                Row row = it.next();
                DenseVector features =
                        (DenseVector) row.getField(isolationForest.getFeaturesCol());
                int predictId = (Integer) row.getField(isolationForest.getPredictionCol());
                System.out.printf("Features: %s \tPrediction: %s\n", features, predictId);
            }
        }
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.ml.anomalydetection.isolationforest;

import org.apache.flink.ml.linalg.DenseVector;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;

/**
 * Construct isolation forest.
 *
 * <p>An {@code IForest} holds a list of {@link ITree}s grown on random sub-samples of the
 * training data, can turn the average path length of a sample through those trees into an
 * anomaly score, and can split a set of scores into two clusters whose centers are kept in
 * {@link #center0} and {@link #center1}.
 */
public class IForest implements Serializable {
    /** Upper bound on the number of samples drawn for each tree. */
    private static final int MAX_SUB_SAMPLES = 256;

    /** Largest movement of a cluster center that still counts as converged. */
    private static final double CENTER_TOLERANCE = 1e-6;

    /** Number of trees added by each call to {@link #generateIsolationForest}. */
    public final int numTrees;

    /** Trees grown so far. */
    public List<ITree> iTreeList;

    /** Cluster center of abnormal scores; {@code null} until {@link #classifyByCluster} runs. */
    public Double center0;

    /** Cluster center of normal scores; {@code null} until {@link #classifyByCluster} runs. */
    public Double center1;

    /** Number of samples drawn per tree by the last {@link #generateIsolationForest} call. */
    public int subSamplesSize;

    /**
     * Creates an empty forest.
     *
     * @param numTrees number of trees to grow when the forest is generated
     */
    public IForest(int numTrees) {
        this.numTrees = numTrees;
        // Size the list for the trees it will actually hold; clamp so that a non-positive
        // numTrees still yields an empty list instead of an exception.
        this.iTreeList = new ArrayList<>(Math.max(numTrees, 0));
        this.center0 = null;
        this.center1 = null;
    }

    /**
     * Grows {@link #numTrees} isolation trees and appends them to {@link #iTreeList}.
     *
     * <p>Every tree is built from {@code min(256, samplesData.length)} samples drawn uniformly
     * with replacement, and its height is limited to {@code ceil(log2(max(subSamplesSize, 2)))}.
     *
     * @param samplesData training samples
     * @param featureIndices indices of the features a tree may split on
     */
    public void generateIsolationForest(DenseVector[] samplesData, int[] featureIndices) {
        int n = samplesData.length;
        subSamplesSize = Math.min(MAX_SUB_SAMPLES, n);
        int limitHeight = (int) Math.ceil(Math.log(Math.max(subSamplesSize, 2)) / Math.log(2));

        Random randomState = new Random(System.nanoTime());
        for (int i = 0; i < numTrees; i++) {
            DenseVector[] subSamples = new DenseVector[subSamplesSize];
            for (int j = 0; j < subSamplesSize; j++) {
                subSamples[j] = samplesData[randomState.nextInt(n)];
            }
            ITree isolationTree =
                    ITree.generateIsolationTree(
                            subSamples, 0, limitHeight, randomState, featureIndices);
            this.iTreeList.add(isolationTree);
        }
    }

    /**
     * Computes one anomaly score per sample.
     *
     * <p>The score of a sample is {@code 2^(-avgPathLength / c)}, where {@code avgPathLength} is
     * the mean of {@link ITree#calculatePathLength} over all trees in {@link #iTreeList} and
     * {@code c} is {@link ITree#calculateCn} of {@link #subSamplesSize}.
     *
     * @param samplesData samples to score
     * @return vector whose i-th entry is the score of {@code samplesData[i]}
     * @throws Exception if computing a path length fails
     */
    public DenseVector calculateScore(DenseVector[] samplesData) throws Exception {
        DenseVector score = new DenseVector(samplesData.length);

        // Both values are identical for every sample, so compute them once.
        final int treeCount = iTreeList.size();
        final double cn = ITree.calculateCn(subSamplesSize);

        for (int i = 0; i < samplesData.length; i++) {
            double pathLengthSum = 0;
            for (ITree isolationTree : iTreeList) {
                pathLengthSum += ITree.calculatePathLength(samplesData[i], isolationTree);
            }

            double pathLengthAvg = pathLengthSum / treeCount;
            double index = pathLengthAvg / cn;
            score.set(i, Math.pow(2, -index));
        }

        return score;
    }

    /**
     * Splits the scores into two clusters and stores the resulting cluster centers.
     *
     * <p>The centers start at the largest score ({@link #center0}) and the smallest score
     * ({@link #center1}). Each iteration assigns every score to the nearer center (ties go to
     * {@code center1}) and moves each center to the mean of its scores. Iteration stops after
     * {@code iters} rounds or as soon as neither center moved by more than {@code 1e-6}. A
     * center whose cluster received no score keeps its previous value instead of becoming NaN.
     *
     * @param score anomaly scores to cluster; must contain at least one value
     * @param iters maximum number of iterations
     * @return two-element vector {@code [center0, center1]}
     * @throws IllegalArgumentException if {@code score} is empty
     */
    public DenseVector classifyByCluster(DenseVector score, int iters) {
        int scoresSize = score.size();
        if (scoresSize == 0) {
            throw new IllegalArgumentException("Cannot cluster an empty score vector.");
        }

        // Work on primitives; the boxed fields are written once at the end.
        double abnormalCenter = score.get(0); // Cluster center of abnormal
        double normalCenter = score.get(0); // Cluster center of normal

        for (double s : score.values) {
            if (s > abnormalCenter) {
                abnormalCenter = s;
            }

            if (s < normalCenter) {
                normalCenter = s;
            }
        }

        for (int i = 0; i < iters; i++) {
            int cnt0 = 0;
            int cnt1 = 0;
            double sum0 = 0.0;
            double sum1 = 0.0;

            for (int j = 0; j < scoresSize; j++) {
                double s = score.get(j);
                double diff0 = Math.abs(s - abnormalCenter);
                double diff1 = Math.abs(s - normalCenter);

                if (diff0 < diff1) {
                    cnt0++;
                    sum0 += s;
                } else {
                    cnt1++;
                    sum1 += s;
                }
            }

            double previousCenter0 = abnormalCenter;
            double previousCenter1 = normalCenter;

            // An empty cluster has no mean; dividing by zero would turn the center into NaN.
            if (cnt0 > 0) {
                abnormalCenter = sum0 / cnt0;
            }
            if (cnt1 > 0) {
                normalCenter = sum1 / cnt1;
            }

            // Compare absolute movement: a signed difference would treat any downward move,
            // however large, as convergence.
            if (Math.abs(abnormalCenter - previousCenter0) <= CENTER_TOLERANCE
                    && Math.abs(normalCenter - previousCenter1) <= CENTER_TOLERANCE) {
                break;
            }
        }

        center0 = abnormalCenter;
        center1 = normalCenter;

        return new DenseVector(new double[] {abnormalCenter, normalCenter});
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.ml.anomalydetection.isolationforest;

import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.ml.linalg.DenseVector;

import java.io.Serializable;
import java.util.Random;

/**
 * Construct isolation tree.
 *
 * <p>Each instance is one node of the tree. The static methods build a tree from samples and
 * measure how deep a sample travels in an existing tree.
 */
public class ITree implements Serializable {
    /** Constant added to {@code ln(n - 1)} in {@link #calculateCn(double)}. */
    private static final double EULER_MASCHERONI = 0.5772156649015329;

    /** Index of the feature this node splits on. */
    public final int attributeIndex;

    /** Threshold compared against the feature value at {@link #attributeIndex}. */
    public final double splitAttributeValue;

    /** Height at which this node was created; the root is created with the caller's value. */
    public final int currentHeight;

    /** Number of samples that reached this node while the tree was built. */
    public final int leafNodesNum;

    /** Subtree built from samples below the threshold, or {@code null}. */
    public ITree leftTree;

    /** Subtree built from the remaining samples, or {@code null}. */
    public ITree rightTree;

    /**
     * Creates a node without children.
     *
     * @param attributeIndex index of the split feature
     * @param splitAttributeValue split threshold
     * @param currentHeight height of the node
     * @param leafNodesNum number of samples at the node
     */
    public ITree(
            int attributeIndex, double splitAttributeValue, int currentHeight, int leafNodesNum) {
        this.attributeIndex = attributeIndex;
        this.splitAttributeValue = splitAttributeValue;
        this.currentHeight = currentHeight;
        this.leafNodesNum = leafNodesNum;
        // leftTree and rightTree keep their default value, null.
    }

    /**
     * Recursively builds an isolation tree.
     *
     * <p>Recursion ends with {@code null} for an empty sample array, and with a childless node
     * when only one sample is left, the height limit is reached, or all samples are equal.
     * Otherwise a random feature and threshold are drawn, samples below the threshold go to the
     * left subtree and all others to the right subtree.
     *
     * @param samplesData samples that reached this node
     * @param currentHeight height of the node to create
     * @param limitHeight height at which no further split is made
     * @param randomState source of randomness
     * @param featureIndices indices of the features that may be chosen for a split
     * @return root of the (sub)tree, or {@code null} if {@code samplesData} is empty
     */
    public static ITree generateIsolationTree(
            DenseVector[] samplesData,
            int currentHeight,
            int limitHeight,
            Random randomState,
            int[] featureIndices) {
        final int sampleCount = samplesData.length;
        if (sampleCount == 0) {
            return null;
        }

        final boolean cannotSplit =
                sampleCount == 1 || currentHeight >= limitHeight || allEqual(samplesData);
        if (cannotSplit) {
            return new ITree(0, samplesData[0].get(0), currentHeight, sampleCount);
        }

        final Tuple2<Integer, Double> split =
                getRandomFeatureToSplit(samplesData, randomState, featureIndices);
        final int feature = split.f0;
        final double threshold = split.f1;

        // First pass: size the left partition; everything else belongs to the right one.
        int leftCount = 0;
        for (int i = 0; i < sampleCount; i++) {
            if (samplesData[i].get(feature) < threshold) {
                leftCount++;
            }
        }

        // Second pass: distribute the samples, keeping their original relative order.
        final DenseVector[] below = new DenseVector[leftCount];
        final DenseVector[] notBelow = new DenseVector[sampleCount - leftCount];
        int belowPos = 0;
        int notBelowPos = 0;
        for (int i = 0; i < sampleCount; i++) {
            final DenseVector sample = samplesData[i];
            if (sample.get(feature) < threshold) {
                below[belowPos++] = sample;
            } else {
                notBelow[notBelowPos++] = sample;
            }
        }

        final int childHeight = currentHeight + 1;
        final ITree node = new ITree(feature, threshold, currentHeight, sampleCount);
        // The left subtree is built first so random numbers are consumed in a fixed order.
        node.leftTree =
                generateIsolationTree(below, childHeight, limitHeight, randomState, featureIndices);
        node.rightTree =
                generateIsolationTree(
                        notBelow, childHeight, limitHeight, randomState, featureIndices);
        return node;
    }

    /** Returns whether every sample equals its predecessor, i.e. all samples are equal. */
    private static boolean allEqual(DenseVector[] samples) {
        int i = 1;
        while (i < samples.length) {
            if (!samples[i].equals(samples[i - 1])) {
                return false;
            }
            i++;
        }
        return true;
    }

    /**
     * Draws a random feature and a random threshold between that feature's smallest and largest
     * value among the samples.
     *
     * @param samplesData samples to inspect; must not be empty
     * @param randomState source of randomness
     * @param featureIndices candidate feature indices
     * @return pair of chosen feature index and threshold
     */
    private static Tuple2<Integer, Double> getRandomFeatureToSplit(
            DenseVector[] samplesData, Random randomState, int[] featureIndices) {
        final int feature = featureIndices[randomState.nextInt(featureIndices.length)];

        double low = samplesData[0].get(feature);
        double high = low;
        for (int i = 1; i < samplesData.length; i++) {
            final double value = samplesData[i].get(feature);
            low = Math.min(low, value);
            high = Math.max(high, value);
        }

        final double threshold = (high - low) * randomState.nextDouble() + low;
        return Tuple2.of(feature, threshold);
    }

    /**
     * Measures the path length of a sample in a tree.
     *
     * <p>Starting at the root, the sample descends left when its feature value is below the
     * node's threshold and right when it is above. Descent stops at a node that lacks either
     * child or whose threshold equals the feature value. The result is the number of steps
     * taken plus {@link #calculateCn(double)} of the sample count stored at the final node.
     *
     * @param sampleData sample to route through the tree
     * @param isolationTree root of the tree
     * @return path length of the sample
     * @throws Exception declared for callers; see the method body for what can be raised
     */
    public static double calculatePathLength(DenseVector sampleData, ITree isolationTree)
            throws Exception {
        ITree node = isolationTree;
        int steps = 0;
        while (true) {
            if (node.leftTree == null || node.rightTree == null) {
                break;
            }
            final double value = sampleData.get(node.attributeIndex);
            if (value == node.splitAttributeValue) {
                break;
            }
            node = value < node.splitAttributeValue ? node.leftTree : node.rightTree;
            steps++;
        }

        return steps + calculateCn(node.leafNodesNum);
    }

    /**
     * Computes the path-length adjustment for a node holding {@code n} samples.
     *
     * @param n number of samples
     * @return {@code 0} when {@code n <= 1}, otherwise {@code 2 * (ln(n - 1) + 0.5772156649015329)
     *     - 2 * (n - 1) / n}
     */
    public static double calculateCn(double n) {
        if (n <= 1) {
            return 0;
        }
        final double harmonicEstimate = Math.log(n - 1.0) + EULER_MASCHERONI;
        return 2.0 * harmonicEstimate - 2.0 * (n - 1.0) / n;
    }
}
Loading