Skip to content

Isolate NETLINK_NETFILTER socket behavior behind the nftables flag in runsc. #11812

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions pkg/sentry/inet/inet.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,12 @@ type Stack interface {
// IsSaveRestoreEnabled returns true when netstack s/r is enabled.
IsSaveRestoreEnabled() bool

// EnableNFTables enables nftables support for the stack.
EnableNFTables() error

// IsNFTablesEnabled returns true when nftables support is enabled.
IsNFTablesEnabled() bool

// Stats returns the network stats.
Stats() tcpip.Stats
}
Expand Down
12 changes: 12 additions & 0 deletions pkg/sentry/inet/test_stack.go
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,18 @@ func (*TestStack) IsSaveRestoreEnabled() bool {
return false
}

// EnableNFTables implements Stack.
func (*TestStack) EnableNFTables() error {
// No-op.
return nil
}

// IsNFTablesEnabled implements Stack.
func (*TestStack) IsNFTablesEnabled() bool {
// No-op.
return false
}

// Stats implements Stack.
func (*TestStack) Stats() tcpip.Stats {
// No-op.
Expand Down
10 changes: 10 additions & 0 deletions pkg/sentry/socket/hostinet/stack.go
Original file line number Diff line number Diff line change
Expand Up @@ -439,6 +439,16 @@ func (s *Stack) IsSaveRestoreEnabled() bool {
return false
}

// EnableNFTables implements inet.Stack.EnableNFTables.
func (s *Stack) EnableNFTables() error {
return fmt.Errorf("nftables is not supported for hostinet")
}

// IsNFTablesEnabled implements inet.Stack.IsNFTablesEnabled.
func (s *Stack) IsNFTablesEnabled() bool {
return false
}

// Stats implements inet.Stack.Stats.
func (s *Stack) Stats() tcpip.Stats {
return tcpip.Stats{}
Expand Down
1 change: 1 addition & 0 deletions pkg/sentry/socket/netlink/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ go_library(
"//pkg/sentry/socket",
"//pkg/sentry/socket/netlink/nlmsg",
"//pkg/sentry/socket/netlink/port",
"//pkg/sentry/socket/netstack",
"//pkg/sentry/socket/unix",
"//pkg/sentry/socket/unix/transport",
"//pkg/sentry/vfs",
Expand Down
5 changes: 4 additions & 1 deletion pkg/sentry/socket/netlink/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,11 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/sockfs"
"gvisor.dev/gvisor/pkg/sentry/inet"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/socket"
"gvisor.dev/gvisor/pkg/sentry/socket/netlink/nlmsg"
"gvisor.dev/gvisor/pkg/sentry/socket/netstack"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserr"
)
Expand Down Expand Up @@ -81,8 +83,9 @@ func (*socketProvider) Socket(t *kernel.Task, stype linux.SockType, protocol int
return nil, syserr.ErrSocketNotSupported
}

nftEnabled := inet.StackFromContext(t.Kernel().SupervisorContext()).(*netstack.Stack).Stack.IsNFTablesEnabled()
provider, ok := protocols[protocol]
if !ok {
if !ok || (!nftEnabled && protocol == linux.NETLINK_NETFILTER) {
return nil, syserr.ErrProtocolNotSupported
}

Expand Down
14 changes: 14 additions & 0 deletions pkg/sentry/socket/netstack/stack.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,20 @@ func (s *Stack) IsSaveRestoreEnabled() bool {
return s.Stack.IsSaveRestoreEnabled()
}

// EnableNFTables enables nftables support for the stack.
func (s *Stack) EnableNFTables() error {
s.Stack.EnableNFTables()
return nil
}

// IsNFTablesEnabled implements inet.Stack.IsNFTablesEnabled.
func (s *Stack) IsNFTablesEnabled() bool {
if s.Stack == nil {
return false
}
return s.Stack.IsNFTablesEnabled()
}

// Destroy implements inet.Stack.Destroy.
func (s *Stack) Destroy() {
s.Stack.Close()
Expand Down
2 changes: 1 addition & 1 deletion pkg/tcpip/nftables/nft_comparison.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,6 @@ func (op comparison) evaluate(regs *registerSet, pkt *stack.PacketBuffer, rule *
}
if !result {
// Comparison is false, so break from the rule.
regs.verdict = Verdict{Code: VC(linux.NFT_BREAK)}
regs.verdict = stack.NFVerdict{Code: VC(linux.NFT_BREAK)}
}
}
4 changes: 2 additions & 2 deletions pkg/tcpip/nftables/nft_metaload.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ func (op metaLoad) evaluate(regs *registerSet, pkt *stack.PacketBuffer, rule *Ru
// Netfilter (Family) Protocol (8-bit, single byte).
case linux.NFT_META_NFPROTO:
family := rule.chain.GetAddressFamily()
target = []byte{family.Protocol()}
target = []byte{AfProtocol(family)}

// L4 Transport Layer Protocol (8-bit, single byte).
case linux.NFT_META_L4PROTO:
Expand Down Expand Up @@ -225,7 +225,7 @@ func (op metaLoad) evaluate(regs *registerSet, pkt *stack.PacketBuffer, rule *Ru

// Breaks if could not retrieve meta data.
if target == nil {
regs.verdict = Verdict{Code: VC(linux.NFT_BREAK)}
regs.verdict = stack.NFVerdict{Code: VC(linux.NFT_BREAK)}
return
}

Expand Down
2 changes: 1 addition & 1 deletion pkg/tcpip/nftables/nft_metaset.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,6 @@ func (op metaSet) evaluate(regs *registerSet, pkt *stack.PacketBuffer, rule *Rul
}

// Breaks if could not set the meta data.
regs.verdict = Verdict{Code: VC(linux.NFT_BREAK)}
regs.verdict = stack.NFVerdict{Code: VC(linux.NFT_BREAK)}
return
}
2 changes: 1 addition & 1 deletion pkg/tcpip/nftables/nft_payload_load.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ func (op payloadLoad) evaluate(regs *registerSet, pkt *stack.PacketBuffer, rule

// Breaks if could not retrieve packet data.
if payload == nil || len(payload) < int(op.offset+op.blen) {
regs.verdict = Verdict{Code: VC(linux.NFT_BREAK)}
regs.verdict = stack.NFVerdict{Code: VC(linux.NFT_BREAK)}
return
}

Expand Down
2 changes: 1 addition & 1 deletion pkg/tcpip/nftables/nft_payload_set.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ func (op payloadSet) evaluate(regs *registerSet, pkt *stack.PacketBuffer, rule *

// Breaks if could not retrieve packet data.
if payload == nil || len(payload) < int(op.offset+op.blen) {
regs.verdict = Verdict{Code: VC(linux.NFT_BREAK)}
regs.verdict = stack.NFVerdict{Code: VC(linux.NFT_BREAK)}
return
}

Expand Down
2 changes: 1 addition & 1 deletion pkg/tcpip/nftables/nft_ranged.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,6 @@ func (op ranged) evaluate(regs *registerSet, pkt *stack.PacketBuffer, rule *Rule
// Determines the comparison result depending on the operator.
if (d1 >= 0 && d2 <= 0) != (op.rop == linux.NFT_RANGE_EQ) {
// Comparison is false, so break from the rule.
regs.verdict = Verdict{Code: VC(linux.NFT_BREAK)}
regs.verdict = stack.NFVerdict{Code: VC(linux.NFT_BREAK)}
}
}
2 changes: 1 addition & 1 deletion pkg/tcpip/nftables/nft_route.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ func (op route) evaluate(regs *registerSet, pkt *stack.PacketBuffer, rule *Rule)

// Breaks if could not retrieve target data.
if target == nil {
regs.verdict = Verdict{Code: VC(linux.NFT_BREAK)}
regs.verdict = stack.NFVerdict{Code: VC(linux.NFT_BREAK)}
return
}

Expand Down
95 changes: 75 additions & 20 deletions pkg/tcpip/nftables/nftables.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,31 +24,86 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/stack"
)

//
// Interface-Related Methods
//

// CheckPrerouting checks at the Prerouting hook if the packet should continue traversing the stack.
func (nf *NFTables) CheckPrerouting(pkt *stack.PacketBuffer, af stack.AddressFamily) bool {
return nf.checkHook(pkt, af, stack.NFPrerouting)
}

// CheckInput checks at the Input hook if the packet should continue traversing the stack.
func (nf *NFTables) CheckInput(pkt *stack.PacketBuffer, af stack.AddressFamily) bool {
return nf.checkHook(pkt, af, stack.NFInput)
}

// CheckForward checks at the Forward hook if the packet should continue traversing the stack.
func (nf *NFTables) CheckForward(pkt *stack.PacketBuffer, af stack.AddressFamily) bool {
return nf.checkHook(pkt, af, stack.NFForward)
}

// CheckOutput checks at the Output hook if the packet should continue traversing the stack.
func (nf *NFTables) CheckOutput(pkt *stack.PacketBuffer, af stack.AddressFamily) bool {
return nf.checkHook(pkt, af, stack.NFOutput)
}

// CheckPostrouting checks at the Postrouting hook if the packet should continue traversing the stack.
func (nf *NFTables) CheckPostrouting(pkt *stack.PacketBuffer, af stack.AddressFamily) bool {
return nf.checkHook(pkt, af, stack.NFPostrouting)
}

// CheckIngress checks at the Ingress hook if the packet should continue traversing the stack.
func (nf *NFTables) CheckIngress(pkt *stack.PacketBuffer, af stack.AddressFamily) bool {
return nf.checkHook(pkt, af, stack.NFIngress)
}

// CheckEgress checks at the Egress hook if the packet should continue traversing the stack.
func (nf *NFTables) CheckEgress(pkt *stack.PacketBuffer, af stack.AddressFamily) bool {
return nf.checkHook(pkt, af, stack.NFEgress)
}

// checkHook returns true if the packet should continue traversing the stack or false
// if the packet should be dropped.
func (nf *NFTables) checkHook(pkt *stack.PacketBuffer, af stack.AddressFamily, hook stack.NFHook) bool {
v, err := nf.EvaluateHook(af, hook, pkt)

if err != nil {
return false
}

return v.Code == VC(linux.NF_ACCEPT)
}

//
// Core Evaluation Functions
//

// EvaluateHook evaluates a packet using the rules of the given hook for the
// given address family, returning a netfilter verdict and modifying the packet
// in place.
// Returns an error if address family or hook is invalid or they don't match.
// TODO(b/345684870): Consider removing error case if we never return an error.
func (nf *NFTables) EvaluateHook(family AddressFamily, hook Hook, pkt *stack.PacketBuffer) (Verdict, error) {
func (nf *NFTables) EvaluateHook(family stack.AddressFamily, hook stack.NFHook, pkt *stack.PacketBuffer) (stack.NFVerdict, error) {
// Note: none of the other evaluate functions are public because they require
// jumping to different chains in the same table, so all chains, rules, and
// operations must be tied to a table. Thus, calling evaluate for standalone
// chains, rules, or operations can be misleading and dangerous.

// Ensures address family is valid.
if err := validateAddressFamily(family); err != nil {
return Verdict{}, err
return stack.NFVerdict{}, err
}

// Ensures hook is valid.
if err := validateHook(hook, family); err != nil {
return Verdict{}, err
return stack.NFVerdict{}, err
}

// Immediately accept if there are no base chains for the specified hook.
if nf.filters[family] == nil || nf.filters[family].hfStacks[hook] == nil ||
len(nf.filters[family].hfStacks[hook].baseChains) == 0 {
return Verdict{Code: VC(linux.NF_ACCEPT)}, nil
return stack.NFVerdict{Code: VC(linux.NF_ACCEPT)}, nil
}

regs := newRegisterSet()
Expand All @@ -63,7 +118,7 @@ func (nf *NFTables) EvaluateHook(family AddressFamily, hook Hook, pkt *stack.Pac

err := bc.evaluate(&regs, pkt)
if err != nil {
return Verdict{}, err
return stack.NFVerdict{}, err
}

// Terminates immediately on netfilter terminal verdicts.
Expand All @@ -78,9 +133,9 @@ func (nf *NFTables) EvaluateHook(family AddressFamily, hook Hook, pkt *stack.Pac
switch regs.Verdict().Code {
case VC(linux.NFT_CONTINUE), VC(linux.NFT_RETURN):
if bc.GetBaseChainInfo().PolicyDrop {
return Verdict{Code: VC(linux.NF_DROP)}, nil
return stack.NFVerdict{Code: VC(linux.NF_DROP)}, nil
}
return Verdict{Code: VC(linux.NF_ACCEPT)}, nil
return stack.NFVerdict{Code: VC(linux.NF_ACCEPT)}, nil
}

panic(fmt.Sprintf("unexpected verdict from hook evaluation: %s", VerdictCodeToString(regs.Verdict().Code)))
Expand Down Expand Up @@ -184,14 +239,14 @@ func NewNFTables(clock tcpip.Clock, rng rand.RNG) *NFTables {

// Flush clears entire ruleset and all data for all address families.
func (nf *NFTables) Flush() {
for family := range NumAFs {
for family := range stack.NumAFs {
nf.filters[family] = nil
}
}

// FlushAddressFamily clears ruleset and all data for the given address family,
// returning an error if the address family is invalid.
func (nf *NFTables) FlushAddressFamily(family AddressFamily) error {
func (nf *NFTables) FlushAddressFamily(family stack.AddressFamily) error {
// Ensures address family is valid.
if err := validateAddressFamily(family); err != nil {
return err
Expand All @@ -202,7 +257,7 @@ func (nf *NFTables) FlushAddressFamily(family AddressFamily) error {
}

// GetTable validates the inputs and gets a table if it exists, error otherwise.
func (nf *NFTables) GetTable(family AddressFamily, tableName string) (*Table, error) {
func (nf *NFTables) GetTable(family stack.AddressFamily, tableName string) (*Table, error) {
// Ensures address family is valid.
if err := validateAddressFamily(family); err != nil {
return nil, err
Expand Down Expand Up @@ -232,7 +287,7 @@ func (nf *NFTables) GetTable(family AddressFamily, tableName string) (*Table, er
// Note: if the table already exists, the existing table is returned without any
// modifications.
// Note: Table initialized as not dormant.
func (nf *NFTables) AddTable(family AddressFamily, name string, comment string,
func (nf *NFTables) AddTable(family stack.AddressFamily, name string, comment string,
errorOnDuplicate bool) (*Table, error) {
// Ensures address family is valid.
if err := validateAddressFamily(family); err != nil {
Expand All @@ -245,7 +300,7 @@ func (nf *NFTables) AddTable(family AddressFamily, name string, comment string,
family: family,
nftState: nf,
tables: make(map[string]*Table),
hfStacks: make(map[Hook]*hookFunctionStack),
hfStacks: make(map[stack.NFHook]*hookFunctionStack),
}
}

Expand Down Expand Up @@ -278,14 +333,14 @@ func (nf *NFTables) AddTable(family AddressFamily, name string, comment string,
// but also returns an error if a table by the same name already exists.
// Note: this interface mirrors the difference between the create and add
// commands within the nft binary.
func (nf *NFTables) CreateTable(family AddressFamily, name string, comment string) (*Table, error) {
func (nf *NFTables) CreateTable(family stack.AddressFamily, name string, comment string) (*Table, error) {
return nf.AddTable(family, name, comment, true)
}

// DeleteTable deletes the specified table from the NFTables object returning
// true if the table was deleted and false if the table doesn't exist. Returns
// an error if the address family is invalid.
func (nf *NFTables) DeleteTable(family AddressFamily, tableName string) (bool, error) {
func (nf *NFTables) DeleteTable(family stack.AddressFamily, tableName string) (bool, error) {
// Ensures address family is valid.
if err := validateAddressFamily(family); err != nil {
return false, err
Expand All @@ -308,7 +363,7 @@ func (nf *NFTables) DeleteTable(family AddressFamily, tableName string) (bool, e
}

// GetChain validates the inputs and gets a chain if it exists, error otherwise.
func (nf *NFTables) GetChain(family AddressFamily, tableName string, chainName string) (*Chain, error) {
func (nf *NFTables) GetChain(family stack.AddressFamily, tableName string, chainName string) (*Chain, error) {
// Gets and checks the table.
t, err := nf.GetTable(family, tableName)
if err != nil {
Expand All @@ -326,7 +381,7 @@ func (nf *NFTables) GetChain(family AddressFamily, tableName string, chainName s
// Note: if the chain already exists, the existing chain is returned without any
// modifications.
// Note: if the chain is not a base chain, info should be nil.
func (nf *NFTables) AddChain(family AddressFamily, tableName string, chainName string, info *BaseChainInfo, comment string, errorOnDuplicate bool) (*Chain, error) {
func (nf *NFTables) AddChain(family stack.AddressFamily, tableName string, chainName string, info *BaseChainInfo, comment string, errorOnDuplicate bool) (*Chain, error) {
// Gets and checks the table.
t, err := nf.GetTable(family, tableName)
if err != nil {
Expand All @@ -341,14 +396,14 @@ func (nf *NFTables) AddChain(family AddressFamily, tableName string, chainName s
// chain by the same name already exists.
// Note: this interface mirrors the difference between the create and add
// commands within the nft binary.
func (nf *NFTables) CreateChain(family AddressFamily, tableName string, chainName string, info *BaseChainInfo, comment string) (*Chain, error) {
func (nf *NFTables) CreateChain(family stack.AddressFamily, tableName string, chainName string, info *BaseChainInfo, comment string) (*Chain, error) {
return nf.AddChain(family, tableName, chainName, info, comment, true)
}

// DeleteChain deletes the specified chain from the NFTables object returning
// true if the chain was deleted and false if the chain doesn't exist. Returns
// an error if the address family is invalid or the table doesn't exist.
func (nf *NFTables) DeleteChain(family AddressFamily, tableName string, chainName string) (bool, error) {
func (nf *NFTables) DeleteChain(family stack.AddressFamily, tableName string, chainName string) (bool, error) {
// Gets and checks the table.
t, err := nf.GetTable(family, tableName)
if err != nil {
Expand All @@ -373,7 +428,7 @@ func (t *Table) GetName() string {
}

// GetAddressFamily returns the address family of the table.
func (t *Table) GetAddressFamily() AddressFamily {
func (t *Table) GetAddressFamily() stack.AddressFamily {
return t.afFilter.family
}

Expand Down Expand Up @@ -486,7 +541,7 @@ func (c *Chain) GetName() string {
}

// GetAddressFamily returns the address family of the chain.
func (c *Chain) GetAddressFamily() AddressFamily {
func (c *Chain) GetAddressFamily() stack.AddressFamily {
return c.table.GetAddressFamily()
}

Expand Down
Loading