@@ -10,6 +10,7 @@ import (
1010
1111 "github.com/lib/pq"
1212 "github.com/xataio/pgroll/pkg/db"
13+ "github.com/xataio/pgroll/pkg/schema"
1314)
1415
1516// DBAction is an interface for common database actions
@@ -842,75 +843,46 @@ func (a *setReplicaIdentityAction) Execute(ctx context.Context) error {
842843}
843844
844845type alterReferencesAction struct {
845- conn db.DB
846- table string
847- column string
846+ conn db.DB
847+ referencedBy map [string ][]* schema.ReferencedBy
848+ table string
849+ column string
848850}
849851
850- func NewAlterReferencesAction (conn db.DB , table , column string ) * alterReferencesAction {
852+ func NewAlterReferencesAction (conn db.DB , referencedBy map [ string ][] * schema. ReferencedBy , table , column string ) * alterReferencesAction {
851853 return & alterReferencesAction {
852- conn : conn ,
853- table : table ,
854- column : column ,
854+ conn : conn ,
855+ referencedBy : referencedBy ,
856+ table : table ,
857+ column : column ,
855858 }
856859}
857860
858861func (a * alterReferencesAction ) Execute (ctx context.Context ) error {
859- definitions , err := a .constraintDefinitions (ctx )
860- if err != nil {
861- return err
862- }
863- for _ , def := range definitions {
864- // Drop the existing constraint
865- _ , err := a .conn .ExecContext (ctx , fmt .Sprintf ("ALTER TABLE %s DROP CONSTRAINT %s" ,
866- pq .QuoteIdentifier (def .table ),
867- pq .QuoteIdentifier (def .name ),
868- ))
869- if err != nil {
870- return fmt .Errorf ("dropping constraint %s on %s: %w" , def .name , def .table , err )
871- }
872-
873- // Recreate the constraint with the table and new column
874- newDef := strings .ReplaceAll (def .def , a .column , pq .QuoteIdentifier (TemporaryName (a .column )))
875- newDef = strings .ReplaceAll (newDef , a .table , pq .QuoteIdentifier (a .table ))
876- _ , err = a .conn .ExecContext (ctx , fmt .Sprintf ("ALTER TABLE %s ADD CONSTRAINT %s %s" ,
877- pq .QuoteIdentifier (def .table ),
878- pq .QuoteIdentifier (def .name ),
879- newDef ,
880- ))
881- if err != nil {
882- return fmt .Errorf ("altering references for %s.%s: %w" , a .table , a .column , err )
883- }
884- }
885- return nil
886- }
887-
888- type constraintDefinition struct {
889- name string
890- table string
891- def string
892- }
862+ for table , constraints := range a .referencedBy {
863+ for _ , constraint := range constraints {
864+ // Drop the existing constraint
865+ _ , err := a .conn .ExecContext (ctx , fmt .Sprintf ("ALTER TABLE %s DROP CONSTRAINT %s" ,
866+ pq .QuoteIdentifier (table ),
867+ pq .QuoteIdentifier (constraint .Name ),
868+ ))
869+ if err != nil {
870+ return fmt .Errorf ("dropping constraint %s on %s: %w" , constraint .Name , table , err )
871+ }
893872
894- func (a * alterReferencesAction ) constraintDefinitions (ctx context.Context ) ([]constraintDefinition , error ) {
895- rows , err := a .conn .QueryContext (ctx , fmt .Sprintf (`
896- SELECT conname, r.conrelid::regclass, pg_catalog.pg_get_constraintdef(r.oid, true) as condef
897- FROM pg_catalog.pg_constraint r
898- WHERE confrelid = %s::regclass AND r.contype = 'f'` ,
899- pq .QuoteIdentifier (a .table ),
900- ))
901- // No FK constraint for table
902- if err != nil {
903- return nil , nil
904- }
905- defer rows .Close ()
873+ // Recreate the constraint with the table and new column
874+ newDef := strings .ReplaceAll (constraint .Definition , a .column , pq .QuoteIdentifier (TemporaryName (a .column )))
875+ newDef = strings .ReplaceAll (newDef , a .table , pq .QuoteIdentifier (a .table ))
876+ _ , err = a .conn .ExecContext (ctx , fmt .Sprintf ("ALTER TABLE %s ADD CONSTRAINT %s %s" ,
877+ pq .QuoteIdentifier (table ),
878+ pq .QuoteIdentifier (constraint .Name ),
879+ newDef ,
880+ ))
881+ if err != nil {
882+ return fmt .Errorf ("altering references for %s.%s: %w" , a .table , a .column , err )
883+ }
906884
907- defs := make ([]constraintDefinition , 0 )
908- for rows .Next () {
909- var def constraintDefinition
910- if err := rows .Scan (& def .name , & def .table , & def .def ); err != nil {
911- return nil , fmt .Errorf ("scanning referencing constraints for %s.%s: %w" , a .table , a .column , err )
912885 }
913- defs = append (defs , def )
914886 }
915- return defs , rows . Err ()
887+ return nil
916888}
0 commit comments