Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
162 changes: 133 additions & 29 deletions caveat_set.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package macaroon

import (
"bytes"
"encoding/json"
"errors"
"fmt"
Expand All @@ -11,7 +12,22 @@ import (

// CaveatSet is how a set of caveats is serailized/encoded.
type CaveatSet struct {
Caveats []Caveat
// The caveats encoded as necessary for macaroon signing/verification.
// Access this via the PackedCaveats() method. When signing, each caveat is
// encoded as a msgpack array containing the caveat type and the caveat body
// (the same as a caveat set containing the one caveat). These encodings are
// preserved from caveats sets we decode. This accounts for any variability
// from different msgpack libraries in different languages.
packedCaveats [][]byte

// it's possible to build a CaveatSet that can't be msgpack encoded. We
// track any error here so we can return it when something tries to get the
// msgpack representation. This is mostly a thing when JSON decoding
// unregistered caveat types.
packErr error

// Decoded caveats. Access this via the Caveats() method.
caveats []Caveat
}

var (
Expand All @@ -20,12 +36,14 @@ var (
_ msgpack.Marshaler = (*CaveatSet)(nil)
)

// Create a new CaveatSet comprised of the specified caveats.
// NewCaveatSet creates a new CaveatSet comprised of the specified caveats.
func NewCaveatSet(caveats ...Caveat) *CaveatSet {
return &CaveatSet{append([]Caveat{}, caveats...)}
c := &CaveatSet{caveats: []Caveat{}, packedCaveats: [][]byte{}}
c.Add(caveats...)
return c
}

// Decodes a set of serialized caveats.
// DecodeCaveats decodes a set of serialized caveats.
func DecodeCaveats(buf []byte) (*CaveatSet, error) {
cavs := new(CaveatSet)

Expand All @@ -36,6 +54,16 @@ func DecodeCaveats(buf []byte) (*CaveatSet, error) {
return cavs, nil
}

// Caveats are the decoded caveats.
func (c *CaveatSet) Caveats() []Caveat {
return c.caveats
}

// PackedCaveats are the caveats, msgpack encoded for signing/verification.
func (c *CaveatSet) PackedCaveats() ([][]byte, error) {
return c.packedCaveats, c.packErr
}

// Validates that the caveat set permits the specified accesses.
func (c *CaveatSet) Validate(accesses ...Access) error {
return Validate(c, accesses...)
Expand All @@ -58,7 +86,7 @@ func Validate[A Access](cs *CaveatSet, accesses ...A) error {

func (c *CaveatSet) validateAccess(access Access) error {
var err error
for _, caveat := range c.Caveats {
for _, caveat := range c.caveats {
if IsAttestation(caveat) {
continue
}
Expand All @@ -72,7 +100,7 @@ func (c *CaveatSet) validateAccess(access Access) error {
// GetCaveats gets any caveats of type T, including those nested within
// IfPresent caveats.
func GetCaveats[T Caveat](c *CaveatSet) (ret []T) {
for _, cav := range c.Caveats {
for _, cav := range c.caveats {
if typed, ok := cav.(T); ok {
ret = append(ret, typed)
}
Expand All @@ -89,18 +117,25 @@ func (c CaveatSet) MarshalMsgpack() ([]byte, error) {
return encode(c)
}

// cavPrefix is the msgpack tag indicating an array of length 2.
const cavPrefix = byte(0x92)

// Implements msgpack.CustomEncoder
func (c CaveatSet) EncodeMsgpack(enc *msgpack.Encoder) error {
if err := enc.EncodeArrayLen(len(c.Caveats) * 2); err != nil {
return err
if c.packErr != nil {
return c.packErr
}

for _, cav := range c.Caveats {
if err := enc.EncodeUint(uint64(cav.CaveatType())); err != nil {
return err
}
// TODO: resize enc buffer, since we know how much we're going to write?

if err := enc.EncodeArrayLen(len(c.packedCaveats) * 2); err != nil {
return err
}

if err := enc.Encode(cav); err != nil {
for _, b := range c.packedCaveats {
// each CaveatBytes is itself a caveat set (msgpack array len=2). Skip
// the array tag when encoding them all together.
if err := enc.Encode(msgpack.RawMessage(b[1:])); err != nil {
return err
}
}
Expand All @@ -118,37 +153,94 @@ func (c *CaveatSet) DecodeMsgpack(dec *msgpack.Decoder) error {
return errors.New("bad caveat container")
}

nCavs := aLen / 2
c.caveats = make([]Caveat, aLen/2)
c.packedCaveats = make([][]byte, aLen/2)

if c.Caveats == nil {
c.Caveats = make([]Caveat, 0, nCavs)
}
for i := 0; i < aLen/2; i++ {
rawTyp, err := dec.DecodeRaw()
if err != nil {
return err
}

for i := 0; i < nCavs; i++ {
t, err := dec.DecodeUint()
var typ CaveatType
if err := msgpack.Unmarshal(rawTyp, &typ); err != nil {
return err
}

rawCav, err := dec.DecodeRaw()
if err != nil {
return err
}

cav := typeToCaveat(CaveatType(t))
if err := dec.Decode(cav); err != nil {
c.caveats[i] = typeToCaveat(CaveatType(typ))
if err := msgpack.Unmarshal(rawCav, c.caveats[i]); err != nil {
return err
}

c.Caveats = append(c.Caveats, cav)
c.packedCaveats[i] = make([]byte, 0, 1+len(rawTyp)+len(rawCav))
c.packedCaveats[i] = append(c.packedCaveats[i], cavPrefix)
c.packedCaveats[i] = append(c.packedCaveats[i], rawTyp...)
c.packedCaveats[i] = append(c.packedCaveats[i], rawCav...)
}

return nil
}

func (c *CaveatSet) Add(caveats ...Caveat) {
c.caveats = append(c.caveats, caveats...)

if c.packErr != nil {
return
}

for _, cav := range caveats {
packed, err := packCaveat(cav)
if err != nil {
c.packedCaveats = nil
c.packErr = err

return
}

c.packedCaveats = append(c.packedCaveats, packed)
}
}

func (c *CaveatSet) addWithPacked(cav Caveat, packed []byte) {
c.caveats = append(c.caveats, cav)
if c.packErr == nil {
c.packedCaveats = append(c.packedCaveats, packed)
}
}

func packCaveat(cav Caveat) ([]byte, error) {
enc := msgpack.GetEncoder()
defer msgpack.PutEncoder(enc)

var buf bytes.Buffer
configEncoder(enc, &buf)

if err := enc.Encode(msgpack.RawMessage([]byte{cavPrefix})); err != nil {
return nil, err
}
if err := enc.EncodeUint(uint64(cav.CaveatType())); err != nil {
return nil, err
}
if err := enc.Encode(cav); err != nil {
return nil, err
}

return buf.Bytes(), nil
}

func (c CaveatSet) MarshalJSON() ([]byte, error) {
var (
jcavs = make([]jsonCaveat, len(c.Caveats))
jcavs = make([]jsonCaveat, len(c.caveats))
err error
)

for i := range c.Caveats {
ct := c.Caveats[i].CaveatType()
for i := range c.caveats {
ct := c.caveats[i].CaveatType()
cts := caveatTypeToString(ct)
if cts == "" {
return nil, fmt.Errorf("unregistered caveat type: %d", ct)
Expand All @@ -158,7 +250,7 @@ func (c CaveatSet) MarshalJSON() ([]byte, error) {
Type: cts,
}

if jcavs[i].Body, err = json.Marshal(c.Caveats[i]); err != nil {
if jcavs[i].Body, err = json.Marshal(c.caveats[i]); err != nil {
return nil, err
}
}
Expand All @@ -173,14 +265,26 @@ func (c *CaveatSet) UnmarshalJSON(b []byte) error {
return err
}

c.Caveats = make([]Caveat, len(jcavs))
c.caveats = make([]Caveat, 0, len(jcavs))
c.packedCaveats = make([][]byte, 0, len(jcavs))
for i := range jcavs {
t := caveatTypeFromString(jcavs[i].Type)

c.Caveats[i] = typeToCaveat(t)
if err := json.Unmarshal(jcavs[i].Body, &c.Caveats[i]); err != nil {
cav := typeToCaveat(t)
if err := json.Unmarshal(jcavs[i].Body, &cav); err != nil {
return err
}
c.caveats = append(c.caveats, cav)

if c.packErr == nil {
if packed, err := packCaveat(cav); err != nil {
c.packErr = err
c.packedCaveats = nil
} else {
c.packedCaveats = append(c.packedCaveats, packed)
}
}

}

return nil
Expand Down
12 changes: 6 additions & 6 deletions caveat_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,17 @@ func TestCaveatRegistry(t *testing.T) {
)

assert.NoError(t, json.Unmarshal(j1, cs))
assert.Equal(t, 1, len(cs.Caveats))
assert.Equal(t, c, cs.Caveats[0])
assert.Equal(t, 1, len(cs.Caveats()))
assert.Equal(t, c, cs.Caveats()[0])

RegisterCaveatJSONAlias(cavTestParentResource, "Foobar")
t.Cleanup(func() { unegisterCaveatJSONAlias("Foobar") })

assert.NoError(t, json.Unmarshal(j1, cs))
assert.Equal(t, 1, len(cs.Caveats))
assert.Equal(t, c, cs.Caveats[0])
assert.Equal(t, 1, len(cs.Caveats()))
assert.Equal(t, c, cs.Caveats()[0])

assert.NoError(t, json.Unmarshal(j2, cs))
assert.Equal(t, 1, len(cs.Caveats))
assert.Equal(t, c, cs.Caveats[0])
assert.Equal(t, 1, len(cs.Caveats()))
assert.Equal(t, c, cs.Caveats()[0])
}
10 changes: 5 additions & 5 deletions caveats_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,9 +165,9 @@ func TestUnregisteredCaveatJSON(t *testing.T) {
cs2 := NewCaveatSet()
err = json.Unmarshal(b, cs2)
assert.NoError(t, err)
assert.Equal(t, 1, len(cs2.Caveats))
assert.Equal(t, 1, len(cs2.Caveats()))

uc, ok := cs2.Caveats[0].(*UnregisteredCaveat)
uc, ok := cs2.Caveats()[0].(*UnregisteredCaveat)
assert.True(t, ok)
assert.Equal(t, cavMyUnregistered, uc.Type)

Expand Down Expand Up @@ -207,9 +207,9 @@ func TestUnregisteredCaveatMsgpack(t *testing.T) {

cs2, err := DecodeCaveats(b)
assert.NoError(t, err)
assert.Equal(t, 1, len(cs2.Caveats))
assert.Equal(t, 1, len(cs2.Caveats()))

uc, ok := cs2.Caveats[0].(*UnregisteredCaveat)
uc, ok := cs2.Caveats()[0].(*UnregisteredCaveat)
assert.True(t, ok)
assert.Equal(t, cavMyUnregistered, uc.Type)

Expand All @@ -235,7 +235,7 @@ func TestUnregisteredCaveatMsgpack(t *testing.T) {

cs3, err := DecodeCaveats(b2)
assert.NoError(t, err)
assert.Equal(t, 1, len(cs3.Caveats))
assert.Equal(t, 1, len(cs3.Caveats()))
mucs := GetCaveats[*myUnregistered](cs3)
assert.Equal(t, 1, len(mucs))
assert.Equal(t, c, mucs[0])
Expand Down
2 changes: 1 addition & 1 deletion cid.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,5 +53,5 @@ func dischargeTicket(ka EncryptionKey, location string, ticket []byte, issueProo
return nil, nil, err
}

return tWire.Caveats.Caveats, dm, nil
return tWire.Caveats.Caveats(), dm, nil
}
6 changes: 3 additions & 3 deletions internal/test-vectors/test_vectors.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ func main() {
}
v.KID = keyFingerprint(v.Key)

for _, c := range caveats.Caveats {
for _, c := range caveats.Caveats() {
m, _ := macaroon.New(v.KID, v.Location, v.Key)

// put attestations in discharge tokens
Expand All @@ -57,7 +57,7 @@ func main() {
aBaseTok, _ := aBase.Encode()
aBaseHdr := macaroon.ToAuthorizationHeader(otherTok, aBaseTok, otherTok)
v.Attenuation[aBaseHdr] = map[string]string{}
for _, c := range caveats.Caveats {
for _, c := range caveats.Caveats() {
cpy := ptr(*aBase)
cpy.UnsafeCaveats = *macaroon.NewCaveatSet()
cpy.Add(c)
Expand All @@ -66,7 +66,7 @@ func main() {
v.Attenuation[aBaseHdr][base64.StdEncoding.EncodeToString(cavsPacked)] = macaroon.ToAuthorizationHeader(otherTok, cpyEnc, otherTok)
}

for _, c := range caveats.Caveats {
for _, c := range caveats.Caveats() {
v.Caveats[c.Name()] = pack(c)
}

Expand Down
1 change: 0 additions & 1 deletion internal/test-vectors/test_vectors_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import (
)

func TestCaveatSerialization(t *testing.T) {

b, err := json.Marshal(caveats)
assert.NoError(t, err)

Expand Down
Loading