Add Go wrappers for Get/SetVirtualDiskInformation win32 APIs

The SetVirtualDiskInformation API is required for running confidential Windows
containers. The VHDs used for starting confidential pods/UVMs need to have a specific
disk identifier. These newly added wrappers will be used when preparing the VHDs for
confidential pods.

Signed-off-by: Amit <[email protected]>
diff --git a/vhd/vhd.go b/vhd/vhd.go
index c0a22d6..7305cb8 100644
--- a/vhd/vhd.go
+++ b/vhd/vhd.go
@@ -3,8 +3,12 @@
 package vhd
 
 import (
+	"bytes"
+	"encoding/binary"
 	"fmt"
+	"strings"
 	"syscall"
+	"unsafe"
 
 	"github.com/Microsoft/go-winio/pkg/guid"
 	"golang.org/x/sys/windows"
@@ -17,6 +21,8 @@
 //sys attachVirtualDisk(handle syscall.Handle, securityDescriptor *uintptr, attachVirtualDiskFlag uint32, providerSpecificFlags uint32, parameters *AttachVirtualDiskParameters, overlapped *syscall.Overlapped) (win32err error) = virtdisk.AttachVirtualDisk
 //sys detachVirtualDisk(handle syscall.Handle, detachVirtualDiskFlags uint32, providerSpecificFlags uint32) (win32err error) = virtdisk.DetachVirtualDisk
 //sys getVirtualDiskPhysicalPath(handle syscall.Handle, diskPathSizeInBytes *uint32, buffer *uint16) (win32err error) = virtdisk.GetVirtualDiskPhysicalPath
+//sys getVirtualDiskInformation(handle syscall.Handle, bufferSize *uint32, info *virtualDiskInfo, sizeUsed *uint32) (win32err error) = virtdisk.GetVirtualDiskInformation
+//sys setVirtualDiskInformation(handle syscall.Handle, info *virtualDiskInfo) (win32err error) = virtdisk.SetVirtualDiskInformation
 
 type (
 	CreateVirtualDiskFlag uint32
@@ -85,6 +91,20 @@
 	Version2 AttachVersion2
 }
 
+// virtualDiskInfo represents both the GET_VIRTUAL_DISK_INFO
+// (https://learn.microsoft.com/en-us/windows/win32/api/virtdisk/ns-virtdisk-get_virtual_disk_info)
+// and SET_VIRTUAL_DISK_INFO
+// (https://learn.microsoft.com/en-us/windows/win32/api/virtdisk/ns-virtdisk-set_virtual_disk_info)
+// win32 types, which have the same size and a very similar layout. Both are tagged
+// unions, which Go does not support directly, so this type stays unexported and the
+// exported helpers set `version` (the tag) and encode/decode `data` as the payload
+// type that matches that version.
+type virtualDiskInfo struct {
+	version uint32
+	_       [4]byte  // padding so that the union starts at offset 8
+	data    [24]byte // raw bytes of the union; interpretation depends on version
+}
+
 const (
 	//revive:disable-next-line:var-naming ALL_CAPS
 	VIRTUAL_STORAGE_TYPE_DEVICE_VHDX = 0x3
@@ -142,6 +162,34 @@
 
 	// Flags for detaching a VHD.
 	DetachVirtualDiskFlagNone DetachVirtualDiskFlag = 0x0
+
+	// Version values (win32 SET_VIRTUAL_DISK_INFO_VERSION) that select which field SetVirtualDiskInformation updates - these should remain unexported as we provide APIs to directly get/set a particular field.
+	setVirtualDiskInfoUnspecified         uint32 = 0x0
+	setVirtualDiskInfoParentPath          uint32 = 0x1
+	setVirtualDiskInfoIdentifier          uint32 = 0x2
+	setVirtualDiskInfoParentPathWithDepth uint32 = 0x3
+	setVirtualDiskInfoPhysicalSectorSize  uint32 = 0x4
+	setVirtualDiskInfoVirtualDiskID       uint32 = 0x5
+	setVirtualDiskInfoChangeTrackingState uint32 = 0x6
+	setVirtualDiskInfoParentLocator       uint32 = 0x7
+
+	// Version values (win32 GET_VIRTUAL_DISK_INFO_VERSION) that select which field GetVirtualDiskInformation returns - these should remain unexported as we provide APIs to directly get/set a particular field.
+	getVirtualDiskInfoUnspecified             uint32 = 0x0
+	getVirtualDiskInfoSize                    uint32 = 0x1
+	getVirtualDiskInfoIdentifier              uint32 = 0x2
+	getVirtualDiskInfoParentLocation          uint32 = 0x3
+	getVirtualDiskInfoParentIdentifier        uint32 = 0x4
+	getVirtualDiskInfoParentTimestamp         uint32 = 0x5
+	getVirtualDiskInfoVirtualStorageType      uint32 = 0x6
+	getVirtualDiskInfoProviderSubtype         uint32 = 0x7
+	getVirtualDiskInfoIs4kAligned             uint32 = 0x8
+	getVirtualDiskInfoPhysicalDisk            uint32 = 0x9
+	getVirtualDiskInfoVHDPhysicalSectorSize   uint32 = 0xA
+	getVirtualDiskInfoSmallestSafeVirtualSize uint32 = 0xB
+	getVirtualDiskInfoFragmentation           uint32 = 0xC
+	getVirtualDiskInfoIsLoaded                uint32 = 0xD
+	getVirtualDiskInfoVirtualDiskID           uint32 = 0xE
+	getVirtualDiskInfoChangeTrackingState     uint32 = 0xF
 )
 
 // CreateVhdx is a helper function to create a simple vhdx file at the given path using
@@ -374,3 +422,60 @@
 	}
 	return nil
 }
