From c519261defbb6258fa33a550980b8260daa85e06 Mon Sep 17 00:00:00 2001 From: "@jnak" Date: Sat, 14 Dec 2019 13:51:27 -0500 Subject: [PATCH] make thread safe --- promise/async_.py | 3 +- promise/dataloader.py | 14 ++--- tests/test_thread_safety.py | 115 ++++++++++++++++++++++++++++++++++++ 3 files changed, 124 insertions(+), 8 deletions(-) create mode 100644 tests/test_thread_safety.py diff --git a/promise/async_.py b/promise/async_.py index 48088a7..21ac6e2 100644 --- a/promise/async_.py +++ b/promise/async_.py @@ -1,12 +1,13 @@ # Based on https://github.com/petkaantonov/bluebird/blob/master/src/promise.js from collections import deque +from threading import local if False: from .promise import Promise from typing import Any, Callable, Optional, Union # flake8: noqa -class Async(object): +class Async(local): def __init__(self, trampoline_enabled=True): self.is_tick_used = False self.late_queue = deque() # type: ignore diff --git a/promise/dataloader.py b/promise/dataloader.py index 7b0f9ee..fb779ad 100644 --- a/promise/dataloader.py +++ b/promise/dataloader.py @@ -4,6 +4,7 @@ except ImportError: from collections import Iterable from functools import partial +from threading import local from .promise import Promise, async_instance, get_default_scheduler @@ -33,7 +34,7 @@ def get_chunks(iterable_obj, chunk_size=1): Loader = namedtuple("Loader", "key,resolve,reject") -class DataLoader(object): +class DataLoader(local): batch = True max_batch_size = None # type: int @@ -212,14 +213,13 @@ def prime(self, key, value): # ensuring that it always occurs after "PromiseJobs" ends. # Private: cached resolved Promise instance -resolved_promise = None # type: Optional[Promise[None]] - +cache = local() def enqueue_post_promise_job(fn, scheduler): # type: (Callable, Any) -> None - global resolved_promise - if not resolved_promise: - resolved_promise = Promise.resolve(None) + global cache + if not hasattr(cache, 'resolved_promise'): + cache.resolved_promise = Promise.resolve(None) if not scheduler: scheduler = get_default_scheduler() @@ -227,7 +227,7 @@ def on_promise_resolve(v): # type: (Any) -> None async_instance.invoke(fn, scheduler) - resolved_promise.then(on_promise_resolve) + cache.resolved_promise.then(on_promise_resolve) def dispatch_queue(loader): diff --git a/tests/test_thread_safety.py b/tests/test_thread_safety.py new file mode 100644 index 0000000..ed55a84 --- /dev/null +++ b/tests/test_thread_safety.py @@ -0,0 +1,115 @@ +from promise import Promise +from promise.dataloader import DataLoader +import threading + + + +def test_promise_thread_safety(): + """ + Promise tasks should never be executed in a different thread from the one they are scheduled from, + unless the ThreadPoolExecutor is used. + + Here we assert that the pending promise tasks on thread 1 are not executed on thread 2 as thread 2 + resolves its own promise tasks. + """ + event_1 = threading.Event() + event_2 = threading.Event() + + assert_object = {'is_same_thread': True} + + def task_1(): + thread_name = threading.current_thread().getName() + + def then_1(value): + # Enqueue tasks to run later. + # This relies on the fact that `then` does not execute the function synchronously when called from + # within another `then` callback function. + promise = Promise.resolve(None).then(then_2) + assert promise.is_pending + event_1.set() # Unblock main thread + event_2.wait() # Wait for thread 2 + + def then_2(value): + assert_object['is_same_thread'] = (thread_name == threading.current_thread().getName()) + + promise = Promise.resolve(None).then(then_1) + + def task_2(): + promise = Promise.resolve(None).then(lambda v: None) + promise.get() # Drain task queue + event_2.set() # Unblock thread 1 + + thread_1 = threading.Thread(target=task_1) + thread_1.start() + + event_1.wait() # Wait for Thread 1 to enqueue promise tasks + + thread_2 = threading.Thread(target=task_2) + thread_2.start() + + for thread in (thread_1, thread_2): + thread.join() + + assert assert_object['is_same_thread'] + + +def test_dataloader_thread_safety(): + """ + Dataloader should only batch `load` calls that happened on the same thread. + + Here we assert that `load` calls on thread 2 are not batched on thread 1 as + thread 1 batches its own `load` calls. + """ + def load_many(keys): + thead_name = threading.current_thread().getName() + return Promise.resolve([thead_name for key in keys]) + + thread_name_loader = DataLoader(load_many) + + event_1 = threading.Event() + event_2 = threading.Event() + event_3 = threading.Event() + + assert_object = { + 'is_same_thread_1': True, + 'is_same_thread_2': True, + } + + def task_1(): + @Promise.safe + def do(): + promise = thread_name_loader.load(1) + event_1.set() + event_2.wait() # Wait for thread 2 to call `load` + assert_object['is_same_thread_1'] = ( + promise.get() == threading.current_thread().getName() + ) + event_3.set() # Unblock thread 2 + + do().get() + + def task_2(): + @Promise.safe + def do(): + promise = thread_name_loader.load(2) + event_2.set() + event_3.wait() # Wait for thread 1 to run `dispatch_queue_batch` + assert_object['is_same_thread_2'] = ( + promise.get() == threading.current_thread().getName() + ) + + do().get() + + thread_1 = threading.Thread(target=task_1) + thread_1.start() + + event_1.wait() # Wait for thread 1 to call `load` + + thread_2 = threading.Thread(target=task_2) + thread_2.start() + + for thread in (thread_1, thread_2): + thread.join() + + assert assert_object['is_same_thread_1'] + assert assert_object['is_same_thread_2']