Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,8 @@ func resourceNATGatewayPublicIpAssociationCreate(d *pluginsdk.ResourceData, meta
return err
}

locks.ByName(natGatewayId.NatGatewayName, natGatewayResourceName)
defer locks.UnlockByName(natGatewayId.NatGatewayName, natGatewayResourceName)
locks.ByID(natGatewayId.ID())
defer locks.UnlockByID(natGatewayId.ID())

natGateway, err := client.Get(ctx, *natGatewayId, natgateways.DefaultGetOperationOptions())
if err != nil {
Expand Down Expand Up @@ -183,8 +183,8 @@ func resourceNATGatewayPublicIpAssociationDelete(d *pluginsdk.ResourceData, meta
return err
}

locks.ByName(id.First.NatGatewayName, natGatewayResourceName)
defer locks.UnlockByName(id.First.NatGatewayName, natGatewayResourceName)
locks.ByID(id.First.ID())
defer locks.UnlockByID(id.First.ID())

natGateway, err := client.Get(ctx, *id.First, natgateways.DefaultGetOperationOptions())
if err != nil {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,8 @@ func resourceNATGatewayPublicIpPrefixAssociationCreate(d *pluginsdk.ResourceData
return err
}

locks.ByName(natGatewayId.NatGatewayName, natGatewayResourceName)
defer locks.UnlockByName(natGatewayId.NatGatewayName, natGatewayResourceName)
locks.ByID(natGatewayId.ID())
defer locks.UnlockByID(natGatewayId.ID())

natGateway, err := client.Get(ctx, *natGatewayId, natgateways.DefaultGetOperationOptions())
if err != nil {
Expand Down Expand Up @@ -184,8 +184,8 @@ func resourceNATGatewayPublicIpPrefixAssociationDelete(d *pluginsdk.ResourceData
return err
}

locks.ByName(id.First.NatGatewayName, natGatewayResourceName)
defer locks.UnlockByName(id.First.NatGatewayName, natGatewayResourceName)
locks.ByID(id.First.ID())
defer locks.UnlockByID(id.First.ID())

natGateway, err := client.Get(ctx, *id.First, natgateways.DefaultGetOperationOptions())
if err != nil {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,286 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0

package network

import (
"context"
"errors"
"fmt"
"strings"
"time"

"github.com/hashicorp/go-azure-helpers/lang/pointer"
"github.com/hashicorp/go-azure-helpers/lang/response"
"github.com/hashicorp/go-azure-helpers/resourcemanager/commonids"
"github.com/hashicorp/go-azure-helpers/resourcemanager/commonschema"
"github.com/hashicorp/go-azure-sdk/resource-manager/network/2025-01-01/natgateways"
"github.com/hashicorp/go-azure-sdk/resource-manager/network/2025-01-01/publicipaddresses"
"github.com/hashicorp/terraform-provider-azurerm/internal/locks"
"github.com/hashicorp/terraform-provider-azurerm/internal/sdk"
"github.com/hashicorp/terraform-provider-azurerm/internal/tf/pluginsdk"
)

type NatGatewayPublicIPv6AssociationResource struct{}

var _ sdk.ResourceWithCustomizeDiff = NatGatewayPublicIPv6AssociationResource{}

type NatGatewayPublicIPv6AssociationModel struct {
NatGatewayId string `tfschema:"nat_gateway_id"`
PublicIpAddressId string `tfschema:"public_ip_address_id"`
}

func (r NatGatewayPublicIPv6AssociationResource) Arguments() map[string]*pluginsdk.Schema {
return map[string]*pluginsdk.Schema{
"nat_gateway_id": commonschema.ResourceIDReferenceRequiredForceNew(&natgateways.NatGatewayId{}),

"public_ip_address_id": commonschema.ResourceIDReferenceRequiredForceNew(&commonids.PublicIPAddressId{}),
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add a validator in CustomizeDiff to ensure the ID is an IPv6 address, assuming the Get API is a lightweight call? That way we’ll catch the error during the plan stage if an IPv4 ID is supplied.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

updated

}
}

func (r NatGatewayPublicIPv6AssociationResource) Attributes() map[string]*pluginsdk.Schema {
return map[string]*pluginsdk.Schema{}
}

func (r NatGatewayPublicIPv6AssociationResource) ModelObject() interface{} {
return &NatGatewayPublicIPv6AssociationModel{}
}

func (r NatGatewayPublicIPv6AssociationResource) ResourceType() string {
return "azurerm_nat_gateway_public_ipv6_association"
}

func (r NatGatewayPublicIPv6AssociationResource) CustomizeDiff() sdk.ResourceFunc {
return sdk.ResourceFunc{
Timeout: 5 * time.Minute,
Func: func(ctx context.Context, metadata sdk.ResourceMetaData) error {
rawNatGatewayId := metadata.ResourceDiff.GetRawConfig().AsValueMap()["nat_gateway_id"]
if rawNatGatewayId.IsNull() || !rawNatGatewayId.IsKnown() {
return nil
}

rawPublicIPAddressId := metadata.ResourceDiff.GetRawConfig().AsValueMap()["public_ip_address_id"]
if rawPublicIPAddressId.IsNull() || !rawPublicIPAddressId.IsKnown() {
return nil
}

var model NatGatewayPublicIPv6AssociationModel
if err := metadata.DecodeDiff(&model); err != nil {
return fmt.Errorf("decoding: %+v", err)
}

natGatewayId, err := natgateways.ParseNatGatewayID(model.NatGatewayId)
if err != nil {
return err
}

natGatewayResp, err := metadata.Client.Network.NatGateways.Get(ctx, *natGatewayId, natgateways.DefaultGetOperationOptions())
if err != nil {
if response.WasNotFound(natGatewayResp.HttpResponse) {
return nil
}
return fmt.Errorf("retrieving %s: %+v", *natGatewayId, err)
}

if natGatewayResp.Model == nil {
return fmt.Errorf("retrieving %s: `model` was nil", *natGatewayId)
}

natGatewaySku := pointer.From(pointer.From(natGatewayResp.Model.Sku).Name)
if natGatewaySku == natgateways.NatGatewaySkuNameStandard {
return errors.New("`nat_gateway_id` must reference a NAT Gateway with SKU `StandardV2`")
}

publicIPAddressId, err := commonids.ParsePublicIPAddressID(model.PublicIpAddressId)
if err != nil {
return err
}

resp, err := metadata.Client.Network.PublicIPAddresses.Get(ctx, *publicIPAddressId, publicipaddresses.DefaultGetOperationOptions())
if err != nil {
if response.WasNotFound(resp.HttpResponse) {
return nil
}
return fmt.Errorf("retrieving %s: %+v", *publicIPAddressId, err)
}

if resp.Model == nil || resp.Model.Properties == nil {
return fmt.Errorf("retrieving %s: `model` or `properties` was nil", *publicIPAddressId)
}

if natGatewaySku == natgateways.NatGatewaySkuNameStandardVTwo && (pointer.From(pointer.From(resp.Model.Sku).Name) != publicipaddresses.PublicIPAddressSkuNameStandardVTwo || pointer.From(resp.Model.Properties.PublicIPAddressVersion) != publicipaddresses.IPVersionIPvSix) {
return errors.New("`public_ip_address_id` must reference an `IPv6` Public IP Address with SKU `StandardV2`")
}

return nil
},
}
}

func (r NatGatewayPublicIPv6AssociationResource) Create() sdk.ResourceFunc {
return sdk.ResourceFunc{
Timeout: 30 * time.Minute,
Func: func(ctx context.Context, metadata sdk.ResourceMetaData) error {
client := metadata.Client.Network.NatGateways

var state NatGatewayPublicIPv6AssociationModel
if err := metadata.Decode(&state); err != nil {
return fmt.Errorf("decoding: %+v", err)
}

publicIpAddressId, err := commonids.ParsePublicIPAddressID(state.PublicIpAddressId)
if err != nil {
return err
}

natGatewayId, err := natgateways.ParseNatGatewayID(state.NatGatewayId)
if err != nil {
return err
}

locks.ByID(natGatewayId.ID())
defer locks.UnlockByID(natGatewayId.ID())
Comment on lines +141 to +142
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we need to update the locks of azurerm_nat_gateway_public_ip_prefix_association and azurerm_nat_gateway_public_ip_association to use ID instead locks.ByName if we use ID here.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

updated


natGateway, err := client.Get(ctx, *natGatewayId, natgateways.DefaultGetOperationOptions())
if err != nil {
if response.WasNotFound(natGateway.HttpResponse) {
return fmt.Errorf("%s was not found", *natGatewayId)
}
return fmt.Errorf("retrieving %s: %+v", *natGatewayId, err)
}

if natGateway.Model == nil {
return fmt.Errorf("retrieving %s: `model` was nil", *natGatewayId)
}
if natGateway.Model.Properties == nil {
return fmt.Errorf("retrieving %s: `properties` was nil", *natGatewayId)
}

id := commonids.NewCompositeResourceID(natGatewayId, publicIpAddressId)

existingIPs := pointer.From(natGateway.Model.Properties.PublicIPAddressesV6)
for _, existingPublicIPAddress := range existingIPs {
if strings.EqualFold(pointer.From(existingPublicIPAddress.Id), publicIpAddressId.ID()) {
return metadata.ResourceRequiresImport(r.ResourceType(), id)
}
}

existingIPs = append(existingIPs, natgateways.SubResource{
Id: pointer.To(state.PublicIpAddressId),
})
natGateway.Model.Properties.PublicIPAddressesV6 = pointer.To(existingIPs)

if err := client.CreateOrUpdateThenPoll(ctx, *natGatewayId, *natGateway.Model); err != nil {
return fmt.Errorf("creating %s: %+v", id, err)
}

metadata.SetID(id)
return nil
},
}
}

func (r NatGatewayPublicIPv6AssociationResource) Read() sdk.ResourceFunc {
return sdk.ResourceFunc{
Timeout: 5 * time.Minute,
Func: func(ctx context.Context, metadata sdk.ResourceMetaData) error {
client := metadata.Client.Network.NatGateways

id, err := commonids.ParseCompositeResourceID(metadata.ResourceData.Id(), &natgateways.NatGatewayId{}, &commonids.PublicIPAddressId{})
if err != nil {
return err
}

natGateway, err := client.Get(ctx, *id.First, natgateways.DefaultGetOperationOptions())
if err != nil {
if response.WasNotFound(natGateway.HttpResponse) {
return metadata.MarkAsGone(id)
}
return fmt.Errorf("retrieving %s: %+v", *id.First, err)
}

publicIPAddressFound := false
if natGateway.Model == nil || natGateway.Model.Properties == nil {
return fmt.Errorf("retrieving %s: `model` or `properties` was nil", id.First)
}

for _, pip := range pointer.From(natGateway.Model.Properties.PublicIPAddressesV6) {
if strings.EqualFold(pointer.From(pip.Id), id.Second.ID()) {
publicIPAddressFound = true
break
}
}

if !publicIPAddressFound {
return metadata.MarkAsGone(id)
}

return metadata.Encode(&NatGatewayPublicIPv6AssociationModel{
NatGatewayId: id.First.ID(),
PublicIpAddressId: id.Second.ID(),
})
},
}
}

func (r NatGatewayPublicIPv6AssociationResource) Delete() sdk.ResourceFunc {
return sdk.ResourceFunc{
Timeout: 30 * time.Minute,
Func: func(ctx context.Context, metadata sdk.ResourceMetaData) error {
client := metadata.Client.Network.NatGateways

id, err := commonids.ParseCompositeResourceID(metadata.ResourceData.Id(), &natgateways.NatGatewayId{}, &commonids.PublicIPAddressId{})
if err != nil {
return err
}

locks.ByID(id.First.ID())
defer locks.UnlockByID(id.First.ID())

natGateway, err := client.Get(ctx, *id.First, natgateways.DefaultGetOperationOptions())
if err != nil {
if response.WasNotFound(natGateway.HttpResponse) {
return fmt.Errorf("%s was not found", *id.First)
}
return fmt.Errorf("retrieving %s: %+v", *id.First, err)
}

if natGateway.Model == nil {
return fmt.Errorf("retrieving %s: `model` was nil", *id.First)
}
if natGateway.Model.Properties == nil {
return fmt.Errorf("retrieving %s: `properties` was nil", *id.First)
}

publicIpAddressesV6 := make([]natgateways.SubResource, 0)
needToDelete := false
for _, publicIPAddress := range pointer.From(natGateway.Model.Properties.PublicIPAddressesV6) {
if !strings.EqualFold(pointer.From(publicIPAddress.Id), id.Second.ID()) {
publicIpAddressesV6 = append(publicIpAddressesV6, publicIPAddress)
} else {
needToDelete = true
}
}

if !needToDelete {
return nil
}
natGateway.Model.Properties.PublicIPAddressesV6 = pointer.To(publicIpAddressesV6)

if err := client.CreateOrUpdateThenPoll(ctx, *id.First, *natGateway.Model); err != nil {
return fmt.Errorf("deleting %s: %+v", id, err)
}

return nil
},
}
}

func (r NatGatewayPublicIPv6AssociationResource) IDValidationFunc() pluginsdk.SchemaValidateFunc {
return func(input interface{}, key string) (warnings []string, errors []error) {
if _, err := commonids.ParseCompositeResourceID(input.(string), &natgateways.NatGatewayId{}, &commonids.PublicIPAddressId{}); err != nil {
errors = append(errors, err)
}
return warnings, errors
}
}
Loading
Loading