Skip to content

Commit

Permalink
support tag-based item-to-item recommendations (#926)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhenghaoz authored Jan 19, 2025
1 parent 8d6638b commit 150d375
Show file tree
Hide file tree
Showing 10 changed files with 500 additions and 183 deletions.
2 changes: 1 addition & 1 deletion config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ type NeighborsConfig struct {

// ItemToItemConfig describes a single item-to-item configuration entry by
// name, type and column expression.
//
// NOTE(review): this diff view shows both the previous and the updated Type
// field; the updated validation tag additionally accepts "tags" next to
// "embedding".
type ItemToItemConfig struct {
Name string `mapstructure:"name" json:"name"`
Type string `mapstructure:"type" json:"type" validate:"oneof=embedding"`
Type string `mapstructure:"type" json:"type" validate:"oneof=embedding tags"`
Column string `mapstructure:"column" json:"column" validate:"item_expr"`
}

Expand Down
37 changes: 30 additions & 7 deletions dataset/dataset.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,28 @@
package dataset

import (
"github.com/chewxy/math32"
"github.com/samber/lo"
"github.com/zhenghaoz/gorse/storage/data"
"modernc.org/strutil"
"time"
)

// ID is the integer identifier that replaces a string label value once the
// value has been registered in the dataset's column-value dictionary.
type ID int

// Dataset holds the items collected at a given timestamp together with the
// dictionaries built while their labels are processed: columnNames is the
// string pool that label keys are passed through, and columnValues maps
// string label values to integer IDs while counting their occurrences.
//
// NOTE(review): this diff view lists timestamp and items twice (old and new
// alignment of the same fields).
type Dataset struct {
timestamp time.Time
items []data.Item
timestamp time.Time
items []data.Item
columnNames *strutil.Pool
columnValues *FreqDict
}

// NewDataset creates an empty Dataset stamped with timestamp. The item slice
// is allocated with capacity itemCount, and fresh column-name and
// column-value dictionaries are attached.
//
// NOTE(review): this diff view lists timestamp and items twice (old and new
// alignment of the same keys).
func NewDataset(timestamp time.Time, itemCount int) *Dataset {
return &Dataset{
timestamp: timestamp,
items: make([]data.Item, 0, itemCount),
timestamp: timestamp,
items: make([]data.Item, 0, itemCount),
columnNames: strutil.NewPool(),
columnValues: NewFreqDict(),
}
}

Expand All @@ -40,32 +48,47 @@ func (d *Dataset) GetItems() []data.Item {
return d.items
}

// GetItemColumnValuesIDF returns the inverse document frequency (IDF) of every
// column value registered in the dataset. The result is indexed by the value's
// ID in the column-value dictionary, so its length equals the dictionary size.
//
// The IDF of a value is log(N / freq), where N is the number of items added so
// far and freq is the occurrence count recorded for that value. Every entry is
// clamped to a minimum of 1e-3.
func (d *Dataset) GetItemColumnValuesIDF() []float32 {
	total := float32(len(d.items))
	idf := make([]float32, d.columnValues.Count())
	for i := range idf {
		// Divide in floating point. Integer division would truncate the ratio
		// before the logarithm (e.g. 3/2 == 1, yielding log(1) == 0).
		ratio := total / float32(d.columnValues.Freq(i))
		// Since zero IDF will cause NaN in the future, we set the minimum value to 1e-3.
		idf[i] = max(math32.Log(ratio), 1e-3)
	}
	return idf
}

// AddItem appends a new data.Item built from item to the dataset. ItemId,
// IsHidden, Categories, Timestamp and Comment are carried over as-is; Labels
// is replaced by the result of processLabels.
//
// NOTE(review): this diff view shows both the old and the new Labels line; the
// new call passes an empty parent path for the root of the label tree.
func (d *Dataset) AddItem(item data.Item) {
d.items = append(d.items, data.Item{
ItemId: item.ItemId,
IsHidden: item.IsHidden,
Categories: item.Categories,
Timestamp: item.Timestamp,
Labels: d.processLabels(item.Labels),
Labels: d.processLabels(item.Labels, ""),
Comment: item.Comment,
})
}

func (d *Dataset) processLabels(labels any) any {
func (d *Dataset) processLabels(labels any, parent string) any {
switch typed := labels.(type) {
case map[string]any:
o := make(map[string]any)
for k, v := range typed {
o[k] = d.processLabels(v)
o[d.columnNames.Align(k)] = d.processLabels(v, parent+"."+k)
}
return o
case []any:
if isSliceOf[float64](typed) {
return lo.Map(typed, func(e any, _ int) float32 {
return float32(e.(float64))
})
} else if isSliceOf[string](typed) {
return lo.Map(typed, func(e any, _ int) ID {
return ID(d.columnValues.Id(parent + ":" + e.(string)))
})
}
return typed
case string:
return ID(d.columnValues.Id(parent + ":" + typed))
default:
return labels
}
Expand Down
59 changes: 58 additions & 1 deletion dataset/dataset_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package dataset

import (
"github.com/chewxy/math32"
"github.com/stretchr/testify/assert"
"github.com/zhenghaoz/gorse/storage/data"
"testing"
Expand All @@ -31,10 +32,24 @@ func TestDataset_AddItem(t *testing.T) {
Labels: map[string]any{
"a": 1,
"embedded": []any{1.1, 2.2, 3.3},
"tags": []any{"a", "b", "c"},
},
Comment: "comment",
})
assert.Len(t, dataSet.GetItems(), 1)
dataSet.AddItem(data.Item{
ItemId: "2",
IsHidden: true,
Categories: []string{"a", "b"},
Timestamp: time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC),
Labels: map[string]any{
"a": 1,
"embedded": []any{1.1, 2.2, 3.3},
"tags": []any{"b", "c", "a"},
"topics": []any{"a", "b", "c"},
},
Comment: "comment",
})
assert.Len(t, dataSet.GetItems(), 2)
assert.Equal(t, data.Item{
ItemId: "1",
IsHidden: false,
Expand All @@ -43,7 +58,49 @@ func TestDataset_AddItem(t *testing.T) {
Labels: map[string]any{
"a": 1,
"embedded": []float32{1.1, 2.2, 3.3},
"tags": []ID{1, 2, 3},
},
Comment: "comment",
}, dataSet.GetItems()[0])
assert.Equal(t, data.Item{
ItemId: "2",
IsHidden: true,
Categories: []string{"a", "b"},
Timestamp: time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC),
Labels: map[string]any{
"a": 1,
"embedded": []float32{1.1, 2.2, 3.3},
"tags": []ID{2, 3, 1},
"topics": []ID{4, 5, 6},
},
Comment: "comment",
}, dataSet.GetItems()[1])
}

// TestDataset_GetItemColumnValuesIDF adds two items with overlapping tags and
// checks the IDF vector: it has one entry per registered value (including the
// empty string that the dictionary registers on creation), a tag shared by
// both items is clamped to 1e-3, and a tag present in only one of the two
// items gets log(2).
func TestDataset_GetItemColumnValuesIDF(t *testing.T) {
	fixtures := []struct {
		itemID string
		tags   []any
	}{
		{itemID: "1", tags: []any{"a", "b", "c"}},
		{itemID: "2", tags: []any{"a", "e"}},
	}

	ds := NewDataset(time.Now(), 1)
	for _, f := range fixtures {
		ds.AddItem(data.Item{
			ItemId:     f.itemID,
			IsHidden:   false,
			Categories: []string{"a", "b"},
			Timestamp:  time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC),
			Labels: map[string]any{
				"tags": f.tags,
			},
			Comment: "comment",
		})
	}

	weights := ds.GetItemColumnValuesIDF()
	assert.Len(t, weights, 5)
	// ID 1 is tag "a", which occurs in both items.
	assert.InDelta(t, 1e-3, weights[1], 1e-6)
	// ID 2 is tag "b", which occurs in one of the two items.
	assert.InDelta(t, math32.Log(2), weights[2], 1e-6)
}
58 changes: 58 additions & 0 deletions dataset/dict.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
// Copyright 2025 gorse Project Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package dataset

// FreqDict is a string dictionary that assigns consecutive integer IDs to
// distinct strings and counts how many times each string has been passed to
// Id.
type FreqDict struct {
	si  map[string]int // string -> ID
	is  []string       // ID -> string
	cnt []int          // ID -> number of Id calls made with that string
}

// NewFreqDict creates a FreqDict in which the empty string is already
// registered under ID 0 with a count of 1.
func NewFreqDict() (d *FreqDict) {
	d = &FreqDict{map[string]int{}, []string{}, []int{}}
	d.Id("")
	return d
}

// Count returns the number of distinct strings in the dictionary, including
// the empty string registered by NewFreqDict.
func (d *FreqDict) Count() int {
	return len(d.is)
}

// Id returns the ID of s, registering s under the next free ID if it is not
// yet known. Every call increments the occurrence count of s, so Id is not a
// pure lookup.
func (d *FreqDict) Id(s string) (y int) {
	if id, ok := d.si[s]; ok {
		d.cnt[id]++
		return id
	}

	y = len(d.is)
	d.si[s] = y
	d.is = append(d.is, s)
	d.cnt = append(d.cnt, 1)
	return y
}

// String returns the string registered under id. ok is false when id is
// negative or not assigned.
func (d *FreqDict) String(id int) (s string, ok bool) {
	// Guard both bounds: a negative id would otherwise panic on indexing.
	if id < 0 || id >= len(d.is) {
		return "", false
	}
	return d.is[id], true
}

// Freq returns how many times the string registered under id has been passed
// to Id. It returns 0 when id is negative or not assigned.
func (d *FreqDict) Freq(id int) int {
	// Guard both bounds: a negative id would otherwise panic on indexing.
	if id < 0 || id >= len(d.cnt) {
		return 0
	}
	return d.cnt[id]
}
35 changes: 35 additions & 0 deletions dataset/dict_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
// Copyright 2025 gorse Project Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package dataset

import (
"github.com/stretchr/testify/assert"
"testing"
)

// TestFreqDict checks that IDs are handed out in first-seen order starting
// with the pre-registered empty string, that repeated lookups return the same
// ID, and that the per-ID occurrence counts match the number of lookups.
func TestFreqDict(t *testing.T) {
	freqDict := NewFreqDict()

	// Lookups are performed in order; repeated keys must keep their ID.
	lookups := []struct {
		key    string
		wantID int
	}{
		{"", 0},
		{"a", 1},
		{"b", 2},
		{"b", 2},
		{"c", 3},
		{"c", 3},
		{"c", 3},
	}
	for _, lookup := range lookups {
		assert.Equal(t, lookup.wantID, freqDict.Id(lookup.key))
	}

	assert.Equal(t, 4, freqDict.Count())

	frequencies := []struct {
		id   int
		want int
	}{
		{1, 1},
		{2, 2},
		{3, 3},
	}
	for _, f := range frequencies {
		assert.Equal(t, f.want, freqDict.Freq(f.id))
	}
}
Loading

0 comments on commit 150d375

Please sign in to comment.