Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
206 changes: 185 additions & 21 deletions internal/store/postgres/policy_repository.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,14 @@ import (
"errors"
"fmt"
"strings"

"github.com/raystack/frontier/internal/bootstrap/schema"
"time"

"github.com/doug-martin/goqu/v9"
"github.com/jmoiron/sqlx"
"github.com/raystack/frontier/core/namespace"
"github.com/raystack/frontier/core/policy"
"github.com/raystack/frontier/internal/bootstrap/schema"
"github.com/raystack/frontier/pkg/auditrecord"
"github.com/raystack/frontier/pkg/db"
)

Expand Down Expand Up @@ -199,15 +201,42 @@ func (r PolicyRepository) Upsert(ctx context.Context, pol policy.Policy) (policy
"principal_type": pol.PrincipalType,
"metadata": marshaledMetadata,
}).OnConflict(goqu.DoUpdate("role_id, resource_id, resource_type, principal_id, principal_type", goqu.Record{
"metadata": marshaledMetadata,
"metadata": marshaledMetadata,
"updated_at": goqu.L("now()"),
})).Returning(&PolicyCols{}).ToSQL()
if err != nil {
return policy.Policy{}, fmt.Errorf("%w: %w", queryErr, err)
}

// Check if policy exists before upsert
_, exists := r.getPolicyByConstraint(ctx, pol)

