@@ -704,10 +704,13 @@ def process(self):
704704 return
705705
706706 desired_firewall_ips = set ()
707+ fw_chains_created = set ()
707708 if self .config .is_vpc () and self .config .is_vpc_firewall_enabled ():
708709 desired_firewall_ips = self ._get_desired_vpc_firewall_ips ()
709-
710- fw_chains_created = set ()
710+ # Pre-create FIREWALL chains for ALL public IPs that have any active rule
711+ # (static NAT, port forwarding, LB, or explicit firewall rule) so that the
712+ # default DROP is always in place even before any explicit firewall rule exists.
713+ self ._ensure_vpc_firewall_chains (desired_firewall_ips , fw_chains_created )
711714
712715 for item in self .dbag :
713716 if item == "id" :
@@ -717,8 +720,8 @@ def process(self):
717720 else :
718721 if self .config .is_vpc () and self .dbag [item ].get ("purpose" ) == "Firewall" and not self .config .is_vpc_firewall_enabled ():
719722 continue
720- # For VPC firewall rules, create the PREROUTING jump and chain skeleton
721- # once per public IP before adding the individual rule
723+ # Chain skeleton is already ensured by the pre-creation pass above;
724+ # _ensure_vpc_firewall_chains is idempotent (skips IPs in fw_chains_created).
722725 if self .config .is_vpc () and self .config .is_vpc_firewall_enabled () and self .dbag [item ].get ("purpose" ) == "Firewall" :
723726 src_ip = self .dbag [item ].get ("src_ip" )
724727 self ._ensure_vpc_firewall_chains ([src_ip ], fw_chains_created )
@@ -728,19 +731,66 @@ def process(self):
728731 self ._cleanup_removed_vpc_firewall_chains (desired_firewall_ips )
729732
730733 def _get_desired_vpc_firewall_ips (self ):
731- desired_firewall_ips = set ()
734+ """
735+ Collect the full set of public IPs that should have a FIREWALL mangle chain
736+ in a VPC with firewall capability. This includes IPs from explicit firewall
737+ rules, forwarding/static-NAT rules, and load-balancer rules.
738+ """
732739 if not self .config .is_vpc ():
733- return desired_firewall_ips
740+ return set ()
741+
742+ ips = set ()
743+ ips .update (self ._get_firewall_rule_ips ())
744+ ips .update (self ._get_forwarding_rule_ips ())
745+ ips .update (self ._get_loadbalancer_ips ())
746+ return ips
734747
748+ def _get_firewall_rule_ips (self ):
749+ """Return public IPs that have explicit firewall rules in this data bag."""
750+ ips = set ()
735751 for item in self .dbag :
736752 if item == "id" :
737753 continue
738754 rule = self .dbag [item ]
739755 if rule .get ("purpose" ) == "Firewall" :
740756 src_ip = rule .get ("src_ip" )
741757 if src_ip :
742- desired_firewall_ips .add (src_ip )
743- return desired_firewall_ips
758+ ips .add (src_ip )
759+ return ips
760+
761+ def _get_forwarding_rule_ips (self ):
762+ """
763+ Return public IPs from the forwardingrules bag (static NAT and port forwarding).
764+ That bag is keyed by public IP, so each key (other than 'id') is a public IP.
765+ """
766+ ips = set ()
767+ try :
768+ fwd_bag = CsDataBag ("forwardingrules" , self .config )
769+ for public_ip in fwd_bag .get_bag ():
770+ if public_ip == "id" :
771+ continue
772+ ips .add (public_ip )
773+ except Exception as e :
774+ logging .debug ("Could not load forwardingrules for VPC firewall chain collection: %s" , e )
775+ return ips
776+
777+ def _get_loadbalancer_ips (self ):
778+ """
779+ Return public IPs from the loadbalancer bag.
780+ add_rules entries are formatted as 'ip:port', so the IP is the first segment.
781+ """
782+ ips = set ()
783+ try :
784+ lb_bag = CsDataBag ("loadbalancer" , self .config )
785+ lb_data = lb_bag .get_bag ()
786+ if "config" in lb_data and lb_data ["config" ]:
787+ for rule_str in lb_data ["config" ][0 ].get ("add_rules" , []):
788+ ip = rule_str .split (":" )[0 ]
789+ if ip :
790+ ips .add (ip )
791+ except Exception as e :
792+ logging .debug ("Could not load loadbalancer for VPC firewall chain collection: %s" , e )
793+ return ips
744794
745795 def _ensure_vpc_firewall_chains (self , source_ips , fw_chains_created ):
746796 fw = self .config .get_fw ()
0 commit comments