diff --git a/doc/progress.rst b/doc/progress.rst
index b8e6864a8..d3d33caf6 100644
--- a/doc/progress.rst
+++ b/doc/progress.rst
@@ -12,10 +12,12 @@ Changelog

  * FIX#1058, #1100: Avoid ``NoneType`` error when printing task without ``class_labels`` attribute.
  * FIX#1110: Make arguments to ``create_study`` and ``create_suite`` that are defined as optional by the OpenML XSD actually optional.
  * FIX#1147: ``openml.flow.flow_exists`` no longer requires an API key.
+ * FIX#1184: Automatically resolve proxies when downloading from MinIO. Turn this off by setting the environment variable ``no_proxy="*"``.
  * MAIN#1088: Do CI for Windows on Github Actions instead of Appveyor.
  * MAINT#1104: Fix outdated docstring for ``list_task``.
  * MAIN#1146: Update the pre-commit dependencies.
  * ADD#1103: Add a ``predictions`` property to OpenMLRun for easy accessibility of prediction data.
+ * ADD#1188: EXPERIMENTAL. Allow downloading all files from a MinIO bucket with ``download_all_files=True`` for ``get_dataset``.

 0.12.2
diff --git a/openml/_api_calls.py b/openml/_api_calls.py
index 7db1155cc..f3c3306fc 100644
--- a/openml/_api_calls.py
+++ b/openml/_api_calls.py
@@ -12,6 +12,7 @@
 import xmltodict
 from urllib3 import ProxyManager
 from typing import Dict, Optional, Union
+import zipfile

 import minio

@@ -44,6 +45,7 @@ def resolve_env_proxies(url: str) -> Optional[str]:
     selected_proxy = requests.utils.select_proxy(url, resolved_proxies)
     return selected_proxy

+
 def _create_url_from_endpoint(endpoint: str) -> str:
     url = config.server
     if not url.endswith("/"):
@@ -137,11 +139,7 @@ def _download_minio_file(

     proxy_client = ProxyManager(proxy) if proxy else None

-    client = minio.Minio(
-        endpoint=parsed_url.netloc,
-        secure=False,
-        http_client=proxy_client
-    )
+    client = minio.Minio(endpoint=parsed_url.netloc, secure=False, http_client=proxy_client)

     try:
         client.fget_object(
@@ -149,6 +147,10 @@
             object_name=object_name,
             file_path=str(destination),
         )
+        if destination.is_file() and destination.suffix == ".zip":
+            with zipfile.ZipFile(destination, "r") as zip_ref:
+                zip_ref.extractall(destination.parent)
+
     except minio.error.S3Error as e:
         if e.message.startswith("Object does not exist"):
             raise FileNotFoundError(f"Object at '{source}' does not exist.") from e
@@ -157,6 +159,49 @@
         raise FileNotFoundError("Bucket does not exist or is private.") from e


+def _download_minio_bucket(
+    source: str,
+    destination: Union[str, pathlib.Path],
+    exists_ok: bool = True,
+) -> None:
+    """Download all files in the MinIO bucket at ``source`` and store them in ``destination``.
+
+    Parameters
+    ----------
+    source : str
+        URL to a MinIO bucket.
+    destination : Union[str, pathlib.Path]
+        Path to a directory to store the bucket content in.
+    exists_ok : bool, optional (default=True)
+        If False, raise ``FileExistsError`` if a file already exists in ``destination``.
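+
+    Notes
+    -----
+    A minimal usage sketch; the endpoint and bucket name below are hypothetical::
+
+        # Any bucket URL of the form http://HOST/BUCKET should work here.
+        _download_minio_bucket(
+            source="http://minio.example.org/dataset20",
+            destination="/tmp/dataset20",
+        )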
+    """
+
+    destination = pathlib.Path(destination)
+    parsed_url = urllib.parse.urlparse(source)
+
+    # expect path format: /BUCKET
+    bucket = parsed_url.path[1:]
+
+    client = minio.Minio(endpoint=parsed_url.netloc, secure=False)
+
+    for file_object in client.list_objects(bucket, recursive=True):
+        _download_minio_file(
+            source=source + "/" + file_object.object_name,
+            destination=pathlib.Path(destination, file_object.object_name),
+            exists_ok=exists_ok,
+        )
+
+
 def _download_text_file(
     source: str,
     output_path: Optional[str] = None,
diff --git a/openml/datasets/functions.py b/openml/datasets/functions.py
index 1e6fb5c78..770413a23 100644
--- a/openml/datasets/functions.py
+++ b/openml/datasets/functions.py
@@ -5,6 +5,8 @@
 import os
 from pyexpat import ExpatError
 from typing import List, Dict, Union, Optional, cast
+import warnings

 import numpy as np
 import arff
+import urllib3
@@ -356,6 +357,7 @@ def get_dataset(
     error_if_multiple: bool = False,
     cache_format: str = "pickle",
     download_qualities: bool = True,
+    download_all_files: bool = False,
 ) -> OpenMLDataset:
     """Download the OpenML dataset representation, optionally also download actual data file.

@@ -389,11 +391,20 @@
         no.of.rows is very high.
     download_qualities : bool (default=True)
         Option to download 'qualities' meta-data in addition to the minimal dataset description.
+    download_all_files : bool (default=False)
+        EXPERIMENTAL. Download all files related to the dataset that reside on the server.
+        Useful for datasets which refer to auxiliary files (e.g., meta-album).
+
     Returns
     -------
     dataset : :class:`openml.OpenMLDataset`
         The downloaded dataset.
     """
+    if download_all_files:
+        warnings.warn(
+            "``download_all_files`` is experimental and is likely to break with new releases."
+        )
+
     if cache_format not in ["feather", "pickle"]:
         raise ValueError(
             "cache_format must be one of 'feather' or 'pickle. "
@@ -434,7 +445,12 @@
     arff_file = _get_dataset_arff(description) if download_data else None

     if "oml:minio_url" in description and download_data:
-        parquet_file = _get_dataset_parquet(description)
+        try:
+            parquet_file = _get_dataset_parquet(
+                description, download_all_files=download_all_files
+            )
+        except urllib3.exceptions.MaxRetryError:
+            parquet_file = None
     else:
         parquet_file = None
     remove_dataset_cache = False
@@ -967,7 +983,9 @@ def _get_dataset_description(did_cache_dir, dataset_id):


 def _get_dataset_parquet(
-    description: Union[Dict, OpenMLDataset], cache_directory: str = None
+    description: Union[Dict, OpenMLDataset],
+    cache_directory: Optional[str] = None,
+    download_all_files: bool = False,
 ) -> Optional[str]:
     """Return the path to the local parquet file of the dataset. If is not cached, it is downloaded.

@@ -987,23 +1005,50 @@
         Folder to store the parquet file in.
         If None, use the default cache directory for the dataset.

+    download_all_files : bool, optional (default=False)
+        If ``True``, download all data found in the bucket to which the description's
+        ``minio_url`` points; otherwise only download the parquet file.
+
     Returns
     -------
     output_filename : string, optional
         Location of the Parquet file if successfully downloaded, None otherwise.
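+
+    Notes
+    -----
+    A sketch of the user-facing call that reaches this helper with
+    ``download_all_files=True``; the dataset id is arbitrary::
+
+        import openml
+
+        # Also fetches any auxiliary files that live next to the parquet file.
+        dataset = openml.datasets.get_dataset(20, download_all_files=True)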
""" if isinstance(description, dict): - url = description.get("oml:minio_url") + url = cast(str, description.get("oml:minio_url")) did = description.get("oml:id") elif isinstance(description, OpenMLDataset): - url = description._minio_url + url = cast(str, description._minio_url) did = description.dataset_id else: raise TypeError("`description` should be either OpenMLDataset or Dict.") if cache_directory is None: cache_directory = _create_cache_directory_for_id(DATASETS_CACHE_DIR_NAME, did) - output_file_path = os.path.join(cache_directory, "dataset.pq") + output_file_path = os.path.join(cache_directory, f"dataset_{did}.pq") + + old_file_path = os.path.join(cache_directory, "dataset.pq") + if os.path.isfile(old_file_path): + os.rename(old_file_path, output_file_path) + + # For this release, we want to be able to force a new download even if the + # parquet file is already present when ``download_all_files`` is set. + # For now, it would be the only way for the user to fetch the additional + # files in the bucket (no function exists on an OpenMLDataset to do this). + if download_all_files: + if url.endswith(".pq"): + url, _ = url.rsplit("/", maxsplit=1) + openml._api_calls._download_minio_bucket(source=cast(str, url), destination=cache_directory) if not os.path.isfile(output_file_path): try: diff --git a/tests/test_datasets/test_dataset_functions.py b/tests/test_datasets/test_dataset_functions.py index 50f449ebb..e6c4fe3ec 100644 --- a/tests/test_datasets/test_dataset_functions.py +++ b/tests/test_datasets/test_dataset_functions.py @@ -322,6 +322,15 @@ def test_get_dataset_by_name(self): openml.config.server = self.production_server self.assertRaises(OpenMLPrivateDatasetError, openml.datasets.get_dataset, 45) + @pytest.mark.skip("Feature is experimental, can not test against stable server.") + def test_get_dataset_download_all_files(self): + # openml.datasets.get_dataset(id, download_all_files=True) + # check for expected files + # checking that no additional files are downloaded if + # the default (false) is used, seems covered by + # test_get_dataset_lazy + raise NotImplementedError + def test_get_dataset_uint8_dtype(self): dataset = openml.datasets.get_dataset(1) self.assertEqual(type(dataset), OpenMLDataset)