@@ -344,37 +344,42 @@ public void run() {
344
344
345
345
float entryValue = scaling * sqrtSparsity / sqrtEmbeddingDimension ;
346
346
random .reseed (randomSeed ^ graph .toOriginalNodeId (nodeId ));
347
- var randomVector = computeRandomVector (nodeId , random , entryValue );
347
+ var randomVector = computeRandomVector (nodeId , random , entryValue , scaling );
348
348
embeddingB .set (nodeId , randomVector );
349
349
embeddingA .set (nodeId , new float [embeddingDimension ]);
350
350
});
351
351
progressTracker .logProgress (partition .nodeCount ());
352
352
}
353
353
354
- private float [] computeRandomVector (long nodeId , Random random , float entryValue ) {
354
+ private float [] computeRandomVector (long nodeId , Random random , float entryValue , float scaling ) {
355
355
var randomVector = new float [embeddingDimension ];
356
356
for (int i = 0 ; i < baseEmbeddingDimension ; i ++) {
357
357
randomVector [i ] = computeRandomEntry (random , entryValue );
358
358
}
359
359
360
360
propertyVectorAdder .setRandomVector (randomVector );
361
+ propertyVectorAdder .setScaling (scaling );
361
362
FeatureExtraction .extract (nodeId , -1 , featureExtractors , propertyVectorAdder );
362
363
363
364
return randomVector ;
364
365
}
365
366
366
367
private class PropertyVectorAdder implements FeatureConsumer {
367
368
private float [] randomVector ;
369
+ private float scaling = 1.0f ;
368
370
369
371
void setRandomVector (float [] randomVector ) {
370
372
this .randomVector = randomVector ;
371
373
}
374
+ void setScaling (float scaling ) {
375
+ this .scaling = scaling ;
376
+ }
372
377
373
378
@ Override
374
379
public void acceptScalar (long ignored , int offset , double value ) {
375
380
float floatValue = (float ) value ;
376
381
for (int i = baseEmbeddingDimension ; i < embeddingDimension ; i ++) {
377
- randomVector [i ] += floatValue * propertyVectors [offset ][i - baseEmbeddingDimension ];
382
+ randomVector [i ] += scaling * floatValue * propertyVectors [offset ][i - baseEmbeddingDimension ];
378
383
}
379
384
}
380
385
@@ -384,7 +389,7 @@ public void acceptArray(long ignored, int offset, double[] values) {
384
389
var value = (float ) values [j ];
385
390
float [] propertyVector = propertyVectors [offset + j ];
386
391
for (int i = baseEmbeddingDimension ; i < embeddingDimension ; i ++) {
387
- randomVector [i ] += value * propertyVector [i - baseEmbeddingDimension ];
392
+ randomVector [i ] += scaling * value * propertyVector [i - baseEmbeddingDimension ];
388
393
}
389
394
}
390
395
}
0 commit comments