Skip to content

Commit b9c092a

Browse files
Merge pull request #50 from veraison/marshaling
Marshaling refactor
2 parents 89fde61 + 714b478 commit b9c092a

10 files changed

+343
-493
lines changed

claims_p1.go

+14-73
Original file line numberDiff line numberDiff line change
@@ -172,108 +172,49 @@ func (c *P1Claims) SetVSI(v string) error {
172172

173173
// Codecs
174174

175-
func (c *P1Claims) FromCBOR(buf []byte) error {
176-
err := c.FromUnvalidatedCBOR(buf)
177-
if err != nil {
178-
return err
179-
}
180-
181-
err = c.Validate()
182-
if err != nil {
183-
return fmt.Errorf("validation of PSA claims failed: %w", err)
184-
}
175+
// this type alias is used to prevent infinite recursion during marshaling.
176+
type p1Claims P1Claims
185177

186-
return nil
187-
}
188-
189-
func (c *P1Claims) FromUnvalidatedCBOR(buf []byte) error {
178+
func (c *P1Claims) UnmarshalCBOR(buf []byte) error {
190179
c.Profile = nil // clear profile to make sure we take it from buf
191180

192-
err := dm.Unmarshal(buf, c)
181+
// cast prevents the decoder invoking this method again
182+
err := dm.Unmarshal(buf, (*p1Claims)(c))
193183
if err != nil {
194184
return fmt.Errorf("CBOR decoding of PSA claims failed: %w", err)
195185
}
196186

197187
return nil
198188
}
199189

200-
func (c P1Claims) ToCBOR() ([]byte, error) { //nolint:gocritic
201-
err := c.Validate()
202-
if err != nil {
203-
return nil, fmt.Errorf("validation of PSA claims failed: %w", err)
204-
}
205-
206-
return c.ToUnvalidatedCBOR()
207-
}
208-
209-
func (c P1Claims) ToUnvalidatedCBOR() ([]byte, error) { //nolint:gocritic
210-
var scs ISwComponents
190+
func (c P1Claims) MarshalCBOR() ([]byte, error) { //nolint:gocritic
211191
if c.SwComponents != nil && c.SwComponents.IsEmpty() {
212-
scs = c.SwComponents
213192
c.SwComponents = nil
214193
}
215194

216-
buf, err := em.Marshal(&c)
217-
if scs != nil {
218-
c.SwComponents = scs
219-
}
220-
if err != nil {
221-
return nil, fmt.Errorf("CBOR encoding of PSA claims failed: %w", err)
222-
}
223-
224-
return buf, nil
195+
// cast prevents encoder from invoking this method again
196+
return em.Marshal((*p1Claims)(&c))
225197
}
226198

227-
func (c *P1Claims) FromJSON(buf []byte) error {
228-
err := c.FromUnvalidatedJSON(buf)
229-
if err != nil {
230-
return err
231-
}
232-
233-
err = c.Validate()
234-
if err != nil {
235-
return fmt.Errorf("validation of PSA claims failed: %w", err)
236-
}
237-
238-
return nil
239-
}
240-
241-
func (c *P1Claims) FromUnvalidatedJSON(buf []byte) error {
199+
func (c *P1Claims) UnmarshalJSON(buf []byte) error {
242200
c.Profile = nil // clear profile to make sure we take it from buf
243201

244-
err := json.Unmarshal(buf, c)
202+
// cast prevents the decoder invoking this method again
203+
err := json.Unmarshal(buf, (*p1Claims)(c))
245204
if err != nil {
246205
return fmt.Errorf("JSON decoding of PSA claims failed: %w", err)
247206
}
248207

249208
return nil
250209
}
251210

252-
func (c P1Claims) ToJSON() ([]byte, error) { //nolint:gocritic
253-
err := c.Validate()
254-
if err != nil {
255-
return nil, fmt.Errorf("validation of PSA claims failed: %w", err)
256-
}
257-
258-
return c.ToUnvalidatedJSON()
259-
}
260-
261-
func (c P1Claims) ToUnvalidatedJSON() ([]byte, error) { //nolint:gocritic
262-
var scs ISwComponents
211+
func (c P1Claims) MarshalJSON() ([]byte, error) { //nolint:gocritic
263212
if c.SwComponents != nil && c.SwComponents.IsEmpty() {
264-
scs = c.SwComponents
265213
c.SwComponents = nil
266214
}
267215

268-
buf, err := json.Marshal(&c)
269-
if scs != nil {
270-
c.SwComponents = scs
271-
}
272-
if err != nil {
273-
return nil, fmt.Errorf("JSON encoding of PSA claims failed: %w", err)
274-
}
275-
276-
return buf, nil
216+
// cast prevents encoder from invoking this method again
217+
return json.Marshal((*p1Claims)(&c))
277218
}
278219

279220
// Getters return a validated value or an error

claims_p1_test.go

+27-34
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
package psatoken
55

66
import (
7+
"encoding/json"
78
"os"
89
"testing"
910

@@ -93,88 +94,75 @@ func Test_P1Claims_Validate_mandatory_only_claims_and_no_swmeasurements(t *testi
9394

9495
func Test_P1Claims_ToCBOR_invalid(t *testing.T) {
9596
c := newP1Claims(false)
97+
expectedErr := `validating security lifecycle: missing mandatory claim`
9698

97-
expectedErr := `validation of PSA claims failed: validating security lifecycle: missing mandatory claim`
98-
99-
_, err := c.ToCBOR()
99+
_, err := ValidateAndEncodeClaimsToCBOR(c)
100100

101101
assert.EqualError(t, err, expectedErr)
102102
}
103103

104104
func Test_P1Claims_ToCBOR_all_claims_and_no_swmeasurements(t *testing.T) {
105105
c := mustBuildValidP1Claims(t, true, true)
106-
107106
expected := mustHexDecode(t, testEncodedP1ClaimsAllNoSwMeasurements)
108107

109-
actual, err := c.ToCBOR()
108+
actual, err := ValidateAndEncodeClaimsToCBOR(c)
110109

111110
assert.NoError(t, err)
112111
assert.Equal(t, expected, actual)
113112
}
114113

115114
func Test_P1Claims_ToCBOR_all_claims(t *testing.T) {
116115
c := mustBuildValidP1Claims(t, true, false)
117-
118116
expected := mustHexDecode(t, testEncodedP1ClaimsAll)
119117

120-
actual, err := c.ToCBOR()
118+
actual, err := ValidateAndEncodeClaimsToCBOR(c)
121119

122120
assert.NoError(t, err)
123121
assert.Equal(t, expected, actual)
124122
}
125123

126124
func Test_P1Claims_ToCBOR_mandatory_only_claims(t *testing.T) {
127125
c := mustBuildValidP1Claims(t, false, false)
128-
129126
expected := mustHexDecode(t, testEncodedP1ClaimsMandatoryOnly)
130127

131-
actual, err := c.ToCBOR()
128+
actual, err := ValidateAndEncodeClaimsToCBOR(c)
132129

133130
assert.NoError(t, err)
134131
assert.Equal(t, expected, actual)
135132
}
136133

137134
func Test_P1Claims_ToCBOR_mandatory_only_claims_and_no_swmeasurements(t *testing.T) {
138135
c := mustBuildValidP1Claims(t, false, true)
139-
140136
expected := mustHexDecode(t, testEncodedP1ClaimsMandatoryOnlyNoSwMeasurements)
141137

142-
actual, err := c.ToCBOR()
138+
actual, err := ValidateAndEncodeClaimsToCBOR(c)
143139

144140
assert.NoError(t, err)
145141
assert.Equal(t, expected, actual)
146142
}
147143

148144
func Test_P1Claims_FromCBOR_bad_input(t *testing.T) {
149145
buf := mustHexDecode(t, testNotCBOR)
146+
expectedErr := "unexpected EOF"
150147

151-
expectedErr := "CBOR decoding of PSA claims failed: unexpected EOF"
152-
153-
c := newP1Claims(false)
154-
155-
err := c.FromCBOR(buf)
148+
_, err := DecodeAndValidateClaimsFromCBOR(buf)
156149

157150
assert.EqualError(t, err, expectedErr)
158151
}
159152

160153
func Test_P1Claims_FromCBOR_missing_mandatory_claim(t *testing.T) {
161154
buf := mustHexDecode(t, testEncodedP1ClaimsMissingMandatoryNonce)
155+
expectedErr := "validating nonce: missing mandatory claim"
162156

163-
expectedErr := "validation of PSA claims failed: validating nonce: missing mandatory claim"
164-
165-
c := newP1Claims(false)
166-
167-
err := c.FromCBOR(buf)
157+
_, err := DecodeAndValidateClaimsFromCBOR(buf)
168158

169159
assert.EqualError(t, err, expectedErr)
170160
}
171161

172162
func Test_P1Claims_FromCBOR_ok_mandatory_only(t *testing.T) {
173163
buf := mustHexDecode(t, testEncodedP1ClaimsMandatoryOnly)
174164

175-
c := newP1Claims(false)
176-
177-
err := c.FromCBOR(buf)
165+
c, err := DecodeAndValidateClaimsFromCBOR(buf)
178166
assert.NoError(t, err)
179167

180168
// even if it's not physically present the profile indication is always returned
@@ -250,9 +238,13 @@ func Test_P1Claims_FromJSON_positives(t *testing.T) {
250238
buf, err := os.ReadFile(fn)
251239
require.NoError(t, err)
252240

253-
claimsSet := newP1Claims(false)
241+
claims := newP1Claims(false)
242+
243+
err = json.Unmarshal(buf, claims)
244+
require.NoError(t, err)
245+
246+
err = claims.Validate()
254247

255-
err = claimsSet.FromJSON(buf)
256248
assert.NoError(t, err, "test vector %d failed", i)
257249
}
258250
}
@@ -292,21 +284,22 @@ func Test_P1Claims_FromJSON_negatives(t *testing.T) {
292284
buf, err := os.ReadFile(fn)
293285
require.NoError(t, err)
294286

295-
var claimsSet P1Claims
287+
claims := newP1Claims(false)
288+
289+
err = json.Unmarshal(buf, claims)
290+
require.NoError(t, err)
291+
292+
err = claims.Validate()
296293

297-
err = claimsSet.FromJSON(buf)
298294
assert.Error(t, err, "test vector %d failed", i)
299295
}
300296
}
301297

302298
func TestP1Claims_FromJSON_invalid_json(t *testing.T) {
303-
tv := testNotJSON
299+
expectedErr := `unexpected end of JSON input`
304300

305-
expectedErr := `JSON decoding of PSA claims failed: unexpected end of JSON input`
306-
307-
c := newP1Claims(false)
301+
_, err := DecodeAndValidateClaimsFromJSON(testNotJSON)
308302

309-
err := c.FromJSON(tv)
310303
assert.EqualError(t, err, expectedErr)
311304
}
312305

@@ -331,7 +324,7 @@ func Test_P1Claims_ToJSON_ok(t *testing.T) {
331324
"psa-verification-service-indicator": "https://veraison.example/v1/challenge-response"
332325
}`
333326

334-
actual, err := c.ToJSON()
327+
actual, err := ValidateAndEncodeClaimsToJSON(c)
335328
assert.NoError(t, err)
336329
assert.JSONEq(t, expected, string(actual))
337330
}

claims_p2.go

+6-67
Original file line numberDiff line numberDiff line change
@@ -167,80 +167,19 @@ func (c *P2Claims) SetVSI(v string) error {
167167

168168
// Codecs
169169

170-
func (c *P2Claims) FromCBOR(buf []byte) error {
171-
err := c.FromUnvalidatedCBOR(buf)
172-
if err != nil {
173-
return err
174-
}
175-
176-
err = c.Validate()
177-
if err != nil {
178-
return fmt.Errorf("validation of PSA claims failed: %w", err)
179-
}
170+
// this type alias is used to prevent infinite recursion during marshaling.
171+
type p2Claims P2Claims
180172

181-
return nil
182-
}
183-
184-
func (c *P2Claims) FromUnvalidatedCBOR(buf []byte) error {
173+
func (c *P2Claims) UnmarshalCBOR(buf []byte) error {
185174
c.Profile = nil // clear profile to make sure we take it from buf
186175

187-
err := dm.Unmarshal(buf, c)
188-
if err != nil {
189-
return fmt.Errorf("CBOR decoding of PSA claims failed: %w", err)
190-
}
191-
192-
return nil
193-
}
194-
195-
func (c P2Claims) ToCBOR() ([]byte, error) { //nolint:gocritic
196-
err := c.Validate()
197-
if err != nil {
198-
return nil, fmt.Errorf("validation of PSA claims failed: %w", err)
199-
}
200-
201-
return c.ToUnvalidatedCBOR()
176+
return dm.Unmarshal(buf, (*p2Claims)(c))
202177
}
203178

204-
func (c P2Claims) ToUnvalidatedCBOR() ([]byte, error) { //nolint:gocritic
205-
return em.Marshal(&c)
206-
}
207-
208-
func (c *P2Claims) FromJSON(buf []byte) error {
209-
err := c.FromUnvalidatedJSON(buf)
210-
if err != nil {
211-
return err
212-
}
213-
214-
err = c.Validate()
215-
if err != nil {
216-
return fmt.Errorf("validation of PSA claims failed: %w", err)
217-
}
218-
219-
return nil
220-
}
221-
222-
func (c *P2Claims) FromUnvalidatedJSON(buf []byte) error {
179+
func (c *P2Claims) UnmarshalJSON(buf []byte) error {
223180
c.Profile = nil // clear profile to make sure we take it from buf
224181

225-
err := json.Unmarshal(buf, c)
226-
if err != nil {
227-
return fmt.Errorf("JSON decoding of PSA claims failed: %w", err)
228-
}
229-
230-
return nil
231-
}
232-
233-
func (c P2Claims) ToJSON() ([]byte, error) { //nolint:gocritic
234-
err := c.Validate()
235-
if err != nil {
236-
return nil, fmt.Errorf("validation of PSA claims failed: %w", err)
237-
}
238-
239-
return c.ToUnvalidatedJSON()
240-
}
241-
242-
func (c P2Claims) ToUnvalidatedJSON() ([]byte, error) { //nolint:gocritic
243-
return json.Marshal(&c)
182+
return json.Unmarshal(buf, (*p2Claims)(c))
244183
}
245184

246185
// Getters return a validated value or an error

0 commit comments

Comments
 (0)