var policyDB Policy
if err = r.dbc.WithTimeout(ctx, TABLE_POLICIES, "Upsert", func(ctx context.Context) error {
return r.dbc.QueryRowxContext(ctx, query, params...).StructScan(&policyDB)
if err = r.dbc.WithTxn(ctx, sql.TxOptions{}, func(tx *sqlx.Tx) error {
return r.dbc.WithTimeout(ctx, TABLE_POLICIES, "Upsert", func(ctx context.Context) error {
if err := tx.QueryRowxContext(ctx, query, params...).StructScan(&policyDB); err != nil {
return err
}

var (
event auditrecord.Event
timestamp time.Time
additionalMetadata map[string]any
)
if exists {
event = auditrecord.PolicyUpdatedEvent
timestamp = policyDB.UpdatedAt
additionalMetadata = map[string]any{
"updated_metadata": pol.Metadata,
}
} else {
event = auditrecord.PolicyCreatedEvent
timestamp = policyDB.CreatedAt
}

auditRecord := r.buildPolicyAuditRecord(ctx, tx, event, policyDB, timestamp, additionalMetadata)
return InsertAuditRecordInTx(ctx, tx, auditRecord)
})
}); err != nil {
err = checkPostgresError(err)
switch {
Expand All @@ -225,6 +254,13 @@ func (r PolicyRepository) Update(ctx context.Context, toUpdate policy.Policy) (s
if strings.TrimSpace(toUpdate.ID) == "" {
return "", policy.ErrInvalidID
}

// Fetch existing policy for audit record
existingPolicy, err := r.Get(ctx, toUpdate.ID)
if err != nil {
return "", err
}

marshaledMetadata, err := json.Marshal(toUpdate.Metadata)
if err != nil {
return "", fmt.Errorf("%w: %s", parseErr, err)
Expand All @@ -236,14 +272,32 @@ func (r PolicyRepository) Update(ctx context.Context, toUpdate policy.Policy) (s
"updated_at": goqu.L("now()"),
}).Where(goqu.Ex{
"id": toUpdate.ID,
}).Returning("id").ToSQL()
}).Returning("id", "updated_at").ToSQL()
if err != nil {
return "", fmt.Errorf("%w: %s", queryErr, err)
}

var policyID string
if err = r.dbc.WithTimeout(ctx, TABLE_POLICIES, "Update", func(ctx context.Context) error {
return r.dbc.QueryRowxContext(ctx, query, params...).Scan(&policyID)
var updatedAt time.Time
if err = r.dbc.WithTxn(ctx, sql.TxOptions{}, func(tx *sqlx.Tx) error {
return r.dbc.WithTimeout(ctx, TABLE_POLICIES, "Update", func(ctx context.Context) error {
if err := tx.QueryRowxContext(ctx, query, params...).Scan(&policyID, &updatedAt); err != nil {
return err
}

policyDB := Policy{
ID: existingPolicy.ID,
RoleID: existingPolicy.RoleID,
ResourceID: existingPolicy.ResourceID,
ResourceType: existingPolicy.ResourceType,
PrincipalID: existingPolicy.PrincipalID,
PrincipalType: existingPolicy.PrincipalType,
}
auditRecord := r.buildPolicyAuditRecord(ctx, tx, auditrecord.PolicyUpdatedEvent, policyDB, updatedAt, map[string]any{
"updated_metadata": toUpdate.Metadata,
})
return InsertAuditRecordInTx(ctx, tx, auditRecord)
})
}); err != nil {
err = checkPostgresError(err)
switch {
Expand All @@ -264,20 +318,35 @@ func (r PolicyRepository) Update(ctx context.Context, toUpdate policy.Policy) (s
}

func (r PolicyRepository) Delete(ctx context.Context, id string) error {
query, params, err := dialect.Delete(TABLE_POLICIES).Where(
goqu.Ex{
"id": id,
},
).ToSQL()
// Fetch policy for audit record
existingPolicy, err := r.Get(ctx, id)
if err != nil {
return fmt.Errorf("%w: %s", queryErr, err)
}

if err = r.dbc.WithTimeout(ctx, TABLE_POLICIES, "Delete", func(ctx context.Context) error {
if _, err = r.dbc.DB.ExecContext(ctx, query, params...); err != nil {
return err
}
return nil
return err
}

if err := r.dbc.WithTxn(ctx, sql.TxOptions{}, func(tx *sqlx.Tx) error {
return r.dbc.WithTimeout(ctx, TABLE_POLICIES, "Delete", func(ctx context.Context) error {
deleteQuery, deleteParams, err := dialect.Delete(TABLE_POLICIES).
Where(goqu.Ex{"id": id}).
ToSQL()
if err != nil {
return fmt.Errorf("%w: %w", queryErr, err)
}
if _, err := tx.ExecContext(ctx, deleteQuery, deleteParams...); err != nil {
return err
}

policyDB := Policy{
ID: existingPolicy.ID,
RoleID: existingPolicy.RoleID,
ResourceID: existingPolicy.ResourceID,
ResourceType: existingPolicy.ResourceType,
PrincipalID: existingPolicy.PrincipalID,
PrincipalType: existingPolicy.PrincipalType,
}
auditRecord := r.buildPolicyAuditRecord(ctx, tx, auditrecord.PolicyDeletedEvent, policyDB, time.Now(), nil)
return InsertAuditRecordInTx(ctx, tx, auditRecord)
})
}); err != nil {
err = checkPostgresError(err)
switch {
Expand Down Expand Up @@ -398,3 +467,98 @@ func (r PolicyRepository) OrgMemberCount(ctx context.Context, id string) (policy

return result, nil
}

// buildPolicyAuditRecord builds an audit record for policy events
func (r PolicyRepository) buildPolicyAuditRecord(ctx context.Context, tx *sqlx.Tx, event auditrecord.Event, pol Policy, timestamp time.Time, additionalMetadata map[string]any) AuditRecord {
orgID, resourceName := r.getResourceInfo(ctx, tx, pol.ResourceType, pol.ResourceID)

targetMetadata := map[string]any{
"role_id": pol.RoleID,
"principal_id": pol.PrincipalID,
"principal_type": pol.PrincipalType,
}
for k, v := range additionalMetadata {
targetMetadata[k] = v
}

return BuildAuditRecord(
ctx,
event,
AuditResource{
ID: pol.ResourceID,
Type: mapResourceTypeToAuditType(pol.ResourceType),
Name: resourceName,
},
&AuditTarget{
ID: pol.ID,
Type: auditrecord.PolicyType,
Metadata: targetMetadata,
},
orgID,
nil,
timestamp,
)
}

// getPolicyByConstraint fetches a policy by unique constraint fields
// Returns the policy and true if found, empty policy and false if not found
func (r PolicyRepository) getPolicyByConstraint(ctx context.Context, pol policy.Policy) (Policy, bool) {
query, params, _ := dialect.From(TABLE_POLICIES).
Select("id", "resource_type", "resource_id", "principal_id", "principal_type", "role_id").
Where(goqu.Ex{
"role_id": pol.RoleID,
"resource_id": pol.ResourceID,
"resource_type": pol.ResourceType,
"principal_id": pol.PrincipalID,
"principal_type": pol.PrincipalType,
}).
Limit(1).
ToSQL()

var existing Policy
if err := r.dbc.QueryRowxContext(ctx, query, params...).StructScan(&existing); err != nil {
return Policy{}, false
}
return existing, true
}

// getResourceInfo fetches org ID and resource name based on resource type
func (r PolicyRepository) getResourceInfo(ctx context.Context, tx *sqlx.Tx, resourceType, resourceID string) (string, string) {
var orgID, resourceName string
switch resourceType {
case schema.OrganizationNamespace:
orgID = resourceID
orgQuery, orgParams, _ := dialect.From(TABLE_ORGANIZATIONS).
Select("title").
Where(goqu.Ex{"id": resourceID}).
ToSQL()
_ = tx.QueryRowContext(ctx, orgQuery, orgParams...).Scan(&resourceName)
case schema.ProjectNamespace:
projQuery, projParams, _ := dialect.From(TABLE_PROJECTS).
Select("org_id", "title").
Where(goqu.Ex{"id": resourceID}).
ToSQL()
_ = tx.QueryRowContext(ctx, projQuery, projParams...).Scan(&orgID, &resourceName)
case schema.GroupNamespace:
grpQuery, grpParams, _ := dialect.From(TABLE_GROUPS).
Select("org_id", "title").
Where(goqu.Ex{"id": resourceID}).
ToSQL()
_ = tx.QueryRowContext(ctx, grpQuery, grpParams...).Scan(&orgID, &resourceName)
}
return orgID, resourceName
}

// mapResourceTypeToAuditType maps resource namespace to audit entity type
func mapResourceTypeToAuditType(resourceType string) auditrecord.EntityType {
switch resourceType {
case schema.OrganizationNamespace:
return auditrecord.OrganizationType
case schema.ProjectNamespace:
return auditrecord.ProjectType
case schema.GroupNamespace:
return auditrecord.GroupType
default:
return auditrecord.EntityType(resourceType)
}
}
8 changes: 8 additions & 0 deletions pkg/auditrecord/consts.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,15 +46,23 @@ const (
RoleCreatedEvent Event = "role.created"
RoleUpdatedEvent Event = "role.updated"

// Policy Events
PolicyCreatedEvent Event = "policy.created"
PolicyUpdatedEvent Event = "policy.updated"
PolicyDeletedEvent Event = "policy.deleted"

// Session Events
SessionRevokedEvent Event = "session.revoked"

SystemActor = "system"

// Entity Types (used in Resource.Type and Target.Type)
OrganizationType EntityType = "organization"
ProjectType EntityType = "project"
GroupType EntityType = "group"
UserType EntityType = "user"
RoleType EntityType = "role"
PolicyType EntityType = "policy"
ServiceUserType EntityType = "serviceuser"
InvitationType EntityType = "invitation"
KycType EntityType = "kyc"
Expand Down
Loading