From 08971c1e5c6f776e422ea3d9e3dd21a5cf232192 Mon Sep 17 00:00:00 2001 From: Kathryn Baldauf Date: Mon, 4 Nov 2024 09:55:21 -0800 Subject: [PATCH] Add unit tests for assignIPToLink in GCS network setup Signed-off-by: Kathryn Baldauf --- internal/guest/network/netns.go | 15 +- internal/guest/network/netns_test.go | 228 +++++++++++++++++++++++++++ 2 files changed, 239 insertions(+), 4 deletions(-) create mode 100644 internal/guest/network/netns_test.go diff --git a/internal/guest/network/netns.go b/internal/guest/network/netns.go index e414e5e320..7953cf07bc 100644 --- a/internal/guest/network/netns.go +++ b/internal/guest/network/netns.go @@ -20,6 +20,13 @@ import ( "github.com/vishvananda/netns" ) +var ( + // function definitions for mocking assignIPToLink + netlinkAddrAdd = netlink.AddrAdd + netlinkRuleAdd = netlink.RuleAdd + netlinkRouteAdd = netlink.RouteAdd +) + // MoveInterfaceToNS moves the adapter with interface name `ifStr` to the network namespace // of `pid`. func MoveInterfaceToNS(ifStr string, pid int) error { @@ -219,7 +226,7 @@ func assignIPToLink(ctx context.Context, "IP": addr, }).Debugf("parsed ip address %s/%d", allocatedIP, prefixLen) ipAddr := &netlink.Addr{IPNet: addr, Label: ""} - if err := netlink.AddrAdd(link, ipAddr); err != nil { + if err := netlinkAddrAdd(link, ipAddr); err != nil { return errors.Wrapf(err, "netlink.AddrAdd(%#v, %#v) failed", link, ipAddr) } if gatewayIP == "" { @@ -242,7 +249,7 @@ func assignIPToLink(ctx context.Context, IP: gw, Mask: net.CIDRMask(ml, ml)} ipAddr2 := &netlink.Addr{IPNet: addr2, Label: ""} - if err := netlink.AddrAdd(link, ipAddr2); err != nil { + if err := netlinkAddrAdd(link, ipAddr2); err != nil { return errors.Wrapf(err, "netlink.AddrAdd(%#v, %#v) failed", link, ipAddr2) } } @@ -261,7 +268,7 @@ func assignIPToLink(ctx context.Context, rule.Src = srcNet rule.Priority = 5 - if err := netlink.RuleAdd(rule); err != nil { + if err := netlinkRuleAdd(rule); err != nil { return errors.Wrapf(err, "netlink.RuleAdd(%#v) failed", rule) } table = rule.Table @@ -274,7 +281,7 @@ func assignIPToLink(ctx context.Context, Table: table, Priority: metric, } - if err := netlink.RouteAdd(&route); err != nil { + if err := netlinkRouteAdd(&route); err != nil { return errors.Wrapf(err, "netlink.RouteAdd(%#v) failed", route) } return nil diff --git a/internal/guest/network/netns_test.go b/internal/guest/network/netns_test.go new file mode 100644 index 0000000000..eebb72d74b --- /dev/null +++ b/internal/guest/network/netns_test.go @@ -0,0 +1,228 @@ +//go:build linux +// +build linux + +package network + +import ( + "bytes" + "context" + "fmt" + "net" + "testing" + + "github.com/vishvananda/netlink" +) + +type fakeLink struct { + attr *netlink.LinkAttrs +} + +func (l *fakeLink) Attrs() *netlink.LinkAttrs { + return l.attr +} + +func (l *fakeLink) Type() string { + return "" +} + +func newFakeLink(name string, index int) *fakeLink { + attr := &netlink.LinkAttrs{ + Name: name, + Index: index, + } + return &fakeLink{ + attr: attr, + } +} + +var _ = (netlink.Link)(&fakeLink{}) + +func standardNetlinkAddrAdd(expectedIP string, prefixLen, totalMaskSize int) func(_ netlink.Link, _ *netlink.Addr) error { + return func(link netlink.Link, addr *netlink.Addr) error { + if addr.IP.String() != expectedIP { + return fmt.Errorf("expected to add address %s, instead got %s", expectedIP, addr.IP.String()) + } + expectedMask := net.CIDRMask(prefixLen, totalMaskSize) + if !bytes.Equal(addr.Mask, expectedMask) { + return fmt.Errorf("expected mask to be %s, instead got %s", expectedMask, addr.Mask) + } + return nil + } +} + +func standardNetlinkRouteAdd(gatewayIP string, table, metric int) func(_ *netlink.Route) error { + return func(route *netlink.Route) error { + if route.Gw.String() != gatewayIP { + return fmt.Errorf("expected to add gateway %s, instead got %s", gatewayIP, route.Gw.String()) + } + if route.Table != table { + return fmt.Errorf("expected to use table %d, instead got %d", table, route.Table) + } + if route.Priority != metric { + return fmt.Errorf("expected to use metric %d, instead used %d", metric, route.Priority) + } + return nil + } +} + +type assignIPToLinkTest struct { + name string + ifStr string + allocatedIP string + gatewayIP string + prefixLen uint8 + totalMaskSize int +} + +var defaultAssignIPToLinkTests = []assignIPToLinkTest{ + { + name: "ipv4 standard", + ifStr: "eth0", + allocatedIP: "192.168.0.5", + gatewayIP: "192.168.0.100", + prefixLen: uint8(24), + totalMaskSize: 32, + }, + { + name: "ipv6 standard", + ifStr: "eth0", + allocatedIP: "9541:a2d4:f0f3:18ff:c868:26ce:e9c4:30a6", + gatewayIP: "9541:a2d4:f0f3:18ff:c868:26ce:e9c4:aaaa", + prefixLen: uint8(64), + totalMaskSize: 128, + }, +} + +func Test_AssignIPToLink(t *testing.T) { + ctx := context.Background() + + for _, tt := range defaultAssignIPToLinkTests { + t.Run(tt.name, func(st *testing.T) { + link1 := newFakeLink(tt.ifStr, 0) + + netlinkAddrAdd = standardNetlinkAddrAdd(tt.allocatedIP, int(tt.prefixLen), tt.totalMaskSize) + netlinkRouteAdd = standardNetlinkRouteAdd(tt.gatewayIP, 0, 1) + + if err := assignIPToLink(ctx, tt.ifStr, 10, link1, tt.allocatedIP, tt.gatewayIP, tt.prefixLen, false, 1); err != nil { + st.Fatalf("assignIPToLink: %s", err) + } + }) + } + +} + +func Test_AssignIPToLink_No_Gateway(t *testing.T) { + ctx := context.Background() + + for _, tt := range defaultAssignIPToLinkTests { + t.Run(tt.name, func(st *testing.T) { + // remove the gateway IP set for the tests + tt.gatewayIP = "" + link1 := newFakeLink(tt.ifStr, 0) + + netlinkAddrAdd = standardNetlinkAddrAdd(tt.allocatedIP, int(tt.prefixLen), tt.totalMaskSize) + netlinkRouteAdd = standardNetlinkRouteAdd(tt.gatewayIP, 0, 1) + + if err := assignIPToLink(ctx, tt.ifStr, 10, link1, tt.allocatedIP, tt.gatewayIP, tt.prefixLen, false, 1); err != nil { + st.Fatalf("assignIPToLink: %s", err) + } + }) + } + +} + +func Test_AssignIPToLink_GatewayOutsideSubnet(t *testing.T) { + ctx := context.Background() + + var assignIPToLinkTestsGateway = []assignIPToLinkTest{ + { + name: "ipv4 standard", + ifStr: "eth0", + allocatedIP: "192.168.0.5", + gatewayIP: "10.0.0.5", + prefixLen: uint8(24), + totalMaskSize: 32, + }, + { + name: "ipv6 standard", + ifStr: "eth0", + allocatedIP: "9541:a2d4:f0f3:18ff:c868:26ce:e9c4:30a6", + gatewayIP: "337c:83ab:b4cc:d823:6b5d:6aea:f605:80c5", + prefixLen: uint8(64), + totalMaskSize: 128, + }, + } + + for _, tt := range assignIPToLinkTestsGateway { + t.Run(tt.name, func(st *testing.T) { + link1 := newFakeLink(tt.ifStr, 0) + + netlinkAddCalls := 0 + netlinkAddrAdd = func(link netlink.Link, addr *netlink.Addr) error { + expectedIP := tt.allocatedIP + expectedMask := net.CIDRMask(int(tt.prefixLen), tt.totalMaskSize) + if netlinkAddCalls != 0 { + // on the second call, we want to check for the gateway address being added + expectedIP = tt.gatewayIP + expectedMask = net.CIDRMask(tt.totalMaskSize, tt.totalMaskSize) + } + if addr.IP.String() != expectedIP { + return fmt.Errorf("expected to add address %s, instead got %s", expectedIP, addr.IP.String()) + } + if !bytes.Equal(addr.Mask, expectedMask) { + return fmt.Errorf("expected mask to be %s, instead got %s", expectedMask, addr.Mask) + } + netlinkAddCalls++ + return nil + } + + netlinkRouteAdd = standardNetlinkRouteAdd(tt.gatewayIP, 0, 1) + + if err := assignIPToLink(ctx, tt.ifStr, 10, link1, tt.allocatedIP, tt.gatewayIP, tt.prefixLen, false, 1); err != nil { + st.Fatalf("assignIPToLink: %s", err) + } + + if netlinkAddCalls < 2 { + st.Fatalf("expected to call netlink AddrAdd %d times, instead got %d times", 2, netlinkAddCalls) + } + }) + } + +} + +func Test_AssignIPToLink_EnableLowMetric(t *testing.T) { + ctx := context.Background() + table := 101 + metric := 500 + + for _, tt := range defaultAssignIPToLinkTests { + t.Run(tt.name, func(st *testing.T) { + link1 := newFakeLink(tt.ifStr, 0) + + netlinkAddrAdd = standardNetlinkAddrAdd(tt.allocatedIP, int(tt.prefixLen), tt.totalMaskSize) + netlinkRouteAdd = standardNetlinkRouteAdd(tt.gatewayIP, table, metric) + + netlinkRuleAddCalled := false + netlinkRuleAdd = func(rule *netlink.Rule) error { + netlinkRuleAddCalled = true + if rule.Src.IP.String() != tt.allocatedIP { + return fmt.Errorf("expected to add rule for address %s, instead got %s", tt.allocatedIP, rule.Src.IP.String()) + } + expectedMask := net.CIDRMask(tt.totalMaskSize, tt.totalMaskSize) + if !bytes.Equal(expectedMask, rule.Src.Mask) { + return fmt.Errorf("expected mask to be %s, instead got %s", expectedMask, rule.Src.Mask) + } + return nil + } + + if err := assignIPToLink(ctx, tt.ifStr, 10, link1, tt.allocatedIP, tt.gatewayIP, tt.prefixLen, true, metric); err != nil { + t.Fatalf("assignIPToLink: %s", err) + } + + if !netlinkRuleAddCalled { + t.Fatal("should have added a rule since enableLowMetric was set") + } + }) + } + +}