Skip to content

Commit 5679785

Browse files
committed
Improve floating-point precision and new getWords method
1 parent b8bd530 commit 5679785

File tree

3 files changed

+93
-46
lines changed

3 files changed

+93
-46
lines changed

.github/workflows/phpunit.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ jobs:
88
strategy:
99
fail-fast: false
1010
matrix:
11-
php: [ 8.2, 8.1]
11+
php: [8.3, 8.2, 8.1]
1212
os: [ ubuntu-latest, windows-latest ]
1313
laravel: [ 10.* ]
1414
dependency-version: [ prefer-lowest, prefer-stable ]

src/Classifier.php

+38-10
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
namespace AssistedMindfulness\NaiveBayes;
44

55
use Brick\Math\BigDecimal;
6+
use Illuminate\Support\Arr;
67
use Illuminate\Support\Collection;
78
use Illuminate\Support\Str;
89

@@ -26,6 +27,8 @@ class Classifier
2627
private bool $uneven = false;
2728

2829
/**
30+
* Sets a custom tokenizer function for tokenizing input strings.
31+
*
2932
* @param callable(string): array<int, string> $tokenizer
3033
*/
3134
public function setTokenizer(callable $tokenizer): void
@@ -34,7 +37,18 @@ public function setTokenizer(callable $tokenizer): void
3437
}
3538

3639
/**
37-
* @return Collection<int, string>
40+
* Retrieves the word counts associated with a specific type or all types if no key is provided.
41+
*/
42+
public function getWords(int|null|string $type = null):array
43+
{
44+
return Arr::get($this->words, $type);
45+
}
46+
47+
/**
48+
* Tokenizes a given string into individual words.
49+
*
50+
* @param string $string The input string to tokenize.
51+
* @return Collection<int, string> A collection of tokens.
3852
*/
3953
public function tokenize(string $string): Collection
4054
{
@@ -51,7 +65,7 @@ public function tokenize(string $string): Collection
5165
}
5266

5367
/**
54-
* @return $this
68+
* Learns from a given statement by updating word and document counts.
5569
*/
5670
public function learn(string $statement, string $type): self
5771
{
@@ -65,7 +79,7 @@ public function learn(string $statement, string $type): self
6579
}
6680

6781
/**
68-
* @return Collection<string, string>
82+
* Guesses the type of a given statement using Naive Bayes classification.
6983
*/
7084
public function guess(string $statement): Collection
7185
{
@@ -81,17 +95,22 @@ public function guess(string $statement): Collection
8195

8296
return (string) BigDecimal::of($likelihood);
8397
})
84-
->sortDesc();
98+
->sort(function ($a, $b) {
99+
return BigDecimal::of($a)->compareTo($b);
100+
});
85101
}
86102

103+
/**
104+
* Retrieves the most likely type for a given statement.
105+
*/
87106
public function most(string $statement): string
88107
{
89108
/** @var string */
90-
return $this->guess($statement)->keys()->first();
109+
return $this->guess($statement)->keys()->last();
91110
}
92111

