diff --git a/algorithms/kernel/k_nearest_neighbors/bf_knn_impl.i b/algorithms/kernel/k_nearest_neighbors/bf_knn_impl.i
index bf72369eb9f..d7bf8027a05 100644
--- a/algorithms/kernel/k_nearest_neighbors/bf_knn_impl.i
+++ b/algorithms/kernel/k_nearest_neighbors/bf_knn_impl.i
@@ -83,7 +83,7 @@ public:
         TlsMem<int, cpu> tlsIdx(outBlockSize);
         TlsMem<FPType, cpu> tlsKDistances(inBlockSize * k);
         TlsMem<int, cpu> tlsKIndexes(inBlockSize * k);
-        TlsMem<int, cpu> tlsVoting(nClasses);
+        TlsMem<FPType, cpu> tlsVoting(nClasses);
 
         SafeStatus safeStat;
 
@@ -150,7 +150,7 @@ protected:
                              FPType * trainLabel, const NumericTable * trainTable, const NumericTable * testTable, NumericTable * testLabelTable,
                              NumericTable * indicesTable, NumericTable * distancesTable, TlsMem<FPType, cpu> & tlsDistances, TlsMem<int, cpu> & tlsIdx, TlsMem<FPType, cpu> & tlsKDistances,
-                             TlsMem<int, cpu> & tlsKIndexes, TlsMem<int, cpu> & tlsVoting, size_t nOuterBlocks)
+                             TlsMem<int, cpu> & tlsKIndexes, TlsMem<FPType, cpu> & tlsVoting, size_t nOuterBlocks)
     {
         const size_t inBlockSize = trainBlockSize;
         const size_t inRows      = nTrain;
@@ -265,7 +265,7 @@ protected:
             DAAL_CHECK_BLOCK_STATUS(testLabelRows);
             int * testLabel = testLabelRows.get();
 
-            int * voting = tlsVoting.local();
+            FPType * voting = tlsVoting.local();
             DAAL_CHECK_MALLOC(voting);
 
             if (voteWeights == VoteWeights::voteUniform)
@@ -351,7 +351,7 @@ protected:
     }
 
     services::Status uniformWeightedVoting(const size_t nClasses, const size_t k, const size_t n, const size_t nTrain, int * indices,
-                                           const FPType * trainLabel, int * testLabel, int * classWeights)
+                                           const FPType * trainLabel, int * testLabel, FPType * classWeights)
     {
         for (size_t i = 0; i < n; ++i)
         {
@@ -380,28 +380,28 @@ protected:
     }
 
     services::Status distanceWeightedVoting(const size_t nClasses, const size_t k, const size_t n, const size_t nTrain, FPType * distances,
-                                            int * indices, const FPType * trainLabel, int * testLabel, int * classWeights)
+                                            int * indices, const FPType * trainLabel, int * testLabel, FPType * classWeights)
     {
         const FPType epsilon = daal::services::internal::EpsilonVal<FPType>::get();
 
-        bool isContainZero = false;
-        for (size_t i = 0; i < k * n; ++i)
-        {
-            if (distances[i] < epsilon)
-            {
-                isContainZero = true;
-                break;
-            }
-        }
         for (size_t i = 0; i < n; ++i)
         {
+            bool isContainZero = false;
+            for (size_t j = 0; j < k * n; ++j)
+            {
+                if (distances[j] < epsilon)
+                {
+                    isContainZero = true;
+                    break;
+                }
+            }
             for (size_t j = 0; j < nClasses; ++j)
             {
                 classWeights[j] = 0;
             }
-            for (size_t j = 0; j < k; ++j)
+            if (isContainZero)
             {
-                if (isContainZero)
+                for (size_t j = 0; j < k; ++j)
                 {
                     if (distances[i] < epsilon)
                     {
@@ -409,7 +409,10 @@ protected:
                         classWeights[label] += 1;
                     }
                 }
-                else
+            }
+            else
+            {
+                for (size_t j = 0; j < k; ++j)
                 {
                     const int label = static_cast<int>(trainLabel[indices[i * k + j]]);
                     classWeights[label] += 1 / distances[i * k + j];
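The bf_knn change above switches the voting buffers from int to FPType so that fractional 1/distance weights are accumulated instead of being truncated to zero. Below is a minimal standalone sketch of the voting rule the patched distanceWeightedVoting() applies to a single query point, using plain float/int arrays in place of the kernel's FPType, TlsMem and NumericTable machinery; the function and variable names are illustrative, not part of the library.

    #include <cstddef>
    #include <limits>
    #include <vector>

    // Distance-weighted vote over the k nearest neighbors of one query point:
    // if any neighbor is at (numerically) zero distance, only those neighbors
    // vote, each with weight 1; otherwise every neighbor votes with weight
    // 1/distance. The weights must be floating point; an int buffer would
    // truncate every 1/distance contribution.
    static int voteDistanceWeighted(std::size_t k, std::size_t nClasses, const float * distances, const int * labels)
    {
        const float epsilon = std::numeric_limits<float>::epsilon();
        std::vector<float> classWeights(nClasses, 0.0f);

        bool hasZeroDistance = false;
        for (std::size_t j = 0; j < k; ++j)
        {
            if (distances[j] < epsilon) hasZeroDistance = true;
        }

        for (std::size_t j = 0; j < k; ++j)
        {
            if (hasZeroDistance)
            {
                if (distances[j] < epsilon) classWeights[labels[j]] += 1.0f;
            }
            else
            {
                classWeights[labels[j]] += 1.0f / distances[j];
            }
        }

        // The predicted label is the class with the largest accumulated weight.
        int best = 0;
        for (std::size_t c = 1; c < nClasses; ++c)
        {
            if (classWeights[c] > classWeights[best]) best = static_cast<int>(c);
        }
        return best;
    }

For example, with k = 3, distances {0.5, 0.25, 2.0} and labels {0, 1, 1}, class 1 wins with weight 1/0.25 + 1/2.0 = 4.5 against 2.0 for class 0.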
diff --git a/algorithms/kernel/k_nearest_neighbors/kdtree_knn_classification_predict_dense_default_batch.h b/algorithms/kernel/k_nearest_neighbors/kdtree_knn_classification_predict_dense_default_batch.h
index 3119f3570b4..f10920f12b0 100755
--- a/algorithms/kernel/k_nearest_neighbors/kdtree_knn_classification_predict_dense_default_batch.h
+++ b/algorithms/kernel/k_nearest_neighbors/kdtree_knn_classification_predict_dense_default_batch.h
@@ -77,7 +77,7 @@ class KNNClassificationPredictKernel<algorithmFpType, defaultDense, cpu> : public
     services::Status predict(algorithmFpType * predictedClass, const Heap<GlobalNeighbors<algorithmFpType, cpu>, cpu> & heap, const NumericTable * labels, size_t k,
                              VoteWeights voteWeights, const NumericTable * modelIndices, data_management::BlockDescriptor<int> & indices,
-                             data_management::BlockDescriptor<algorithmFpType> & distances, size_t index);
+                             data_management::BlockDescriptor<algorithmFpType> & distances, size_t index, const size_t nClasses);
 };
 
 } // namespace internal
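The header change above adds nClasses to KNNClassificationPredictKernel::predict(); the implementation hunks below use it to replace the old sort-the-labels-and-scan-for-the-longest-run winner search with a single pass that accumulates a weight per class, followed by an argmax over nClasses. A minimal sketch of that scheme, assuming plain arrays of neighbor labels and squared distances in place of the kernel's Heap and Math types (all names here are illustrative):

    #include <cmath>
    #include <cstddef>
    #include <vector>

    // One pass over the k neighbors fills classWeights; a second pass over the
    // nClasses entries picks the heaviest class. Uniform voting adds 1 per
    // neighbor; distance voting adds 1/sqrt(d^2), i.e. 1/d. The zero-distance
    // special case handled by the kernel is omitted here.
    static std::size_t votedClass(std::size_t k, std::size_t nClasses, const std::size_t * neighborLabels,
                                  const double * neighborSquaredDistances, bool distanceWeighted)
    {
        std::vector<double> classWeights(nClasses, 0.0);
        for (std::size_t i = 0; i < k; ++i)
        {
            const double w = distanceWeighted ? 1.0 / std::sqrt(neighborSquaredDistances[i]) : 1.0;
            classWeights[neighborLabels[i]] += w;
        }

        std::size_t winner = 0;
        for (std::size_t c = 1; c < nClasses; ++c)
        {
            if (classWeights[c] > classWeights[winner]) winner = c;
        }
        return winner;
    }

The dense per-class buffer is what makes nClasses a required prediction-time parameter, which the compute() hunks below read from the parameter object and pass into predict().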
diff --git a/algorithms/kernel/k_nearest_neighbors/kdtree_knn_classification_predict_dense_default_batch_impl.i b/algorithms/kernel/k_nearest_neighbors/kdtree_knn_classification_predict_dense_default_batch_impl.i
index f206a7a84bc..9d0c80e299b 100755
--- a/algorithms/kernel/k_nearest_neighbors/kdtree_knn_classification_predict_dense_default_batch_impl.i
+++ b/algorithms/kernel/k_nearest_neighbors/kdtree_knn_classification_predict_dense_default_batch_impl.i
@@ -282,18 +282,24 @@ Status KNNClassificationPredictKernel<algorithmFpType, defaultDense, cpu>::compute(
     typedef daal::internal::Math<algorithmFpType, cpu> Math;
 
     size_t k;
+    size_t nClasses;
     VoteWeights voteWeights       = voteUniform;
     DAAL_UINT64 resultsToEvaluate = classifier::computeClassLabels;
 
     {
         auto par1 = dynamic_cast<const kdtree_knn_classification::interface1::Parameter *>(par);
-        if (par1) k = par1->k;
+        if (par1)
+        {
+            k        = par1->k;
+            nClasses = par1->nClasses;
+        }
 
         auto par2 = dynamic_cast<const kdtree_knn_classification::interface2::Parameter *>(par);
         if (par2)
         {
             k                 = par2->k;
             resultsToEvaluate = par2->resultsToEvaluate;
+            nClasses          = par2->nClasses;
         }
 
         const auto par3 = dynamic_cast<const kdtree_knn_classification::interface3::Parameter *>(par);
@@ -302,6 +308,7 @@ Status KNNClassificationPredictKernel<algorithmFpType, defaultDense, cpu>::compute(
             k                 = par3->k;
             voteWeights       = par3->voteWeights;
             resultsToEvaluate = par3->resultsToEvaluate;
+            nClasses          = par3->nClasses;
         }
 
         if (par1 == NULL && par2 == NULL && par3 == NULL) return Status(ErrorNullParameterNotSupported);
@@ -408,7 +415,7 @@ Status KNNClassificationPredictKernel<algorithmFpType, defaultDense, cpu>::compute(
         {
             findNearestNeighbors(&dx[i * xColumnCount], local->heap, local->stack, k, radius, kdTreeTable, rootTreeNodeIndex, data, isHomogenSOA, soa_arrays);
-            s = predict(&(dy[i * yColumnCount]), local->heap, labels, k, voteWeights, modelIndices, indicesBD, distancesBD, i);
+            s = predict(&(dy[i * yColumnCount]), local->heap, labels, k, voteWeights, modelIndices, indicesBD, distancesBD, i, nClasses);
             DAAL_CHECK_STATUS_THR(s)
         }
@@ -421,7 +428,7 @@ Status KNNClassificationPredictKernel<algorithmFpType, defaultDense, cpu>::compute(
         {
             findNearestNeighbors(&dx[i * xColumnCount], local->heap, local->stack, k, radius, kdTreeTable, rootTreeNodeIndex, data, isHomogenSOA, soa_arrays);
-            s = predict(nullptr, local->heap, labels, k, voteWeights, modelIndices, indicesBD, distancesBD, i);
+            s = predict(nullptr, local->heap, labels, k, voteWeights, modelIndices, indicesBD, distancesBD, i, nClasses);
             DAAL_CHECK_STATUS_THR(s)
         }
     }
@@ -599,7 +606,7 @@ template <typename algorithmFpType, CpuType cpu>
 services::Status KNNClassificationPredictKernel<algorithmFpType, defaultDense, cpu>::predict(
     algorithmFpType * predictedClass, const Heap<GlobalNeighbors<algorithmFpType, cpu>, cpu> & heap, const NumericTable * labels, size_t k,
     VoteWeights voteWeights, const NumericTable * modelIndices, data_management::BlockDescriptor<int> & indices,
-    data_management::BlockDescriptor<algorithmFpType> & distances, size_t index)
+    data_management::BlockDescriptor<algorithmFpType> & distances, size_t index, const size_t nClasses)
 {
     typedef daal::internal::Math<algorithmFpType, cpu> Math;
@@ -661,39 +668,29 @@ services::Status KNNClassificationPredictKernel<algorithmFpType, defaultDense, cpu>::predict(
         data_management::BlockDescriptor<algorithmFpType> labelBD;
-        algorithmFpType * classes = static_cast<algorithmFpType *>(daal::services::internal::service_malloc<algorithmFpType, cpu>(heapSize));
+        algorithmFpType * classes      = static_cast<algorithmFpType *>(daal::services::internal::service_malloc<algorithmFpType, cpu>(heapSize));
+        algorithmFpType * classWeights = static_cast<algorithmFpType *>(daal::services::internal::service_malloc<algorithmFpType, cpu>(nClasses));
+        DAAL_CHECK_MALLOC(classWeights);
         DAAL_CHECK_MALLOC(classes);
+
+        for (size_t i = 0; i < nClasses; ++i)
+        {
+            classWeights[i] = 0;
+        }
+
         for (size_t i = 0; i < heapSize; ++i)
         {
             const_cast<NumericTable *>(labels)->getBlockOfColumnValues(0, heap[i].index, 1, readOnly, labelBD);
             classes[i] = *(labelBD.getBlockPtr());
             const_cast<NumericTable *>(labels)->releaseBlockOfColumnValues(labelBD);
         }
-        daal::algorithms::internal::qSort<algorithmFpType, cpu>(heapSize, classes);
-        algorithmFpType currentClass = classes[0];
-        algorithmFpType winnerClass  = currentClass;
 
         if (voteWeights == voteUniform)
         {
-            size_t currentWeight = 1;
-            size_t winnerWeight  = currentWeight;
-            for (size_t i = 1; i < heapSize; ++i)
+            for (size_t i = 0; i < heapSize; ++i)
             {
-                if (classes[i] == currentClass)
-                {
-                    if ((++currentWeight) > winnerWeight)
-                    {
-                        winnerWeight = currentWeight;
-                        winnerClass  = currentClass;
-                    }
-                }
-                else
-                {
-                    currentWeight = 1;
-                    currentClass  = classes[i];
-                }
+                classWeights[(size_t)(classes[i])] += 1;
             }
-            *predictedClass = winnerClass;
         }
         else
         {
@@ -714,55 +711,37 @@ services::Status KNNClassificationPredictKernel<algorithmFpType, defaultDense, cpu>::predict(
-                    if ((++currentWeight) > winnerWeight)
-                    {
-                        winnerWeight = currentWeight;
-                        winnerClass  = currentClass;
+                    classWeights[(size_t)(classes[i])] += 1;
                 }
             }
-            *predictedClass = winnerClass;
         }
         else
         {
-            algorithmFpType currentWeight = Math::sSqrt(1.0 / heap[0].distance);
-            algorithmFpType winnerWeight  = currentWeight;
-            for (size_t i = 1; i < heapSize; ++i)
+            for (size_t i = 0; i < heapSize; ++i)
             {
-                if (classes[i] == currentClass)
-                {
-                    currentWeight += Math::sSqrt(1.0 / heap[i].distance);
-                }
-                else
-                {
-                    currentWeight = Math::sSqrt(1.0 / heap[i].distance);
-                    currentClass  = classes[i];
-                }
-
-                if (currentWeight > winnerWeight)
-                {
-                    winnerWeight = currentWeight;
-                    winnerClass  = currentClass;
-                }
+                classWeights[(size_t)(classes[i])] += Math::sSqrt(1 / heap[i].distance);
             }
-            *predictedClass = winnerClass;
         }
     }
 
+    algorithmFpType maxWeightClass = 0;
+    algorithmFpType maxWeight      = 0;
+    for (size_t i = 0; i < nClasses; ++i)
+    {
+        if (classWeights[i] > maxWeight)
+        {
+            maxWeight      = classWeights[i];
+            maxWeightClass = i;
+        }
+    }
+    *predictedClass = maxWeightClass;
+
     service_free<algorithmFpType, cpu>(classes);
+    service_free<algorithmFpType, cpu>(classWeights);
     classes = nullptr;
 }
diff --git a/examples/cpp/source/k_nearest_neighbors/kdtree_knn_dense_batch.cpp b/examples/cpp/source/k_nearest_neighbors/kdtree_knn_dense_batch.cpp
old mode 100644
new mode 100755
index a21c011f41f..476899c9983
--- a/examples/cpp/source/k_nearest_neighbors/kdtree_knn_dense_batch.cpp
+++ b/examples/cpp/source/k_nearest_neighbors/kdtree_knn_dense_batch.cpp
@@ -107,6 +107,7 @@ void testModel()
     /* Pass the testing data set and trained model to the algorithm */
     algorithm.input.set(classifier::prediction::data, testData);
    algorithm.input.set(classifier::prediction::model, trainingResult->get(classifier::training::model));
+    algorithm.parameter.nClasses = nClasses;
 
     /* Compute prediction results */
     algorithm.compute();
diff --git a/examples/java/com/intel/daal/examples/knn_classification/KDTreeKNNDenseBatch.java b/examples/java/com/intel/daal/examples/knn_classification/KDTreeKNNDenseBatch.java
index f6a33bdf290..d35c45fc98f 100755
--- a/examples/java/com/intel/daal/examples/knn_classification/KDTreeKNNDenseBatch.java
+++ b/examples/java/com/intel/daal/examples/knn_classification/KDTreeKNNDenseBatch.java
@@ -122,6 +122,7 @@ private static void testModel() {
         kNearestNeighborsPredict.input.set(NumericTableInputId.data, testData);
         kNearestNeighborsPredict.input.set(ModelInputId.model, model);
+        kNearestNeighborsPredict.parameter.setNClasses(nClasses);
 
         /* Compute prediction results */
         PredictionResult predictionResult = kNearestNeighborsPredict.compute();
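Because the prediction kernel now sizes its per-class weight buffer from the parameter, callers must set nClasses at prediction time as well, which is what the two example updates above do. Below is a minimal sketch of the corresponding C++ call site, assuming the same API used by the example (the kdtree_knn_classification prediction Batch, the classifier input and result ids, and a trainingResult produced by the training Batch); the helper name and surrounding setup are illustrative only.

    #include "daal.h"

    using namespace daal;
    using namespace daal::algorithms;
    using namespace daal::data_management;

    /* Predict labels for testData with a previously trained kd-tree kNN model. */
    NumericTablePtr predictLabels(const NumericTablePtr & testData,
                                  const kdtree_knn_classification::training::ResultPtr & trainingResult,
                                  size_t nClasses, size_t k)
    {
        kdtree_knn_classification::prediction::Batch<> algorithm;

        algorithm.input.set(classifier::prediction::data, testData);
        algorithm.input.set(classifier::prediction::model, trainingResult->get(classifier::training::model));

        /* Without nClasses the kernel cannot size its per-class weight buffer. */
        algorithm.parameter.nClasses = nClasses;
        algorithm.parameter.k        = k;

        algorithm.compute();

        /* One predicted label per row of testData. */
        return algorithm.getResult()->get(classifier::prediction::prediction);
    }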