+
+// SetVirtualDiskIdentifier opens the virtual disk at vhdPath and sets its identifier to the given GUID; a path ending in ".vhdx" (in any letter case) is updated through the VHDX virtual disk ID version instead. It returns a wrapped error if opening, encoding or the win32 call fails.
+func SetVirtualDiskIdentifier(vhdPath string, identifier guid.GUID) error {
+	handle, err := OpenVirtualDisk(vhdPath, VirtualDiskAccessNone, OpenVirtualDiskFlagNone)
+	if err != nil {
+		return fmt.Errorf("failed to open %s: %w", vhdPath, err)
+	}
+	defer syscall.Close(handle)
+
+	info := &virtualDiskInfo{
+		version: setVirtualDiskInfoIdentifier,
+	}
+	if strings.HasSuffix(strings.ToLower(vhdPath), ".vhdx") {
+		// VHDx requires a different version to set disk id; Windows file names are case-insensitive, so match the extension in any case.
+		info.version = setVirtualDiskInfoVirtualDiskID
+	}
+
+	if _, err := binary.Encode(info.data[:], binary.LittleEndian, identifier); err != nil {
+		return fmt.Errorf("failed to serialize virtual disk identifier: %w", err)
+	}
+
+	if err := setVirtualDiskInformation(handle, info); err != nil {
+		return fmt.Errorf("failed to set virtual disk identifier: %w", err)
+	}
+	return nil
+}
+
+// GetVirtualDiskIdentifier opens the virtual disk at vhdPath and returns its identifier; a path ending in ".vhdx" (in any letter case) is queried for the VHDX virtual disk ID instead. On failure it returns the zero GUID and a wrapped error.
+func GetVirtualDiskIdentifier(vhdPath string) (guid.GUID, error) {
+	handle, err := OpenVirtualDisk(vhdPath, VirtualDiskAccessNone, OpenVirtualDiskFlagNone)
+	if err != nil {
+		return guid.GUID{}, fmt.Errorf("failed to open %s: %w", vhdPath, err)
+	}
+	defer syscall.Close(handle)
+
+	info := &virtualDiskInfo{
+		version: getVirtualDiskInfoIdentifier,
+	}
+	if strings.HasSuffix(strings.ToLower(vhdPath), ".vhdx") {
+		// VHDx requires a different version to get disk id; Windows file names are case-insensitive, so match the extension in any case.
+		info.version = getVirtualDiskInfoVirtualDiskID
+	}
+
+	var sizeUsed uint32
+	bufferSize := uint32(unsafe.Sizeof(*info))
+	if err := getVirtualDiskInformation(handle, &bufferSize, info, &sizeUsed); err != nil {
+		return guid.GUID{}, fmt.Errorf("failed to get virtual disk identifier: %w", err)
+	}
+
+	// Decode the GUID from the start of the union payload
+	id := &guid.GUID{}
+	reader := bytes.NewReader(info.data[:])
+	if err := binary.Read(reader, binary.LittleEndian, id); err != nil {
+		return guid.GUID{}, fmt.Errorf("failed to parse virtual disk identifier: %w", err)
+	}
+	return *id, nil
+}
diff --git a/vhd/vhd_test.go b/vhd/vhd_test.go
new file mode 100644
index 0000000..4124af9
--- /dev/null
+++ b/vhd/vhd_test.go
@@ -0,0 +1,59 @@
+//go:build windows
+
+package vhd
+
+import (
+	"os"
+	"path/filepath"
+	"testing"
+
+	"github.com/Microsoft/go-winio/pkg/guid"
+)
+
+func TestVirtualDiskIdentifier(t *testing.T) {
+	// Round-trip check: read the identifier, set a new one, read it back. Work in a per-test temporary directory.
+	tempDir := t.TempDir()
+	// TODO(ambarve): We should add a test for VHD too, but our current create VHD API
+	// seems to only work for VHDX.
+	vhdPath := filepath.Join(tempDir, "test.vhdx")
+
+	// Create the virtual disk
+	if err := CreateVhdx(vhdPath, 1, 1); err != nil { // 1GB, 1MB block size
+		t.Fatalf("failed to create virtual disk: %s", err)
+	}
+	defer os.Remove(vhdPath)
+
+	// Get the initial identifier (only logged, to make a failing run easier to read)
+	initialID, err := GetVirtualDiskIdentifier(vhdPath)
+	if err != nil {
+		t.Fatalf("failed to get initial virtual disk identifier: %s", err)
+	}
+	t.Logf("initial identifier: %s", initialID.String())
+
+	// Create a new, fixed GUID to set
+	newID := guid.GUID{
+		Data1: 0x12345678,
+		Data2: 0x1234,
+		Data3: 0x5678,
+		Data4: [8]byte{0x12, 0x34, 0x56, 0x78, 0x9a, 0xbc, 0xde, 0xf0},
+	}
+	t.Logf("setting new identifier: %s", newID.String())
+
+	// Set the new identifier
+	if err := SetVirtualDiskIdentifier(vhdPath, newID); err != nil {
+		t.Fatalf("failed to set virtual disk identifier: %s", err)
+	}
+
+	// Get the identifier again to verify it was set correctly
+	retrievedID, err := GetVirtualDiskIdentifier(vhdPath)
+	if err != nil {
+		t.Fatalf("failed to get virtual disk identifier after setting: %s", err)
+	}
+	t.Logf("retrieved identifier: %s", retrievedID.String())
+
+	// Verify the retrieved ID matches the one we set
+	if retrievedID != newID {
+		t.Errorf("retrieved identifier does not match set identifier.\nExpected: %s\nGot: %s",
+			newID.String(), retrievedID.String())
+	}
+}
diff --git a/vhd/zvhd_windows.go b/vhd/zvhd_windows.go
index 95c0407..e9d202e 100644
--- a/vhd/zvhd_windows.go
+++ b/vhd/zvhd_windows.go
@@ -42,8 +42,10 @@
 	procAttachVirtualDisk          = modvirtdisk.NewProc("AttachVirtualDisk")
 	procCreateVirtualDisk          = modvirtdisk.NewProc("CreateVirtualDisk")
 	procDetachVirtualDisk          = modvirtdisk.NewProc("DetachVirtualDisk")
