From 3733040e18c8fa568821cb47cb4b7ecbe4f91cde Mon Sep 17 00:00:00 2001 From: SoenkeD Date: Sun, 24 Mar 2024 11:48:25 +0100 Subject: [PATCH] VPCEndpoint: add security group modification --- moto/ec2/models/vpcs.py | 21 ++++++++++++++++++++- moto/ec2/responses/vpcs.py | 27 ++++++++++++++++++++++----- tests/test_ec2/test_vpcs.py | 18 ++++++++++++++++++ 3 files changed, 60 insertions(+), 6 deletions(-) diff --git a/moto/ec2/models/vpcs.py b/moto/ec2/models/vpcs.py index 614c9ea90bdd..759830193c7c 100644 --- a/moto/ec2/models/vpcs.py +++ b/moto/ec2/models/vpcs.py @@ -376,6 +376,8 @@ def modify( add_subnets: Optional[List[str]], add_route_tables: Optional[List[str]], remove_route_tables: Optional[List[str]], + add_security_groups: Optional[List[str]], + remove_security_groups: Optional[List[str]], ) -> None: if policy_doc: self.policy_document = policy_doc @@ -389,6 +391,14 @@ def modify( for rt_id in self.route_table_ids if rt_id not in remove_route_tables ] + if add_security_groups: + self.security_group_ids.extend(add_security_groups) + if remove_security_groups: + self.security_group_ids = [ + sg_id + for sg_id in self.security_group_ids + if sg_id not in remove_security_groups + ] def get_filter_value( self, filter_name: str, method_name: Optional[str] = None @@ -966,9 +976,18 @@ def modify_vpc_endpoint( add_subnets: Optional[List[str]], remove_route_tables: Optional[List[str]], add_route_tables: Optional[List[str]], + add_security_groups: Optional[List[str]], + remove_security_groups: Optional[List[str]], ) -> None: endpoint = self.describe_vpc_endpoints(vpc_end_point_ids=[vpc_id])[0] - endpoint.modify(policy_doc, add_subnets, add_route_tables, remove_route_tables) + endpoint.modify( + policy_doc, + add_subnets, + add_route_tables, + remove_route_tables, + add_security_groups, + remove_security_groups, + ) def delete_vpc_endpoints(self, vpce_ids: Optional[List[str]] = None) -> None: for vpce_id in vpce_ids or []: diff --git a/moto/ec2/responses/vpcs.py b/moto/ec2/responses/vpcs.py index 19de3cc2bc4e..36c775cf169e 100644 --- a/moto/ec2/responses/vpcs.py +++ b/moto/ec2/responses/vpcs.py @@ -224,12 +224,17 @@ def modify_vpc_endpoint(self) -> str: add_route_tables = self._get_multi_param("AddRouteTableId") remove_route_tables = self._get_multi_param("RemoveRouteTableId") policy_doc = self._get_param("PolicyDocument") + add_security_groups = self._get_multi_param("AddSecurityGroupId") + remove_security_groups = self._get_multi_param("RemoveSecurityGroupId") + self.ec2_backend.modify_vpc_endpoint( vpc_id=vpc_id, policy_doc=policy_doc, add_subnets=add_subnets, add_route_tables=add_route_tables, remove_route_tables=remove_route_tables, + add_security_groups=add_security_groups, + remove_security_groups=remove_security_groups, ) template = self.response_template(MODIFY_VPC_END_POINT) return template.render() @@ -251,9 +256,17 @@ def describe_vpc_endpoints(self) -> str: vpc_end_points = self.ec2_backend.describe_vpc_endpoints( vpc_end_point_ids=vpc_end_points_ids, filters=filters ) + + security_group_ids = [] + for service in vpc_end_points: + security_group_ids.extend(service.security_group_ids) + security_groups = self.ec2_backend.describe_security_groups(group_ids=security_group_ids) + template = self.response_template(DESCRIBE_VPC_ENDPOINT_RESPONSE) return template.render( - vpc_end_points=vpc_end_points, account_id=self.current_account + vpc_end_points=vpc_end_points, + account_id=self.current_account, + security_groups=security_groups, ) def delete_vpc_endpoints(self) -> str: @@ -736,10 +749,14 @@ def modify_managed_prefix_list(self) -> str: {% if vpc_end_point.security_group_ids %} {% for group_id in vpc_end_point.security_group_ids %} - - {{ group_id }} - TODO - + {% for sec_g in security_groups %} + {% if sec_g.id == group_id %} + + {{ group_id }} + {{ sec_g.name }} + + {% endif %} + {% endfor %} {% endfor %} {% endif %} diff --git a/tests/test_ec2/test_vpcs.py b/tests/test_ec2/test_vpcs.py index 221b66a311c0..d5edd200dfac 100644 --- a/tests/test_ec2/test_vpcs.py +++ b/tests/test_ec2/test_vpcs.py @@ -1170,6 +1170,24 @@ def test_modify_vpc_endpoint(): endpoint = ec2.describe_vpc_endpoints(VpcEndpointIds=[vpc_id])["VpcEndpoints"][0] assert endpoint["PolicyDocument"] == "doc" + sg_id_1 = ec2.create_security_group(GroupName="sg-1", VpcId=vpc_id, Description="sg-1")["GroupId"] + sg_id_2 = ec2.create_security_group(GroupName="sg-2", VpcId=vpc_id, Description="sg-2")["GroupId"] + ec2.modify_vpc_endpoint( + VpcEndpointId=vpc_id, + AddSecurityGroupIds=[sg_id_1, sg_id_2], + ) + endpoint = ec2.describe_vpc_endpoints(VpcEndpointIds=[vpc_id])["VpcEndpoints"][0] + group_ids = endpoint.get("Groups", []) + assert any(group.get("GroupId") == sg_id_1 and group.get("GroupName") == "sg-1" for group in group_ids) + assert any(group.get("GroupId") == sg_id_2 and group.get("GroupName") == "sg-2" for group in group_ids) + + ec2.modify_vpc_endpoint( + VpcEndpointId=vpc_id, + RemoveSecurityGroupIds=[sg_id_1], + ) + endpoint = ec2.describe_vpc_endpoints(VpcEndpointIds=[vpc_id])["VpcEndpoints"][0] + assert any(group.get("GroupId") == sg_id_2 and group.get("GroupName") == "sg-2" for group in group_ids) + @mock_aws def test_delete_vpc_end_points():