Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Glue] Added Connections Support #8626

Merged
merged 4 commits into from
Feb 26, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 88 additions & 6 deletions moto/glue/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,24 +104,32 @@ def __init__(
self.last_modified_timestamp = self.created_timestamp
self.status = "READY"
self.availability_zone = "us-east-1a"
self.vpc_id = "vpc-12345678"
self.yarn_endpoint_address = f"yarn-{endpoint_name}.glue.amazonaws.com"
self.private_address = "10.0.0.1"
self.public_address = f"{endpoint_name}.glue.amazonaws.com"
# TODO: Get the vpc id from the subnet using the subnet_id
self.vpc_id = "vpc-12345678" if subnet_id != "subnet-default" else None
self.yarn_endpoint_address = (
f"yarn-{endpoint_name}.glue.amazonaws.com"
if subnet_id == "subnet-default"
else None
)
self.public_address = (
f"{endpoint_name}.glue.amazonaws.com"
if subnet_id == "subnet-default"
else None
)
self.zeppelin_remote_spark_interpreter_port = 9007
self.public_key = pubkey
self.public_keys = [self.public_key]

def as_dict(self) -> Dict[str, Any]:
return {
response = {
"EndpointName": self.endpoint_name,
"RoleArn": self.role_arn,
"SecurityGroupIds": self.security_group_ids,
"SubnetId": self.subnet_id,
"YarnEndpointAddress": self.yarn_endpoint_address,
"PrivateAddress": self.private_address,
"ZeppelinRemoteSparkInterpreterPort": self.zeppelin_remote_spark_interpreter_port,
"PublicAddress": self.public_address,
"Status": self.status,
"WorkerType": self.worker_type,
"GlueVersion": self.glue_version,
Expand All @@ -138,10 +146,19 @@ def as_dict(self) -> Dict[str, Any]:
"SecurityConfiguration": self.security_configuration,
"Arguments": self.arguments,
}
if self.public_address:
response["PublicAddress"] = self.public_address
return response


class GlueBackend(BaseBackend):
PAGINATION_MODEL = {
"get_connections": {
"input_token": "next_token",
"limit_key": "max_results",
"limit_default": 100,
"unique_attribute": "name",
},
"get_jobs": {
"input_token": "next_token",
"limit_key": "max_results",
Expand Down Expand Up @@ -190,6 +207,7 @@ def __init__(self, region_name: str, account_id: str):
super().__init__(region_name, account_id)
self.databases: Dict[str, FakeDatabase] = OrderedDict()
self.crawlers: Dict[str, FakeCrawler] = OrderedDict()
self.connections: Dict[str, FakeConnection] = OrderedDict()
self.jobs: Dict[str, FakeJob] = OrderedDict()
self.job_runs: Dict[str, FakeJobRun] = OrderedDict()
self.sessions: Dict[str, FakeSession] = OrderedDict()
Expand Down Expand Up @@ -1179,6 +1197,36 @@ def get_dev_endpoint(self, endpoint_name: str) -> FakeDevEndpoint:
except KeyError:
raise EntityNotFoundException(f"DevEndpoint {endpoint_name} not found")

def create_connection(
    self, catalog_id: str, connection_input: Dict[str, Any], tags: Dict[str, str]
) -> str:
    """Store a new connection and return its status.

    The connection is keyed by the ``Name`` entry of ``connection_input``
    (an empty string when that key is absent).

    Raises:
        AlreadyExistsException: a connection is already stored under that name.
    """
    connection_name = connection_input.get("Name", "")
    if connection_name in self.connections:
        raise AlreadyExistsException(f"Connection {connection_name} already exists")
    self.connections[connection_name] = FakeConnection(
        self, catalog_id, connection_input, tags
    )
    return self.connections[connection_name].status

def get_connection(
    self,
    catalog_id: str,
    name: str,
    hide_password: bool,
    apply_override_for_compute_environment: str,
) -> "FakeConnection":
    """Return the connection stored under ``name``.

    ``catalog_id``, ``hide_password`` and
    ``apply_override_for_compute_environment`` are accepted but not used yet.

    Raises:
        EntityNotFoundException: nothing is stored under ``name``.
    """
    # TODO: Implement filtering
    found = self.connections.get(name)
    if found:
        return found
    raise EntityNotFoundException(f"Connection {name} not found")

@paginate(pagination_model=PAGINATION_MODEL)
def get_connections(
    self, catalog_id: str, filter: Dict[str, Any], hide_password: bool
) -> List["FakeConnection"]:
    """Return all stored connections as a new list.

    ``catalog_id``, ``filter`` and ``hide_password`` are accepted but not
    used yet. Paging is left to the ``paginate`` decorator.
    """
    # TODO: Implement filtering
    return list(self.connections.values())


class FakeDatabase(BaseModel):
def __init__(
Expand Down Expand Up @@ -1337,7 +1385,7 @@ def __init__(
tags: Dict[str, str],
backend: GlueBackend,
):
self.name = name
self.name: str = name
self.role = role
self.database_name = database_name
self.description = description
Expand Down Expand Up @@ -1869,4 +1917,38 @@ def as_dict(self) -> Dict[str, Any]:
return data


class FakeConnection(BaseModel):
    """In-memory model of a Glue connection built from a ``ConnectionInput``."""

    def __init__(
        self,
        backend: GlueBackend,
        catalog_id: str,
        connection_input: Dict[str, Any],
        tags: Dict[str, str],
    ) -> None:
        """Create the connection and register its tags with ``backend``.

        Args:
            backend: Owning backend; supplies region/account for the ARN and
                receives the ``tag_resource`` call.
            catalog_id: Catalog id supplied by the caller (stored as-is).
            connection_input: The request's ``ConnectionInput`` mapping; must
                contain ``Name``.
            tags: Tags to attach to the connection's ARN.

        Raises:
            KeyError: ``connection_input`` has no ``Name`` entry.
        """
        self.catalog_id = catalog_id
        self.connection_input = connection_input
        # A new connection has never been modified, so both timestamps use a
        # single clock read instead of two slightly different ones.
        now = utcnow()
        self.created_time = now
        self.updated_time = now
        self.arn = f"arn:{get_partition(backend.region_name)}:glue:{backend.region_name}:{backend.account_id}:connection/{self.connection_input['Name']}"
        self.backend = backend
        self.backend.tag_resource(self.arn, tags)
        self.status = "READY"
        self.name = self.connection_input.get("Name")
        self.description = self.connection_input.get("Description")

    def as_dict(self) -> Dict[str, Any]:
        """Return the JSON-serialisable description of this connection."""
        response: Dict[str, Any] = {
            "Name": self.name,
            "Description": self.description,
            "Connection": self.connection_input,
            "CreationTime": self.created_time.isoformat(),
            "LastUpdatedTime": self.updated_time.isoformat(),
            "CatalogId": self.catalog_id,
            "Status": self.status,
            "PhysicalConnectionRequirements": self.connection_input.get(
                "PhysicalConnectionRequirements"
            ),
        }
        # These are top-level members of the Glue Connection structure, so
        # echo them back when the caller supplied them.
        # NOTE(review): ConnectionProperties is returned unfiltered;
        # HidePassword is not applied here.
        for key in ("ConnectionType", "MatchCriteria", "ConnectionProperties"):
            if key in self.connection_input:
                response[key] = self.connection_input[key]
        return response


glue_backends = BackendDict(GlueBackend, "glue")
42 changes: 42 additions & 0 deletions moto/glue/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -793,3 +793,45 @@ def get_dev_endpoint(self) -> str:
endpoint_name = self._get_param("EndpointName")
dev_endpoint = self.glue_backend.get_dev_endpoint(endpoint_name)
return json.dumps({"DevEndpoint": dev_endpoint.as_dict()})

def create_connection(self) -> str:
    """Handle CreateConnection.

    Reads ``CatalogId``, ``ConnectionInput`` and ``Tags`` from the request,
    passes them to the backend, and returns the resulting status as JSON
    under ``CreateConnectionStatus``.
    """
    status = self.glue_backend.create_connection(
        catalog_id=self._get_param("CatalogId"),
        connection_input=self._get_param("ConnectionInput"),
        tags=self._get_param("Tags"),
    )
    return json.dumps({"CreateConnectionStatus": status})

def get_connection(self) -> str:
    """Handle GetConnection.

    Reads ``CatalogId``, ``Name``, ``HidePassword`` and
    ``ApplyOverrideForComputeEnvironment`` from the request, asks the backend
    for the connection, and returns its ``as_dict()`` as JSON under
    ``Connection``.
    """
    found = self.glue_backend.get_connection(
        catalog_id=self._get_param("CatalogId"),
        name=self._get_param("Name"),
        hide_password=self._get_param("HidePassword"),
        apply_override_for_compute_environment=self._get_param(
            "ApplyOverrideForComputeEnvironment"
        ),
    )
    return json.dumps({"Connection": found.as_dict()})

def get_connections(self) -> str:
    """Handle GetConnections.

    Forwards ``CatalogId``, ``Filter``, ``HidePassword``, ``NextToken`` and
    ``MaxResults`` to the backend, which returns a page of connections and
    the token for the next page. Responds with JSON containing
    ``ConnectionList`` and ``NextToken``.
    """
    page, following_token = self.glue_backend.get_connections(
        catalog_id=self._get_param("CatalogId"),
        filter=self._get_param("Filter"),
        hide_password=self._get_param("HidePassword"),
        next_token=self._get_param("NextToken"),
        max_results=self._get_param("MaxResults"),
    )
    return json.dumps(
        {
            "ConnectionList": [item.as_dict() for item in page],
            "NextToken": following_token,
        }
    )
86 changes: 86 additions & 0 deletions tests/test_glue/test_glue.py
Original file line number Diff line number Diff line change
Expand Up @@ -920,6 +920,92 @@ def test_get_dev_endpoint():
response = client.get_dev_endpoint(EndpointName="test-endpoint")
assert response["DevEndpoint"]["EndpointName"] == "test-endpoint"
assert response["DevEndpoint"]["Status"] == "READY"
assert "PublicAddress" in response["DevEndpoint"]

with pytest.raises(client.exceptions.EntityNotFoundException):
client.get_dev_endpoint(EndpointName="nonexistent")

client.create_dev_endpoint(
EndpointName="test-endpoint-private",
RoleArn="arn:aws:iam::123456789012:role/GlueDevEndpoint",
SubnetId="subnet-1234567890abcdef0",
)

response = client.get_dev_endpoint(EndpointName="test-endpoint-private")
assert response["DevEndpoint"]["EndpointName"] == "test-endpoint-private"
assert "PublicAddress" not in response["DevEndpoint"]


@mock_aws
def test_create_connection():
    """Creating a connection reports READY; reusing the name is rejected."""
    glue = boto3.client("glue", region_name="us-east-2")
    physical_requirements = {
        "SubnetId": "subnet-1234567890abcdef0",
        "SecurityGroupIdList": [],
        "AvailabilityZone": "us-east-1a",
    }
    payload = {
        "Name": "test-connection",
        "Description": "Test Connection",
        "ConnectionType": "JDBC",
        "ConnectionProperties": {"key": "value"},
        "PhysicalConnectionRequirements": physical_requirements,
    }

    created = glue.create_connection(ConnectionInput=payload)
    assert created["CreateConnectionStatus"] == "READY"

    # A second create with the same name must fail.
    with pytest.raises(glue.exceptions.AlreadyExistsException):
        glue.create_connection(ConnectionInput=payload)


@mock_aws
def test_get_connection():
    """A created connection can be fetched by name; unknown names raise."""
    client = boto3.client("glue", region_name="us-east-2")
    subnet_id = "subnet-1234567890abcdef0"
    connection_input = {
        "Name": "test-connection",
        "Description": "Test Connection",
        "ConnectionType": "JDBC",
        "ConnectionProperties": {"key": "value"},
        "PhysicalConnectionRequirements": {
            "SubnetId": subnet_id,
            "SecurityGroupIdList": [],
            "AvailabilityZone": "us-east-1a",
        },
    }
    client.create_connection(ConnectionInput=connection_input)
    connection = client.get_connection(Name="test-connection")["Connection"]
    assert connection["Name"] == "test-connection"
    assert connection["Status"] == "READY"
    # Check the response, not the request dict (which trivially has the key).
    assert "PhysicalConnectionRequirements" in connection
    assert connection["PhysicalConnectionRequirements"]["SubnetId"] == subnet_id

    # Test not found
    with pytest.raises(client.exceptions.EntityNotFoundException):
        client.get_connection(Name="nonexistent")


@mock_aws
def test_get_connections():
    """Listing returns every created connection in creation order."""
    glue = boto3.client("glue", region_name="ap-southeast-1")
    for index in range(3):
        physical_requirements = {
            "SubnetId": f"subnet-1234567890abcdef{index}",
            "SecurityGroupIdList": [],
            "AvailabilityZone": "us-east-1a",
        }
        glue.create_connection(
            ConnectionInput={
                "Name": f"test-connection-{index}",
                "Description": "Test Connection",
                "ConnectionType": "JDBC",
                "ConnectionProperties": {"key": "value"},
                "PhysicalConnectionRequirements": physical_requirements,
            }
        )

    listed = glue.get_connections()["ConnectionList"]
    assert len(listed) == 3
    for index in range(3):
        assert listed[index]["Name"] == f"test-connection-{index}"
Loading