Skip to content

Commit 8c7735a

Browse files
committed
[Store][Postgres] allow store initialization with utilized distance
1 parent fb18eb1 commit 8c7735a

File tree

3 files changed

+148
-13
lines changed

3 files changed

+148
-13
lines changed
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
<?php
2+
3+
/*
4+
* This file is part of the Symfony package.
5+
*
6+
* (c) Fabien Potencier <[email protected]>
7+
*
8+
* For the full copyright and license information, please view the LICENSE
9+
* file that was distributed with this source code.
10+
*/
11+
12+
namespace Symfony\AI\Store\Bridge\Postgres;
13+
14+
use OskarStark\Enum\Trait\Comparable;
15+
16+
/**
17+
* @author Denis Zunke <[email protected]>
18+
*/
19+
enum Distance: string
20+
{
21+
use Comparable;
22+
23+
case Cosine = 'cosine';
24+
case InnerProduct = 'inner_product';
25+
case L1 = 'l1';
26+
case L2 = 'l2';
27+
28+
public function getComparisonSign(): string
29+
{
30+
return match ($this) {
31+
self::Cosine => '<=>',
32+
self::InnerProduct => '<#>',
33+
self::L1 => '<+>',
34+
self::L2 => '<->',
35+
};
36+
}
37+
}

src/store/src/Bridge/Postgres/Store.php

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -34,23 +34,32 @@ public function __construct(
3434
private \PDO $connection,
3535
private string $tableName,
3636
private string $vectorFieldName = 'embedding',
37+
private Distance $distance = Distance::L2,
3738
) {
3839
}
3940

