Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
package network

import (
"context"
"errors"
"fmt"
"log"
"strings"
Expand All @@ -13,6 +15,7 @@ import (
"github.com/hashicorp/go-azure-helpers/lang/response"
"github.com/hashicorp/go-azure-helpers/resourcemanager/commonids"
"github.com/hashicorp/go-azure-sdk/resource-manager/network/2025-01-01/natgateways"
"github.com/hashicorp/go-azure-sdk/resource-manager/network/2025-01-01/publicipaddresses"
"github.com/hashicorp/terraform-provider-azurerm/helpers/tf"
"github.com/hashicorp/terraform-provider-azurerm/internal/clients"
"github.com/hashicorp/terraform-provider-azurerm/internal/locks"
Expand All @@ -26,6 +29,8 @@ func resourceNATGatewayPublicIpAssociation() *pluginsdk.Resource {
Read: resourceNATGatewayPublicIpAssociationRead,
Delete: resourceNATGatewayPublicIpAssociationDelete,

CustomizeDiff: pluginsdk.CustomizeDiffShim(resourceNATGatewayPublicIpAssociationCustomizeDiff),

Importer: pluginsdk.ImporterValidatingResourceId(func(id string) error {
_, err := commonids.ParseCompositeResourceID(id, &natgateways.NatGatewayId{}, &commonids.PublicIPAddressId{})
return err
Expand Down Expand Up @@ -55,6 +60,56 @@ func resourceNATGatewayPublicIpAssociation() *pluginsdk.Resource {
}
}

func resourceNATGatewayPublicIpAssociationCustomizeDiff(ctx context.Context, d *pluginsdk.ResourceDiff, meta any) error {
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

as we are supporting both ipv4 and ipv6 now, this CustomizeDiff is not necessary anymore, we can error out in Create if configured id not eixsts.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

removed

ctx, cancel := context.WithTimeout(ctx, 5*time.Minute)
defer cancel()

rawNatGatewayId := d.GetRawConfig().AsValueMap()["nat_gateway_id"]
if rawNatGatewayId.IsNull() || !rawNatGatewayId.IsKnown() {
return nil
}

rawPublicIPAddressId := d.GetRawConfig().AsValueMap()["public_ip_address_id"]
if rawPublicIPAddressId.IsNull() || !rawPublicIPAddressId.IsKnown() {
return nil
}

natGatewayId, err := natgateways.ParseNatGatewayID(d.Get("nat_gateway_id").(string))
if err != nil {
return err
}

publicIpAddressId, err := commonids.ParsePublicIPAddressID(d.Get("public_ip_address_id").(string))
if err != nil {
return err
}

client := meta.(*clients.Client)
natGateway, err := client.Network.NatGateways.Get(ctx, *natGatewayId, natgateways.DefaultGetOperationOptions())
if err != nil {
if response.WasNotFound(natGateway.HttpResponse) {
return nil
}
return fmt.Errorf("retrieving %s: %+v", natGatewayId, err)
}
if natGateway.Model == nil {
return fmt.Errorf("retrieving %s: `model` was nil", natGatewayId)
}

publicIPAddress, err := client.Network.PublicIPAddresses.Get(ctx, *publicIpAddressId, publicipaddresses.DefaultGetOperationOptions())
if err != nil {
if response.WasNotFound(publicIPAddress.HttpResponse) {
return nil
}
return fmt.Errorf("retrieving %s: %+v", publicIpAddressId, err)
}
if publicIPAddress.Model == nil || publicIPAddress.Model.Properties == nil {
return fmt.Errorf("retrieving %s: `model` or `properties` was nil", publicIpAddressId)
}

return validateNATGatewayPublicIpAssociation(natGateway.Model, publicIPAddress.Model)
}

func resourceNATGatewayPublicIpAssociationCreate(d *pluginsdk.ResourceData, meta interface{}) error {
client := meta.(*clients.Client).Network.NatGateways
ctx, cancel := timeouts.ForCreate(meta.(*clients.Client).StopContext, d)
Expand All @@ -70,8 +125,8 @@ func resourceNATGatewayPublicIpAssociationCreate(d *pluginsdk.ResourceData, meta
return err
}

locks.ByName(natGatewayId.NatGatewayName, natGatewayResourceName)
defer locks.UnlockByName(natGatewayId.NatGatewayName, natGatewayResourceName)
locks.ByID(natGatewayId.ID())
defer locks.UnlockByID(natGatewayId.ID())

natGateway, err := client.Get(ctx, *natGatewayId, natgateways.DefaultGetOperationOptions())
if err != nil {
Expand All @@ -88,27 +143,42 @@ func resourceNATGatewayPublicIpAssociationCreate(d *pluginsdk.ResourceData, meta
return fmt.Errorf("retrieving %s: `properties` was nil", natGatewayId)
}

id := commonids.NewCompositeResourceID(natGatewayId, publicIpAddressId)
publicIPAddress, err := meta.(*clients.Client).Network.PublicIPAddresses.Get(ctx, *publicIpAddressId, publicipaddresses.DefaultGetOperationOptions())
if err != nil {
if response.WasNotFound(publicIPAddress.HttpResponse) {
return fmt.Errorf("%s was not found", publicIpAddressId)
}
return fmt.Errorf("retrieving %s: %+v", publicIpAddressId, err)
}
if publicIPAddress.Model == nil || publicIPAddress.Model.Properties == nil {
return fmt.Errorf("retrieving %s: `model` or `properties` was nil", publicIpAddressId)
}

publicIpAddresses := make([]natgateways.SubResource, 0)
if natGateway.Model.Properties.PublicIPAddresses != nil {
for _, existingPublicIPAddress := range *natGateway.Model.Properties.PublicIPAddresses {
if existingPublicIPAddress.Id == nil {
continue
}
isIPv6 := natGatewayPublicIpAssociationIsIPv6(publicIPAddress.Model)
if err := validateNATGatewayPublicIpAssociation(natGateway.Model, publicIPAddress.Model); err != nil {
return err
}

if strings.EqualFold(*existingPublicIPAddress.Id, publicIpAddressId.ID()) {
return tf.ImportAsExistsError("azurerm_nat_gateway_public_ip_association", id.ID())
}
id := commonids.NewCompositeResourceID(natGatewayId, publicIpAddressId)

publicIpAddresses = append(publicIpAddresses, existingPublicIPAddress)
publicIpAddresses := pointer.From(natGateway.Model.Properties.PublicIPAddresses)
if isIPv6 {
publicIpAddresses = pointer.From(natGateway.Model.Properties.PublicIPAddressesV6)
}
for _, existingPublicIPAddress := range publicIpAddresses {
if strings.EqualFold(pointer.From(existingPublicIPAddress.Id), publicIpAddressId.ID()) {
return tf.ImportAsExistsError("azurerm_nat_gateway_public_ip_association", id.ID())
}
}

publicIpAddresses = append(publicIpAddresses, natgateways.SubResource{
Id: pointer.To(publicIpAddressId.ID()),
})
natGateway.Model.Properties.PublicIPAddresses = &publicIpAddresses
if isIPv6 {
natGateway.Model.Properties.PublicIPAddressesV6 = pointer.To(publicIpAddresses)
} else {
natGateway.Model.Properties.PublicIPAddresses = pointer.To(publicIpAddresses)
}
Comment thread
sreallymatt marked this conversation as resolved.

if err := client.CreateOrUpdateThenPoll(ctx, *natGatewayId, *natGateway.Model); err != nil {
return fmt.Errorf("updating %s: %+v", natGatewayId, err)
Expand Down Expand Up @@ -139,32 +209,14 @@ func resourceNATGatewayPublicIpAssociationRead(d *pluginsdk.ResourceData, meta i
return fmt.Errorf("retrieving %s: %+v", id.First, err)
}

if model := natGateway.Model; model != nil {
if props := model.Properties; props != nil {
if props.PublicIPAddresses == nil {
log.Printf("[DEBUG] %s doesn't have any Public IP's - removing from state!", id.First)
d.SetId("")
return nil
}

publicIPAddressId := ""
for _, pip := range *props.PublicIPAddresses {
if pip.Id == nil {
continue
}

if strings.EqualFold(*pip.Id, id.Second.ID()) {
publicIPAddressId = *pip.Id
break
}
}

if publicIPAddressId == "" {
log.Printf("[DEBUG] Association between %s and %s was not found - removing from state", id.First, id.Second)
d.SetId("")
return nil
}
if model := natGateway.Model; model != nil && model.Properties != nil {
if !natGatewayPublicIpAssociationExists(model.Properties, id.Second.ID()) {
log.Printf("[DEBUG] Association between %s and %s was not found - removing from state", id.First, id.Second)
d.SetId("")
return nil
}
} else {
return fmt.Errorf("retrieving %s: `model` or `properties` was nil", id.First)
}
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

make it clear:

Suggested change
if model := natGateway.Model; model != nil && model.Properties != nil {
if !natGatewayPublicIpAssociationExists(model.Properties, id.Second.ID()) {
log.Printf("[DEBUG] Association between %s and %s was not found - removing from state", id.First, id.Second)
d.SetId("")
return nil
}
} else {
return fmt.Errorf("retrieving %s: `model` or `properties` was nil", id.First)
}
if natGateway.Model == nil || natGateway.Model.Properties == nil {
return fmt.Errorf("retrieving %s: `model` or `properties` was nil", id.First)
}
if !natGatewayPublicIpAssociationExists(natGateway.Model.Properties, id.Second.ID()) {
log.Printf("[DEBUG] Association between %s and %s was not found - removing from state", id.First, id.Second)
d.SetId("")
return nil
}

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

updated


d.Set("nat_gateway_id", id.First.ID())
Expand All @@ -183,8 +235,8 @@ func resourceNATGatewayPublicIpAssociationDelete(d *pluginsdk.ResourceData, meta
return err
}

locks.ByName(id.First.NatGatewayName, natGatewayResourceName)
defer locks.UnlockByName(id.First.NatGatewayName, natGatewayResourceName)
locks.ByID(id.First.ID())
defer locks.UnlockByID(id.First.ID())

natGateway, err := client.Get(ctx, *id.First, natgateways.DefaultGetOperationOptions())
if err != nil {
Expand All @@ -201,23 +253,83 @@ func resourceNATGatewayPublicIpAssociationDelete(d *pluginsdk.ResourceData, meta
return fmt.Errorf("retrieving %s: `properties` was nil", id.First)
}

publicIpAddresses := make([]natgateways.SubResource, 0)
if publicIPAddresses := natGateway.Model.Properties.PublicIPAddresses; publicIPAddresses != nil {
for _, publicIPAddress := range *publicIPAddresses {
if publicIPAddress.Id == nil {
continue
}
if !removeNATGatewayPublicIpAssociation(natGateway.Model.Properties, id.Second.ID()) {
return nil
}

if err := client.CreateOrUpdateThenPoll(ctx, *id.First, *natGateway.Model); err != nil {
return fmt.Errorf("deleting %s: %+v", id, err)
}

return nil
}

func validateNATGatewayPublicIpAssociation(natGateway *natgateways.NatGateway, publicIPAddress *publicipaddresses.PublicIPAddress) error {
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm unsure if this validate is solid and necessary here. if the api call is a light operation, may be we should defer to service side to validate it

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

make sense, the validation is removed

isIPv6 := natGatewayPublicIpAssociationIsIPv6(publicIPAddress)
natGatewaySku := pointer.From(pointer.From(natGateway.Sku).Name)
publicIPAddressSku := pointer.From(pointer.From(publicIPAddress.Sku).Name)

if !strings.EqualFold(*publicIPAddress.Id, id.Second.ID()) {
publicIpAddresses = append(publicIpAddresses, publicIPAddress)
}
if isIPv6 {
if natGatewaySku != natgateways.NatGatewaySkuNameStandardVTwo || publicIPAddressSku != publicipaddresses.PublicIPAddressSkuNameStandardVTwo {
return errors.New("`nat_gateway_id` must reference a NAT Gateway with SKU `StandardV2` and `public_ip_address_id` must reference an `IPv6` Public IP Address with SKU `StandardV2` when `public_ip_address_id` references an `IPv6` Public IP Address")
}
}
natGateway.Model.Properties.PublicIPAddresses = &publicIpAddresses

if err := client.CreateOrUpdateThenPoll(ctx, *id.First, *natGateway.Model); err != nil {
return fmt.Errorf("removing association between %s and %s: %+v", id.First, id.Second, err)
if natGatewaySku == natgateways.NatGatewaySkuNameStandard && publicIPAddressSku == publicipaddresses.PublicIPAddressSkuNameStandardVTwo {
return errors.New("`public_ip_address_id` must reference a Public IP Address with SKU `Standard` when `nat_gateway_id` references a NAT Gateway with SKU `Standard`")
}

if natGatewaySku == natgateways.NatGatewaySkuNameStandardVTwo && publicIPAddressSku != publicipaddresses.PublicIPAddressSkuNameStandardVTwo {
return errors.New("`public_ip_address_id` must reference a Public IP Address with SKU `StandardV2` when `nat_gateway_id` references a NAT Gateway with SKU `StandardV2`")
}

return nil
}

func natGatewayPublicIpAssociationIsIPv6(publicIPAddress *publicipaddresses.PublicIPAddress) bool {
return pointer.From(publicIPAddress.Properties.PublicIPAddressVersion) == publicipaddresses.IPVersionIPvSix
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

publicIPAddress.Properties can be nil and panic here

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nil check added

}

func natGatewayPublicIpAssociationExists(properties *natgateways.NatGatewayPropertiesFormat, publicIPAddressId string) bool {
for _, publicIPAddress := range pointer.From(properties.PublicIPAddresses) {
Comment thread
sreallymatt marked this conversation as resolved.
if strings.EqualFold(pointer.From(publicIPAddress.Id), publicIPAddressId) {
return true
}
}

for _, publicIPAddress := range pointer.From(properties.PublicIPAddressesV6) {
if strings.EqualFold(pointer.From(publicIPAddress.Id), publicIPAddressId) {
return true
}
}

return false
}

func removeNATGatewayPublicIpAssociation(properties *natgateways.NatGatewayPropertiesFormat, publicIPAddressId string) bool {
Comment thread
sreallymatt marked this conversation as resolved.
removed := false

updatedIPv4Addresses := make([]natgateways.SubResource, 0)
Comment thread
sreallymatt marked this conversation as resolved.
for _, publicIPAddress := range pointer.From(properties.PublicIPAddresses) {
if strings.EqualFold(pointer.From(publicIPAddress.Id), publicIPAddressId) {
removed = true
continue
}

updatedIPv4Addresses = append(updatedIPv4Addresses, publicIPAddress)
}
properties.PublicIPAddresses = pointer.To(updatedIPv4Addresses)

updatedIPv6Addresses := make([]natgateways.SubResource, 0)
for _, publicIPAddress := range pointer.From(properties.PublicIPAddressesV6) {
if strings.EqualFold(pointer.From(publicIPAddress.Id), publicIPAddressId) {
removed = true
continue
}

updatedIPv6Addresses = append(updatedIPv6Addresses, publicIPAddress)
}
properties.PublicIPAddressesV6 = pointer.To(updatedIPv6Addresses)

return removed
}
Loading
Loading