|
16 | 16 | use PHPUnit\Framework\Attributes\Test;
|
17 | 17 | use PHPUnit\Framework\TestCase;
|
18 | 18 | use Symfony\AI\Platform\Vector\Vector;
|
| 19 | +use Symfony\AI\Store\Bridge\Postgres\Distance; |
19 | 20 | use Symfony\AI\Store\Bridge\Postgres\Store;
|
20 | 21 | use Symfony\AI\Store\Document\Metadata;
|
21 | 22 | use Symfony\AI\Store\Document\VectorDocument;
|
@@ -156,6 +157,54 @@ public function queryWithoutMinScore(): void
|
156 | 157 | $this->assertSame(['title' => 'Test Document'], $results[0]->metadata->getArrayCopy());
|
157 | 158 | }
|
158 | 159 |
|
| 160 | + #[Test] |
| 161 | + public function queryChangedDistanceMethodWithoutMinScore(): void |
| 162 | + { |
| 163 | + $pdo = $this->createMock(\PDO::class); |
| 164 | + $statement = $this->createMock(\PDOStatement::class); |
| 165 | + |
| 166 | + $store = new Store($pdo, 'embeddings_table', 'embedding', Distance::Cosine); |
| 167 | + |
| 168 | + $expectedSql = 'SELECT id, embedding AS embedding, metadata, (embedding <=> :embedding) AS score |
| 169 | + FROM embeddings_table |
| 170 | +
|
| 171 | + ORDER BY score ASC |
| 172 | + LIMIT 5'; |
| 173 | + |
| 174 | + $pdo->expects($this->once()) |
| 175 | + ->method('prepare') |
| 176 | + ->with($this->callback(function ($sql) use ($expectedSql) { |
| 177 | + return $this->normalizeQuery($sql) === $this->normalizeQuery($expectedSql); |
| 178 | + })) |
| 179 | + ->willReturn($statement); |
| 180 | + |
| 181 | + $uuid = Uuid::v4(); |
| 182 | + |
| 183 | + $statement->expects($this->once()) |
| 184 | + ->method('execute') |
| 185 | + ->with(['embedding' => '[0.1,0.2,0.3]']); |
| 186 | + |
| 187 | + $statement->expects($this->once()) |
| 188 | + ->method('fetchAll') |
| 189 | + ->with(\PDO::FETCH_ASSOC) |
| 190 | + ->willReturn([ |
| 191 | + [ |
| 192 | + 'id' => $uuid->toRfc4122(), |
| 193 | + 'embedding' => '[0.1,0.2,0.3]', |
| 194 | + 'metadata' => json_encode(['title' => 'Test Document']), |
| 195 | + 'score' => 0.95, |
| 196 | + ], |
| 197 | + ]); |
| 198 | + |
| 199 | + $results = $store->query(new Vector([0.1, 0.2, 0.3])); |
| 200 | + |
| 201 | + $this->assertCount(1, $results); |
| 202 | + $this->assertInstanceOf(VectorDocument::class, $results[0]); |
| 203 | + $this->assertEquals($uuid, $results[0]->id); |
| 204 | + $this->assertSame(0.95, $results[0]->score); |
| 205 | + $this->assertSame(['title' => 'Test Document'], $results[0]->metadata->getArrayCopy()); |
| 206 | + } |
| 207 | + |
159 | 208 | #[Test]
|
160 | 209 | public function queryWithMinScore(): void
|
161 | 210 | {
|
@@ -194,6 +243,44 @@ public function queryWithMinScore(): void
|
194 | 243 | $this->assertCount(0, $results);
|
195 | 244 | }
|
196 | 245 |
|
| 246 | + #[Test] |
| 247 | + public function queryWithMinScoreAndDifferentDistance(): void |
| 248 | + { |
| 249 | + $pdo = $this->createMock(\PDO::class); |
| 250 | + $statement = $this->createMock(\PDOStatement::class); |
| 251 | + |
| 252 | + $store = new Store($pdo, 'embeddings_table', 'embedding', Distance::Cosine); |
| 253 | + |
| 254 | + $expectedSql = 'SELECT id, embedding AS embedding, metadata, (embedding <=> :embedding) AS score |
| 255 | + FROM embeddings_table |
| 256 | + WHERE (embedding <=> :embedding) >= :minScore |
| 257 | + ORDER BY score ASC |
| 258 | + LIMIT 5'; |
| 259 | + |
| 260 | + $pdo->expects($this->once()) |
| 261 | + ->method('prepare') |
| 262 | + ->with($this->callback(function ($sql) use ($expectedSql) { |
| 263 | + return $this->normalizeQuery($sql) === $this->normalizeQuery($expectedSql); |
| 264 | + })) |
| 265 | + ->willReturn($statement); |
| 266 | + |
| 267 | + $statement->expects($this->once()) |
| 268 | + ->method('execute') |
| 269 | + ->with([ |
| 270 | + 'embedding' => '[0.1,0.2,0.3]', |
| 271 | + 'minScore' => 0.8, |
| 272 | + ]); |
| 273 | + |
| 274 | + $statement->expects($this->once()) |
| 275 | + ->method('fetchAll') |
| 276 | + ->with(\PDO::FETCH_ASSOC) |
| 277 | + ->willReturn([]); |
| 278 | + |
| 279 | + $results = $store->query(new Vector([0.1, 0.2, 0.3]), [], 0.8); |
| 280 | + |
| 281 | + $this->assertCount(0, $results); |
| 282 | + } |
| 283 | + |
197 | 284 | #[Test]
|
198 | 285 | public function queryWithCustomLimit(): void
|
199 | 286 | {
|
|
0 commit comments