40-
public static function fromPdo(\PDO $connection, string $tableName, string $vectorFieldName = 'embedding'): self
41-
{
42-
return new self($connection, $tableName, $vectorFieldName);
41+
public static function fromPdo(
42+
\PDO $connection,
43+
string $tableName,
44+
string $vectorFieldName = 'embedding',
45+
Distance $distance = Distance::L2,
46+
): self {
47+
return new self($connection, $tableName, $vectorFieldName, $distance);
4348
}
4449

45-
public static function fromDbal(Connection $connection, string $tableName, string $vectorFieldName = 'embedding'): self
46-
{
50+
public static function fromDbal(
51+
Connection $connection,
52+
string $tableName,
53+
string $vectorFieldName = 'embedding',
54+
Distance $distance = Distance::L2,
55+
): self {
4756
$pdo = $connection->getNativeConnection();
4857

4958
if (!$pdo instanceof \PDO) {
5059
throw new InvalidArgumentException('Only DBAL connections using PDO driver are supported.');
5160
}
5261

53-
return self::fromPdo($pdo, $tableName, $vectorFieldName);
62+
return self::fromPdo($pdo, $tableName, $vectorFieldName, $distance);
5463
}
5564

5665
public function add(VectorDocument ...$documents): void
@@ -84,16 +93,18 @@ public function add(VectorDocument ...$documents): void
8493
*/
8594
public function query(Vector $vector, array $options = [], ?float $minScore = null): array
8695
{
87-
$sql = \sprintf(
88-
'SELECT id, %s AS embedding, metadata, (%s <-> :embedding) AS score
89-
FROM %s
90-
%s
91-
ORDER BY score ASC
92-
LIMIT %d',
96+
$sql = \sprintf(<<<SQL
97+
SELECT id, %s AS embedding, metadata, (%s %s :embedding) AS score
98+
FROM %s
99+
%s
100+
ORDER BY score ASC
101+
LIMIT %d
102+
SQL,
93103
$this->vectorFieldName,
94104
$this->vectorFieldName,
105+
$this->distance->getComparisonSign(),
95106
$this->tableName,
96-
null !== $minScore ? "WHERE ({$this->vectorFieldName} <-> :embedding) >= :minScore" : '',
107+
null !== $minScore ? "WHERE ({$this->vectorFieldName} {$this->distance->getComparisonSign()} :embedding) >= :minScore" : '',
97108
$options['limit'] ?? 5,
98109
);
99110
$statement = $this->connection->prepare($sql);

src/store/tests/Bridge/Postgres/StoreTest.php

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
use PHPUnit\Framework\Attributes\Test;
1717
use PHPUnit\Framework\TestCase;
1818
use Symfony\AI\Platform\Vector\Vector;
19+
use Symfony\AI\Store\Bridge\Postgres\Distance;
1920
use Symfony\AI\Store\Bridge\Postgres\Store;
2021
use Symfony\AI\Store\Document\Metadata;
2122
use Symfony\AI\Store\Document\VectorDocument;
@@ -156,6 +157,54 @@ public function queryWithoutMinScore(): void
156157
$this->assertSame(['title' => 'Test Document'], $results[0]->metadata->getArrayCopy());
157158
}
158159

160+
#[Test]
161+
public function queryChangedDistanceMethodWithoutMinScore(): void
162+
{
163+
$pdo = $this->createMock(\PDO::class);
164+
$statement = $this->createMock(\PDOStatement::class);
165+
166+
$store = new Store($pdo, 'embeddings_table', 'embedding', Distance::Cosine);
167+
168+
$expectedSql = 'SELECT id, embedding AS embedding, metadata, (embedding <=> :embedding) AS score
169+
FROM embeddings_table
170+
171+
ORDER BY score ASC
172+
LIMIT 5';
173+
174+
$pdo->expects($this->once())
175+
->method('prepare')
176+
->with($this->callback(function ($sql) use ($expectedSql) {
177+
return $this->normalizeQuery($sql) === $this->normalizeQuery($expectedSql);
178+
}))
179+
->willReturn($statement);
180+
181+
$uuid = Uuid::v4();
182+
183+
$statement->expects($this->once())
184+
->method('execute')
185+
->with(['embedding' => '[0.1,0.2,0.3]']);
186+
187+
$statement->expects($this->once())
188+
->method('fetchAll')
189+
->with(\PDO::FETCH_ASSOC)
190+
->willReturn([
191+
[
192+
'id' => $uuid->toRfc4122(),
193+
'embedding' => '[0.1,0.2,0.3]',
194+
'metadata' => json_encode(['title' => 'Test Document']),
195+
'score' => 0.95,
196+
],
197+
]);
198+
199+
$results = $store->query(new Vector([0.1, 0.2, 0.3]));
200+
201+
$this->assertCount(1, $results);
202+
$this->assertInstanceOf(VectorDocument::class, $results[0]);
203+
$this->assertEquals($uuid, $results[0]->id);
204+
$this->assertSame(0.95, $results[0]->score);
205+
$this->assertSame(['title' => 'Test Document'], $results[0]->metadata->getArrayCopy());
206+
}
207+
159208
#[Test]
160209
public function queryWithMinScore(): void
161210
{
@@ -194,6 +243,44 @@ public function queryWithMinScore(): void
194243
$this->assertCount(0, $results);
195244
}
196245

246+
#[Test]
247+
public function queryWithMinScoreAndDifferentDistance(): void
248+
{
249+
$pdo = $this->createMock(\PDO::class);
250+
$statement = $this->createMock(\PDOStatement::class);
251+
252+
$store = new Store($pdo, 'embeddings_table', 'embedding', Distance::Cosine);
253+
254+
$expectedSql = 'SELECT id, embedding AS embedding, metadata, (embedding <=> :embedding) AS score
255+
FROM embeddings_table
256+
WHERE (embedding <=> :embedding) >= :minScore
257+
ORDER BY score ASC
258+
LIMIT 5';
259+
260+
$pdo->expects($this->once())
261+
->method('prepare')
262+
->with($this->callback(function ($sql) use ($expectedSql) {
263+
return $this->normalizeQuery($sql) === $this->normalizeQuery($expectedSql);
264+
}))
265+
->willReturn($statement);
266+
267+
$statement->expects($this->once())
268+
->method('execute')
269+
->with([
270+
'embedding' => '[0.1,0.2,0.3]',
271+
'minScore' => 0.8,
272+
]);
273+
274+
$statement->expects($this->once())
275+
->method('fetchAll')
276+
->with(\PDO::FETCH_ASSOC)
277+
->willReturn([]);
278+
279+
$results = $store->query(new Vector([0.1, 0.2, 0.3]), [], 0.8);
280+
281+
$this->assertCount(0, $results);
282+
}
283+
197284
#[Test]
198285
public function queryWithCustomLimit(): void
199286
{

0 commit comments

Comments
 (0)