+	procGetVirtualDiskInformation  = modvirtdisk.NewProc("GetVirtualDiskInformation")
 	procGetVirtualDiskPhysicalPath = modvirtdisk.NewProc("GetVirtualDiskPhysicalPath")
 	procOpenVirtualDisk            = modvirtdisk.NewProc("OpenVirtualDisk")
+	procSetVirtualDiskInformation  = modvirtdisk.NewProc("SetVirtualDiskInformation")
 )
 
 func attachVirtualDisk(handle syscall.Handle, securityDescriptor *uintptr, attachVirtualDiskFlag uint32, providerSpecificFlags uint32, parameters *AttachVirtualDiskParameters, overlapped *syscall.Overlapped) (win32err error) {
@@ -79,6 +81,14 @@
 	return
 }
 
+func getVirtualDiskInformation(handle syscall.Handle, bufferSize *uint32, info *virtualDiskInfo, sizeUsed *uint32) (win32err error) {
+	r0, _, _ := syscall.SyscallN(procGetVirtualDiskInformation.Addr(), uintptr(handle), uintptr(unsafe.Pointer(bufferSize)), uintptr(unsafe.Pointer(info)), uintptr(unsafe.Pointer(sizeUsed)))
+	if r0 != 0 {
+		win32err = syscall.Errno(r0)
+	}
+	return
+}
+
 func getVirtualDiskPhysicalPath(handle syscall.Handle, diskPathSizeInBytes *uint32, buffer *uint16) (win32err error) {
 	r0, _, _ := syscall.SyscallN(procGetVirtualDiskPhysicalPath.Addr(), uintptr(handle), uintptr(unsafe.Pointer(diskPathSizeInBytes)), uintptr(unsafe.Pointer(buffer)))
 	if r0 != 0 {
@@ -103,3 +113,11 @@
 	}
 	return
 }
+
+func setVirtualDiskInformation(handle syscall.Handle, info *virtualDiskInfo) (win32err error) {
+	r0, _, _ := syscall.SyscallN(procSetVirtualDiskInformation.Addr(), uintptr(handle), uintptr(unsafe.Pointer(info)))
+	if r0 != 0 {
+		win32err = syscall.Errno(r0)
+	}
+	return
+}