diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..52fa6f7
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,4 @@
+.cache
+.fcache
+__pycache__
+/fcache.egg-info
diff --git a/fcache/__init__.py b/fcache/__init__.py
new file mode 100644
index 0000000..e66d264
--- /dev/null
+++ b/fcache/__init__.py
@@ -0,0 +1 @@
+from .fcache import fcache
diff --git a/fcache/fcache.py b/fcache/fcache.py
new file mode 100644
index 0000000..5af3744
--- /dev/null
+++ b/fcache/fcache.py
@@ -0,0 +1,53 @@
+import atexit
+import functools
+
+from fcache.hashing import stable_hash
+from fcache.file_cache import FileCache
+
+
+GLOBAL_CACHE = FileCache('.fcache')
+
+
+def get_global_cache():
+    """Return the process-wide ``FileCache`` used by ``fcache``."""
+    return GLOBAL_CACHE
+
+
+def fcache(f):
+    """Decorate ``f`` so that its return values are cached on disk.
+
+    The cache key is ``stable_hash((f, args, kwargs))``, so passing the same
+    argument positionally and by keyword produces two separate entries.
+    Results are stored in the global cache, which is looked up once, when
+    the decorator is applied.
+
+    Set ``fcache.clear_at_exit = True`` to empty the global cache when the
+    interpreter exits.
+    """
+    cache = get_global_cache()
+
+    # functools.wraps keeps f's name and docstring on the wrapper.
+    @functools.wraps(f)
+    def decorated(*args, **kwargs):
+        call_hash = stable_hash((f, args, kwargs))
+        if call_hash in cache:
+            try:
+                return cache[call_hash]
+            except KeyError:
+                # The key is known but its file can no longer be read;
+                # fall through and recompute the value.
+                pass
+        ret_val = f(*args, **kwargs)
+        cache[call_hash] = ret_val
+        return ret_val
+    return decorated
+
+
+fcache.clear_at_exit = False
+
+
+@atexit.register
+def maybe_clear_at_exit():
+    """Empty the global cache at exit when ``fcache.clear_at_exit`` is set."""
+    if fcache.clear_at_exit:
+        GLOBAL_CACHE.clear()
diff --git a/fcache/file_cache.py b/fcache/file_cache.py
new file mode 100644
index 0000000..f98678b
--- /dev/null
+++ b/fcache/file_cache.py
@@ -0,0 +1,66 @@
+import os
+import pickle
+
+
+class FileCache:
+    """Dict-like store that keeps each value in ``<cache_dir>/<key>.pkl``.
+
+    Keys are expected to be integers: when an existing directory is
+    reopened, the key set is rebuilt by parsing the file names with ``int``.
+    """
+
+    def __init__(self, cache_dir):
+        """Open the cache, creating ``cache_dir`` if it does not exist."""
+        self._cache_dir = cache_dir
+        self._cached_keys = set()
+        os.makedirs(cache_dir, exist_ok=True)
+        for cached_pickle in os.listdir(cache_dir):
+            key, ext = os.path.splitext(cached_pickle)
+            if ext != '.pkl':
+                continue
+            try:
+                self._cached_keys.add(int(key))
+            except ValueError:
+                # A .pkl file whose name is not an integer was not written
+                # for an integer key; skip it instead of failing to start.
+                continue
+
+    def __setitem__(self, key, value):
+        """Pickle ``value`` into the file for ``key`` and remember the key."""
+        with open(self.cache_fn(key), 'wb') as f:
+            pickle.dump(value, f)
+        self._cached_keys.add(key)
+
+    def __getitem__(self, key):
+        """Load and return the value stored for ``key``.
+
+        Raises ``KeyError`` if the file for ``key`` cannot be opened or read.
+        """
+        try:
+            with open(self.cache_fn(key), 'rb') as f:
+                # NOTE: unpickling can run arbitrary code, so the cache
+                # directory must only contain trusted files.
+                return pickle.load(f)
+        except IOError:
+            raise KeyError(key)
+
+    def __len__(self):
+        """Return the number of known keys."""
+        return len(self._cached_keys)
+
+    def __contains__(self, key):
+        """Return whether ``key`` is among the known keys."""
+        return key in self._cached_keys
+
+    def clear(self):
+        """Forget every key and delete the ``.pkl`` files in the directory."""
+        self._cached_keys = set()
+        for cached_pickle in os.listdir(self._cache_dir):
+            # Only pickle files belong to the cache; leave anything else
+            # that happens to live in the directory untouched.
+            if os.path.splitext(cached_pickle)[1] == '.pkl':
+                os.unlink(os.path.join(self._cache_dir, cached_pickle))
+
+    def cache_fn(self, key):
+        """Return the path of the pickle file that stores ``key``."""
+        return os.path.join(self._cache_dir, str(key) + '.pkl')
diff --git a/fcache/hashing.py b/fcache/hashing.py
new file mode 100644
index 0000000..8562d45
--- /dev/null
+++ b/fcache/hashing.py
@@ -0,0 +1,48 @@
+import hashlib
+import pickle
+import io
+import types
+from collections import OrderedDict
+
+
+def _code_fingerprint(code):
+    """Return a picklable summary of what the code object ``code`` computes.
+
+    The raw bytecode is not enough on its own: literals and referenced names
+    live in ``co_consts`` and ``co_names``, so ``n + 1`` and ``n + 2`` can
+    compile to identical ``co_code``.
+    """
+    return ('code', code.co_code, code.co_consts, code.co_names)
+
+
+class StablePickler(pickle.Pickler):
+    """Pickler that replaces functions, dicts and sets with stable stand-ins.
+
+    Each stand-in is tagged with the kind of object it replaces so that, for
+    example, an empty dict and an empty set do not pickle identically.
+    """
+
+    def persistent_id(self, obj):
+        """Return the stand-in to pickle for ``obj``, or ``None`` for none."""
+        if isinstance(obj, types.CodeType):
+            # Code objects cannot be pickled; nested functions and lambdas
+            # appear as code objects inside their parent's co_consts.
+            return _code_fingerprint(obj)
+        elif hasattr(obj, '__code__'):
+            return _code_fingerprint(obj.__code__)
+        elif isinstance(obj, dict):
+            # Sorting removes the dependence on insertion order.
+            return ('dict', sorted(obj.items()))
+        elif isinstance(obj, set):
+            # Sorting removes the dependence on hash-based iteration order.
+            return ('set', sorted(obj))
+        elif isinstance(obj, frozenset):
+            return ('frozenset', sorted(obj))
+        return None
+
+
+def stable_hash(obj):
+    """Return the SHA-1 of ``obj``'s ``StablePickler`` pickle as an integer."""
+    buffer = io.BytesIO()
+    StablePickler(buffer).dump(obj)
+    return int.from_bytes(hashlib.sha1(buffer.getvalue()).digest(), 'little')
diff --git a/fcache/test_file_cache.py b/fcache/test_file_cache.py
new file mode 100644
index 0000000..3eafb17
--- /dev/null
+++ b/fcache/test_file_cache.py
@@ -0,0 +1,36 @@
+from fcache.file_cache import FileCache
+
+def test_caching_values(tmpdir):
+    cache = FileCache(str(tmpdir))
+    kvs = [(1, 1), (2, 20), (3, 300)]
+    for k, v in kvs:
+        cache[k] = v
+    for k, v in kvs:
+        assert k in cache
+        assert cache[k] == v
+    assert len(cache) == 3
+
+    # test that values are preserved between runs
+    del cache
+    cache = FileCache(str(tmpdir))
+    for k, v in kvs:
+        assert k in cache
+        assert cache[k] == v
+    assert len(cache) == 3
+
+def test_cache_clearing(tmpdir):
+    cache = FileCache(str(tmpdir))
+    kvs = [(1, 1), (2, 20), (3, 300)]
+    for k, v in kvs:
+        cache[k] = v
+    assert len(cache) == 3
+    cache.clear()
+    assert len(cache) == 0
+
+def test_non_cache_files_are_tolerated(tmpdir):
+    tmpdir.join('notes.pkl').write('not a cache entry')
+    tmpdir.join('README').write('keep me')
+    cache = FileCache(str(tmpdir))
+    assert len(cache) == 0
+    cache.clear()
+    assert tmpdir.join('README').check()
diff --git a/fcache/tests/test_caching_functions.py b/fcache/tests/test_caching_functions.py
new file mode 100644
index 0000000..cfbf4ca
--- /dev/null
+++ b/fcache/tests/test_caching_functions.py
@@ -0,0 +1,95 @@
+from fcache import fcache
+
+fcache.clear_at_exit = True
+
+class Num:
+    def __init__(self, n):
+        self.n = n
+
+    def __eq__(self, other):
+        return self.n == other.n
+
+    def __hash__(self):
+        return hash(self.n)
+
+# Test global function
+
+global_function_call_counter = 0
+def global_function(n):
+    global global_function_call_counter
+    global_function_call_counter += 1
+    return n * n
+
+def test_global_function():
+    cached = fcache(global_function)
+    exp_5 = global_function(5)
+    exp_1000 = global_function(1000)
+    exp_large = global_function(10000000)
+    global global_function_call_counter
+    assert global_function_call_counter == 3
+
+    for _ in range(2):
+        assert cached(5) == exp_5
+        assert cached(1000) == exp_1000
+        assert cached(10000000) == exp_large
+    assert global_function_call_counter == 6
+
+
+# Test local function
+def test_local_function():
+    def loc_func(n):
+        loc_func._num_calls_ += 1
+        return n * n
+    loc_func._num_calls_ = 0
+    cached = fcache(loc_func)
+    for i in range(5):
+        for n in range(5):
+            assert cached(n) == loc_func(n)
+        assert loc_func._num_calls_ == 5 * (i + 2)
+
+
+# Test lambda
+def test_lambda():
+    lamb = lambda x: x
+    cached = fcache(lamb)
+    for i in range(5):
+        for n in range(5):
+            f_input = Num(n)
+            actual_res = lamb(f_input)
+            cached_res = cached(f_input)
+            assert actual_res == cached_res
+            if i == 0:
+                assert f_input is cached_res
+            else:
+                assert f_input is not cached_res
+
+# Test redefining a function
+def test_redefining_a_function():
+    @fcache
+    def f1(n):
+        return n + 1
+    for n in range(5):
+        assert f1(n) == (n + 1)
+
+    @fcache
+    def f1(n):
+        return n + 2
+    for n in range(5):
+        assert f1(n) == (n + 2)
+
+
+# Test kwargs
+def test_kwargs():
+    @fcache
+    def fun(n, add=1):
+        return n + add
+
+    assert fun(10) == 11
+    assert fun(10, add=20) == 30
+
+
+# TODO Test closure
+# TODO Test class function
+# TODO Test well behaved
+
+# Test with numpy and pandas - in a separate file
diff --git a/fcache/tests/test_hashing.py b/fcache/tests/test_hashing.py
new file mode 100644
index 0000000..de040b1
--- /dev/null
+++ b/fcache/tests/test_hashing.py
@@ -0,0 +1,31 @@
+import pytest
+import sys
+import subprocess
+import os
+
+from fcache.hashing import stable_hash
+
+
+def get_hash_command(repr_):
+    return 'from fcache.hashing import stable_hash; print(stable_hash(%s))' % repr_
+
+
+def get_hash(repr_):
+    # adapted from https://hg.python.org/cpython/file/5e8fa1b13516/Lib/test/test_hash.py#l145
+    cmd_line = [sys.executable, '-c', get_hash_command(repr_)]
+    # check_output fails loudly on a non-zero exit and keeps stderr out of
+    # the text that is parsed as the hash.
+    out = subprocess.check_output(cmd_line, env=os.environ)
+    return int(out.strip())
+
+
+@pytest.mark.parametrize('object_to_hash', ['string', (('key1', 1), ('key2', 2)),
+                                            {'key1': 10, 'key2': 20}, {1, 50, 10, 20},
+                                            {'alpha', 'beta', 'gamma'},
+                                            frozenset(['alpha', 'beta', 'gamma']),
+                                            5, ('str', 10), [2, 3, 'xv']])
+def test_cache_stability(object_to_hash):
+    expected_hash = stable_hash(object_to_hash)
+    for _ in range(3):
+        another_hash = get_hash(repr(object_to_hash))
+        assert expected_hash == another_hash
diff --git a/setup.py b/setup.py
new file mode 100644
index 0000000..1e02ce8
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,13 @@
+# setuptools' setup() is used because distutils' does not support
+# install_requires.
+from setuptools import find_packages, setup
+
+setup(
+    name='fcache',
+    version='0.1.0',
+    description='Caching Function Calls',
+    author='Svetlin Mladenov',
+    packages=find_packages(),
+    install_requires=[
+    ],
+)