Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,20 +19,78 @@

package org.apache.geaflow.ai.index.vector;

import java.util.Objects;

public class MagnitudeVector implements IVector {

private final double magnitude;

public MagnitudeVector() {
this.magnitude = 0.0;
}

public MagnitudeVector(double magnitude) {
this.magnitude = magnitude;
}

public double getMagnitude() {
return magnitude;
}

@Override
public double match(IVector other) {
return 0;
if (!(other instanceof MagnitudeVector)) {
return 0.0;
}

MagnitudeVector otherVec = (MagnitudeVector) other;
double otherMagnitude = otherVec.magnitude;

// Both zero -> perfect match
if (this.magnitude == 0.0 && otherMagnitude == 0.0) {
return 1.0;
}

// One is zero, other is not -> no match
if (this.magnitude == 0.0 || otherMagnitude == 0.0) {
return 0.0;
}

// Compute normalized difference
double diff = Math.abs(this.magnitude - otherMagnitude);
double max = Math.max(Math.abs(this.magnitude), Math.abs(otherMagnitude));

if (max == 0.0) {
return 1.0;
}

return 1.0 - (diff / max);
}

@Override
public VectorType getType() {
return VectorType.MagnitudeVector;
}

@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
MagnitudeVector that = (MagnitudeVector) o;
return Double.compare(that.magnitude, magnitude) == 0;
}

@Override
public int hashCode() {
return Objects.hash(magnitude);
}

