Add threads flag to parallel apply (#1171)
* Add use_threads argument to parallel_apply

* Cleanup imports

* Update last modified date

* Add brief explanation of threads vs processes
alexgleith authored Jan 22, 2024
1 parent f41432d commit ad849c4
Showing 1 changed file with 30 additions and 22 deletions.
52 changes: 30 additions & 22 deletions Tools/dea_tools/datahandling.py
@@ -1,4 +1,4 @@
-## dea_datahandling.py
+# dea_datahandling.py
"""
Loading and manipulating Digital Earth Australia products and data
using the Open Data Cube and xarray.
@@ -17,31 +17,28 @@
If you would like to report an issue with this script, you can file one
on Github (https://github.com/GeoscienceAustralia/dea-notebooks/issues/new).
-Last modified: June 2023
+Last modified: Jan 2024
"""

[Import block tidied up under "Cleanup imports": the previous flat list (os, zipfile, datetime, requests, warnings,
odc.algo, dask, numpy, pandas, dask.array, xarray, skimage.transform, sklearn.decomposition, skimage.exposure,
skimage.color, random, collections, odc.algo.mask_cleanup, datacube.utils.masking, scipy.ndimage,
datacube.utils.dates) is regrouped, with standard-library imports (datetime, os, warnings, zipfile,
collections.Counter) ahead of the third-party imports (numpy, odc.algo, pandas, requests, xarray, datacube,
scipy, skimage).]


def _dc_query_only(**kw):
@@ -905,7 +902,7 @@ def nearest(
return nearest_array


-def parallel_apply(ds, dim, func, *args, **kwargs):
+def parallel_apply(ds, dim, func, use_threads=False, *args, **kwargs):
"""
Applies a custom function in parallel along the dimension of an
xarray.Dataset or xarray.DataArray.
@@ -929,6 +926,11 @@ def parallel_apply(ds, dim, func, *args, **kwargs):
The function that will be applied in parallel to each array
along dimension `dim`. The first argument passed to this
function should be the array along `dim`.
+use_threads : bool, optional
+Whether to use threads instead of processes for parallelisation.
+Defaults to False, which means multiprocessing will be used.
+In brief, the difference between threads and processes is that threads
+share memory, while processes have separate memory.
*args :
Any number of arguments that will be passed to `func`.
**kwargs :
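
In practice the new flag is just another keyword argument to parallel_apply. A minimal usage sketch, assuming dea-tools is installed; the synthetic dataset, band names and threshold_ndvi helper are illustrative, not part of this commit:

    import numpy as np
    import pandas as pd
    import xarray as xr

    from dea_tools.datahandling import parallel_apply


    def threshold_ndvi(ds_slice, threshold):
        # Runs on one time slice. With the default (processes), this function
        # and its inputs are pickled, so it must be defined at module level.
        ndvi = (ds_slice.nir - ds_slice.red) / (ds_slice.nir + ds_slice.red)
        return ndvi > threshold


    if __name__ == "__main__":
        time = pd.date_range("2023-01-01", periods=4)
        ds = xr.Dataset(
            {
                "nir": (("time", "y", "x"), np.random.rand(4, 50, 50)),
                "red": (("time", "y", "x"), np.random.rand(4, 50, 50)),
            },
            coords={"time": time},
        )

        # Default: a process pool (separate memory per worker)
        masks = parallel_apply(ds, "time", threshold_ndvi, threshold=0.3)

        # New flag: a thread pool (shared memory, no pickling); handy for
        # I/O-bound work or functions that release the GIL
        masks_threaded = parallel_apply(
            ds, "time", threshold_ndvi, use_threads=True, threshold=0.3
        )

        print(masks.dims, masks_threaded.dims)

Because extra keyword arguments are bound to func via functools.partial, per-call parameters such as threshold can be passed by keyword and never collide with the new use_threads flag.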
@@ -941,13 +943,19 @@ def parallel_apply(ds, dim, func, *args, **kwargs):
along the input `dim` dimension.
"""

-    from concurrent.futures import ProcessPoolExecutor
-    from tqdm import tqdm
-    from itertools import repeat
+    from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
    from functools import partial
+    from itertools import repeat
+
+    from tqdm import tqdm

-    with ProcessPoolExecutor() as executor:
+    # Use threads or processes
+    if use_threads:
+        Executor = ThreadPoolExecutor
+    else:
+        Executor = ProcessPoolExecutor
+
+    with Executor() as executor:
        # Update func to add kwargs
        func = partial(func, **kwargs)

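The remainder of parallel_apply (splitting the data along dim, mapping func across the pieces and re-concatenating the results) is collapsed in this view. Below is a self-contained sketch of the executor-selection pattern the commit introduces; the apply_along_dim name, the grouping step and the xr.concat call are assumptions rather than code from the commit:

    from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
    from functools import partial

    import numpy as np
    import xarray as xr
    from tqdm import tqdm


    def apply_along_dim(ds, dim, func, use_threads=False, **kwargs):
        # Threads share the parent's memory and skip pickling; processes get
        # separate memory, so func and each slice must be picklable.
        Executor = ThreadPoolExecutor if use_threads else ProcessPoolExecutor

        func = partial(func, **kwargs)
        groups = [ds.sel({dim: value}) for value in ds[dim].values]

        with Executor() as executor:
            results = list(tqdm(executor.map(func, groups), total=len(groups)))

        return xr.concat(results, dim=dim)


    def double_values(da_slice):
        return da_slice * 2


    if __name__ == "__main__":
        da = xr.DataArray(
            np.arange(12.0).reshape(3, 4),
            dims=("time", "x"),
            coords={"time": [0, 1, 2]},
        )
        out = apply_along_dim(da, "time", double_values, use_threads=True)
        print(out.shape)  # (3, 4)

Swapping only the executor class keeps the map/concat logic identical for both modes, which is what makes use_threads a one-line decision inside the function.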
