Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make Promise and Dataloader thread-safe #81

Merged
merged 1 commit into from
Dec 16, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion promise/async_.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
# Based on https://github.com/petkaantonov/bluebird/blob/master/src/promise.js
from collections import deque
from threading import local

if False:
from .promise import Promise
from typing import Any, Callable, Optional, Union # flake8: noqa


class Async(object):
class Async(local):
def __init__(self, trampoline_enabled=True):
self.is_tick_used = False
self.late_queue = deque() # type: ignore
Expand Down
14 changes: 7 additions & 7 deletions promise/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
except ImportError:
from collections import Iterable
from functools import partial
from threading import local

from .promise import Promise, async_instance, get_default_scheduler

Expand Down Expand Up @@ -33,7 +34,7 @@ def get_chunks(iterable_obj, chunk_size=1):
Loader = namedtuple("Loader", "key,resolve,reject")


class DataLoader(object):
class DataLoader(local):

batch = True
max_batch_size = None # type: int
Expand Down Expand Up @@ -212,22 +213,21 @@ def prime(self, key, value):
# ensuring that it always occurs after "PromiseJobs" ends.

# Private: cached resolved Promise instance
resolved_promise = None # type: Optional[Promise[None]]

cache = local()

def enqueue_post_promise_job(fn, scheduler):
# type: (Callable, Any) -> None
global resolved_promise
if not resolved_promise:
resolved_promise = Promise.resolve(None)
global cache
if not hasattr(cache, 'resolved_promise'):
cache.resolved_promise = Promise.resolve(None)
if not scheduler:
scheduler = get_default_scheduler()

def on_promise_resolve(v):
# type: (Any) -> None
async_instance.invoke(fn, scheduler)

resolved_promise.then(on_promise_resolve)
cache.resolved_promise.then(on_promise_resolve)


def dispatch_queue(loader):
Expand Down
115 changes: 115 additions & 0 deletions tests/test_thread_safety.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
from promise import Promise
from promise.dataloader import DataLoader
import threading



def test_promise_thread_safety():
"""
Promise tasks should never be executed in a different thread from the one they are scheduled from,
unless the ThreadPoolExecutor is used.

Here we assert that the pending promise tasks on thread 1 are not executed on thread 2 as thread 2
resolves its own promise tasks.
"""
event_1 = threading.Event()
event_2 = threading.Event()

assert_object = {'is_same_thread': True}

def task_1():
thread_name = threading.current_thread().getName()

def then_1(value):
# Enqueue tasks to run later.
# This relies on the fact that `then` does not execute the function synchronously when called from
# within another `then` callback function.
promise = Promise.resolve(None).then(then_2)
assert promise.is_pending
event_1.set() # Unblock main thread
event_2.wait() # Wait for thread 2

def then_2(value):
assert_object['is_same_thread'] = (thread_name == threading.current_thread().getName())

promise = Promise.resolve(None).then(then_1)

def task_2():
promise = Promise.resolve(None).then(lambda v: None)
promise.get() # Drain task queue
event_2.set() # Unblock thread 1

thread_1 = threading.Thread(target=task_1)
thread_1.start()

event_1.wait() # Wait for Thread 1 to enqueue promise tasks

thread_2 = threading.Thread(target=task_2)
thread_2.start()

for thread in (thread_1, thread_2):
thread.join()

assert assert_object['is_same_thread']


def test_dataloader_thread_safety():
"""
Dataloader should only batch `load` calls that happened on the same thread.

Here we assert that `load` calls on thread 2 are not batched on thread 1 as
thread 1 batches its own `load` calls.
"""
def load_many(keys):
thead_name = threading.current_thread().getName()
return Promise.resolve([thead_name for key in keys])

thread_name_loader = DataLoader(load_many)

event_1 = threading.Event()
event_2 = threading.Event()
event_3 = threading.Event()

assert_object = {
'is_same_thread_1': True,
'is_same_thread_2': True,
}

def task_1():
@Promise.safe
def do():
promise = thread_name_loader.load(1)
event_1.set()
event_2.wait() # Wait for thread 2 to call `load`
assert_object['is_same_thread_1'] = (
promise.get() == threading.current_thread().getName()
)
event_3.set() # Unblock thread 2

do().get()

def task_2():
@Promise.safe
def do():
promise = thread_name_loader.load(2)
event_2.set()
event_3.wait() # Wait for thread 1 to run `dispatch_queue_batch`
assert_object['is_same_thread_2'] = (
promise.get() == threading.current_thread().getName()
)

do().get()

thread_1 = threading.Thread(target=task_1)
thread_1.start()

event_1.wait() # Wait for thread 1 to call `load`

thread_2 = threading.Thread(target=task_2)
thread_2.start()

for thread in (thread_1, thread_2):
thread.join()

assert assert_object['is_same_thread_1']
assert assert_object['is_same_thread_2']