93112
/**
94-
* @return self
113+
* Toggles the "uneven" mode which adjusts probability calculation for document types.
95114
*/
96115
public function uneven(bool $enabled = true): self
97116
{
@@ -105,7 +124,7 @@ public function uneven(bool $enabled = true): self
105124
*/
106125
private function incrementType(string $type): void
107126
{
108-
if (! isset($this->documents[$type])) {
127+
if (!isset($this->documents[$type])) {
109128
$this->documents[$type] = 0;
110129
}
111130

@@ -117,15 +136,20 @@ private function incrementType(string $type): void
117136
*/
118137
private function incrementWord(string $type, string $word): void
119138
{
120-
if (! isset($this->words[$type][$word])) {
139+
if (!isset($this->words[$type][$word])) {
121140
$this->words[$type][$word] = 0;
122141
}
123142

124143
$this->words[$type][$word]++;
125144
}
126145

127146
/**
128-
* @return float|int
147+
* Calculates the conditional probability of a word occurring in a type.
148+
*
149+
* @param string $word The word to calculate probability for.
150+
* @param string $type The type to calculate probability in.
151+
*
152+
* @return float|int The calculated probability.
129153
*/
130154
private function p(string $word, string $type)
131155
{
@@ -135,7 +159,11 @@ private function p(string $word, string $type)
135159
}
136160

137161
/**
138-
* @return float|int
162+
* Calculates the prior probability of a type.
163+
*
164+
* @param string $type The type to calculate probability for.
165+
*
166+
* @return float|int The calculated probability.
139167
*/
140168
private function pTotal(string $type)
141169
{

tests/ClassifierTest.php

+54-35
Original file line numberDiff line numberDiff line change
@@ -58,24 +58,22 @@ public function testMostClassifier(): void
5858
}
5959

6060
/*
61-
public function testTextDataSetsClassifier(): void
62-
{
63-
$classifier = new Classifier();
64-
65-
$classifier
66-
->learn(file_get_contents(__DIR__ . '/datasets/positive-words.txt'), 'positive')
67-
->learn(file_get_contents(__DIR__ . '/datasets/negative-words.txt'), 'negative');
61+
public function testTextDataSetsClassifier(): void
62+
{
63+
$classifier = new Classifier();
6864
65+
$classifier
66+
->learn(file_get_contents(__DIR__ . '/datasets/positive-words.txt'), 'positive')
67+
->learn(file_get_contents(__DIR__ . '/datasets/negative-words.txt'), 'negative');
6968
70-
dd(
71-
$classifier->guess("Service outside is horrible. The waiter didn't even get our appetizers right. It came out after our mains. And did not fix the heater when it was out. Service is appalling.")
72-
);
69+
// Test for a sentence containing positive words
70+
$this->assertSame('positive', $classifier->most('The movie was absolutely fantastic and uplifting.'));
7371
74-
$this->assertSame('positive', $classifier->most('I love sunny days'));
75-
$this->assertSame('negative', $classifier->most('I hate rain'));
76-
}
72+
// Test for a sentence containing negative words
73+
//$this->assertSame('negative', $classifier->most('The service at the restaurant was terrible, and the food was awful.'));}
7774
*/
7875

76+
7977
public function testTextClassifier(): void
8078
{
8179
$classifier = new Classifier();
@@ -126,22 +124,43 @@ public function testCorrectnessLearn(): void
126124
$this->assertSame('positive', $classifier->most('awesome, cool, amazing Yeah.'));
127125
}
128126

129-
/*
130-
public function testCategorizesChineseCorrectly(): void
131-
{
132-
$classifier = new Classifier();
133-
134-
$classifier
135-
->learn('Chinese Beijing Chinese', 'chinese')
136-
->learn('Chinese Chinese Shanghai', 'chinese')
137-
->learn('Chinese Macao', 'chinese')
138-
->learn('Tokyo Japan Chinese', 'japanese')
139-
->learn('Chinese Macao Beijing Chinese Tokyo Japan', 'chinese');
140-
141-
$this->assertSame('chinese', $classifier->most('Tokyo Japan Chinese Chinese Chinese Chinese'));
142-
$this->assertSame('japanese', $classifier->most('Tokyo'));
143-
}
144-
*/
127+
public function testWordCountCorrectly(): void
128+
{
129+
$classifier = new Classifier();
130+
131+
$classifier
132+
->uneven(true)
133+
->learn('Chinese Beijing Chinese', 'chinese')
134+
->learn('Chinese Chinese Shanghai', 'chinese')
135+
->learn('Chinese Macao', 'chinese');
136+
137+
// teach it how to identify the `japanese` category
138+
$classifier->learn('Tokyo Japan Chinese', 'japanese');
139+
140+
// make sure it learned the `chinese` category correctly
141+
$chineseFrequencyCount = $classifier->getWords('chinese');
142+
143+
$this->assertTrue($chineseFrequencyCount['chinese'] === 5);
144+
$this->assertTrue($chineseFrequencyCount['beijing'] === 1);
145+
$this->assertTrue($chineseFrequencyCount['shanghai'] === 1);
146+
$this->assertTrue($chineseFrequencyCount['macao'] === 1);
147+
148+
149+
// make sure it learned the `japanese` category correctly
150+
$japaneseFrequencyCount = $classifier->getWords('japanese');
151+
152+
$this->assertTrue($japaneseFrequencyCount['tokyo'] === 1);
153+
$this->assertTrue($japaneseFrequencyCount['japan'] === 1);
154+
$this->assertTrue($japaneseFrequencyCount['chinese'] === 1);
155+
156+
157+
// Verify that the classifier correctly categorizes a new document
158+
// Due to the higher weight assigned to the word 'Tokyo' in the training data for the 'japanese' category,
159+
// the classifier is expected to classify the document 'Chinese Macao Tokyo' as 'japanese',
160+
// despite the presence of the words 'Chinese' and 'Macao'.
161+
$this->assertSame('japanese', $classifier->most('Chinese Macao Tokyo'));
162+
}
163+
145164

146165
public function testCategorizesSimpleCorrectly(): void
147166
{
@@ -185,13 +204,13 @@ public function testSimpleSpam(): void
185204
$classifier = new Classifier();
186205

187206
$classifier
188-
->learn('Some spam document', 'spam')
189-
->learn('Another spam document', 'spam')
190-
->learn('Some ham document', 'ham')
191-
->learn('Another ham document', 'ham');
207+
->learn('Learn how to grow your business with these proven strategies', 'ham')
208+
->learn('Unlock the secrets of successful investing in our latest guide', 'ham')
209+
->learn('Get exclusive access to limited-time discounts and offers', 'spam')
210+
->learn('Earn money from home with our easy-to-follow program', 'spam');
192211

193212

194-
$this->assertSame('ham', $classifier->most('Some ham document'));
195-
$this->assertSame('spam', $classifier->most('Some ham spam'));
213+
$this->assertEquals('ham', $classifier->most('Discover the art of effective communication in our workshop'));
214+
$this->assertEquals('spam', $classifier->most('Start making money from home today with our revolutionary system'));
196215
}
197216
}

0 commit comments

Comments
 (0)