@Override
public String toString() {
return "MagnitudeVector{}";
return "MagnitudeVector{magnitude=" + magnitude + '}';
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,11 @@

package org.apache.geaflow.ai.index.vector;

import java.util.Arrays;
import java.util.HashSet;
import java.util.Objects;
import java.util.Set;

public class TraversalVector implements IVector {

private final String[] vec;
Expand All @@ -32,25 +37,86 @@ public TraversalVector(String... vec) {

@Override
public double match(IVector other) {
return 0;
if (!(other instanceof TraversalVector)) {
return 0.0;
}

TraversalVector otherVec = (TraversalVector) other;

// Exact match: identical triple sequence
if (Arrays.equals(this.vec, otherVec.vec)) {
return 1.0;
}

// Convert triples to set for efficient comparison
Set<String> thisTriples = getTriplesSet();
Set<String> otherTriples = otherVec.getTriplesSet();

// Check for subgraph containment (this is contained in other)
if (otherTriples.containsAll(thisTriples)) {
return 0.8;
}

// Compute partial overlap using Jaccard similarity
Set<String> intersection = new HashSet<>(thisTriples);
intersection.retainAll(otherTriples);

if (intersection.isEmpty()) {
return 0.0;
}

Set<String> union = new HashSet<>(thisTriples);
union.addAll(otherTriples);

return (double) intersection.size() / union.size();
}

/**
* Converts the array of triples into a Set of string representations.
* Each triple is represented as "src|edge|dst".
*/
private Set<String> getTriplesSet() {
Set<String> triples = new HashSet<>();
for (int i = 0; i < vec.length; i += 3) {
String triple = vec[i] + "|" + vec[i + 1] + "|" + vec[i + 2];
triples.add(triple);
}
return triples;
}

@Override
public VectorType getType() {
return VectorType.TraversalVector;
}

@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
TraversalVector that = (TraversalVector) o;
return Arrays.equals(vec, that.vec);
}

@Override
public int hashCode() {
return Objects.hash(Arrays.hashCode(vec));
}

@Override
public String toString() {
StringBuilder sb = new StringBuilder("TraversalVector{vec=");
for (int i = 0; i < vec.length; i++) {
if (i > 0) {
sb.append(i % 3 == 0 ? "; " : "-");
}
sb.append(vec[i]);
if (i % 3 == 2) {
sb.append(">");
}
sb.append(vec[i]);
}
return sb.append('}').toString();
}
Expand Down
187 changes: 187 additions & 0 deletions geaflow-ai/src/test/java/org/apache/geaflow/ai/GraphMemoryTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,11 @@
import org.apache.geaflow.ai.index.EntityAttributeIndexStore;
import org.apache.geaflow.ai.index.IndexStore;
import org.apache.geaflow.ai.index.vector.EmbeddingVector;
import org.apache.geaflow.ai.index.vector.IVector;
import org.apache.geaflow.ai.index.vector.KeywordVector;
import org.apache.geaflow.ai.index.vector.MagnitudeVector;
import org.apache.geaflow.ai.index.vector.TraversalVector;
import org.apache.geaflow.ai.index.vector.VectorType;
import org.apache.geaflow.ai.search.VectorSearch;
import org.apache.geaflow.ai.verbalization.Context;
import org.apache.geaflow.ai.verbalization.SubgraphSemanticPromptFunction;
Expand All @@ -38,6 +40,9 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotEquals;

public class GraphMemoryTest {

private static final Logger LOGGER = LoggerFactory.getLogger(GraphMemoryTest.class);
Expand All @@ -52,6 +57,188 @@ public void testVectorSearch() {
LOGGER.info(String.valueOf(search));
}

// ========== MagnitudeVector Tests ==========

@Test
public void testMagnitudeVectorConstructorAndGetter() {
MagnitudeVector vector = new MagnitudeVector(0.85);
assertEquals(vector.getMagnitude(), 0.85, 0.0001);
}

@Test
public void testMagnitudeVectorMatchExactSameValue() {
MagnitudeVector v1 = new MagnitudeVector(5.0);
MagnitudeVector v2 = new MagnitudeVector(5.0);

assertEquals(v1.match(v2), 1.0, 0.0001);
}

@Test
public void testMagnitudeVectorMatchDifferentValues() {
MagnitudeVector v1 = new MagnitudeVector(10.0);
MagnitudeVector v2 = new MagnitudeVector(5.0);

// Expected: 1 - |10-5|/max(10,5) = 1 - 5/10 = 0.5
assertEquals(v1.match(v2), 0.5, 0.0001);
}

@Test
public void testMagnitudeVectorMatchWithIncompatibleType() {
MagnitudeVector v1 = new MagnitudeVector(5.0);
IVector incompatibleVector = new IVector() {
@Override
public double match(IVector other) {
return 0;
}

@Override
public VectorType getType() {
return null;
}
};

assertEquals(v1.match(incompatibleVector), 0.0, 0.0001);
}

@Test
public void testMagnitudeVectorEqualsAndHashCode() {
MagnitudeVector v1 = new MagnitudeVector(5.0);
MagnitudeVector v2 = new MagnitudeVector(5.0);
MagnitudeVector v3 = new MagnitudeVector(10.0);

assertEquals(v1, v2);
assertEquals(v1.hashCode(), v2.hashCode());
assertNotEquals(v1, v3);
}

@Test
public void testMagnitudeVectorToString() {
MagnitudeVector vector = new MagnitudeVector(0.75);
String str = vector.toString();

assertEquals(str, "MagnitudeVector{magnitude=0.75}");
}

@Test
public void testMagnitudeVectorGetType() {
MagnitudeVector vector = new MagnitudeVector(1.0);
assertEquals(vector.getType(), VectorType.MagnitudeVector);
}

// ========== TraversalVector Tests ==========

@Test
public void testTraversalVectorConstructorValidInput() {
new TraversalVector(
"Alice", "knows", "Bob",
"Bob", "knows", "Charlie"
);

// Should not throw exception
}

@Test
public void testTraversalVectorConstructorInvalidInput() {
// Should throw exception if not multiple of 3
org.junit.jupiter.api.Assertions.assertThrows(RuntimeException.class, () ->
new TraversalVector("Alice", "knows", "Bob", "Bob", "knows")
);
}

@Test
public void testTraversalVectorMatchExactSamePath() {
TraversalVector v1 = new TraversalVector(
"Alice", "knows", "Bob",
"Bob", "knows", "Charlie"
);
TraversalVector v2 = new TraversalVector(
"Alice", "knows", "Bob",
"Bob", "knows", "Charlie"
);

assertEquals(v1.match(v2), 1.0, 0.0001);
}

@Test
public void testTraversalVectorMatchSubgraphContainment() {
// v1 is contained within v2
TraversalVector v1 = new TraversalVector(
"Bob", "knows", "Charlie"
);
TraversalVector v2 = new TraversalVector(
"Alice", "knows", "Bob",
"Bob", "knows", "Charlie",
"Charlie", "knows", "Dave"
);

// v1 is subgraph of v2, so score should be 0.8
assertEquals(v1.match(v2), 0.8, 0.0001);
}

@Test
public void testTraversalVectorMatchPartialOverlap() {
// Two vectors sharing one common edge
TraversalVector v1 = new TraversalVector(
"Alice", "knows", "Bob",
"Bob", "likes", "Charlie"
);
TraversalVector v2 = new TraversalVector(
"Bob", "knows", "Charlie",
"Alice", "knows", "Bob"
);

// One common edge out of 3 unique edges total = 1/3
double expected = 1.0 / 3.0;
assertEquals(v1.match(v2), expected, 0.0001);
}

@Test
public void testTraversalVectorMatchNoOverlap() {
TraversalVector v1 = new TraversalVector(
"Alice", "knows", "Bob"
);
TraversalVector v2 = new TraversalVector(
"Charlie", "knows", "Dave"
);

assertEquals(v1.match(v2), 0.0, 0.0001);
}

@Test
public void testTraversalVectorEqualsAndHashCode() {
TraversalVector v1 = new TraversalVector(
"Alice", "knows", "Bob",
"Bob", "knows", "Charlie"
);
TraversalVector v2 = new TraversalVector(
"Alice", "knows", "Bob",
"Bob", "knows", "Charlie"
);
TraversalVector v3 = new TraversalVector(
"Bob", "knows", "Charlie"
);

assertEquals(v1, v2);
assertEquals(v1.hashCode(), v2.hashCode());
assertNotEquals(v1, v3);
}

@Test
public void testTraversalVectorToString() {
TraversalVector vector = new TraversalVector(
"Alice", "knows", "Bob"
);
String str = vector.toString();

assertEquals(str, "TraversalVector{vec=Alice-knows->Bob}");
}

@Test
public void testTraversalVectorGetType() {
TraversalVector vector = new TraversalVector("Alice", "knows", "Bob");
assertEquals(vector.getType(), VectorType.TraversalVector);
}

@Test
public void testEmptyMainPipeline() {
GraphMemoryServer server = new GraphMemoryServer();
Expand Down