Skip to content

Commit 9e3decf

Browse files
Merge pull request #10186 from breakanalysis/fastrp-apply-degree-normalization-to-property-part
Apply degree normalization to property part of initial random vectors in FastRP
2 parents b577920 + decc31d commit 9e3decf

File tree

1 file changed

+9
-4
lines changed
  • algo/src/main/java/org/neo4j/gds/embeddings/fastrp

1 file changed

+9
-4
lines changed

algo/src/main/java/org/neo4j/gds/embeddings/fastrp/FastRP.java

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -344,37 +344,42 @@ public void run() {
344344

345345
float entryValue = scaling * sqrtSparsity / sqrtEmbeddingDimension;
346346
random.reseed(randomSeed ^ graph.toOriginalNodeId(nodeId));
347-
var randomVector = computeRandomVector(nodeId, random, entryValue);
347+
var randomVector = computeRandomVector(nodeId, random, entryValue, scaling);
348348
embeddingB.set(nodeId, randomVector);
349349
embeddingA.set(nodeId, new float[embeddingDimension]);
350350
});
351351
progressTracker.logProgress(partition.nodeCount());
352352
}
353353

354-
private float[] computeRandomVector(long nodeId, Random random, float entryValue) {
354+
private float[] computeRandomVector(long nodeId, Random random, float entryValue, float scaling) {
355355
var randomVector = new float[embeddingDimension];
356356
for (int i = 0; i < baseEmbeddingDimension; i++) {
357357
randomVector[i] = computeRandomEntry(random, entryValue);
358358
}
359359

360360
propertyVectorAdder.setRandomVector(randomVector);
361+
propertyVectorAdder.setScaling(scaling);
361362
FeatureExtraction.extract(nodeId, -1, featureExtractors, propertyVectorAdder);
362363

363364
return randomVector;
364365
}
365366

366367
private class PropertyVectorAdder implements FeatureConsumer {
367368
private float[] randomVector;
369+
private float scaling = 1.0f;
368370

369371
void setRandomVector(float[] randomVector) {
370372
this.randomVector = randomVector;
371373
}
374+
void setScaling(float scaling) {
375+
this.scaling = scaling;
376+
}
372377

373378
@Override
374379
public void acceptScalar(long ignored, int offset, double value) {
375380
float floatValue = (float) value;
376381
for (int i = baseEmbeddingDimension; i < embeddingDimension; i++) {
377-
randomVector[i] += floatValue * propertyVectors[offset][i - baseEmbeddingDimension];
382+
randomVector[i] += scaling * floatValue * propertyVectors[offset][i - baseEmbeddingDimension];
378383
}
379384
}
380385

@@ -384,7 +389,7 @@ public void acceptArray(long ignored, int offset, double[] values) {
384389
var value = (float) values[j];
385390
float[] propertyVector = propertyVectors[offset + j];
386391
for (int i = baseEmbeddingDimension; i < embeddingDimension; i++) {
387-
randomVector[i] += value * propertyVector[i - baseEmbeddingDimension];
392+
randomVector[i] += scaling * value * propertyVector[i - baseEmbeddingDimension];
388393
}
389394
}
390395
}

0 commit comments

Comments
 (0)