Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Handle nullable columns #233

Draft
wants to merge 12 commits into
base: main
Choose a base branch
from
114 changes: 114 additions & 0 deletions source/common/table_info.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
// Copyright © 2024 Meroxa, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package common

import (
"context"
"fmt"

sdk "github.com/conduitio/conduit-connector-sdk"
"github.com/jackc/pgx/v5/pgxpool"
)

// TableInfo holds the cached column metadata for a single table.
type TableInfo struct {
	Name    string
	Columns map[string]*ColumnInfo
}

// NewTableInfo returns a TableInfo for the given table with an empty,
// ready-to-use column map.
func NewTableInfo(tableName string) *TableInfo {
	ti := &TableInfo{Name: tableName}
	ti.Columns = make(map[string]*ColumnInfo)
	return ti
}

// ColumnInfo describes a single column of a table.
type ColumnInfo struct {
	// IsNotNull reports whether the column carries a NOT NULL constraint.
	IsNotNull bool
}

// TableInfoFetcher fetches per-table column metadata from PostgreSQL via a
// pgx connection pool and caches it by table name.
type TableInfoFetcher struct {
	connPool  *pgxpool.Pool
	tableInfo map[string]*TableInfo
}

// NewTableInfoFetcher returns a TableInfoFetcher that uses the given
// connection pool and starts with an empty cache.
func NewTableInfoFetcher(connPool *pgxpool.Pool) *TableInfoFetcher {
	f := &TableInfoFetcher{connPool: connPool}
	f.tableInfo = map[string]*TableInfo{}
	return f
}

// Refresh loads column metadata (currently only NOT NULL-ness) for the given
// table from the PostgreSQL system catalog and caches it, replacing any
// previously cached info for that table. It returns an error if the catalog
// query fails; the cache is left untouched in that case.
//
// The receiver is a pointer because Refresh mutates the fetcher's cache.
func (i *TableInfoFetcher) Refresh(ctx context.Context, tableName string) error {
	tx, err := i.connPool.Begin(ctx)
	if err != nil {
		return fmt.Errorf("failed to start tx for getting table info: %w", err)
	}
	defer func() {
		// The tx is read-only, so rolling back is always the correct way to
		// end it; a rollback failure is only worth a warning.
		if err := tx.Rollback(ctx); err != nil {
			sdk.Logger(ctx).Warn().
				Err(err).
				Msgf("error on tx rollback for getting table info")
		}
	}()

	// pg_attribute has one row per column; attnum > 0 excludes system
	// columns, and attisdropped excludes dropped-but-not-vacuumed columns.
	query := `
	SELECT a.attname as column_name, a.attnotnull as is_not_null
	FROM pg_catalog.pg_attribute a
	WHERE a.attrelid = $1::regclass
	AND a.attnum > 0
	AND NOT a.attisdropped
	ORDER BY a.attnum;
	`

	// BUGFIX: use the caller's ctx (was context.Background()) so the query
	// respects cancellation and deadlines.
	rows, err := tx.Query(ctx, query, tableName)
	if err != nil {
		sdk.Logger(ctx).
			Err(err).
			Str("query", query).
			Msgf("failed to execute table info query")

		return fmt.Errorf("failed to get table info: %w", err)
	}
	defer rows.Close()

	ti := NewTableInfo(tableName)
	for rows.Next() {
		var columnName string
		var isNotNull bool

		if err := rows.Scan(&columnName, &isNotNull); err != nil {
			return fmt.Errorf("failed to scan table info row: %w", err)
		}

		ci := ti.Columns[columnName]
		if ci == nil {
			ci = &ColumnInfo{}
			ti.Columns[columnName] = ci
		}
		ci.IsNotNull = isNotNull
	}

	// rows.Err surfaces any error that terminated the iteration early.
	if err := rows.Err(); err != nil {
		return fmt.Errorf("failed to get table info rows: %w", err)
	}

	i.tableInfo[tableName] = ti
	return nil
}

// GetTable returns the cached info for the named table, or nil if the table
// has not been refreshed yet.
func (i TableInfoFetcher) GetTable(name string) *TableInfo {
	ti, ok := i.tableInfo[name]
	if !ok {
		return nil
	}
	return ti
}
9 changes: 8 additions & 1 deletion source/logrepl/cdc.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"fmt"

"github.com/conduitio/conduit-commons/opencdc"
"github.com/conduitio/conduit-connector-postgres/source/common"
"github.com/conduitio/conduit-connector-postgres/source/logrepl/internal"
"github.com/conduitio/conduit-connector-postgres/source/position"
sdk "github.com/conduitio/conduit-connector-sdk"
Expand Down Expand Up @@ -65,7 +66,13 @@ func NewCDCIterator(ctx context.Context, pool *pgxpool.Pool, c CDCConfig) (*CDCI
}

