Add Go wrappers for the Get/SetVirtualDiskInformation Win32 APIs
The SetVirtualDiskInformation API is required for running confidential Windows containers:
the VHDs used to start confidential pods/UVMs must carry a specific disk identifier.
The new wrappers will be used when preparing VHDs for confidential
pods.
Signed-off-by: Amit <[email protected]>
diff --git a/vhd/vhd.go b/vhd/vhd.go
index c0a22d6..7305cb8 100644
--- a/vhd/vhd.go
+++ b/vhd/vhd.go
@@ -3,8 +3,12 @@
package vhd
import (
+ "bytes"
+ "encoding/binary"
"fmt"
+ "strings"
"syscall"
+ "unsafe"
"github.com/Microsoft/go-winio/pkg/guid"
"golang.org/x/sys/windows"
@@ -17,6 +21,8 @@
//sys attachVirtualDisk(handle syscall.Handle, securityDescriptor *uintptr, attachVirtualDiskFlag uint32, providerSpecificFlags uint32, parameters *AttachVirtualDiskParameters, overlapped *syscall.Overlapped) (win32err error) = virtdisk.AttachVirtualDisk
//sys detachVirtualDisk(handle syscall.Handle, detachVirtualDiskFlags uint32, providerSpecificFlags uint32) (win32err error) = virtdisk.DetachVirtualDisk
//sys getVirtualDiskPhysicalPath(handle syscall.Handle, diskPathSizeInBytes *uint32, buffer *uint16) (win32err error) = virtdisk.GetVirtualDiskPhysicalPath
+//sys getVirtualDiskInformation(handle syscall.Handle, bufferSize *uint32, info *virtualDiskInfo, sizeUsed *uint32) (win32err error) = virtdisk.GetVirtualDiskInformation
+//sys setVirtualDiskInformation(handle syscall.Handle, info *virtualDiskInfo) (win32err error) = virtdisk.SetVirtualDiskInformation
type (
CreateVirtualDiskFlag uint32
@@ -85,6 +91,20 @@
Version2 AttachVersion2
}
+// virtualDiskInfo is the Go representation of both GET_VIRTUAL_DISK_INFO
+// (https://learn.microsoft.com/en-us/windows/win32/api/virtdisk/ns-virtdisk-get_virtual_disk_info)
+// and SET_VIRTUAL_DISK_INFO
+// (https://learn.microsoft.com/en-us/windows/win32/api/virtdisk/ns-virtdisk-set_virtual_disk_info).
+// Each Win32 type is a Version tag followed by a tagged union, which Go cannot express
+// directly, so the union is kept as raw bytes in data and decoded by the exported
+// helpers according to version. The type stays unexported so callers never see the union.
+// NOTE(review): layout assumes 64-bit Windows (8-byte-aligned union, 32-byte struct) — confirm for 386.
+type virtualDiskInfo struct {
+	version uint32   // a getVirtualDiskInfo* or setVirtualDiskInfo* constant
+	_       [4]byte  // padding: the union starts at offset 8
+	data    [24]byte // raw union bytes, interpreted according to version
+}
+
const (
//revive:disable-next-line:var-naming ALL_CAPS
VIRTUAL_STORAGE_TYPE_DEVICE_VHDX = 0x3
@@ -142,6 +162,34 @@
// Flags for detaching a VHD.
DetachVirtualDiskFlagNone DetachVirtualDiskFlag = 0x0
+
+	// SET_VIRTUAL_DISK_INFO_VERSION values (not flags): each selects which member of virtualDiskInfo's union is written. Unexported because callers use dedicated helpers such as SetVirtualDiskIdentifier.
+	setVirtualDiskInfoUnspecified uint32 = 0x0
+	setVirtualDiskInfoParentPath uint32 = 0x1
+	setVirtualDiskInfoIdentifier uint32 = 0x2 // used by SetVirtualDiskIdentifier for paths not ending in .vhdx
+	setVirtualDiskInfoParentPathWithDepth uint32 = 0x3
+	setVirtualDiskInfoPhysicalSectorSize uint32 = 0x4
+	setVirtualDiskInfoVirtualDiskID uint32 = 0x5 // used by SetVirtualDiskIdentifier for .vhdx paths
+	setVirtualDiskInfoChangeTrackingState uint32 = 0x6
+	setVirtualDiskInfoParentLocator uint32 = 0x7
+
+	// GET_VIRTUAL_DISK_INFO_VERSION values (not flags): each selects which member of virtualDiskInfo's union is read. Unexported because callers use dedicated helpers such as GetVirtualDiskIdentifier.
+	getVirtualDiskInfoUnspecified uint32 = 0x0
+	getVirtualDiskInfoSize uint32 = 0x1
+	getVirtualDiskInfoIdentifier uint32 = 0x2 // used by GetVirtualDiskIdentifier for paths not ending in .vhdx
+	getVirtualDiskInfoParentLocation uint32 = 0x3
+	getVirtualDiskInfoParentIdentifier uint32 = 0x4
+	getVirtualDiskInfoParentTimestamp uint32 = 0x5
+	getVirtualDiskInfoVirtualStorageType uint32 = 0x6
+	getVirtualDiskInfoProviderSubtype uint32 = 0x7
+	getVirtualDiskInfoIs4kAligned uint32 = 0x8
+	getVirtualDiskInfoPhysicalDisk uint32 = 0x9
+	getVirtualDiskInfoVHDPhysicalSectorSize uint32 = 0xA
+	getVirtualDiskInfoSmallestSafeVirtualSize uint32 = 0xB
+	getVirtualDiskInfoFragmentation uint32 = 0xC
+	getVirtualDiskInfoIsLoaded uint32 = 0xD
+	getVirtualDiskInfoVirtualDiskID uint32 = 0xE // used by GetVirtualDiskIdentifier for .vhdx paths
+	getVirtualDiskInfoChangeTrackingState uint32 = 0xF
)
// CreateVhdx is a helper function to create a simple vhdx file at the given path using
@@ -374,3 +422,60 @@
}
return nil
}
+
+// SetVirtualDiskIdentifier sets the identifier of the virtual disk at vhdPath.
+// Paths ending in ".vhdx" (compared case-insensitively) use the VIRTUAL_DISK_ID info
+// version; every other path uses the IDENTIFIER info version.
+func SetVirtualDiskIdentifier(vhdPath string, identifier guid.GUID) error {
+	handle, err := OpenVirtualDisk(vhdPath, VirtualDiskAccessNone, OpenVirtualDiskFlagNone)
+	if err != nil {
+		return fmt.Errorf("failed to open %s: %w", vhdPath, err)
+	}
+	defer syscall.Close(handle)
+
+	info := &virtualDiskInfo{version: setVirtualDiskInfoIdentifier}
+	// Windows file names are case-insensitive, so "disk.VHDX" must also select the
+	// VHDX-specific version rather than silently falling back to IDENTIFIER.
+	if strings.HasSuffix(strings.ToLower(vhdPath), ".vhdx") {
+		info.version = setVirtualDiskInfoVirtualDiskID
+	}
+	// Serialize the GUID in its native Windows layout (little-endian Data1/2/3).
+	if _, err := binary.Encode(info.data[:], binary.LittleEndian, identifier); err != nil {
+		return fmt.Errorf("failed to serialize virtual disk identifier: %w", err)
+	}
+	if err := setVirtualDiskInformation(handle, info); err != nil {
+		return fmt.Errorf("failed to set virtual disk identifier: %w", err)
+	}
+	return nil
+}
+
+// GetVirtualDiskIdentifier returns the identifier of the virtual disk at vhdPath.
+// Paths ending in ".vhdx" (compared case-insensitively) are queried with the
+// VIRTUAL_DISK_ID info version; every other path uses the IDENTIFIER info version.
+func GetVirtualDiskIdentifier(vhdPath string) (guid.GUID, error) {
+	handle, err := OpenVirtualDisk(vhdPath, VirtualDiskAccessNone, OpenVirtualDiskFlagNone)
+	if err != nil {
+		return guid.GUID{}, fmt.Errorf("failed to open %s: %w", vhdPath, err)
+	}
+	defer syscall.Close(handle)
+
+	info := &virtualDiskInfo{version: getVirtualDiskInfoIdentifier}
+	// Windows file names are case-insensitive, so "disk.VHDX" must also select the
+	// VHDX-specific version rather than silently falling back to IDENTIFIER.
+	if strings.HasSuffix(strings.ToLower(vhdPath), ".vhdx") {
+		info.version = getVirtualDiskInfoVirtualDiskID
+	}
+
+	var sizeUsed uint32
+	bufferSize := uint32(unsafe.Sizeof(*info))
+	if err := getVirtualDiskInformation(handle, &bufferSize, info, &sizeUsed); err != nil {
+		return guid.GUID{}, fmt.Errorf("failed to get virtual disk identifier: %w", err)
+	}
+
+	// The union holds the GUID in its native Windows layout (little-endian Data1/2/3).
+	var id guid.GUID
+	if err := binary.Read(bytes.NewReader(info.data[:]), binary.LittleEndian, &id); err != nil {
+		return guid.GUID{}, fmt.Errorf("failed to parse virtual disk identifier: %w", err)
+	}
+	return id, nil
+}
diff --git a/vhd/vhd_test.go b/vhd/vhd_test.go
new file mode 100644
index 0000000..4124af9
--- /dev/null
+++ b/vhd/vhd_test.go
@@ -0,0 +1,59 @@
+//go:build windows
+
+package vhd
+
+import (
+ "os"
+ "path/filepath"
+ "testing"
+
+ "github.com/Microsoft/go-winio/pkg/guid"
+)
+
+// TestVirtualDiskIdentifier checks that an identifier written with
+// SetVirtualDiskIdentifier is returned unchanged by a later
+// GetVirtualDiskIdentifier call on a freshly created VHDX.
+func TestVirtualDiskIdentifier(t *testing.T) {
+	// TODO(ambarve): We should add a test for VHD too, but our current create VHD API
+	// seems to only work for VHDX.
+	path := filepath.Join(t.TempDir(), "test.vhdx")
+
+	// A 1 GB VHDX with a 1 MB block size is enough for an identifier round trip.
+	if err := CreateVhdx(path, 1, 1); err != nil {
+		t.Fatalf("failed to create virtual disk: %s", err)
+	}
+	t.Cleanup(func() { _ = os.Remove(path) })
+
+	// The starting identifier is only logged; it is not asserted on.
+	before, err := GetVirtualDiskIdentifier(path)
+	if err != nil {
+		t.Fatalf("failed to get initial virtual disk identifier: %s", err)
+	}
+	t.Logf("initial identifier: %s", before.String())
+
+	// A fixed, recognizable GUID keeps failures easy to read in the log.
+	want := guid.GUID{
+		Data1: 0x12345678,
+		Data2: 0x1234,
+		Data3: 0x5678,
+		Data4: [8]byte{0x12, 0x34, 0x56, 0x78, 0x9a, 0xbc, 0xde, 0xf0},
+	}
+	t.Logf("setting new identifier: %s", want.String())
+
+	if err := SetVirtualDiskIdentifier(path, want); err != nil {
+		t.Fatalf("failed to set virtual disk identifier: %s", err)
+	}
+
+	// GetVirtualDiskIdentifier opens the disk again, so this reads through a fresh handle.
+	got, err := GetVirtualDiskIdentifier(path)
+	if err != nil {
+		t.Fatalf("failed to get virtual disk identifier after setting: %s", err)
+	}
+	t.Logf("retrieved identifier: %s", got.String())
+
+	// guid.GUID is a comparable struct, so != compares every field.
+	if got != want {
+		t.Errorf("retrieved identifier does not match set identifier.\nExpected: %s\nGot: %s",
+			want.String(), got.String())
+	}
+}
diff --git a/vhd/zvhd_windows.go b/vhd/zvhd_windows.go
index 95c0407..e9d202e 100644
--- a/vhd/zvhd_windows.go
+++ b/vhd/zvhd_windows.go
@@ -42,8 +42,10 @@
procAttachVirtualDisk = modvirtdisk.NewProc("AttachVirtualDisk")
procCreateVirtualDisk = modvirtdisk.NewProc("CreateVirtualDisk")
procDetachVirtualDisk = modvirtdisk.NewProc("DetachVirtualDisk")
+ procGetVirtualDiskInformation = modvirtdisk.NewProc("GetVirtualDiskInformation")
procGetVirtualDiskPhysicalPath = modvirtdisk.NewProc("GetVirtualDiskPhysicalPath")
procOpenVirtualDisk = modvirtdisk.NewProc("OpenVirtualDisk")
+ procSetVirtualDiskInformation = modvirtdisk.NewProc("SetVirtualDiskInformation")
)
func attachVirtualDisk(handle syscall.Handle, securityDescriptor *uintptr, attachVirtualDiskFlag uint32, providerSpecificFlags uint32, parameters *AttachVirtualDiskParameters, overlapped *syscall.Overlapped) (win32err error) {
@@ -79,6 +81,14 @@
return
}
+func getVirtualDiskInformation(handle syscall.Handle, bufferSize *uint32, info *virtualDiskInfo, sizeUsed *uint32) (win32err error) {
+	r0, _, _ := syscall.SyscallN(procGetVirtualDiskInformation.Addr(), uintptr(handle), uintptr(unsafe.Pointer(bufferSize)), uintptr(unsafe.Pointer(info)), uintptr(unsafe.Pointer(sizeUsed)))
+	if r0 != 0 { // the API returns a Win32 error code; 0 (ERROR_SUCCESS) means success
+		win32err = syscall.Errno(r0)
+	}
+	return
+} // wrapper for the //sys getVirtualDiskInformation directive in vhd.go; NOTE(review): looks generated — regenerate rather than hand-edit
+
func getVirtualDiskPhysicalPath(handle syscall.Handle, diskPathSizeInBytes *uint32, buffer *uint16) (win32err error) {
r0, _, _ := syscall.SyscallN(procGetVirtualDiskPhysicalPath.Addr(), uintptr(handle), uintptr(unsafe.Pointer(diskPathSizeInBytes)), uintptr(unsafe.Pointer(buffer)))
if r0 != 0 {
@@ -103,3 +113,11 @@
}
return
}
+
+func setVirtualDiskInformation(handle syscall.Handle, info *virtualDiskInfo) (win32err error) {
+	r0, _, _ := syscall.SyscallN(procSetVirtualDiskInformation.Addr(), uintptr(handle), uintptr(unsafe.Pointer(info)))
+	if r0 != 0 { // the API returns a Win32 error code; 0 (ERROR_SUCCESS) means success
+		win32err = syscall.Errno(r0)
+	}
+	return
+} // wrapper for the //sys setVirtualDiskInformation directive in vhd.go; NOTE(review): looks generated — regenerate rather than hand-edit