diff --git a/README.md b/README.md index 7e67724..ebb244b 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,11 @@ [![Go Reference](https://pkg.go.dev/badge/github.com/isgj/collection.svg)](https://pkg.go.dev/github.com/isgj/collection) + # collection + Generic go structures ## Install + ``` go get github.com/isgj/collection ``` @@ -91,3 +94,5 @@ func main() { [collection.Set](https://pkg.go.dev/github.com/isgj/collection#Set) [collection.DLList](https://pkg.go.dev/github.com/isgj/collection#DLList) + +[collection.LRUCache](https://pkg.go.dev/github.com/isgj/collection#LRUCache) diff --git a/cache.go b/cache.go new file mode 100644 index 0000000..c139061 --- /dev/null +++ b/cache.go @@ -0,0 +1,169 @@ +package collection + +// LRUCache implements a least recently used cache +type LRUCache[K comparable, V any] struct { + size int + head *cnode[K, V] + tail *cnode[K, V] + cached map[K]*cnode[K, V] +} + +// NewLRUCache creates a new LRUCache. +// If the size is 0 or negative, the cache is unbounded. +func NewCache[K comparable, V any](size int) *LRUCache[K, V] { + return &LRUCache[K, V]{size: size, cached: make(map[K]*cnode[K, V])} +} + +// Clear removes all items from the cache. +func (c *LRUCache[K, V]) Clear() { + c.cached = make(map[K]*cnode[K, V]) + c.head = nil + c.tail = nil +} + +// Get returns the value for the given key if present in the cache. +func (c *LRUCache[K, V]) Get(key K) (val V, ok bool) { + node, ok := c.cached[key] + if !ok { + return val, ok + } + c.moveToHead(node) + return node.val, ok +} + +// GetOrAdd returns the value for the given key if present in the cache. +// If not, it adds the value returned bu f and returns the given value. +func (c *LRUCache[K, V]) GetOrAdd(key K, f func() V) V { + node, ok := c.Get(key) + if ok { + return node + } + val := f() + c.Put(key, val) + return val +} + +// IterKeys returns an iterator over the keys in the cache. +// The keys are returned from the least recently used to the last one. +func (c *LRUCache[K, V]) IterKeys() Iterator[K] { + cur_node := c.head + return func() (k K, ok bool) { + if cur_node == nil { + return k, false + } + k, cur_node = cur_node.key, cur_node.next + return k, true + } +} + +// IterVals returns an iterator over the values in the cache. +// The values are returned from the least recently used to the last one. +func (c *LRUCache[K, V]) IterVals() Iterator[V] { + cur_node := c.head + return func() (v V, ok bool) { + if cur_node == nil { + return v, false + } + v, cur_node = cur_node.val, cur_node.next + return v, true + } +} + +// IsFull returns true if the cache is full. +func (c *LRUCache[K, V]) IsFull() bool { + return len(c.cached) == c.size && c.size > 0 +} + +// IsEmpty returns true if the cache is empty. +func (c *LRUCache[K, V]) IsEmpty() bool { + return len(c.cached) == 0 +} + +// Len returns the number of items in the cache. +func (c *LRUCache[K, V]) Len() int { + return len(c.cached) +} + +// Put adds the given key-value pair to the cache. +func (c *LRUCache[K, V]) Put(key K, val V) { + node, ok := c.cached[key] + if ok { + node.val = val + c.moveToHead(node) + return + } + if c.size > 0 && len(c.cached) >= c.size { + c.removeTail() + } + node = &cnode[K, V]{key: key, val: val} + c.cached[key] = node + // Add the first node + if c.head == nil { + c.head, c.tail = node, node + return + } + c.moveToHead(node) +} + +// ReverseIterKeys returns an iterator over the keys in the cache. +// The keys are returned from the last used to the least recently one. +func (c *LRUCache[K, V]) ReverseIterKeys() Iterator[K] { + cur_node := c.tail + return func() (k K, ok bool) { + if cur_node == nil { + return k, false + } + k, cur_node = cur_node.key, cur_node.prev + return k, true + } +} + +// IterVals returns an iterator over the values in the cache. +// The values are returned from the last used to the least recently one. +func (c *LRUCache[K, V]) ReverseIterVals() Iterator[V] { + cur_node := c.tail + return func() (v V, ok bool) { + if cur_node == nil { + return v, false + } + v, cur_node = cur_node.val, cur_node.prev + return v, true + } +} + +func (c *LRUCache[K, V]) moveToHead(node *cnode[K, V]) { + if node == c.head { + return + } + if node == c.tail { + c.tail = node.prev + } + if node.prev != nil { + node.prev.next = node.next + } + if node.next != nil { + node.next.prev = node.prev + } + node.prev = nil + node.next = c.head + c.head.prev = node + c.head = node +} + +func (c *LRUCache[K, V]) removeTail() { + if c.tail == nil { + return + } + if c.tail.prev != nil { + c.tail.prev.next = nil + } + delete(c.cached, c.tail.key) + c.tail = c.tail.prev +} + +type cnode[K comparable, V any] struct { + key K + val V + prev *cnode[K, V] + next *cnode[K, V] +} diff --git a/cache_test.go b/cache_test.go new file mode 100644 index 0000000..82e74c7 --- /dev/null +++ b/cache_test.go @@ -0,0 +1,240 @@ +package collection + +import ( + "testing" +) + +func TestNewCacheWithLimit(t *testing.T) { + cache := NewCache[int, int](3) + cache.Put(1, 1) + cache.Put(2, 2) + cache.Put(3, 3) + cache.Put(4, 4) + if cache.Len() != 3 { + t.Errorf("cache.Size() = %d, want %d", cache.Len(), 3) + } +} + +func TestLeastUsedIsEvicted(t *testing.T) { + cache := NewCache[int, int](3) + cache.Put(1, 1) + cache.Put(2, 2) + cache.Put(3, 3) + cache.Get(1) // 2 becomes the least used + cache.Put(4, 4) + if v, ok := cache.Get(2); ok { + t.Errorf("cache.Get(1) = %d, want %d", v, 0) + } +} + +func TestCacheClear(t *testing.T) { + cache := NewCache[int, int](3) + cache.Put(1, 1) + cache.Put(2, 2) + cache.Put(3, 3) + cache.Clear() + if cache.Len() != 0 { + t.Errorf("cache.Size() = %d, want %d", cache.Len(), 0) + } +} + +func TestCacheGet(t *testing.T) { + cache := NewCache[int, int](3) + cache.Put(1, 1) + cache.Put(2, 2) + cache.Put(3, 3) + cache.Put(4, 4) // 1 is evicted + if v, ok := cache.Get(1); ok { + t.Errorf("cache.Get(1) = %d, %t, want %d, %t", v, ok, 0, false) + } + if v, ok := cache.Get(2); !ok || v != 2 { + t.Errorf("cache.Get(2) = %d, %t, want %d, %t", v, ok, 2, true) + } + if v, ok := cache.Get(3); !ok || v != 3 { + t.Errorf("cache.Get(3) = %d, %t, want %d, %t", v, ok, 3, true) + } + if v, ok := cache.Get(4); !ok || v != 4 { + t.Errorf("cache.Get(4) = %d, %t, want %d, %t", v, ok, 4, true) + } +} + +func TestCacheGetOrAdd(t *testing.T) { + cache := NewCache[int, int](3) + cache.Put(1, 1) + cache.Put(2, 2) + cache.Put(3, 3) + if v, ok := cache.Get(1); !ok || v != 1 { + t.Errorf("cache.Get(1) = %d, %t, want %d, %t", v, ok, 1, true) + } + if v := cache.GetOrAdd(1, func() int { return 10 }); v != 1 { + t.Errorf("cache.GetOrAdd(1) = %d, want %d", v, 1) + } + if v, ok := cache.Get(1); !ok || v != 1 { + t.Errorf("cache.Get(1) = %d, %t, want %d, %t", v, ok, 1, true) + } + if v := cache.GetOrAdd(4, func() int { return 10 }); v != 10 { + t.Errorf("cache.GetOrAdd(4) = %d, want %d", v, 10) + } + if v, ok := cache.Get(4); !ok || v != 10 { + t.Errorf("cache.Get(4) = %d, %t, want %d, %t", v, ok, 10, false) + } +} + +func TestCacheLen(t *testing.T) { + cache := NewCache[int, int](3) + cache.Put(1, 1) + cache.Put(2, 2) + cache.Put(3, 3) + if cache.Len() != 3 { + t.Errorf("cache.Len() = %d, want %d", cache.Len(), 3) + } + cache.Put(4, 4) // 1 is evicted + if cache.Len() != 3 { + t.Errorf("cache.Len() = %d, want %d", cache.Len(), 3) + } +} + +func TestCachePutLeastUsed(t *testing.T) { + cache := NewCache[int, int](3) + cache.Put(1, 1) + cache.Put(2, 2) + cache.Put(3, 3) + cache.Put(1, 10) // value of key 1 should be 10, and it should be the head + cache.Put(4, 4) // 2 is evicted + cache.Put(5, 5) // 3 is evicted + if v, ok := cache.Get(1); !ok || v != 10 { + t.Errorf("cache.Get(1) = %d, %t, want %d, %t", v, ok, 10, true) + } + if v, ok := cache.Get(2); ok { + t.Errorf("cache.Get(2) = %d, %t, want %d, %t", v, ok, 0, false) + } + if v, ok := cache.Get(3); ok { + t.Errorf("cache.Get(3) = %d, %t, want %d, %t", v, ok, 0, false) + } + if v, ok := cache.Get(4); !ok || v != 4 { + t.Errorf("cache.Get(4) = %d, %t, want %d, %t", v, ok, 4, true) + } + if v, ok := cache.Get(5); !ok || v != 5 { + t.Errorf("cache.Get(4) = %d, %t, want %d, %t", v, ok, 5, true) + } + if cache.Len() != 3 { + t.Errorf("cache.Len() = %d, want %d", cache.Len(), 3) + } +} + +func TestCachePutLastUsed(t *testing.T) { + cache := NewCache[int, int](3) + cache.Put(1, 1) + cache.Put(2, 2) + cache.Put(3, 3) + cache.Put(3, 10) + cache.Put(4, 4) // 1 is evicted + cache.Put(5, 5) // 2 is evicted + if v, ok := cache.Get(1); ok { + t.Errorf("cache.Get(1) = %d, %t, want %d, %t", v, ok, 0, false) + } + if v, ok := cache.Get(2); ok { + t.Errorf("cache.Get(2) = %d, %t, want %d, %t", v, ok, 0, false) + } + if v, ok := cache.Get(3); !ok || v != 10 { + t.Errorf("cache.Get(3) = %d, %t, want %d, %t", v, ok, 10, true) + } + if v, ok := cache.Get(4); !ok || v != 4 { + t.Errorf("cache.Get(4) = %d, %t, want %d, %t", v, ok, 4, true) + } + if v, ok := cache.Get(5); !ok || v != 5 { + t.Errorf("cache.Get(4) = %d, %t, want %d, %t", v, ok, 5, true) + } + if cache.Len() != 3 { + t.Errorf("cache.Len() = %d, want %d", cache.Len(), 3) + } +} + +func TestCacheIterKeys(t *testing.T) { + cache := NewCache[int, int](3) + cache.Put(1, 4) + cache.Put(2, 5) + cache.Put(3, 6) + keys := cache.IterKeys().Collect() + for ind, v := range []int{3, 2, 1} { + if keys[ind] != v { + t.Errorf("cache.IterKeys()[%d] = %d, want %d", ind, keys[ind], v) + } + } +} + +func TestCacheIterVales(t *testing.T) { + cache := NewCache[int, int](3) + cache.Put(1, 4) + cache.Put(2, 5) + cache.Put(3, 6) + keys := cache.IterVals().Collect() + for ind, v := range []int{6, 5, 4} { + if keys[ind] != v { + t.Errorf("cache.IterVals()[%d] = %d, want %d", ind, keys[ind], v) + } + } +} + +func TestCacheReverseIterKeys(t *testing.T) { + cache := NewCache[int, int](3) + cache.Put(1, 4) + cache.Put(2, 5) + cache.Put(3, 6) + keys := cache.ReverseIterKeys().Collect() + for ind, v := range []int{1, 2, 3} { + if keys[ind] != v { + t.Errorf("cache.ReverseIterKeys()[%d] = %d, want %d", ind, keys[ind], v) + } + } +} + +func TestCacheReverseIterVales(t *testing.T) { + cache := NewCache[int, int](3) + cache.Put(1, 4) + cache.Put(2, 5) + cache.Put(3, 6) + keys := cache.ReverseIterVals().Collect() + for ind, v := range []int{4, 5, 6} { + if keys[ind] != v { + t.Errorf("cache.ReverseIterVals()[%d] = %d, want %d", ind, keys[ind], v) + } + } +} + +func TestCacheIsEmpty(t *testing.T) { + cache := NewCache[int, int](3) + if !cache.IsEmpty() { + t.Errorf("cache.IsEmpty() = %t, want %t", cache.IsEmpty(), true) + } + cache.Put(1, 4) + if cache.IsEmpty() { + t.Errorf("cache.IsEmpty() = %t, want %t", cache.IsEmpty(), false) + } +} + +func TestCacheIsFullWithSize(t *testing.T) { + cache := NewCache[int, int](3) + if cache.IsFull() { + t.Errorf("cache.IsFull() = %t, want %t", cache.IsFull(), false) + } + cache.Put(1, 4) + cache.Put(2, 5) + cache.Put(3, 6) + if !cache.IsFull() { + t.Errorf("cache.IsFull() = %t, want %t", cache.IsFull(), true) + } +} + +func TestCacheIsFullWithoutSize(t *testing.T) { + cache := NewCache[int, int](0) + if cache.IsFull() { + t.Errorf("cache.IsFull() = %t, want %t", cache.IsFull(), false) + } + cache.Put(1, 4) + cache.Put(2, 5) + cache.Put(3, 6) + if cache.IsFull() { + t.Errorf("cache.IsFull() = %t, want %t", cache.IsFull(), false) + } +}