Skip to content

Commit

Permalink
implement item to item recommendation (#905)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhenghaoz authored Jan 18, 2025
1 parent a2ec23f commit 8d6638b
Show file tree
Hide file tree
Showing 21 changed files with 670 additions and 80 deletions.
4 changes: 2 additions & 2 deletions common/ann/ann.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

package search
package ann

import (
"github.com/samber/lo"
Expand All @@ -21,5 +21,5 @@ import (
type Index interface {
Add(v []float32) (int, error)
SearchIndex(q, k int, prune0 bool) ([]lo.Tuple2[int, float32], error)
SearchVector(q []float32, k int, prune0 bool) ([]lo.Tuple2[int, float32], error)
SearchVector(q []float32, k int, prune0 bool) []lo.Tuple2[int, float32]
}
97 changes: 90 additions & 7 deletions common/ann/ann_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,15 @@
// See the License for the specific language governing permissions and
// limitations under the License.

package search
package ann

import (
"bufio"
mapset "github.com/deckarep/golang-set/v2"
"github.com/samber/lo"
"github.com/stretchr/testify/assert"
"github.com/zhenghaoz/gorse/base/floats"
"github.com/zhenghaoz/gorse/common/dataset"
"github.com/zhenghaoz/gorse/common/datautil"
"github.com/zhenghaoz/gorse/common/util"
"os"
"path/filepath"
Expand Down Expand Up @@ -57,7 +57,7 @@ type MNIST struct {

func mnist() (*MNIST, error) {
// Download and unzip dataset
path, err := dataset.DownloadAndUnzip("mnist")
path, err := datautil.DownloadAndUnzip("mnist")
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -136,13 +136,96 @@ func TestMNIST(t *testing.T) {
// Test search
r := 0.0
for _, image := range dat.TestImages[:testSize] {
gt, err := bf.SearchVector(image, 100, false)
assert.NoError(t, err)
gt := bf.SearchVector(image, 100, false)
assert.Len(t, gt, 100)
scores, err := hnsw.SearchVector(image, 100, false)
assert.NoError(t, err)
scores := hnsw.SearchVector(image, 100, false)
assert.Len(t, scores, 100)
r += recall(gt, scores)
}
r /= float64(testSize)
assert.Greater(t, r, 0.99)
}

func movieLens() ([][]int, error) {
// Download and unzip dataset
path, err := datautil.DownloadAndUnzip("ml-1m")
if err != nil {
return nil, err
}
// Open file
f, err := os.Open(filepath.Join(path, "train.txt"))
if err != nil {
return nil, err
}
defer f.Close()
// Read data line by line
movies := make([][]int, 0)
scanner := bufio.NewScanner(f)
for scanner.Scan() {
line := scanner.Text()
splits := strings.Split(line, "\t")
userId, err := strconv.Atoi(splits[0])
if err != nil {
return nil, err
}
movieId, err := strconv.Atoi(splits[1])
if err != nil {
return nil, err
}
for movieId >= len(movies) {
movies = append(movies, make([]int, 0))
}
movies[movieId] = append(movies[movieId], userId)
}
return movies, nil
}

func jaccard(a, b []int) float32 {
var i, j, intersection int
for i < len(a) && j < len(b) {
if a[i] == b[j] {
intersection++
i++
j++
} else if a[i] < b[j] {
i++
} else {
j++
}
}
if len(a)+len(b)-intersection == 0 {
return 1
}
return 1 - float32(intersection)/float32(len(a)+len(b)-intersection)
}

func TestMovieLens(t *testing.T) {
movies, err := movieLens()
assert.NoError(t, err)

// Create brute-force index
bf := NewBruteforce(jaccard)
for _, movie := range movies {
_, err := bf.Add(movie)
assert.NoError(t, err)
}

// Create HNSW index
hnsw := NewHNSW(jaccard)
for _, movie := range movies {
_, err := hnsw.Add(movie)
assert.NoError(t, err)
}

// Test search
r := 0.0
for i := range movies[:testSize] {
gt, err := bf.SearchIndex(i, 100, false)
assert.NoError(t, err)
scores, err := hnsw.SearchIndex(i, 100, false)
assert.NoError(t, err)
r += recall(gt, scores)
}
r /= float64(testSize)
assert.Greater(t, r, 0.98)
}
19 changes: 6 additions & 13 deletions common/ann/bruteforce.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

package search
package ann

import (
"github.com/juju/errors"
Expand All @@ -23,7 +23,6 @@ import (
// Bruteforce is a naive implementation of vector index.
type Bruteforce[T any] struct {
distanceFunc func(a, b []T) float32
dimension int
vectors [][]T
}

Expand All @@ -32,15 +31,9 @@ func NewBruteforce[T any](distanceFunc func(a, b []T) float32) *Bruteforce[T] {
}

func (b *Bruteforce[T]) Add(v []T) (int, error) {
// Check dimension
if b.dimension == 0 {
b.dimension = len(v)
} else if b.dimension != len(v) {
return 0, errors.Errorf("dimension mismatch: %v != %v", b.dimension, len(v))
}
// Add vector
b.vectors = append(b.vectors, v)
return len(b.vectors) - 1, nil
return len(b.vectors), nil
}

func (b *Bruteforce[T]) SearchIndex(q, k int, prune0 bool) ([]lo.Tuple2[int, float32], error) {
Expand All @@ -62,14 +55,14 @@ func (b *Bruteforce[T]) SearchIndex(q, k int, prune0 bool) ([]lo.Tuple2[int, flo
scores := make([]lo.Tuple2[int, float32], 0)
for pq.Len() > 0 {
value, score := pq.Pop()
if !prune0 || score < 0 {
if !prune0 || score > 0 {
scores = append(scores, lo.Tuple2[int, float32]{A: int(value), B: score})
}
}
return scores, nil
}

func (b *Bruteforce[T]) SearchVector(q []T, k int, prune0 bool) ([]lo.Tuple2[int, float32], error) {
func (b *Bruteforce[T]) SearchVector(q []T, k int, prune0 bool) []lo.Tuple2[int, float32] {
// Search
pq := heap.NewPriorityQueue(true)
for i, vec := range b.vectors {
Expand All @@ -82,9 +75,9 @@ func (b *Bruteforce[T]) SearchVector(q []T, k int, prune0 bool) ([]lo.Tuple2[int
scores := make([]lo.Tuple2[int, float32], 0)
for pq.Len() > 0 {
value, score := pq.Pop()
if !prune0 || score < 0 {
if !prune0 || score > 0 {
scores = append(scores, lo.Tuple2[int, float32]{A: int(value), B: score})
}
}
return scores, nil
return scores
}
20 changes: 6 additions & 14 deletions common/ann/hnsw.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.

package search
package ann

import (
"errors"
"github.com/chewxy/math32"
mapset "github.com/deckarep/golang-set/v2"
"github.com/samber/lo"
Expand All @@ -28,7 +27,6 @@ import (
// HNSW is a vector index based on Hierarchical Navigable Small Worlds.
type HNSW[T any] struct {
distanceFunc func(a, b []T) float32
dimension int
vectors [][]T
bottomNeighbors []*heap.PriorityQueue
upperNeighbors []map[int32]*heap.PriorityQueue
Expand All @@ -53,41 +51,35 @@ func NewHNSW[T any](distanceFunc func(a, b []T) float32) *HNSW[T] {
}

func (h *HNSW[T]) Add(v []T) (int, error) {
// Check dimension
if h.dimension == 0 {
h.dimension = len(v)
} else if h.dimension != len(v) {
return 0, errors.New("dimension mismatch")
}
// Add vector
h.vectors = append(h.vectors, v)
h.bottomNeighbors = append(h.bottomNeighbors, heap.NewPriorityQueue(false))
h.insert(int32(len(h.vectors) - 1))
return len(h.vectors) - 1, nil
}

func (h *HNSW[T]) Search(q, k int, prune0 bool) ([]lo.Tuple2[int, float32], error) {
func (h *HNSW[T]) SearchIndex(q, k int, prune0 bool) ([]lo.Tuple2[int, float32], error) {
w := h.knnSearch(h.vectors[q], k, h.efSearchValue(k))
scores := make([]lo.Tuple2[int, float32], 0)
for w.Len() > 0 {
value, score := w.Pop()
if !prune0 || score < 0 {
if !prune0 || score > 0 {
scores = append(scores, lo.Tuple2[int, float32]{A: int(value), B: score})
}
}
return scores, nil
}

func (h *HNSW[T]) SearchVector(q []T, k int, prune0 bool) ([]lo.Tuple2[int, float32], error) {
func (h *HNSW[T]) SearchVector(q []T, k int, prune0 bool) []lo.Tuple2[int, float32] {
w := h.knnSearch(q, k, h.efSearchValue(k))
scores := make([]lo.Tuple2[int, float32], 0)
for w.Len() > 0 {
value, score := w.Pop()
if !prune0 || score < 0 {
if !prune0 || score > 0 {
scores = append(scores, lo.Tuple2[int, float32]{A: int(value), B: score})
}
}
return scores, nil
return scores
}

func (h *HNSW[T]) knnSearch(q []T, k, ef int) *heap.PriorityQueue {
Expand Down
14 changes: 0 additions & 14 deletions common/dataset/dataset_test.go

This file was deleted.

4 changes: 2 additions & 2 deletions common/dataset/dataset.go → common/datautil/datautil.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

package dataset
package datautil

import (
"archive/zip"
Expand Down Expand Up @@ -85,7 +85,7 @@ func DownloadAndUnzip(name string) (string, error) {
path := filepath.Join(datasetDir, name)
if _, err := os.Stat(path); os.IsNotExist(err) {
zipFileName, _ := downloadFromUrl(url, tempDir)
if _, err := unzip(zipFileName, path); err != nil {
if _, err := unzip(zipFileName, datasetDir); err != nil {
return "", err
}
}
Expand Down
28 changes: 28 additions & 0 deletions common/datautil/datautil_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
// Copyright 2025 gorse Project Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package datautil

import (
"github.com/stretchr/testify/assert"
"testing"
)

func TestLoadIris(t *testing.T) {
data, target, err := LoadIris()
assert.NoError(t, err)
assert.Len(t, data, 150)
assert.Len(t, data[0], 4)
assert.Len(t, target, 150)
}
6 changes: 3 additions & 3 deletions common/nn/nn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ import (
"github.com/samber/lo"
"github.com/schollz/progressbar/v3"
"github.com/stretchr/testify/assert"
"github.com/zhenghaoz/gorse/common/dataset"
"github.com/zhenghaoz/gorse/common/datautil"
"github.com/zhenghaoz/gorse/common/util"
)

Expand Down Expand Up @@ -91,7 +91,7 @@ func TestNeuralNetwork(t *testing.T) {

func iris() (*Tensor, *Tensor, error) {
// Download dataset
path, err := dataset.DownloadAndUnzip("iris")
path, err := datautil.DownloadAndUnzip("iris")
if err != nil {
return nil, nil, err
}
Expand Down Expand Up @@ -153,7 +153,7 @@ func TestIris(t *testing.T) {
func mnist() (lo.Tuple2[*Tensor, *Tensor], lo.Tuple2[*Tensor, *Tensor], error) {
var train, test lo.Tuple2[*Tensor, *Tensor]
// Download and unzip dataset
path, err := dataset.DownloadAndUnzip("mnist")
path, err := datautil.DownloadAndUnzip("mnist")
if err != nil {
return train, test, err
}
Expand Down
15 changes: 15 additions & 0 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ type RecommendConfig struct {
DataSource DataSourceConfig `mapstructure:"data_source"`
NonPersonalized []NonPersonalizedConfig `mapstructure:"non-personalized" validate:"dive"`
Popular PopularConfig `mapstructure:"popular"`
ItemToItem []ItemToItemConfig `mapstructure:"item-to-item" validate:"dive"`
UserNeighbors NeighborsConfig `mapstructure:"user_neighbors"`
ItemNeighbors NeighborsConfig `mapstructure:"item_neighbors"`
Collaborative CollaborativeConfig `mapstructure:"collaborative"`
Expand Down Expand Up @@ -148,6 +149,20 @@ type NeighborsConfig struct {
IndexFitEpoch int `mapstructure:"index_fit_epoch" validate:"gt=0"`
}

type ItemToItemConfig struct {
Name string `mapstructure:"name" json:"name"`
Type string `mapstructure:"type" json:"type" validate:"oneof=embedding"`
Column string `mapstructure:"column" json:"column" validate:"item_expr"`
}

func (config *ItemToItemConfig) Hash() string {
hash := md5.New()
hash.Write([]byte(config.Name))
hash.Write([]byte(config.Type))
hash.Write([]byte(config.Column))
return string(hash.Sum(nil))
}

type CollaborativeConfig struct {
ModelFitPeriod time.Duration `mapstructure:"model_fit_period" validate:"gt=0"`
ModelSearchPeriod time.Duration `mapstructure:"model_search_period" validate:"gt=0"`
Expand Down
Loading

0 comments on commit 8d6638b

Please sign in to comment.