From 5feb6502d224dc1661e30c5d76f0b134ac7723d9 Mon Sep 17 00:00:00 2001 From: Vitor Avila Date: Wed, 12 Feb 2025 02:20:52 -0300 Subject: [PATCH] feat: Update database permissions in async mode --- .../src/pages/DatabaseList/index.tsx | 65 +++ superset/commands/database/exceptions.py | 6 +- .../commands/database/resync_permissions.py | 273 ++++++++-- .../database/resync_permissions_async.py | 101 ++++ superset/commands/database/test_connection.py | 18 +- superset/commands/database/update.py | 223 +-------- superset/commands/database/utils.py | 18 + superset/config.py | 9 + superset/constants.py | 1 + superset/databases/api.py | 20 +- superset/tasks/permissions.py | 49 +- superset/views/base.py | 1 + .../integration_tests/databases/api_tests.py | 174 +++++++ .../resync_permissions_async_test.py | 151 ++++++ .../databases/resync_permissions_test.py | 467 ++++++++++++++++++ .../commands/databases/update_test.py | 16 +- .../commands/databases/utils_test.py | 85 ++++ 17 files changed, 1386 insertions(+), 291 deletions(-) create mode 100644 superset/commands/database/resync_permissions_async.py create mode 100644 tests/unit_tests/commands/databases/resync_permissions_async_test.py create mode 100644 tests/unit_tests/commands/databases/resync_permissions_test.py create mode 100644 tests/unit_tests/commands/databases/utils_test.py diff --git a/superset-frontend/src/pages/DatabaseList/index.tsx b/superset-frontend/src/pages/DatabaseList/index.tsx index 776dbbe817aa3..3af9ab41453af 100644 --- a/superset-frontend/src/pages/DatabaseList/index.tsx +++ b/superset-frontend/src/pages/DatabaseList/index.tsx @@ -71,6 +71,7 @@ interface DatabaseDeleteObject extends DatabaseObject { interface DatabaseListProps { addDangerToast: (msg: string) => void; addSuccessToast: (msg: string) => void; + addInfoToast: (msg: string) => void; user: { userId: string | number; firstName: string; @@ -101,6 +102,7 @@ function BooleanDisplay({ value }: { value: Boolean }) { function DatabaseList({ 
addDangerToast, + addInfoToast, addSuccessToast, user, }: DatabaseListProps) { @@ -121,6 +123,9 @@ function DatabaseList({ const fullUser = useSelector( state => state.user, ); + const shouldResyncPermsInAsyncMode = useSelector( + state => state.common?.conf.RESYNC_DB_PERMISSIONS_IN_ASYNC_MODE, + ); const showDatabaseModal = getUrlParam(URL_PARAMS.showDatabaseModal); const [query, setQuery] = useQueryParams({ @@ -426,6 +431,49 @@ function DatabaseList({ handleDatabaseEditModal({ database: original, modalOpen: true }); const handleDelete = () => openDatabaseDeleteModal(original); const handleExport = () => handleDatabaseExport(original); + const handleResync = () => { + shouldResyncPermsInAsyncMode + ? addInfoToast( + t('Validating connectivity for %s', original.database_name), + ) + : addInfoToast( + t('Resyncing permissions for %s', original.database_name), + ); + SupersetClient.post({ + endpoint: `/api/v1/database/${original.id}/resync_permissions/`, + }) + .then(({ response, json }) => { + // Sync request + if (response.status === 200) { + addSuccessToast( + t( + 'Permissions successfully resynced for %s', + original.database_name, + ), + ); + } + // Async request + else { + addInfoToast( + t( + 'Syncing permissions for %s in the background', + original.database_name, + ), + ); + } + }) + .catch( + createErrorHandler(errMsg => + addDangerToast( + t( + 'An error occurred while resyncing permissions for %s: %s', + original.database_name, + errMsg, + ), + ), + ), + ); + }; if (!canEdit && !canDelete && !canExport) { return null; } @@ -481,6 +529,23 @@ function DatabaseList({ )} + {canEdit && ( + + + + + + )} ); }, diff --git a/superset/commands/database/exceptions.py b/superset/commands/database/exceptions.py index b80d7acfbe6ce..f3c3988626156 100644 --- a/superset/commands/database/exceptions.py +++ b/superset/commands/database/exceptions.py @@ -88,9 +88,9 @@ def __init__(self, key: str = "") -> None: ) -class DatabaseConnectionNotWorkingError(CommandException): - 
status = 400 - message = _("DB Connection not working, please check your connection settings.") +class DatabaseConnectionResyncPermissionsError(CommandException): + status = 500 + message = _("Unable to resync permissions for this database connection.") class DatabaseNotFoundError(CommandException): diff --git a/superset/commands/database/resync_permissions.py b/superset/commands/database/resync_permissions.py index 0f0b44f55a50a..05c5eb5a58a9a 100644 --- a/superset/commands/database/resync_permissions.py +++ b/superset/commands/database/resync_permissions.py @@ -17,65 +17,260 @@ from __future__ import annotations import logging -from contextlib import closing -from sqlite3 import ProgrammingError - -from flask import current_app as app -from sqlalchemy.engine import Engine +from functools import partial +from typing import Iterable from superset import security_manager from superset.commands.base import BaseCommand from superset.commands.database.exceptions import ( - DatabaseConnectionNotWorkingError, + DatabaseConnectionFailedError, + DatabaseConnectionResyncPermissionsError, DatabaseNotFoundError, - UserNotFoundError, ) +from superset.commands.database.utils import ping from superset.daos.database import DatabaseDAO +from superset.daos.dataset import DatasetDAO +from superset.databases.ssh_tunnel.models import SSHTunnel +from superset.db_engine_specs.base import GenericDBException +from superset.exceptions import OAuth2RedirectError from superset.models.core import Database -from superset.utils.core import timeout +from superset.utils.decorators import on_error, transaction logger = logging.getLogger(__name__) class ResyncPermissionsCommand(BaseCommand): - def __init__(self, model_id: int, username: str | None): - self._model_id = model_id - self._username: str | None = username - self._model: Database | None = None + """ + Command to resync database permissions. 
+ """ - def run(self) -> None: - self.validate() - - def validate(self) -> None: + def __init__( + self, + model_id: int, + old_db_connection_name: str | None = None, + db_connection: Database | None = None, + ssh_tunnel: SSHTunnel | None = None, + ): """ - Validates the command. + Constructor method. """ - self._model = DatabaseDAO.find_by_id(self._model_id) - if not self._model: - raise DatabaseNotFoundError() + self.db_connection_id = model_id + self.old_db_connection_name: str | None = old_db_connection_name + self.db_connection: Database | None = db_connection + self.db_connection_ssh_tunnel: SSHTunnel | None = ssh_tunnel - if not self._username or not security_manager.get_user_by_username( - self._username - ): - raise UserNotFoundError() + def validate(self) -> None: + if not self.db_connection: + database = DatabaseDAO.find_by_id(self.db_connection_id) + if not database: + raise DatabaseNotFoundError() + self.db_connection = database + + if not self.old_db_connection_name: + self.old_db_connection_name = self.db_connection.database_name - # Make sure the connection works before delegating the task - def ping(engine: Engine) -> bool: - with closing(engine.raw_connection()) as conn: - return engine.dialect.do_ping(conn) + if not self.db_connection_ssh_tunnel: + self.db_connection_ssh_tunnel = DatabaseDAO.get_ssh_tunnel( + self.db_connection_id + ) - with self._model.get_sqla_engine() as engine: + with self.db_connection.get_sqla_engine() as engine: try: - time_delta = app.config["TEST_DATABASE_CONNECTION_TIMEOUT"] - with timeout(int(time_delta.total_seconds())): - alive = ping(engine) - except (ProgrammingError, RuntimeError): - logger.warning("Raw connection failed, retrying with engine") - alive = engine.dialect.do_ping(engine) + alive = ping(engine) except Exception as err: - logger.error("Could not stablish a DB connection") - raise DatabaseConnectionNotWorkingError() from err + raise DatabaseConnectionFailedError() from err if not alive: - 
logger.error("Could not stablish a DB connection") - raise DatabaseConnectionNotWorkingError() + raise DatabaseConnectionFailedError() + + @transaction( + on_error=partial(on_error, reraise=DatabaseConnectionResyncPermissionsError) + ) + def run(self) -> None: + """ + Resyncs the permissions for a DB connection. + """ + self.validate() + + # Make mypy happy (these are already checked in validate) + assert self.db_connection + assert self.old_db_connection_name + + catalogs = ( + self._get_catalog_names(self.db_connection) + if self.db_connection.db_engine_spec.supports_catalog + else [None] + ) + + for catalog in catalogs: + try: + schemas = self._get_schema_names(self.db_connection, catalog) + + if catalog: + perm = security_manager.get_catalog_perm( + self.old_db_connection_name, + catalog, + ) + existing_pvm = security_manager.find_permission_view_menu( + "catalog_access", + perm, + ) + if not existing_pvm: + # new catalog + security_manager.add_permission_view_menu( + "catalog_access", + security_manager.get_catalog_perm( + self.db_connection.database_name, + catalog, + ), + ) + for schema in schemas: + security_manager.add_permission_view_menu( + "schema_access", + security_manager.get_schema_perm( + self.db_connection.database_name, + catalog, + schema, + ), + ) + continue + except DatabaseConnectionFailedError: + # more than one catalog, move to next + if catalog: + logger.warning("Error processing catalog %s", catalog) + continue + raise + + # add possible new schemas in catalog + self._refresh_schemas( + self.old_db_connection_name, + self.db_connection.database_name, + catalog, + schemas, + ) + + if self.old_db_connection_name != self.db_connection.database_name: + self._rename_database_in_permissions( + self.old_db_connection_name, + self.db_connection.database_name, + catalog, + schemas, + ) + + def _get_catalog_names(self, db_connection: Database) -> set[str]: + """ + Helper method to load catalogs. 
+ """ + try: + return db_connection.get_all_catalog_names( + force=True, + ssh_tunnel=self.db_connection_ssh_tunnel, + ) + except OAuth2RedirectError: + # raise OAuth2 exceptions as-is + raise + except GenericDBException as ex: + raise DatabaseConnectionFailedError() from ex + + def _get_schema_names( + self, db_connection: Database, catalog: str | None + ) -> set[str]: + """ + Helper method to load schemas. + """ + try: + return db_connection.get_all_schema_names( + force=True, + catalog=catalog, + ssh_tunnel=self.db_connection_ssh_tunnel, + ) + except OAuth2RedirectError: + # raise OAuth2 exceptions as-is + raise + except GenericDBException as ex: + raise DatabaseConnectionFailedError() from ex + + def _refresh_schemas( + self, + old_db_connection_name: str, + new_db_connection_name: str, + catalog: str | None, + schemas: Iterable[str], + ) -> None: + """ + Add new schemas that don't have permissions yet. + """ + for schema in schemas: + perm = security_manager.get_schema_perm( + old_db_connection_name, + catalog, + schema, + ) + existing_pvm = security_manager.find_permission_view_menu( + "schema_access", + perm, + ) + if not existing_pvm: + new_name = security_manager.get_schema_perm( + new_db_connection_name, + catalog, + schema, + ) + security_manager.add_permission_view_menu("schema_access", new_name) + + def _rename_database_in_permissions( + self, + old_db_connection_name: str, + new_db_connection_name: str, + catalog: str | None, + schemas: Iterable[str], + ) -> None: + new_catalog_perm_name = security_manager.get_catalog_perm( + new_db_connection_name, + catalog, + ) + + # rename existing catalog permission + if catalog: + perm = security_manager.get_catalog_perm( + old_db_connection_name, + catalog, + ) + existing_pvm = security_manager.find_permission_view_menu( + "catalog_access", + perm, + ) + if existing_pvm: + existing_pvm.view_menu.name = new_catalog_perm_name + + for schema in schemas: + new_schema_perm_name = security_manager.get_schema_perm( + 
new_db_connection_name, + catalog, + schema, + ) + + # rename existing schema permission + perm = security_manager.get_schema_perm( + old_db_connection_name, + catalog, + schema, + ) + existing_pvm = security_manager.find_permission_view_menu( + "schema_access", + perm, + ) + if existing_pvm: + existing_pvm.view_menu.name = new_schema_perm_name + + # rename permissions on datasets and charts + for dataset in DatabaseDAO.get_datasets( + self.db_connection_id, + catalog=catalog, + schema=schema, + ): + dataset.catalog_perm = new_catalog_perm_name + dataset.schema_perm = new_schema_perm_name + for chart in DatasetDAO.get_related_objects(dataset.id)["charts"]: + chart.catalog_perm = new_catalog_perm_name + chart.schema_perm = new_schema_perm_name diff --git a/superset/commands/database/resync_permissions_async.py b/superset/commands/database/resync_permissions_async.py new file mode 100644 index 0000000000000..be18f2a6296e3 --- /dev/null +++ b/superset/commands/database/resync_permissions_async.py @@ -0,0 +1,101 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +import logging + +from superset import security_manager +from superset.commands.base import BaseCommand +from superset.commands.database.exceptions import ( + DatabaseConnectionFailedError, + DatabaseNotFoundError, + UserNotFoundError, +) +from superset.commands.database.utils import ping +from superset.daos.database import DatabaseDAO +from superset.tasks.permissions import resync_database_permissions + +logger = logging.getLogger(__name__) + + +class ResyncPermissionsAsyncCommand(BaseCommand): + """ + Command to trigger an async task to resync database permissions. + """ + + def __init__( + self, + model_id: int, + username: str | None, + old_db_connection_name: str | None = None, + ): + """ + Constructor method. + """ + self.db_connection_id = model_id + self.username = username + self.old_db_connection_name = old_db_connection_name + + def validate(self) -> None: + """ + Validates the command before triggering the async task. + + Confirms both the DB connection and user exist. Also tests the DB connection. + """ + database = DatabaseDAO.find_by_id(self.db_connection_id) + if not database: + raise DatabaseNotFoundError() + + if not self.old_db_connection_name: + self.old_db_connection_name = database.database_name + + if not self.username or not security_manager.get_user_by_username( + self.username + ): + raise UserNotFoundError() + + with database.get_sqla_engine() as engine: + # Make sure the connection works before delegating the task + try: + alive = ping(engine) + except Exception as err: + logger.error("Could not stablish a DB connection") + raise DatabaseConnectionFailedError() from err + + if not alive: + logger.error("Could not stablish a DB connection") + raise DatabaseConnectionFailedError() + + def trigger_task(self) -> None: + """ + Triggers the async task. + + Delegates Celery to trigger the permission sync using the + ResyncPermissionsCommand command. 
+ """ + resync_database_permissions.delay( + self.db_connection_id, + self.username, + self.old_db_connection_name, + ) + + def run(self) -> None: + """ + Triggers the command validation, and if successful, triggers the async task. + """ + self.validate() + self.trigger_task() diff --git a/superset/commands/database/test_connection.py b/superset/commands/database/test_connection.py index 6d3219253eaaf..3c16730d00b19 100644 --- a/superset/commands/database/test_connection.py +++ b/superset/commands/database/test_connection.py @@ -15,13 +15,9 @@ # specific language governing permissions and limitations # under the License. import logging -import sqlite3 -from contextlib import closing from typing import Any, Optional -from flask import current_app as app from flask_babel import gettext as _ -from sqlalchemy.engine import Engine from sqlalchemy.exc import DBAPIError, NoSuchModuleError from superset import is_feature_enabled @@ -35,6 +31,7 @@ SSHTunnelDatabasePortError, SSHTunnelingNotEnabledError, ) +from superset.commands.database.utils import ping from superset.daos.database import DatabaseDAO, SSHTunnelDAO from superset.databases.ssh_tunnel.models import SSHTunnel from superset.databases.utils import make_url_safe @@ -47,7 +44,6 @@ ) from superset.extensions import event_logger from superset.models.core import Database -from superset.utils import core as utils from superset.utils.ssh_tunnel import unmask_password_info logger = logging.getLogger(__name__) @@ -136,19 +132,9 @@ def run(self) -> None: # pylint: disable=too-many-statements,too-many-branches engine=database.db_engine_spec.__name__, ) - def ping(engine: Engine) -> bool: - with closing(engine.raw_connection()) as conn: - return engine.dialect.do_ping(conn) - with database.get_sqla_engine(override_ssh_tunnel=ssh_tunnel) as engine: try: - time_delta = app.config["TEST_DATABASE_CONNECTION_TIMEOUT"] - with utils.timeout(int(time_delta.total_seconds())): - alive = ping(engine) - except 
(sqlite3.ProgrammingError, RuntimeError): - # SQLite can't run on a separate thread, so ``utils.timeout`` fails - # RuntimeError catches the equivalent error from duckdb. - alive = engine.dialect.do_ping(engine) + alive = ping(engine) except SupersetTimeoutException as ex: raise SupersetTimeoutException( error_type=SupersetErrorType.CONNECTION_DATABASE_TIMEOUT, diff --git a/superset/commands/database/update.py b/superset/commands/database/update.py index fbf90694f48e6..5f703182894b9 100644 --- a/superset/commands/database/update.py +++ b/superset/commands/database/update.py @@ -21,17 +21,21 @@ from functools import partial from typing import Any +from flask import current_app as app from flask_appbuilder.models.sqla import Model -from superset import is_feature_enabled, security_manager +from superset import is_feature_enabled from superset.commands.base import BaseCommand from superset.commands.database.exceptions import ( - DatabaseConnectionFailedError, DatabaseExistsValidationError, DatabaseInvalidError, DatabaseNotFoundError, DatabaseUpdateFailedError, ) +from superset.commands.database.resync_permissions import ResyncPermissionsCommand +from superset.commands.database.resync_permissions_async import ( + ResyncPermissionsAsyncCommand, +) from superset.commands.database.ssh_tunnel.create import CreateSSHTunnelCommand from superset.commands.database.ssh_tunnel.delete import DeleteSSHTunnelCommand from superset.commands.database.ssh_tunnel.exceptions import ( @@ -39,12 +43,11 @@ ) from superset.commands.database.ssh_tunnel.update import UpdateSSHTunnelCommand from superset.daos.database import DatabaseDAO -from superset.daos.dataset import DatasetDAO from superset.databases.ssh_tunnel.models import SSHTunnel -from superset.db_engine_specs.base import GenericDBException from superset.exceptions import OAuth2RedirectError from superset.models.core import Database from superset.utils import json +from superset.utils.core import get_username from 
superset.utils.decorators import on_error, transaction logger = logging.getLogger(__name__) @@ -87,8 +90,23 @@ def run(self) -> Model: database = DatabaseDAO.update(self._model, self._properties) database.set_sqlalchemy_uri(database.sqlalchemy_uri) ssh_tunnel = self._handle_ssh_tunnel(database) + async_resync_perms = app.config["RESYNC_DB_PERMISSIONS_IN_ASYNC_MODE"] try: - self._refresh_catalogs(database, original_database_name, ssh_tunnel) + if async_resync_perms: + current_username = get_username() + ResyncPermissionsAsyncCommand( + self._model_id, + current_username, + old_db_connection_name=original_database_name, + ).run() + + else: + ResyncPermissionsCommand( + self._model_id, + old_db_connection_name=original_database_name, + db_connection=database, + ssh_tunnel=ssh_tunnel, + ).run() except OAuth2RedirectError: pass @@ -153,201 +171,6 @@ def _handle_ssh_tunnel(self, database: Database) -> SSHTunnel | None: ssh_tunnel_properties, ).run() - def _get_catalog_names( - self, - database: Database, - ssh_tunnel: SSHTunnel | None, - ) -> set[str]: - """ - Helper method to load catalogs. - """ - try: - return database.get_all_catalog_names( - force=True, - ssh_tunnel=ssh_tunnel, - ) - except OAuth2RedirectError: - # raise OAuth2 exceptions as-is - raise - except GenericDBException as ex: - raise DatabaseConnectionFailedError() from ex - - def _get_schema_names( - self, - database: Database, - catalog: str | None, - ssh_tunnel: SSHTunnel | None, - ) -> set[str]: - """ - Helper method to load schemas. - """ - try: - return database.get_all_schema_names( - force=True, - catalog=catalog, - ssh_tunnel=ssh_tunnel, - ) - except OAuth2RedirectError: - # raise OAuth2 exceptions as-is - raise - except GenericDBException as ex: - raise DatabaseConnectionFailedError() from ex - - def _refresh_catalogs( - self, - database: Database, - original_database_name: str, - ssh_tunnel: SSHTunnel | None, - ) -> None: - """ - Add permissions for any new catalogs and schemas. 
- """ - catalogs = ( - self._get_catalog_names(database, ssh_tunnel) - if database.db_engine_spec.supports_catalog - else [None] - ) - - for catalog in catalogs: - try: - schemas = self._get_schema_names(database, catalog, ssh_tunnel) - - if catalog: - perm = security_manager.get_catalog_perm( - original_database_name, - catalog, - ) - existing_pvm = security_manager.find_permission_view_menu( - "catalog_access", - perm, - ) - if not existing_pvm: - # new catalog - security_manager.add_permission_view_menu( - "catalog_access", - security_manager.get_catalog_perm( - database.database_name, - catalog, - ), - ) - for schema in schemas: - security_manager.add_permission_view_menu( - "schema_access", - security_manager.get_schema_perm( - database.database_name, - catalog, - schema, - ), - ) - continue - except DatabaseConnectionFailedError: - # more than one catalog, move to next - if catalog: - logger.warning("Error processing catalog %s", catalog) - continue - raise - - # add possible new schemas in catalog - self._refresh_schemas( - database, - original_database_name, - catalog, - schemas, - ) - - if original_database_name != database.database_name: - self._rename_database_in_permissions( - database, - original_database_name, - catalog, - schemas, - ) - - def _refresh_schemas( - self, - database: Database, - original_database_name: str, - catalog: str | None, - schemas: set[str], - ) -> None: - """ - Add new schemas that don't have permissions yet. 
- """ - for schema in schemas: - perm = security_manager.get_schema_perm( - original_database_name, - catalog, - schema, - ) - existing_pvm = security_manager.find_permission_view_menu( - "schema_access", - perm, - ) - if not existing_pvm: - new_name = security_manager.get_schema_perm( - database.database_name, - catalog, - schema, - ) - security_manager.add_permission_view_menu("schema_access", new_name) - - def _rename_database_in_permissions( - self, - database: Database, - original_database_name: str, - catalog: str | None, - schemas: set[str], - ) -> None: - new_catalog_perm_name = security_manager.get_catalog_perm( - database.database_name, - catalog, - ) - - # rename existing catalog permission - if catalog: - perm = security_manager.get_catalog_perm( - original_database_name, - catalog, - ) - existing_pvm = security_manager.find_permission_view_menu( - "catalog_access", - perm, - ) - if existing_pvm: - existing_pvm.view_menu.name = new_catalog_perm_name - - for schema in schemas: - new_schema_perm_name = security_manager.get_schema_perm( - database.database_name, - catalog, - schema, - ) - - # rename existing schema permission - perm = security_manager.get_schema_perm( - original_database_name, - catalog, - schema, - ) - existing_pvm = security_manager.find_permission_view_menu( - "schema_access", - perm, - ) - if existing_pvm: - existing_pvm.view_menu.name = new_schema_perm_name - - # rename permissions on datasets and charts - for dataset in DatabaseDAO.get_datasets( - database.id, - catalog=catalog, - schema=schema, - ): - dataset.catalog_perm = new_catalog_perm_name - dataset.schema_perm = new_schema_perm_name - for chart in DatasetDAO.get_related_objects(dataset.id)["charts"]: - chart.catalog_perm = new_catalog_perm_name - chart.schema_perm = new_schema_perm_name - def validate(self) -> None: if database_name := self._properties.get("database_name"): if not DatabaseDAO.validate_update_uniqueness( diff --git a/superset/commands/database/utils.py 
b/superset/commands/database/utils.py index ea0ce1a27e274..88b7dde367568 100644 --- a/superset/commands/database/utils.py +++ b/superset/commands/database/utils.py @@ -17,15 +17,33 @@ from __future__ import annotations import logging +import sqlite3 +from contextlib import closing + +from flask import current_app as app +from sqlalchemy.engine import Engine from superset import security_manager from superset.databases.ssh_tunnel.models import SSHTunnel from superset.db_engine_specs.base import GenericDBException from superset.models.core import Database +from superset.utils.core import timeout logger = logging.getLogger(__name__) +def ping(engine: Engine) -> bool: + try: + time_delta = app.config["TEST_DATABASE_CONNECTION_TIMEOUT"] + with timeout(int(time_delta.total_seconds())): + with closing(engine.raw_connection()) as conn: + return engine.dialect.do_ping(conn) + except (sqlite3.ProgrammingError, RuntimeError): + # SQLite can't run on a separate thread, so ``utils.timeout`` fails + # RuntimeError catches the equivalent error from duckdb. + return engine.dialect.do_ping(engine) + + def add_permissions(database: Database, ssh_tunnel: SSHTunnel | None) -> None: """ Add DAR for catalogs and schemas. diff --git a/superset/config.py b/superset/config.py index 6362e39aec7e7..8f5c66158fd78 100644 --- a/superset/config.py +++ b/superset/config.py @@ -1916,6 +1916,15 @@ class ExtraDynamicQueryFilters(TypedDict, total=False): CATALOGS_SIMPLIFIED_MIGRATION: bool = False +# When updating a DB connection or manually triggering a resync, the command +# happens in sync mode. If you have a celery worker configured, it's recommended +# to change below config to ``True`` to run this process in async mode. A DB +# connection might have hundreds of catalogs with thousands of schemas each, which +# considerably increases the time to process it. Running it in async mode prevents +# keeping a web API call open for this long. 
+RESYNC_DB_PERMISSIONS_IN_ASYNC_MODE: bool = False + + # ------------------------------------------------------------------- # * WARNING: STOP EDITING HERE * # ------------------------------------------------------------------- diff --git a/superset/constants.py b/superset/constants.py index 3374b2bd90b51..b13f2bfd52829 100644 --- a/superset/constants.py +++ b/superset/constants.py @@ -173,6 +173,7 @@ class RouteMethod: # pylint: disable=too-few-public-methods "slack_channels": "write", "put_filters": "write", "put_colors": "write", + "resync_permissions": "write", } EXTRA_FORM_DATA_APPEND_KEYS = { diff --git a/superset/databases/api.py b/superset/databases/api.py index f6c3d6c5489db..f0e578240faea 100644 --- a/superset/databases/api.py +++ b/superset/databases/api.py @@ -47,6 +47,9 @@ from superset.commands.database.export import ExportDatabasesCommand from superset.commands.database.importers.dispatcher import ImportDatabasesCommand from superset.commands.database.resync_permissions import ResyncPermissionsCommand +from superset.commands.database.resync_permissions_async import ( + ResyncPermissionsAsyncCommand, +) from superset.commands.database.ssh_tunnel.delete import DeleteSSHTunnelCommand from superset.commands.database.ssh_tunnel.exceptions import ( SSHTunnelDatabasePortError, @@ -120,7 +123,6 @@ from superset.models.core import Database from superset.sql_parse import Table from superset.superset_typing import FlaskResponse -from superset.tasks.permissions import resync_database_permissions from superset.utils import json from superset.utils.core import ( error_msg_from_exception, @@ -620,7 +622,7 @@ def delete(self, pk: int) -> Response: ) return self.response_422(message=str(ex)) - @expose("//resync-permissions/", methods=("POST",)) + @expose("//resync_permissions/", methods=("POST",)) @protect() @safe @statsd_metrics @@ -659,11 +661,17 @@ def resync_permissions(self, pk: int, **kwargs: Any) -> FlaskResponse: 500: $ref: '#/components/responses/500' """ + 
async_resync_perms = app.config["RESYNC_DB_PERMISSIONS_IN_ASYNC_MODE"] try: - current_username = get_username() - ResyncPermissionsCommand(pk, current_username).run() - resync_database_permissions.delay(pk, current_username) - return self.response(202, message="OK") + if async_resync_perms: + current_username = get_username() + ResyncPermissionsAsyncCommand(pk, current_username).run() + return self.response( + 202, message="Async task created to resync permissions" + ) + + ResyncPermissionsCommand(pk).run() + return self.response(200, message="Permissions successfully resynced") except DatabaseNotFoundError: return self.response_404() except SupersetException as ex: diff --git a/superset/tasks/permissions.py b/superset/tasks/permissions.py index 4aef3bd457c2c..17b1184f3b5d1 100644 --- a/superset/tasks/permissions.py +++ b/superset/tasks/permissions.py @@ -18,10 +18,10 @@ import logging -from flask import g +from flask import current_app, g from superset import security_manager -from superset.commands.database.update import UpdateDatabaseCommand +from superset.commands.database.resync_permissions import ResyncPermissionsCommand from superset.daos.database import DatabaseDAO from superset.extensions import celery_app @@ -29,25 +29,28 @@ @celery_app.task(name="resync_database_permissions", soft_time_limit=600) -def resync_database_permissions(database_id: int, username: str) -> None: +def resync_database_permissions( + database_id: int, username: str, original_database_name: str +) -> None: logger.info("Resyncing permissions for DB connection ID %s", database_id) - if user := security_manager.get_user_by_username(username): - g.user = user - logger.info("Impersonating user ID %s", g.user.id) - else: - logger.error("No user to impersonate/validate permissions") - return - database = DatabaseDAO.find_by_id(database_id) - ssh_tunnel = DatabaseDAO.get_ssh_tunnel(database_id) - if not database: - logger.error("Database ID %s not found", database_id) - return - cmmd = 
UpdateDatabaseCommand(database_id, {}) - try: - cmmd._refresh_catalogs(database, database.name, ssh_tunnel) - except Exception: - logger.error( - "An error occurred while resyncing permissions for DB connection ID %s", - database_id, - exc_info=True, - ) + with current_app.test_request_context(): + try: + user = security_manager.get_user_by_username(username) + assert user + g.user = user + logger.info("Impersonating user ID %s", g.user.id) + db_connection = DatabaseDAO.find_by_id(database_id) + ssh_tunnel = DatabaseDAO.get_ssh_tunnel(database_id) + cmmd = ResyncPermissionsCommand( + database_id, + old_db_connection_name=original_database_name, + db_connection=db_connection, + ssh_tunnel=ssh_tunnel, + ) + cmmd.run() + except Exception: + logger.error( + "An error occurred while resyncing permissions for DB connection ID %s", + database_id, + exc_info=True, + ) diff --git a/superset/views/base.py b/superset/views/base.py index bc1b5720895b7..f6184586b3eac 100644 --- a/superset/views/base.py +++ b/superset/views/base.py @@ -107,6 +107,7 @@ "PREVENT_UNSAFE_DEFAULT_URLS_ON_DATASET", "JWT_ACCESS_CSRF_COOKIE_NAME", "SQLLAB_QUERY_RESULT_TIMEOUT", + "RESYNC_DB_PERMISSIONS_IN_ASYNC_MODE", ) logger = logging.getLogger(__name__) diff --git a/tests/integration_tests/databases/api_tests.py b/tests/integration_tests/databases/api_tests.py index 6c34942f61c12..7ab71ea76fce9 100644 --- a/tests/integration_tests/databases/api_tests.py +++ b/tests/integration_tests/databases/api_tests.py @@ -50,6 +50,7 @@ from superset.utils.database import get_example_database, get_main_database from superset.utils import json from tests.integration_tests.base_tests import SupersetTestCase +from tests.integration_tests.conftest import with_config from tests.integration_tests.constants import ADMIN_USERNAME, GAMMA_USERNAME from tests.integration_tests.fixtures.birth_names_dashboard import ( load_birth_names_dashboard_with_slices, # noqa: F401 @@ -4070,3 +4071,176 @@ def _base_filter(query): 
db.session.delete(first_model) db.session.delete(second_model) db.session.commit() + + @with_config({"RESYNC_DB_PERMISSIONS_IN_ASYNC_MODE": False}) + @mock.patch( + "superset.commands.database.resync_permissions.ResyncPermissionsCommand.run" + ) + def test_resync_db_perms_sync(self, mock_cmmd): + """ + Database API: Test resync permissions in sync mode. + """ + self.login(ADMIN_USERNAME) + example_db = get_example_database() + test_database = self.insert_database( + "test-database", example_db.sqlalchemy_uri_decrypted + ) + + uri = f"api/v1/database/{test_database.id}/resync_permissions/" + rv = self.client.post(uri) + assert rv.status_code == 200 + response = json.loads(rv.data.decode("utf-8")) + assert response == {"message": "Permissions successfully resynced"} + + # Cleanup + model = db.session.query(Database).get(test_database.id) + db.session.delete(model) + db.session.commit() + + @with_config({"RESYNC_DB_PERMISSIONS_IN_ASYNC_MODE": False}) + @mock.patch("superset.commands.database.resync_permissions.DatabaseDAO.find_by_id") + def test_resync_db_perms_sync_db_not_found(self, mock_find_db): + """ + Database API: Test resync permissions in sync mode when the DB connection + is not found. + """ + self.login(ADMIN_USERNAME) + mock_find_db.return_value = None + + uri = "api/v1/database/10/resync_permissions/" + rv = self.client.post(uri) + assert rv.status_code == 404 + + @with_config({"RESYNC_DB_PERMISSIONS_IN_ASYNC_MODE": False}) + @mock.patch("superset.commands.database.resync_permissions.ping") + def test_resync_db_perms_sync_db_connection_failed(self, mock_ping): + """ + Database API: Test resync permissions in sync mode when the DB connection + is not working. 
+ """ + self.login(ADMIN_USERNAME) + mock_ping.return_value = False + example_db = get_example_database() + test_database = self.insert_database( + "test-database", example_db.sqlalchemy_uri_decrypted + ) + + uri = f"api/v1/database/{test_database.id}/resync_permissions/" + rv = self.client.post(uri) + assert rv.status_code == 500 + + # Cleanup + model = db.session.query(Database).get(test_database.id) + db.session.delete(model) + db.session.commit() + + @with_config({"RESYNC_DB_PERMISSIONS_IN_ASYNC_MODE": True}) + @mock.patch( + "superset.commands.database.resync_permissions_async.ResyncPermissionsAsyncCommand.run" + ) + def test_resync_db_perms_async(self, mock_cmmd): + """ + Database API: Test resync permissions in async mode. + """ + self.login(ADMIN_USERNAME) + example_db = get_example_database() + test_database = self.insert_database( + "test-database", example_db.sqlalchemy_uri_decrypted + ) + + uri = f"api/v1/database/{test_database.id}/resync_permissions/" + rv = self.client.post(uri) + assert rv.status_code == 202 + response = json.loads(rv.data.decode("utf-8")) + assert response == {"message": "Async task created to resync permissions"} + + # Cleanup + model = db.session.query(Database).get(test_database.id) + db.session.delete(model) + db.session.commit() + + @with_config({"RESYNC_DB_PERMISSIONS_IN_ASYNC_MODE": True}) + @mock.patch( + "superset.commands.database.resync_permissions_async.DatabaseDAO.find_by_id" + ) + def test_resync_db_perms_async_db_not_found(self, mock_find_db): + """ + Database API: Test resync permissions in async mode when the DB connection + is not found. 
+ """ + self.login(ADMIN_USERNAME) + mock_find_db.return_value = None + + uri = "api/v1/database/10/resync_permissions/" + rv = self.client.post(uri) + assert rv.status_code == 404 + + @with_config({"RESYNC_DB_PERMISSIONS_IN_ASYNC_MODE": True}) + @mock.patch("superset.commands.database.resync_permissions_async.ping") + def test_resync_db_perms_async_db_connection_failed(self, mock_ping): + """ + Database API: Test resync permissions in async mode when the DB connection + is not working. + """ + self.login(ADMIN_USERNAME) + mock_ping.return_value = False + example_db = get_example_database() + test_database = self.insert_database( + "test-database", example_db.sqlalchemy_uri_decrypted + ) + + uri = f"api/v1/database/{test_database.id}/resync_permissions/" + rv = self.client.post(uri) + assert rv.status_code == 500 + + # Cleanup + model = db.session.query(Database).get(test_database.id) + db.session.delete(model) + db.session.commit() + + @with_config({"RESYNC_DB_PERMISSIONS_IN_ASYNC_MODE": True}) + @mock.patch( + "superset.commands.database.resync_permissions_async.security_manager.get_user_by_username" + ) + def test_resync_db_perms_async_user_not_found(self, mock_get_user): + """ + Database API: Test resync permissions in async mode when the user to be + impersonated can't be found. 
+ """ + self.login(ADMIN_USERNAME) + mock_get_user.return_value = False + example_db = get_example_database() + test_database = self.insert_database( + "test-database", example_db.sqlalchemy_uri_decrypted + ) + + uri = f"api/v1/database/{test_database.id}/resync_permissions/" + rv = self.client.post(uri) + assert rv.status_code == 400 + + # Cleanup + model = db.session.query(Database).get(test_database.id) + db.session.delete(model) + db.session.commit() + + @mock.patch( + "superset.commands.database.resync_permissions.ResyncPermissionsCommand.run" + ) + def test_resync_db_perms_no_access(self, mock_cmmd): + """ + Database API: Test resync permissions with a user without permission to do so. + """ + self.login(GAMMA_USERNAME) + example_db = get_example_database() + test_database = self.insert_database( + "test-database", example_db.sqlalchemy_uri_decrypted + ) + + uri = f"api/v1/database/{test_database.id}/resync_permissions/" + rv = self.client.post(uri) + assert rv.status_code == 403 + + # Cleanup + model = db.session.query(Database).get(test_database.id) + db.session.delete(model) + db.session.commit() diff --git a/tests/unit_tests/commands/databases/resync_permissions_async_test.py b/tests/unit_tests/commands/databases/resync_permissions_async_test.py new file mode 100644 index 0000000000000..af65fb6ca0560 --- /dev/null +++ b/tests/unit_tests/commands/databases/resync_permissions_async_test.py @@ -0,0 +1,151 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import pytest +from pytest_mock import MockerFixture + +from superset.commands.database.exceptions import ( + DatabaseConnectionFailedError, + DatabaseNotFoundError, + UserNotFoundError, +) +from superset.commands.database.resync_permissions_async import ( + ResyncPermissionsAsyncCommand, +) + + +def test_resync_permissions_async_command_validate(mocker: MockerFixture) -> None: + """ + Test the ``validate`` method. + """ + mock_db = mocker.MagicMock() + mock_db.database_name = "Connection Name" + mocker.patch( + "superset.commands.database.resync_permissions_async.DatabaseDAO.find_by_id", + return_value=mock_db, + ) + mocker.patch( + "superset.commands.database.resync_permissions_async.security_manager.get_user_by_username", + ) + mocker.patch( + "superset.commands.database.resync_permissions_async.ping", return_value=True + ) + + command = ResyncPermissionsAsyncCommand(1, "username") + command.validate() + + # Asserts + assert command.db_connection_id == 1 + assert command.username == "username" + assert command.old_db_connection_name == "Connection Name" + + +def test_resync_permissions_async_command_validate_new_db_name(mocker: MockerFixture): + """ + Test the ``validate`` method when the DB connection has a new name. 
+ """ + mock_db = mocker.MagicMock() + mock_db.database_name = "Connection Name" + mocker.patch( + "superset.commands.database.resync_permissions_async.DatabaseDAO.find_by_id", + return_value=mock_db, + ) + mocker.patch( + "superset.commands.database.resync_permissions_async.security_manager.get_user_by_username", + ) + mocker.patch( + "superset.commands.database.resync_permissions_async.ping", return_value=True + ) + + command = ResyncPermissionsAsyncCommand( + 1, "username", old_db_connection_name="Old Connection Name" + ) + command.validate() + + # Asserts + assert command.db_connection_id == 1 + assert command.username == "username" + assert command.old_db_connection_name == "Old Connection Name" + + +def test_resync_permissions_async_command_validate_database_not_found( + mocker: MockerFixture, +) -> None: + """ + Test the ``validate`` method when the database connection is not found. + """ + mocker.patch( + "superset.commands.database.resync_permissions_async.DatabaseDAO.find_by_id", + return_value=None, + ) + + command = ResyncPermissionsAsyncCommand(1, "username") + with pytest.raises(DatabaseNotFoundError): + command.validate() + + +def test_resync_permissions_async_command_validate_user_not_found( + mocker: MockerFixture, +) -> None: + """ + Test the ``validate`` method when the user is not found. + """ + mock_db = mocker.MagicMock() + mock_db.database_name = "Connection Name" + mocker.patch( + "superset.commands.database.resync_permissions_async.DatabaseDAO.find_by_id", + return_value=mock_db, + ) + mocker.patch( + "superset.commands.database.resync_permissions_async.security_manager.get_user_by_username", + return_value=None, + ) + + command = ResyncPermissionsAsyncCommand(1, "username") + with pytest.raises(UserNotFoundError): + command.validate() + + +def test_reynsc_permissions_async_command_validate_db_connection_error( + mocker: MockerFixture, +): + """ + Test the ``validate`` method when the database connection fails. 
+ """ + mock_db = mocker.MagicMock() + mock_db.database_name = "Connection Name" + mocker.patch( + "superset.commands.database.resync_permissions_async.DatabaseDAO.find_by_id", + return_value=mock_db, + ) + mocker.patch( + "superset.commands.database.resync_permissions_async.security_manager.get_user_by_username", + ) + mock_ping = mocker.patch( + "superset.commands.database.resync_permissions_async.ping", return_value=False + ) + + command = ResyncPermissionsAsyncCommand(1, "username") + with pytest.raises(DatabaseConnectionFailedError): + command.validate() + + mock_ping.reset_mock() + mock_ping.side_effect = Exception + + with pytest.raises(DatabaseConnectionFailedError): + command.validate() diff --git a/tests/unit_tests/commands/databases/resync_permissions_test.py b/tests/unit_tests/commands/databases/resync_permissions_test.py new file mode 100644 index 0000000000000..b5721b1b72473 --- /dev/null +++ b/tests/unit_tests/commands/databases/resync_permissions_test.py @@ -0,0 +1,467 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +import pytest +from pytest_mock import MockerFixture + +from superset.commands.database.exceptions import DatabaseConnectionFailedError +from superset.commands.database.resync_permissions import ResyncPermissionsCommand +from superset.db_engine_specs.base import GenericDBException +from superset.exceptions import OAuth2RedirectError + + +def test_resync_permissions_command_validate(mocker: MockerFixture): + """ + Test the ``validate`` method. + """ + mock_db = mocker.MagicMock() + mock_db.database_name = "current name" + mock_ssh = mocker.MagicMock() + mock_databasedao = mocker.patch( + "superset.commands.database.resync_permissions.DatabaseDAO" + ) + mock_databasedao.find_by_id.return_value = mock_db + mock_databasedao.get_ssh_tunnel.return_value = mock_ssh + mocker.patch( + "superset.commands.database.resync_permissions.ping", return_value=True + ) + + cmmd = ResyncPermissionsCommand(1) + cmmd.validate() + + assert cmmd.db_connection == mock_db + assert cmmd.old_db_connection_name == "current name" + assert cmmd.db_connection_ssh_tunnel == mock_ssh + mock_databasedao.find_by_id.assert_called_once_with(1) + mock_databasedao.get_ssh_tunnel.assert_called_once_with(1) + + +def test_resync_permissions_command_validate_passing_all_values(mocker: MockerFixture): + """ + Test the ``validate`` method when providing all arguments to the constructor. 
+ """ + mock_db = mocker.MagicMock() + mock_db.database_name = "current name" + mock_ssh = mocker.MagicMock() + mock_databasedao = mocker.patch( + "superset.commands.database.resync_permissions.DatabaseDAO" + ) + mocker.patch( + "superset.commands.database.resync_permissions.ping", return_value=True + ) + + cmmd = ResyncPermissionsCommand( + 1, + old_db_connection_name="old name", + db_connection=mock_db, + ssh_tunnel=mock_ssh, + ) + cmmd.validate() + + assert cmmd.db_connection == mock_db + assert cmmd.old_db_connection_name == "old name" + assert cmmd.db_connection_ssh_tunnel == mock_ssh + mock_databasedao.find_by_id.assert_not_called() + mock_databasedao.get_ssh_tunnel.assert_not_called() + + +def test_resync_permissions_command_validate_raise(mocker: MockerFixture): + """ + Test the ``validate`` method when an exception is raised. + """ + mock_db = mocker.MagicMock() + mock_db.database_name = "current name" + mock_ssh = mocker.MagicMock() + mock_ping = mocker.patch( + "superset.commands.database.resync_permissions.ping", return_value=False + ) + + cmmd = ResyncPermissionsCommand( + 1, + db_connection=mock_db, + ssh_tunnel=mock_ssh, + ) + with pytest.raises(DatabaseConnectionFailedError): + cmmd.validate() + + mock_ping.reset_mock() + mock_ping.side_effect = Exception + + with pytest.raises(DatabaseConnectionFailedError): + cmmd.validate() + + +def test_resync_permissions_command_run(mocker: MockerFixture): + """ + Test the ``run`` method.
+ """ + mock_db = mocker.MagicMock() + mock_db.database_name = "same_name" + mock_db.db_engine_spec.supports_catalog = True + find_pvm_mock = mocker.patch( + "superset.commands.database.utils.security_manager.find_permission_view_menu" + ) + add_pvm_mock = mocker.patch( + "superset.commands.database.utils.security_manager.add_permission_view_menu" + ) + find_pvm_mock.side_effect = [mocker.MagicMock(), None] + schemas_list = [ + ["schema1_catalog_1", "schema2_catalog_1"], + ["schema1_catalog_2", "schema2_catalog_2"], + ] + + cmmd = ResyncPermissionsCommand(1, "same_name", mock_db, None) + mocker.patch.object( + cmmd, "_get_catalog_names", return_value=["catalog1", "catalog2"] + ) + mocker.patch.object(cmmd, "_get_schema_names", side_effect=schemas_list) + mock_refresh_schemas = mocker.patch.object(cmmd, "_refresh_schemas") + mock_rename_db_perm = mocker.patch.object(cmmd, "_rename_database_in_permissions") + cmmd.run() + + add_pvm_mock.assert_has_calls( + [ + mocker.call("catalog_access", "[same_name].[catalog2]"), + mocker.call("schema_access", "[same_name].[catalog2].[schema1_catalog_2]"), + mocker.call("schema_access", "[same_name].[catalog2].[schema2_catalog_2]"), + ] + ) + mock_refresh_schemas.assert_called_once_with( + "same_name", + "same_name", + "catalog1", + ["schema1_catalog_1", "schema2_catalog_1"], + ) + mock_rename_db_perm.assert_not_called() + + +def test_resync_permissions_command_run_raise_on_getting_schemas(mocker: MockerFixture): + """ + Test the ``run`` method when an exception is raised on getting the schemas + for the catalog. 
+ """ + mock_db = mocker.MagicMock() + mock_db.database_name = "same_name" + mock_db.db_engine_spec.supports_catalog = True + find_pvm_mock = mocker.patch( + "superset.commands.database.utils.security_manager.find_permission_view_menu" + ) + add_pvm_mock = mocker.patch( + "superset.commands.database.utils.security_manager.add_permission_view_menu" + ) + find_pvm_mock.return_value = mocker.MagicMock() + schemas_list = [ + DatabaseConnectionFailedError, + ["schema1_catalog_2", "schema2_catalog_2"], + ] + + cmmd = ResyncPermissionsCommand(1, "same_name", mock_db, None) + mocker.patch.object( + cmmd, "_get_catalog_names", return_value=["catalog1", "catalog2"] + ) + mocker.patch.object(cmmd, "_get_schema_names", side_effect=schemas_list) + mock_refresh_schemas = mocker.patch.object(cmmd, "_refresh_schemas") + mock_rename_db_perm = mocker.patch.object(cmmd, "_rename_database_in_permissions") + cmmd.run() + + add_pvm_mock.assert_not_called() + mock_refresh_schemas.assert_called_once_with( + "same_name", + "same_name", + "catalog2", + ["schema1_catalog_2", "schema2_catalog_2"], + ) + mock_rename_db_perm.assert_not_called() + + +def test_resync_permissions_command_run_new_db_name(mocker: MockerFixture): + """ + Test the ``run`` method when the database name has changed. 
+ """ + mock_db = mocker.MagicMock() + mock_db.database_name = "New Name" + mock_db.db_engine_spec.supports_catalog = True + mocker.patch( + "superset.commands.database.utils.security_manager.find_permission_view_menu", + return_value=mocker.MagicMock(), + ) + + cmmd = ResyncPermissionsCommand(1, "Old Name", mock_db, None) + mocker.patch.object(cmmd, "_get_catalog_names", return_value=["catalog"]) + mocker.patch.object(cmmd, "_get_schema_names", return_value=["schema"]) + mocker.patch.object(cmmd, "_refresh_schemas") + mock_rename_db_perm = mocker.patch.object(cmmd, "_rename_database_in_permissions") + cmmd.run() + + mock_rename_db_perm.assert_called_once_with( + "Old Name", "New Name", "catalog", ["schema"] + ) + + +def test_resync_permissions_command_run_no_catalog(mocker: MockerFixture): + """ + Test the ``run`` method when the DB connection does not support catalogs. + """ + mock_db = mocker.MagicMock() + mock_db.database_name = "Name" + mock_db.db_engine_spec.supports_catalog = False + + cmmd = ResyncPermissionsCommand(1, "Name", mock_db, None) + mocker.patch.object(cmmd, "_get_schema_names", return_value=["schema"]) + mock_refresh_schemas = mocker.patch.object(cmmd, "_refresh_schemas") + cmmd.run() + + mock_refresh_schemas.assert_called_once_with("Name", "Name", None, ["schema"]) + + +def test_resync_permissions_command_run_no_catalog_raise_on_getting_schemas( + mocker: MockerFixture, +): + """ + Test the ``run`` method when an exception is raised on getting the schemas + for a DB connection that does not support catalog.
+ """ + mock_db = mocker.MagicMock() + mock_db.database_name = "Name" + mock_db.db_engine_spec.supports_catalog = False + + cmmd = ResyncPermissionsCommand(1, "Name", mock_db, None) + mocker.patch.object( + cmmd, "_get_schema_names", side_effect=DatabaseConnectionFailedError + ) + with pytest.raises(DatabaseConnectionFailedError): + cmmd.run() + + +def test_resync_permissions_command_run_no_catalog_new_db_name(mocker: MockerFixture): + """ + Test the ``run`` method when the database name has changed and the DB connection + does not support catalog. + """ + mock_db = mocker.MagicMock() + mock_db.database_name = "New Name" + mock_db.db_engine_spec.supports_catalog = False + + cmmd = ResyncPermissionsCommand(1, "Name", mock_db, None) + mocker.patch.object(cmmd, "_get_schema_names", return_value=["schema"]) + mock_refresh_schemas = mocker.patch.object(cmmd, "_refresh_schemas") + mock_rename_db = mocker.patch.object(cmmd, "_rename_database_in_permissions") + cmmd.run() + + mock_refresh_schemas.assert_called_once_with("Name", "New Name", None, ["schema"]) + mock_rename_db.assert_called_once_with("Name", "New Name", None, ["schema"]) + + +def test_resync_permissions_command_get_catalog_names(mocker: MockerFixture): + """ + Test the ``_get_catalog_names`` method. + """ + mock_db = mocker.MagicMock() + mock_db.get_all_catalog_names.return_value = {"catalog1", "catalog2"} + + cmmd = ResyncPermissionsCommand(1, "DB Connection Name", mock_db, None) + result = cmmd._get_catalog_names(mock_db) + + assert result == {"catalog1", "catalog2"} + mock_db.get_all_catalog_names.assert_called_once_with( + force=True, + ssh_tunnel=None, + ) + + +def test_resync_permissions_command_get_catalog_names_oauth2_exception( + mocker: MockerFixture, +): + """ + Test the ``_get_catalog_names`` method when an OAuth2 exception + is raised. 
+ """ + mock_db = mocker.MagicMock() + mock_db.get_all_catalog_names.side_effect = OAuth2RedirectError( + "Missing token", "mock_tab", "mock_url" + ) + + cmmd = ResyncPermissionsCommand(1, "DB Connection Name", mock_db, None) + with pytest.raises(OAuth2RedirectError): + cmmd._get_catalog_names(mock_db) + + +def test_resync_permissions_command_get_catalog_names_generic_db_exception( + mocker: MockerFixture, +): + """ + Test the ``_get_catalog_names`` method when a ``GenericDBException`` + is raised. + """ + mock_db = mocker.MagicMock() + mock_db.get_all_catalog_names.side_effect = GenericDBException + + cmmd = ResyncPermissionsCommand(1, "DB Connection Name", mock_db, None) + with pytest.raises(DatabaseConnectionFailedError): + cmmd._get_catalog_names(mock_db) + + +def test_resync_permissions_command_get_schema_names(mocker: MockerFixture): + """ + Test the ``_get_schema_names`` method. + """ + mock_db = mocker.MagicMock() + mock_db.get_all_schema_names.return_value = {"schema1", "schema2"} + + cmmd = ResyncPermissionsCommand(1, "DB Connection Name", mock_db, None) + result = cmmd._get_schema_names(mock_db, "my_catalog") + + assert result == {"schema1", "schema2"} + mock_db.get_all_schema_names.assert_called_once_with( + force=True, + catalog="my_catalog", + ssh_tunnel=None, + ) + + +def test_resync_permissions_command_get_schema_names_oauth2_exception( + mocker: MockerFixture, +): + """ + Test the ``_get_schema_names`` method when an OAuth2 exception + is raised. + """ + mock_db = mocker.MagicMock() + mock_db.get_all_schema_names.side_effect = OAuth2RedirectError( + "Missing token", "mock_tab", "mock_url" + ) + + cmmd = ResyncPermissionsCommand(1, "DB Connection Name", mock_db, None) + with pytest.raises(OAuth2RedirectError): + cmmd._get_schema_names(mock_db, "my_catalog") + + +def test_resync_permissions_command_get_schema_names_generic_db_exception( + mocker: MockerFixture, +): + """ + Test the ``_get_schema_names`` method when a ``GenericDBException`` + is raised.
+ """ + mock_db = mocker.MagicMock() + mock_db.get_all_schema_names.side_effect = GenericDBException + + cmmd = ResyncPermissionsCommand(1, "DB Connection Name", mock_db, None) + with pytest.raises(DatabaseConnectionFailedError): + cmmd._get_schema_names(mock_db, None) + + +def test_resync_permissions_command_refresh_schemas(mocker: MockerFixture): + """ + Test the ``_refresh_schemas`` method. + """ + mock_db = mocker.MagicMock() + mock_db.database_name = "same_name" + get_schem_perm_mock = mocker.patch( + "superset.commands.database.resync_permissions.security_manager.get_schema_perm" + ) + get_schem_perm_mock.side_effect = [ + "[same_name].[catalog].[schema1]", + None, + "[same_name].[catalog].[schema2]", + ] + find_pvm_mock = mocker.patch( + "superset.commands.database.resync_permissions.security_manager.find_permission_view_menu" + ) + find_pvm_mock.side_effect = [mocker.MagicMock(), None] + add_pvm_mock = mocker.patch( + "superset.commands.database.resync_permissions.security_manager.add_permission_view_menu" + ) + + cmmd = ResyncPermissionsCommand(1, "same_name", mock_db, None) + cmmd._refresh_schemas("same_name", "same_name", "catalog", ["schema1", "schema2"]) + + add_pvm_mock.assert_called_once_with( + "schema_access", "[same_name].[catalog].[schema2]" + ) + + +def test_resync_permissions_command_rename_database_in_permissions( + mocker: MockerFixture, +): + """ + Test the ``_rename_database_in_permissions`` method. 
+ """ + mock_db = mocker.MagicMock() + mock_db.database_name = "new_name" + find_pvm_mock = mocker.patch( + "superset.commands.database.resync_permissions.security_manager.find_permission_view_menu" + ) + get_schema_perm_mock = mocker.patch( + "superset.commands.database.resync_permissions.security_manager.get_schema_perm" + ) + mock_catalog_perm = mocker.MagicMock() + mock_catalog_perm.view_menu.name = "[old_name].[catalog]" + mock_schema_perm = mocker.MagicMock() + mock_schema_perm.view_menu.name = "[old_name].[catalog].[schema1]" + find_pvm_mock.side_effect = [ + mock_catalog_perm, + mock_schema_perm, + None, + ] + get_schema_perm_mock.side_effect = [ + "[new_name].[catalog].[schema1]", + "[old_name].[catalog].[schema1]", + "[new_name].[catalog].[schema2]", + "[old_name].[catalog].[schema2]", + ] + + mock_dataset = mocker.MagicMock() + mock_dataset.id = 1 + mock_dataset.catalog_perm = "[old_name].[catalog]" + mock_dataset.schema_perm = "[old_name].[catalog].[schema1]" + mock_chart = mocker.MagicMock() + mock_chart.catalog_perm = "[old_name].[catalog]" + mock_chart.schema_perm = "[old_name].[catalog].[schema1]" + + mock_database_dao = mocker.patch( + "superset.commands.database.resync_permissions.DatabaseDAO" + ) + mock_database_dao.get_datasets.side_effect = [ + [mock_dataset], + [], + ] + mock_dataset_dao = mocker.patch( + "superset.commands.database.resync_permissions.DatasetDAO" + ) + mock_dataset_dao.get_related_objects.return_value = {"charts": [mock_chart]} + + cmmd = ResyncPermissionsCommand(1, "old_name", mock_db, None) + cmmd._rename_database_in_permissions( + "old_name", "new_name", "catalog", ["schema1", "schema2"] + ) + + find_pvm_mock.assert_has_calls( + [ + mocker.call("catalog_access", "[old_name].[catalog]"), + mocker.call("schema_access", "[old_name].[catalog].[schema1]"), + mocker.call("schema_access", "[old_name].[catalog].[schema2]"), + ] + ) + + assert mock_catalog_perm.view_menu.name == "[new_name].[catalog]" + assert 
mock_schema_perm.view_menu.name == "[new_name].[catalog].[schema1]" + assert mock_dataset.catalog_perm == "[new_name].[catalog]" + assert mock_dataset.schema_perm == "[new_name].[catalog].[schema1]" + assert mock_chart.catalog_perm == "[new_name].[catalog]" + assert mock_chart.schema_perm == "[new_name].[catalog].[schema1]" diff --git a/tests/unit_tests/commands/databases/update_test.py b/tests/unit_tests/commands/databases/update_test.py index daf41b7506888..b8d12ad3a72e7 100644 --- a/tests/unit_tests/commands/databases/update_test.py +++ b/tests/unit_tests/commands/databases/update_test.py @@ -208,6 +208,9 @@ def test_rename_with_catalog( been renamed from `my_db` to `my_other_db`. """ DatabaseDAO = mocker.patch("superset.commands.database.update.DatabaseDAO") # noqa: N806 + resync_db_dao = mocker.patch( + "superset.commands.database.resync_permissions.DatabaseDAO" + ) original_database = mocker.MagicMock() original_database.database_name = "my_db" DatabaseDAO.find_by_id.return_value = original_database @@ -216,9 +219,11 @@ def test_rename_with_catalog( dataset = mocker.MagicMock() chart = mocker.MagicMock() - DatabaseDAO.get_datasets.return_value = [dataset] - DatasetDAO = mocker.patch("superset.commands.database.update.DatasetDAO") # noqa: N806 - DatasetDAO.get_related_objects.return_value = {"charts": [chart]} + resync_db_dao.get_datasets.return_value = [dataset] + dataset_dao = mocker.patch( + "superset.commands.database.resync_permissions.DatasetDAO" + ) # noqa: N806 + dataset_dao.get_related_objects.return_value = {"charts": [chart]} find_permission_view_menu = mocker.patch.object( security_manager, @@ -280,12 +285,15 @@ def test_rename_without_catalog( is added. Additionally, the database has been renamed from `my_db` to `my_other_db`. 
""" # noqa: E501 DatabaseDAO = mocker.patch("superset.commands.database.update.DatabaseDAO") # noqa: N806 + resync_db_dao = mocker.patch( + "superset.commands.database.resync_permissions.DatabaseDAO" + ) original_database = mocker.MagicMock() original_database.database_name = "my_db" DatabaseDAO.find_by_id.return_value = original_database database_without_catalog.database_name = "my_other_db" DatabaseDAO.update.return_value = database_without_catalog - DatabaseDAO.get_datasets.return_value = [] + resync_db_dao.get_datasets.return_value = [] find_permission_view_menu = mocker.patch.object( security_manager, diff --git a/tests/unit_tests/commands/databases/utils_test.py b/tests/unit_tests/commands/databases/utils_test.py new file mode 100644 index 0000000000000..793821fea7e4b --- /dev/null +++ b/tests/unit_tests/commands/databases/utils_test.py @@ -0,0 +1,85 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +import datetime +import sqlite3 +from unittest.mock import MagicMock + +import pytest +from pytest_mock import MockerFixture + +from superset.commands.database.utils import ping +from tests.integration_tests.conftest import with_config + + +@pytest.fixture +def mock_engine(mocker: MockerFixture) -> tuple[MagicMock, MagicMock, MagicMock]: + mock_connection = mocker.MagicMock() + mock_engine = mocker.MagicMock() + mock_dialect = mocker.MagicMock() + mock_engine.raw_connection.return_value = mock_connection + mock_engine.dialect = mock_dialect + return mock_engine, mock_connection, mock_dialect + + +@with_config({"TEST_DATABASE_CONNECTION_TIMEOUT": datetime.timedelta(seconds=10)}) +def test_ping_success(mock_engine: MockerFixture): + """ + Test the ``ping`` method. + """ + mock_engine, mock_connection, mock_dialect = mock_engine + mock_dialect.do_ping.return_value = True + + result = ping(mock_engine) + + assert result is True + + mock_engine.raw_connection.assert_called_once() + mock_dialect.do_ping.assert_called_once_with(mock_connection) + + +@with_config({"TEST_DATABASE_CONNECTION_TIMEOUT": datetime.timedelta(seconds=10)}) +def test_ping_sqlite_exception(mocker: MockerFixture, mock_engine: MockerFixture): + """ + Test the ``ping`` method when a sqlite3.ProgrammingError is raised. + """ + mock_engine, mock_connection, mock_dialect = mock_engine + mock_dialect.do_ping.side_effect = [sqlite3.ProgrammingError, True] + + result = ping(mock_engine) + + assert result is True + + mock_dialect.do_ping.assert_has_calls( + [mocker.call(mock_connection), mocker.call(mock_engine)] + ) + + +def test_ping_runtime_exception(mocker: MockerFixture, mock_engine: MockerFixture): + """ + Test the ``ping`` method when a RuntimeError is raised. 
+ """ + mock_engine, _, mock_dialect = mock_engine + mock_timeout = mocker.patch("superset.commands.database.utils.timeout") + mock_timeout.side_effect = RuntimeError("timeout") + mock_dialect.do_ping.return_value = True + + result = ping(mock_engine) + + assert result is True + mock_dialect.do_ping.assert_called_once_with(mock_engine)