records := make(chan opencdc.Record)
handler := NewCDCHandler(internal.NewRelationSet(), c.TableKeys, records, c.WithAvroSchema)
handler := NewCDCHandler(
internal.NewRelationSet(),
common.NewTableInfoFetcher(pool),
c.TableKeys,
records,
c.WithAvroSchema,
)

sub, err := internal.CreateSubscription(
ctx,
Expand Down
36 changes: 25 additions & 11 deletions source/logrepl/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (

"github.com/conduitio/conduit-commons/opencdc"
cschema "github.com/conduitio/conduit-commons/schema"
"github.com/conduitio/conduit-connector-postgres/source/common"
"github.com/conduitio/conduit-connector-postgres/source/logrepl/internal"
"github.com/conduitio/conduit-connector-postgres/source/position"
"github.com/conduitio/conduit-connector-postgres/source/schema"
Expand All @@ -36,16 +37,25 @@ type CDCHandler struct {
out chan<- opencdc.Record
lastTXLSN pglogrepl.LSN

tableInfo *common.TableInfoFetcher

withAvroSchema bool
keySchemas map[string]cschema.Schema
payloadSchemas map[string]cschema.Schema
}

func NewCDCHandler(rs *internal.RelationSet, tableKeys map[string]string, out chan<- opencdc.Record, withAvroSchema bool) *CDCHandler {
func NewCDCHandler(
rs *internal.RelationSet,
tableInfo *common.TableInfoFetcher,
tableKeys map[string]string,
out chan<- opencdc.Record,
withAvroSchema bool,
) *CDCHandler {
return &CDCHandler{
tableKeys: tableKeys,
relationSet: rs,
out: out,
tableInfo: tableInfo,
withAvroSchema: withAvroSchema,
keySchemas: make(map[string]cschema.Schema),
payloadSchemas: make(map[string]cschema.Schema),
Expand All @@ -64,6 +74,10 @@ func (h *CDCHandler) Handle(ctx context.Context, m pglogrepl.Message, lsn pglogr
case *pglogrepl.RelationMessage:
// We have to add the Relations to our Set so that we can decode our own output
h.relationSet.Add(m)
err := h.tableInfo.Refresh(ctx, m.RelationName)
if err != nil {
return 0, fmt.Errorf("failed to refresh table info: %w", err)
}
case *pglogrepl.InsertMessage:
if err := h.handleInsert(ctx, m, lsn); err != nil {
return 0, fmt.Errorf("logrepl handler insert: %w", err)
Expand Down Expand Up @@ -102,15 +116,15 @@ func (h *CDCHandler) handleInsert(
return fmt.Errorf("failed getting relation %v: %w", msg.RelationID, err)
}

newValues, err := h.relationSet.Values(msg.RelationID, msg.Tuple)
if err != nil {
return fmt.Errorf("failed to decode new values: %w", err)
}

if err := h.updateAvroSchema(ctx, rel); err != nil {
return fmt.Errorf("failed to update avro schema: %w", err)
}

newValues, err := h.relationSet.Values(msg.RelationID, msg.Tuple, h.tableInfo.GetTable(rel.RelationName))
if err != nil {
return fmt.Errorf("failed to decode new values: %w", err)
}

rec := sdk.Util.Source.NewRecordCreate(
h.buildPosition(lsn),
h.buildRecordMetadata(rel),
Expand All @@ -134,7 +148,7 @@ func (h *CDCHandler) handleUpdate(
return err
}

newValues, err := h.relationSet.Values(msg.RelationID, msg.NewTuple)
newValues, err := h.relationSet.Values(msg.RelationID, msg.NewTuple, h.tableInfo.GetTable(rel.RelationName))
if err != nil {
return fmt.Errorf("failed to decode new values: %w", err)
}
Expand All @@ -143,7 +157,7 @@ func (h *CDCHandler) handleUpdate(
return fmt.Errorf("failed to update avro schema: %w", err)
}

oldValues, err := h.relationSet.Values(msg.RelationID, msg.OldTuple)
oldValues, err := h.relationSet.Values(msg.RelationID, msg.OldTuple, h.tableInfo.GetTable(rel.RelationName))
if err != nil {
// this is not a critical error, old values are optional, just log it
// we use level "trace" intentionally to not clog up the logs in production
Expand Down Expand Up @@ -174,7 +188,7 @@ func (h *CDCHandler) handleDelete(
return err
}

oldValues, err := h.relationSet.Values(msg.RelationID, msg.OldTuple)
oldValues, err := h.relationSet.Values(msg.RelationID, msg.OldTuple, h.tableInfo.GetTable(rel.RelationName))
if err != nil {
return fmt.Errorf("failed to decode old values: %w", err)
}
Expand Down Expand Up @@ -251,7 +265,7 @@ func (h *CDCHandler) updateAvroSchema(ctx context.Context, rel *pglogrepl.Relati
return nil
}
// Payload schema
avroPayloadSch, err := schema.Avro.ExtractLogrepl(rel.RelationName+"_payload", rel)
avroPayloadSch, err := schema.Avro.ExtractLogrepl(rel.RelationName+"_payload", rel, h.tableInfo.GetTable(rel.RelationName))
if err != nil {
return fmt.Errorf("failed to extract payload schema: %w", err)
}
Expand All @@ -267,7 +281,7 @@ func (h *CDCHandler) updateAvroSchema(ctx context.Context, rel *pglogrepl.Relati
h.payloadSchemas[rel.RelationName] = ps

// Key schema
avroKeySch, err := schema.Avro.ExtractLogrepl(rel.RelationName+"_key", rel, h.tableKeys[rel.RelationName])
avroKeySch, err := schema.Avro.ExtractLogrepl(rel.RelationName+"_key", rel, h.tableInfo.GetTable(rel.RelationName), h.tableKeys[rel.RelationName])
if err != nil {
return fmt.Errorf("failed to extract key schema: %w", err)
}
Expand Down
9 changes: 5 additions & 4 deletions source/logrepl/internal/relationset.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"errors"
"fmt"

"github.com/conduitio/conduit-connector-postgres/source/common"
"github.com/conduitio/conduit-connector-postgres/source/types"
"github.com/jackc/pglogrepl"
"github.com/jackc/pgx/v5/pgtype"
Expand Down Expand Up @@ -50,7 +51,7 @@ func (rs *RelationSet) Get(id uint32) (*pglogrepl.RelationMessage, error) {
return msg, nil
}

func (rs *RelationSet) Values(id uint32, row *pglogrepl.TupleData) (map[string]any, error) {
func (rs *RelationSet) Values(id uint32, row *pglogrepl.TupleData, tableInfo *common.TableInfo) (map[string]any, error) {
if row == nil {
return nil, errors.New("no tuple data")
}
Expand All @@ -65,7 +66,7 @@ func (rs *RelationSet) Values(id uint32, row *pglogrepl.TupleData) (map[string]a
// assert same number of row and rel columns
for i, tuple := range row.Columns {
col := rel.Columns[i]
v, decodeErr := rs.decodeValue(col, tuple.Data)
v, decodeErr := rs.decodeValue(col, tableInfo.Columns[col.Name], tuple.Data)
if decodeErr != nil {
return nil, fmt.Errorf("failed to decode value for column %q: %w", col.Name, err)
}
Expand All @@ -84,7 +85,7 @@ func (rs *RelationSet) oidToCodec(id uint32) pgtype.Codec {
return dt.Codec
}

func (rs *RelationSet) decodeValue(col *pglogrepl.RelationMessageColumn, data []byte) (any, error) {
func (rs *RelationSet) decodeValue(col *pglogrepl.RelationMessageColumn, colInfo *common.ColumnInfo, data []byte) (any, error) {
decoder := rs.oidToCodec(col.DataType)
// This workaround is due to an issue in pgx v5.7.1.
// Namely, that version introduces an XML codec
Expand All @@ -105,7 +106,7 @@ func (rs *RelationSet) decodeValue(col *pglogrepl.RelationMessageColumn, data []
return nil, fmt.Errorf("failed to decode value of pgtype %v: %w", col.DataType, err)
}

v, err := types.Format(col.DataType, val)
v, err := types.Format(col.DataType, val, colInfo.IsNotNull)
if err != nil {
return nil, fmt.Errorf("failed to format column %q type %T: %w", col.Name, val, err)
}
Expand Down
29 changes: 24 additions & 5 deletions source/schema/avro.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"fmt"
"slices"

"github.com/conduitio/conduit-connector-postgres/source/common"
"github.com/hamba/avro/v2"
"github.com/jackc/pglogrepl"
"github.com/jackc/pgx/v5/pgconn"
Expand Down Expand Up @@ -65,7 +66,7 @@ type avroExtractor struct {

// ExtractLogrepl extracts an Avro schema from the given pglogrepl.RelationMessage.
// If `fieldNames` are specified, then only the given fields will be included in the schema.
func (a avroExtractor) ExtractLogrepl(schemaName string, rel *pglogrepl.RelationMessage, fieldNames ...string) (*avro.RecordSchema, error) {
func (a *avroExtractor) ExtractLogrepl(schemaName string, rel *pglogrepl.RelationMessage, tableInfo *common.TableInfo, fieldNames ...string) (*avro.RecordSchema, error) {
var fields []pgconn.FieldDescription

for i := range rel.Columns {
Expand All @@ -76,12 +77,12 @@ func (a avroExtractor) ExtractLogrepl(schemaName string, rel *pglogrepl.Relation
})
}

return a.Extract(schemaName, fields, fieldNames...)
return a.Extract(schemaName, tableInfo, fields, fieldNames...)
}

// Extract extracts an Avro schema from the given Postgres field descriptions.
// If `fieldNames` are specified, then only the given fields will be included in the schema.
func (a *avroExtractor) Extract(schemaName string, fields []pgconn.FieldDescription, fieldNames ...string) (*avro.RecordSchema, error) {
func (a *avroExtractor) Extract(schemaName string, tableInfo *common.TableInfo, fields []pgconn.FieldDescription, fieldNames ...string) (*avro.RecordSchema, error) {
var avroFields []*avro.Field

for _, f := range fields {
Expand All @@ -94,7 +95,7 @@ func (a *avroExtractor) Extract(schemaName string, fields []pgconn.FieldDescript
return nil, fmt.Errorf("field %q with OID %d cannot be resolved", f.Name, f.DataTypeOID)
}

s, err := a.extractType(t, f.TypeModifier)
s, err := a.extractType(t, f.TypeModifier, tableInfo.Columns[f.Name].IsNotNull)
if err != nil {
return nil, err
}
Expand All @@ -119,7 +120,25 @@ func (a *avroExtractor) Extract(schemaName string, fields []pgconn.FieldDescript
return sch, nil
}

func (a *avroExtractor) extractType(t *pgtype.Type, typeMod int32) (avro.Schema, error) {
func (a *avroExtractor) extractType(t *pgtype.Type, typeMod int32, notNull bool) (avro.Schema, error) {
baseType, err := a.extractBaseType(t, typeMod)
if err != nil {
return nil, err
}

if !notNull {
schema, err := avro.NewUnionSchema([]avro.Schema{avro.NewNullSchema(), baseType})
if err != nil {
return nil, fmt.Errorf("failed to create avro union schema for nullable type %v: %w", baseType, err)
}

return schema, nil
}

return baseType, nil
}

func (a *avroExtractor) extractBaseType(t *pgtype.Type, typeMod int32) (avro.Schema, error) {
if ps, ok := a.avroMap[t.Name]; ok {
return ps, nil
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
"testing"
"time"

"github.com/conduitio/conduit-connector-postgres/source/common"
"github.com/conduitio/conduit-connector-postgres/source/cpool"
"github.com/conduitio/conduit-connector-postgres/source/types"
"github.com/conduitio/conduit-connector-postgres/test"
"github.com/hamba/avro/v2"
Expand All @@ -36,7 +38,14 @@
is := is.New(t)

c := test.ConnectSimple(ctx, t, test.RegularConnString)
connPool, err := cpool.New(ctx, test.RegularConnString)
is.NoErr(err)

table := setupAvroTestTable(ctx, t, c)
tableInfoFetcher := common.NewTableInfoFetcher(connPool)
err = tableInfoFetcher.Refresh(ctx, table)
is.NoErr(err)

insertAvroTestRow(ctx, t, c, table)

rows, err := c.Query(ctx, "SELECT * FROM "+table)
Expand All @@ -50,7 +59,7 @@

fields := rows.FieldDescriptions()

sch, err := Avro.Extract(table, fields)
sch, err := Avro.Extract(table, tableInfoFetcher.GetTable(table), fields)
is.NoErr(err)

t.Run("schema is parsable", func(t *testing.T) {
Expand Down Expand Up @@ -228,7 +237,7 @@
switch f.DataTypeOID {
case pgtype.NumericOID:
n := new(big.Rat)
n.SetString(fmt.Sprint(types.Format(0, values[i])))

Check failure on line 240 in source/schema/avro_integration_test.go

View workflow job for this annotation

GitHub Actions / golangci-lint

not enough arguments in call to types.Format

Check failure on line 240 in source/schema/avro_integration_test.go

View workflow job for this annotation

GitHub Actions / test

not enough arguments in call to types.Format
row[f.Name] = n
case pgtype.UUIDOID:
row[f.Name] = fmt.Sprint(values[i])
Expand Down
Loading
Loading