From 59af4f866439c6ce3a48010f2d37db2387231bd5 Mon Sep 17 00:00:00 2001 From: Jacob Brewer Date: Sun, 3 Nov 2024 11:32:14 +0000 Subject: [PATCH] feat(templates): Adding default templates to the binary (#43) * using default templates * Updating example --- cmd_generate.go | 24 ++- example/models/generate.sh | 2 +- example/models/templates/_delete.tmpl | 24 --- example/models/templates/_insert.tmpl | 73 -------- example/models/templates/_insert_update.tmpl | 37 ---- example/models/templates/_tags.tmpl | 13 -- example/models/templates/_type.tmpl | 169 ----------------- .../templates/_type_reverse_nullable.tmpl | 89 --------- example/models/templates/_update.tmpl | 59 ------ example/models/templates/model.tmpl | 171 ------------------ pkg/generation/templates.go | 20 ++ templates/_insert.tmpl | 4 +- templates/_insert_update.tmpl | 2 +- templates/_type.tmpl | 169 ----------------- templates/_type_reverse_nullable.tmpl | 89 --------- templates/model.tmpl | 8 +- 16 files changed, 48 insertions(+), 905 deletions(-) delete mode 100644 example/models/templates/_delete.tmpl delete mode 100644 example/models/templates/_insert.tmpl delete mode 100644 example/models/templates/_insert_update.tmpl delete mode 100644 example/models/templates/_tags.tmpl delete mode 100644 example/models/templates/_type.tmpl delete mode 100644 example/models/templates/_type_reverse_nullable.tmpl delete mode 100644 example/models/templates/_update.tmpl delete mode 100644 example/models/templates/model.tmpl delete mode 100644 templates/_type.tmpl delete mode 100644 templates/_type_reverse_nullable.tmpl diff --git a/cmd_generate.go b/cmd_generate.go index 0ece532..deafafd 100644 --- a/cmd_generate.go +++ b/cmd_generate.go @@ -2,6 +2,7 @@ package main import ( "context" + "embed" "flag" "log/slog" "path/filepath" @@ -10,6 +11,9 @@ import ( "github.com/jacobbrewer1/goschema/pkg/generation" ) +//go:embed templates/*.tmpl +var defaultTemplates embed.FS + type generateCmd struct { // templatesLocation is the location of the templates to use. templatesLocation string @@ -22,6 +26,9 @@ type generateCmd struct { // fileExtensionPrefix is the prefix to add to the generated file extension. fileExtensionPrefix string + + // defaultTemplates is whether to use the binary templates. + defaultTemplates bool } func (g *generateCmd) Name() string { @@ -43,6 +50,7 @@ func (g *generateCmd) SetFlags(f *flag.FlagSet) { f.StringVar(&g.outputLocation, "out", ".", "The location to write the generated files to.") f.StringVar(&g.sqlLocation, "sql", "./schemas/*.sql", "The location of the SQL files to use.") f.StringVar(&g.fileExtensionPrefix, "extension", "xo", "The prefix to add to the generated file extension.") + f.BoolVar(&g.defaultTemplates, "default", true, "Whether to use the default templates.") } func (g *generateCmd) Execute(_ context.Context, _ *flag.FlagSet, _ ...interface{}) subcommands.ExitStatus { @@ -69,10 +77,18 @@ func (g *generateCmd) Execute(_ context.Context, _ *flag.FlagSet, _ ...interface return subcommands.ExitFailure } - err = generation.RenderTemplates(tables, g.templatesLocation, g.outputLocation, g.fileExtensionPrefix) - if err != nil { - slog.Error("Error rendering templates", slog.String("templatesLocation", g.templatesLocation), slog.String("outputLocation", g.outputLocation), slog.String("error", err.Error())) - return subcommands.ExitFailure + if g.defaultTemplates { + err = generation.RenderWithTemplates(defaultTemplates, tables, g.outputLocation, g.fileExtensionPrefix) + if err != nil { + slog.Error("Error rendering default templates", slog.String("outputLocation", g.outputLocation), slog.String("error", err.Error())) + return subcommands.ExitFailure + } + } else { + err = generation.RenderTemplates(tables, g.templatesLocation, g.outputLocation, g.fileExtensionPrefix) + if err != nil { + slog.Error("Error rendering templates", slog.String("templatesLocation", g.templatesLocation), slog.String("outputLocation", g.outputLocation), slog.String("error", err.Error())) + return subcommands.ExitFailure + } } return subcommands.ExitSuccess diff --git a/example/models/generate.sh b/example/models/generate.sh index 5967603..06f17ef 100644 --- a/example/models/generate.sh +++ b/example/models/generate.sh @@ -78,7 +78,7 @@ if [ "$forced" = false ]; then fi for model in $togen; do - gum spin --spinner dot --title "Generating model $model" -- goschema generate --templates=./templates/*tmpl --out=./ --sql=./schemas/"$model".sql --extension=xo + gum spin --spinner dot --title "Generating model $model" -- goschema generate --out=./ --sql=./schemas/"$model".sql --extension=xo go fmt ./"$model".xo.go goimports -w ./"$model".xo.go done diff --git a/example/models/templates/_delete.tmpl b/example/models/templates/_delete.tmpl deleted file mode 100644 index 278789f..0000000 --- a/example/models/templates/_delete.tmpl +++ /dev/null @@ -1,24 +0,0 @@ -{{- define "delete" -}} -{{- $struct := .Name | structify -}} -// Delete deletes the {{ $struct }} from the database. -func (m *{{ $struct }}) Delete(db DB) error { - t := prometheus.NewTimer(DatabaseLatency.WithLabelValues("delete_{{ $struct | structify -}}")) - defer t.ObserveDuration() - - {{ if identity_columns . -}} - {{ $cols := identity_columns . }} - const sqlstr = "DELETE FROM {{ .Name }} WHERE {{ range $i, $column := $cols }}{{ if $i }} AND {{ end }}`{{ $column.Name }}` = ?{{ end }}" - - DBLog(sqlstr, {{ range $i, $column := $cols }}{{ if $i }}, {{ end }}m.{{ $column.Name | structify }}{{ end }}) - _, err := db.Exec(sqlstr, {{ range $i, $column := $cols }}{{ if $i }}, {{ end }}m.{{ $column.Name | structify }}{{ end }}) - {{- else -}} - {{ $cols := .Columns }} - const sqlstr = "DELETE FROM {{ .Name }} WHERE {{ range $i, $column := $cols }}{{ if $i }} AND {{ end }}`{{ $column.Name }}` = ?{{ end }}" - - DBLog(sqlstr, {{ range $i, $column := $cols }}{{ if $i }}, {{ end }}m.{{ $column.Name | structify }}{{ end }}) - _, err := db.Exec(sqlstr, {{ range $i, $column := $cols }}{{ if $i }}, {{ end }}m.{{ $column.Name | structify }}{{ end }}) - {{- end }} - - return err -} -{{- end -}} diff --git a/example/models/templates/_insert.tmpl b/example/models/templates/_insert.tmpl deleted file mode 100644 index fd3b362..0000000 --- a/example/models/templates/_insert.tmpl +++ /dev/null @@ -1,73 +0,0 @@ -{{- define "insert" -}} -{{- $struct := .Name | structify -}} -// Insert inserts the {{ $struct }} to the database. -func (m *{{ $struct }}) Insert(db DB) error { - t := prometheus.NewTimer(DatabaseLatency.WithLabelValues("insert_{{ $struct | structify -}}")) - defer t.ObserveDuration() - - {{ $autoinc := autoinc_column . }} - {{- $cols := non_autoinc_columns . -}} - const sqlstr = "INSERT INTO {{ .Name }} (" + - "{{ range $i, $column := $cols }}{{ if $i }}, {{ end }}`{{ $column.Name }}`{{ end }}" + - ") VALUES (" + - "{{ range $i, $column := $cols }}{{ if $i }}, {{ end }}?{{ end }}" + - ")" - - DBLog(sqlstr, {{ range $i, $column := $cols }}{{ if $i }}, {{ end }}m.{{ $column.Name | structify }}{{ end }}) - {{ if $autoinc }}res{{ else }}_{{ end }}, err := db.Exec(sqlstr, {{ range $i, $column := $cols }}{{ if $i }}, {{ end }}m.{{ $column.Name | structify }}{{ end }}) - {{ with $autoinc -}} - if err != nil { - return err - } - - id, err := res.LastInsertId() - if err != nil { - return err - } - - m.{{ .Name | structify }} = {{ template "type" . }}(id) - return nil - {{- else -}} - return err - {{- end }} -} - -func InsertMany{{ $struct }}s(db DB, ms ...*{{ $struct }}) error { - if len(ms) == 0 { - return nil - } - - t := prometheus.NewTimer(DatabaseLatency.WithLabelValues("insert_many_{{ $struct | structify -}}")) - defer t.ObserveDuration() - - vals := make([]any, 0, len(ms)) - for _, m := range ms { - // Dereference the pointer to get the struct value. - vals = append(vals, []any{*m}) - } - - sqlstr, args, err := inserter.NewBatch(vals, inserter.WithTable("{{ .Name }}")).GenerateSQL() - if err != nil { - return fmt.Errorf("failed to create batch insert: %w", err) - } - - DBLog(sqlstr, args...) - {{ if $autoinc }}res{{ else }}_{{ end }}, err := db.Exec(sqlstr, args...) - if err != nil { - return err - } - - {{ with $autoinc -}} - id, err := res.LastInsertId() - if err != nil { - return err - } - - for i, m := range ms { - m.{{ .Name | structify }} = {{ template "type" . }}(id + int64(i)) - } - {{- end }} - - return nil -} -{{- end -}} diff --git a/example/models/templates/_insert_update.tmpl b/example/models/templates/_insert_update.tmpl deleted file mode 100644 index f155420..0000000 --- a/example/models/templates/_insert_update.tmpl +++ /dev/null @@ -1,37 +0,0 @@ -{{- define "insert_update" -}} -{{- $struct := .Name | structify -}} -// InsertWithUpdate inserts the {{ $struct }} to the database, and tries to update -// on unique constraint violations. -func (m *{{ $struct }}) InsertWithUpdate(db DB) error { - t := prometheus.NewTimer(DatabaseLatency.WithLabelValues("insert_update_{{ $struct | structify -}}")) - defer t.ObserveDuration() - - {{ $autoinc := autoinc_column . }} - {{- $cols := non_autoinc_columns . -}} - {{- $updates := non_identity_columns . -}} - const sqlstr = "INSERT INTO {{ .Name }} (" + - "{{ range $i, $column := $cols }}{{ if $i }}, {{ end }}`{{ $column.Name }}`{{ end }}" + - ") VALUES (" + - "{{ range $i, $column := $cols }}{{ if $i }}, {{ end }}?{{ end }}" + - ") ON DUPLICATE KEY UPDATE " + - "{{ range $i, $column := $updates }}{{ if $i }}, {{ end }}`{{ $column.Name }}` = VALUES(`{{ $column.Name }}`){{ end }}" - - DBLog(sqlstr, {{ range $i, $column := $cols }}{{ if $i }}, {{ end }}m.{{ $column.Name | structify }}{{ end }}) - {{ if $autoinc }}res{{ else }}_{{ end }}, err := db.Exec(sqlstr, {{ range $i, $column := $cols }}{{ if $i }}, {{ end }}m.{{ $column.Name | structify }}{{ end }}) - {{ with $autoinc -}} - if err != nil { - return err - } - - id, err := res.LastInsertId() - if err != nil { - return err - } - - m.{{ .Name | structify }} = {{ template "type" . }}(id) - return nil - {{- else -}} - return err - {{- end }} -} -{{- end -}} diff --git a/example/models/templates/_tags.tmpl b/example/models/templates/_tags.tmpl deleted file mode 100644 index 1833ff8..0000000 --- a/example/models/templates/_tags.tmpl +++ /dev/null @@ -1,13 +0,0 @@ -{{- define "tags" -}} - `db:"{{- .Name -}} - {{- if .AutoIncrementing -}} - ,autoinc - {{- end -}} - {{- if .InPrimaryKey -}} - ,pk - {{- end -}} - {{- if .HasDefault -}} - ,default - {{- end -}} - "` -{{- end -}} diff --git a/example/models/templates/_type.tmpl b/example/models/templates/_type.tmpl deleted file mode 100644 index eb2a662..0000000 --- a/example/models/templates/_type.tmpl +++ /dev/null @@ -1,169 +0,0 @@ -{{- define "type" -}} - {{- if eq .Type "bigint" -}} - {{- if .Nullable -}} - usql.NullInt64 - {{- else -}} - {{- if .Unsigned -}} - uint64 - {{- else -}} - int64 - {{- end -}} - {{- end -}} - {{- else if eq .Type "int" -}} - {{- if .Nullable -}} - usql.NullInt64 - {{- else -}} - {{- if .Unsigned -}} - uint - {{- else -}} - int - {{- end -}} - {{- end -}} - {{- else if eq .Type "tinyint" -}} - {{- if eq .TypeSize 1 -}} - {{- if .Nullable -}} - usql.NullBool - {{- else -}} - bool - {{- end -}} - {{- else -}} - {{- if .Nullable -}} - usql.NullInt64 - {{- else -}} - {{- if .Unsigned -}} - uint8 - {{- else -}} - int8 - {{- end -}} - {{- end -}} - {{- end -}} - {{- else if eq .Type "smallint" -}} - {{- if .Nullable -}} - usql.NullInt64 - {{- else -}} - {{- if .Unsigned -}} - uint16 - {{- else -}} - int16 - {{- end -}} - {{- end -}} - {{- else if eq .Type "mediumint" -}} - {{- if .Nullable -}} - usql.NullInt64 - {{- else -}} - {{- if .Unsigned -}} - uint32 - {{- else -}} - int32 - {{- end -}} - {{- end -}} - {{- else if eq .Type "float" -}} - {{- if .Nullable -}} - usql.NullFloat64 - {{- else -}} - float{{ if lt .TypeSize 25 }}32{{ else }}64{{ end }} - {{- end -}} - {{- else if eq .Type "decimal" -}} - {{- if .Nullable -}} - usql.NullFloat64 - {{- else -}} - float64 - {{- end -}} - {{- else if eq .Type "double" -}} - {{- if .Nullable -}} - usql.NullFloat64 - {{- else -}} - float64 - {{- end -}} - {{- else if eq .Type "char" -}} - {{- if .Nullable -}} - usql.NullString - {{- else -}} - string - {{- end -}} - {{- else if eq .Type "binary" -}} - []byte - {{- else if eq .Type "varchar" -}} - {{- if .Nullable -}} - usql.NullString - {{- else -}} - string - {{- end -}} - {{- else if eq .Type "varbinary" -}} - []byte - {{- else if eq .Type "tinyblob" -}} - []byte - {{- else if eq .Type "tinytext" -}} - {{- if .Nullable -}} - usql.NullString - {{- else -}} - string - {{- end -}} - {{- else if eq .Type "blob" -}} - []byte - {{- else if eq .Type "text" -}} - {{- if .Nullable -}} - usql.NullString - {{- else -}} - string - {{- end -}} - {{- else if eq .Type "mediumblob" -}} - []byte - {{- else if eq .Type "mediumtext" -}} - {{- if .Nullable -}} - usql.NullString - {{- else -}} - string - {{- end -}} - {{- else if eq .Type "longblob" -}} - []byte - {{- else if eq .Type "longtext" -}} - {{- if .Nullable -}} - usql.NullString - {{- else -}} - string - {{- end -}} - {{- else if eq .Type "enum" -}} - {{- if .Nullable -}} - usql.NullEnum - {{- else -}} - usql.Enum - {{- end -}} - {{- else if eq .Type "mediumint" -}} - {{- if .Nullable -}} - usql.NullString - {{- else -}} - string - {{- end -}} - {{- else if eq .Type "year" -}} - {{- if .Nullable -}} - usql.NullTime - {{- else -}} - time.Time - {{- end -}} - {{- else if eq .Type "date" -}} - {{- if .Nullable -}} - usql.NullTime - {{- else -}} - time.Time - {{- end -}} - {{- else if eq .Type "time" -}} - {{- if .Nullable -}} - usql.NullDuration - {{- else -}} - usql.Duration - {{- end -}} - {{- else if eq .Type "datetime" -}} - {{- if .Nullable -}} - usql.NullTime - {{- else -}} - time.Time - {{- end -}} - {{- else if eq .Type "timestamp" -}} - {{- if .Nullable -}} - usql.NullTime - {{- else -}} - time.Time - {{- end -}} - {{- end -}} -{{- end -}} \ No newline at end of file diff --git a/example/models/templates/_type_reverse_nullable.tmpl b/example/models/templates/_type_reverse_nullable.tmpl deleted file mode 100644 index d65f9c8..0000000 --- a/example/models/templates/_type_reverse_nullable.tmpl +++ /dev/null @@ -1,89 +0,0 @@ -{{- define "type_reverse_nullable" -}} - {{- if eq .Type "bigint" -}} - {{- if .Nullable -}} - Int64 - {{- end -}} - {{- else if eq .Type "int" -}} - {{- if .Nullable -}} - Int64 - {{- end -}} - {{- else if eq .Type "tinyint" -}} - {{- if eq .TypeSize 1 -}} - {{- if .Nullable -}} - Bool - {{- end -}} - {{- end -}} - {{- else if eq .Type "smallint" -}} - {{- if .Nullable -}} - Int64 - {{- end -}} - {{- else if eq .Type "mediumint" -}} - {{- if .Nullable -}} - Int64 - {{- end -}} - {{- else if eq .Type "float" -}} - {{- if .Nullable -}} - Float64 - {{- end -}} - {{- else if eq .Type "decimal" -}} - {{- if .Nullable -}} - Float64 - {{- end -}} - {{- else if eq .Type "double" -}} - {{- if .Nullable -}} - Float64 - {{- end -}} - {{- else if eq .Type "char" -}} - {{- if .Nullable -}} - String - {{- end -}} - {{- else if eq .Type "varchar" -}} - {{- if .Nullable -}} - String - {{- end -}} - {{- else if eq .Type "tinytext" -}} - {{- if .Nullable -}} - String - {{- end -}} - {{- else if eq .Type "text" -}} - {{- if .Nullable -}} - String - {{- end -}} - {{- else if eq .Type "mediumtext" -}} - {{- if .Nullable -}} - String - {{- end -}} - {{- else if eq .Type "longtext" -}} - {{- if .Nullable -}} - String - {{- end -}} - {{- else if eq .Type "enum" -}} - {{- if .Nullable -}} - Enum - {{- end -}} - {{- else if eq .Type "mediumint" -}} - {{- if .Nullable -}} - String - {{- end -}} - {{- else if eq .Type "year" -}} - {{- if .Nullable -}} - mysql.NullTime - {{- end -}} - {{- else if eq .Type "date" -}} - {{- if .Nullable -}} - mysql.NullTime - {{- end -}} - {{- else if eq .Type "time" -}} - {{- if .Nullable -}} - Duration - {{- end -}} - {{- else if eq .Type "datetime" -}} - {{- if .Nullable -}} - mysql.NullTime - {{- end -}} - {{- else if eq .Type "timestamp" -}} - {{- if .Nullable -}} - mysql.NullTime - {{- end -}} - {{- end -}} -{{- end -}} \ No newline at end of file diff --git a/example/models/templates/_update.tmpl b/example/models/templates/_update.tmpl deleted file mode 100644 index cb22ebe..0000000 --- a/example/models/templates/_update.tmpl +++ /dev/null @@ -1,59 +0,0 @@ -{{- define "update" -}} -{{- $struct := .Name | structify -}} -// Update updates the {{ $struct }} in the database. -func (m *{{ $struct }}) Update(db DB) error { - t := prometheus.NewTimer(DatabaseLatency.WithLabelValues("update_{{ $struct | structify -}}")) - defer t.ObserveDuration() - - {{ $cols := non_identity_columns . -}} - {{- $wheres := identity_columns . -}} - const sqlstr = "UPDATE {{ .Name }} " + - "SET {{ range $i, $column := $cols }}{{ if $i }}, {{ end }}`{{ $column.Name }}` = ?{{ end }} " + - "WHERE {{ range $i, $column := $wheres }}{{ if $i }} AND {{ end }}`{{ $column.Name }}` = ?{{ end }}" - - DBLog(sqlstr, {{ range $i, $column := $cols }}{{ if $i }}, {{ end }}m.{{ $column.Name | structify }}{{ end }}, {{ range $i, $column := $wheres }}{{ if $i }}, {{ end }}m.{{ $column.Name | structify }}{{ end }}) - res, err := db.Exec(sqlstr, {{ range $i, $column := $cols }}{{ if $i }}, {{ end }}m.{{ $column.Name | structify }}{{ end }}, {{ range $i, $column := $wheres }}{{ if $i }}, {{ end }}m.{{ $column.Name | structify }}{{ end }}) - if err != nil { - return err - } - - // Requires clientFoundRows=true - if i, err := res.RowsAffected(); err != nil { - return err - } else if i <= 0 { - return ErrNoAffectedRows - } - - return nil -} - -func (m *{{ $struct }}) Patch(db DB, newT *{{ $struct }}) error { - if newT == nil { - return errors.New("new {{ .Name }} is nil") - } - - res, err := patcher.NewDiffSQLPatch(m, newT, patcher.WithTable("{{ .Name }}")) - if err != nil { - return fmt.Errorf("new diff sql patch: %w", err) - } - - sqlstr, args, err := res.GenerateSQL() - if err != nil { - switch { - case errors.Is(err, patcher.ErrNoChanges): - return nil - default: - return fmt.Errorf("failed to create patch: %w", err) - } - } - - DBLog(sqlstr, args...) - _, err = db.Exec(sqlstr, args...) - if err != nil { - return fmt.Errorf("failed to execute patch: %w", err) - } - - return nil -} - -{{- end -}} \ No newline at end of file diff --git a/example/models/templates/model.tmpl b/example/models/templates/model.tmpl deleted file mode 100644 index c9fcf2e..0000000 --- a/example/models/templates/model.tmpl +++ /dev/null @@ -1,171 +0,0 @@ -// Package models contains the database interaction model code -// -// GENERATED BY GOSCHEMA. DO NOT EDIT. -package {{ .OutputDir | base | snakecase }} - -import ( - "database/sql" - "errors" - "time" - - "github.com/go-sql-driver/mysql" - "github.com/jacobbrewer1/patcher" - "github.com/jacobbrewer1/patcher/inserter" - "github.com/jacobbrewer1/goschema/pkg/usql" - "github.com/prometheus/client_golang/prometheus" -) -{{ with .Table }} -{{ $struct := .Name | structify }} -// {{ $struct }} represents a row from '{{ .Name }}'. -{{- if .Comment }} -// {{ .Comment }} -{{- end }} -type {{ $struct }} struct { - {{ range $column := .Columns -}} - {{ .Name | structify }} {{ template "type" $column }} {{ template "tags" $column }} {{ if .Comment }}// {{ .Comment }}{{ end }} - {{ end -}} -} - -{{ template "insert" . }} - -{{ if has_primary_key . -}} -// IsPrimaryKeySet returns true if all primary key fields are set to none zero values -func (m *{{ $struct }}) IsPrimaryKeySet() bool { - {{ $length := len .PrimaryKey.Columns -}} - {{ if eq $length 1 -}} - {{ $column := index .PrimaryKey.Columns 0 -}} - return IsKeySet(m.{{ $column.Name | structify }}) - {{ else -}} - return {{ range $i, $column := .PrimaryKey.Columns -}}{{ if $i }} && {{ end }}IsKeySet(m.{{ $column.Name | structify }}){{ end }} - {{ end -}} -} -{{- end }} - -{{ if identity_columns . -}} -{{ if non_identity_columns . -}} -{{ template "update" . }} - -{{ template "insert_update" . }} -{{- end }} -{{- end }} - -// Save saves the {{ $struct }} to the database. -func (m *{{ $struct }}) Save(db DB) error { - {{ if identity_columns . -}} - {{ if non_identity_columns . -}} - if m.IsPrimaryKeySet() { - return m.Update(db) - } - {{ end -}} - {{ end -}} - return m.Insert(db) -} - -{{ if identity_columns . -}} -{{ if non_identity_columns . -}} -// SaveOrUpdate saves the {{ $struct }} to the database, but tries to update -// on unique constraint violations. -func (m *{{ $struct }}) SaveOrUpdate(db DB) error { - {{ if identity_columns . -}} - {{ if non_identity_columns . -}} - if m.IsPrimaryKeySet() { - return m.Update(db) - } - {{ end -}} - {{ end -}} - return m.InsertWithUpdate(db) -} -{{- end }} -{{- end }} - -{{ template "delete" .}} - -{{ range $key := unique_column_keys . }} -{{ $key_cnt := len $key.Columns }} -{{ $tbl_cnt := len $.Table.Columns }} -{{ if ne $key_cnt $tbl_cnt }} -{{ if eq $key.Type "primary" -}} -// {{ $struct }}By{{ range $i, $col := $key.Columns }}{{ $col.Name | structify }}{{ end }} retrieves a row from '{{ $.Table.Name }}' as a {{ $struct }}. -// -// Generated from primary key. -func {{ $struct }}By{{ range $i, $col := $key.Columns }}{{ $col.Name | structify }}{{ end }}(db DB, {{ range $i, $col := $key.Columns }}{{ if $i }}, {{ end }}{{ $col.Name | structify | lcfirst }} {{ template "type" $col}}{{ end }}) (*{{ $struct }}, error) { - t := prometheus.NewTimer(DatabaseLatency.WithLabelValues("insert_{{ $struct | structify -}}")) - defer t.ObserveDuration() - - const sqlstr = "SELECT {{ range $i, $column := $.Table.Columns }}{{ if $i }}, {{ end }}`{{ $column.Name }}`{{ end }} " + - "FROM {{ $.Table.Name }} " + - "WHERE {{ range $i, $col := $key.Columns }}{{ if $i }} AND {{ end }}`{{ $col.Name }}` = ?{{ end }}" - - DBLog(sqlstr, {{ range $i, $col := $key.Columns }}{{ if $i }}, {{ end }}{{ $col.Name | structify | lcfirst }}{{ end }}) - var m {{ $struct }} - if err := db.Get(&m, sqlstr, {{ range $i, $col := $key.Columns }}{{ if $i }}, {{ end }}{{ $col.Name | structify | lcfirst }}{{ end }}); err != nil { - return nil, err - } - - return &m, nil -} -{{ range $constraint := $.Table.Constraints -}} -{{ $constraint_ref_len := len $constraint.References }} -{{ if eq $constraint_ref_len 1 -}} -{{ $foreign_struct := $constraint.ReferenceTable | structify }} -{{ range $i, $col_data := $.Table.Columns -}} -{{ range $local_col, $foreign_col := $constraint.References -}} -{{ if eq $col_data.Name $local_col -}} -// Get{{ $local_col | structify }}{{ $foreign_struct }} Gets an instance of {{ $foreign_struct }} -// -// Generated from constraint {{ $constraint.Name }} -func (m *{{ $struct }}) Get{{ $local_col | structify }}{{ $foreign_struct }}(db DB) (*{{ $foreign_struct }}, error) { - {{ if $col_data.Nullable -}} - if !m.{{ $local_col | structify }}.Valid { - return nil, nil - } - - {{ end -}} - return {{ $foreign_struct }}By{{ $foreign_col | structify }}(db, m.{{ $local_col | structify }}{{ if $col_data.Nullable }}.{{ template "type_reverse_nullable" $col_data }}{{ end }}) -} -{{ end -}} -{{ end -}} -{{ end -}} -{{- end }} -{{- end }} - -{{- else -}} -{{- $uniq := contains "unique" $key.Type }} -{{- if $uniq }} -// {{ $struct }}By{{ range $i, $col := $key.Columns }}{{ $col.Name | structify }}{{ end }} retrieves {{ if $uniq }}a row{{ else }}rows{{ end }} from '{{ $.Table.Name }}' as a {{ if $uniq }}*{{ $struct }}{{ else }}[]*{{ $struct }}{{ end }}. -// -// Generated from index '{{ $key.Name }}' of type '{{ $key.Type }}'. -func {{ $struct }}By{{ range $i, $col := $key.Columns }}{{ $col.Name | structify }}{{ end }}(db DB, {{ range $i, $col := $key.Columns }}{{ if $i }}, {{ end }}{{ $col.Name | structify | lcfirst }} {{ template "type" $col}}{{ end }}) ({{ if not $uniq }}[]{{ end }}*{{ $struct }}, error) { - t := prometheus.NewTimer(DatabaseLatency.WithLabelValues("insert_{{ $struct | structify -}}")) - defer t.ObserveDuration() - - const sqlstr = "SELECT {{ range $i, $column := $.Table.Columns }}{{ if $i }}, {{ end }}`{{ $column.Name }}`{{ end }} " + - "FROM {{ $.Table.Name }} " + - "WHERE {{ range $i, $col := $key.Columns }}{{ if $i }} AND {{ end }}`{{ $col.Name }}` = ?{{ end }}" - - DBLog(sqlstr, {{ range $i, $col := $key.Columns }}{{ if $i }}, {{ end }}{{ $col.Name | structify | lcfirst }}{{ end }}) - var m {{ if not $uniq }}[]*{{ end }}{{ $struct }} - if err := db.{{ if $uniq }}Get{{ else }}Select{{ end }}(&m, sqlstr, {{ range $i, $col := $key.Columns }}{{ if $i }}, {{ end }}{{ $col.Name | structify | lcfirst }}{{ end }}); err != nil { - return nil, err - } - - return {{ if $uniq }}&{{ end }}m, nil -} -{{ end }} -{{ end }} -{{ end }} -{{ end }} - -{{- range $enumcol := enum_columns . }} -// Valid values for the '{{ $enumcol.Name | structify }}' enum column -var ( -{{- range $enum := .Elements }} - {{ $struct }}{{ $enumcol.Name | structify }}{{ $enum | structify }} = {{ if $enumcol.Nullable }}msql.NewNullEnum("{{ $enum }}"){{ else }}"{{ $enum }}"{{ end }} -{{- end }} -{{- if $enumcol.Nullable }} - {{ $struct }}{{ $enumcol.Name | structify }}Null = msql.NullEnum{} -{{- end }} -) -{{ end }} - -{{ end }} diff --git a/pkg/generation/templates.go b/pkg/generation/templates.go index 3e925e0..2a9d05c 100644 --- a/pkg/generation/templates.go +++ b/pkg/generation/templates.go @@ -1,6 +1,7 @@ package generation import ( + "embed" "fmt" "log/slog" "os" @@ -35,6 +36,25 @@ func RenderTemplates(tables []*models.Table, templatesLoc, outputLoc string, fil return nil } +// RenderWithTemplates renders templates that are provided as embedded files +func RenderWithTemplates(fs embed.FS, tables []*models.Table, outputLoc string, fileExtensionPrefix string) error { + tmpl, err := template.New("model.tmpl").Funcs(sprig.TxtFuncMap()).Funcs(Helpers).ParseFS(fs, "templates/*.tmpl") + if err != nil { + return fmt.Errorf("error parsing templates: %w", err) + } + + for _, t := range tables { + if err := generate(&templateInfo{ + OutputDir: outputLoc, + Table: t, + }, tmpl, outputLoc, fileExtensionPrefix); err != nil { + return fmt.Errorf("error generating template: %w", err) + } + } + + return nil +} + func generate(t *templateInfo, tmpl *template.Template, outputLoc string, fileExtensionPrefix string) error { ext := ".go" if fileExtensionPrefix != "" { diff --git a/templates/_insert.tmpl b/templates/_insert.tmpl index fd3b362..0346a29 100644 --- a/templates/_insert.tmpl +++ b/templates/_insert.tmpl @@ -25,7 +25,7 @@ func (m *{{ $struct }}) Insert(db DB) error { return err } - m.{{ .Name | structify }} = {{ template "type" . }}(id) + m.{{ .Name | structify }} = {{ get_type . }}(id) return nil {{- else -}} return err @@ -64,7 +64,7 @@ func InsertMany{{ $struct }}s(db DB, ms ...*{{ $struct }}) error { } for i, m := range ms { - m.{{ .Name | structify }} = {{ template "type" . }}(id + int64(i)) + m.{{ .Name | structify }} = {{ get_type . }}(id + int64(i)) } {{- end }} diff --git a/templates/_insert_update.tmpl b/templates/_insert_update.tmpl index f155420..4dab1e5 100644 --- a/templates/_insert_update.tmpl +++ b/templates/_insert_update.tmpl @@ -28,7 +28,7 @@ func (m *{{ $struct }}) InsertWithUpdate(db DB) error { return err } - m.{{ .Name | structify }} = {{ template "type" . }}(id) + m.{{ .Name | structify }} = {{ get_type . }}(id) return nil {{- else -}} return err diff --git a/templates/_type.tmpl b/templates/_type.tmpl deleted file mode 100644 index eb2a662..0000000 --- a/templates/_type.tmpl +++ /dev/null @@ -1,169 +0,0 @@ -{{- define "type" -}} - {{- if eq .Type "bigint" -}} - {{- if .Nullable -}} - usql.NullInt64 - {{- else -}} - {{- if .Unsigned -}} - uint64 - {{- else -}} - int64 - {{- end -}} - {{- end -}} - {{- else if eq .Type "int" -}} - {{- if .Nullable -}} - usql.NullInt64 - {{- else -}} - {{- if .Unsigned -}} - uint - {{- else -}} - int - {{- end -}} - {{- end -}} - {{- else if eq .Type "tinyint" -}} - {{- if eq .TypeSize 1 -}} - {{- if .Nullable -}} - usql.NullBool - {{- else -}} - bool - {{- end -}} - {{- else -}} - {{- if .Nullable -}} - usql.NullInt64 - {{- else -}} - {{- if .Unsigned -}} - uint8 - {{- else -}} - int8 - {{- end -}} - {{- end -}} - {{- end -}} - {{- else if eq .Type "smallint" -}} - {{- if .Nullable -}} - usql.NullInt64 - {{- else -}} - {{- if .Unsigned -}} - uint16 - {{- else -}} - int16 - {{- end -}} - {{- end -}} - {{- else if eq .Type "mediumint" -}} - {{- if .Nullable -}} - usql.NullInt64 - {{- else -}} - {{- if .Unsigned -}} - uint32 - {{- else -}} - int32 - {{- end -}} - {{- end -}} - {{- else if eq .Type "float" -}} - {{- if .Nullable -}} - usql.NullFloat64 - {{- else -}} - float{{ if lt .TypeSize 25 }}32{{ else }}64{{ end }} - {{- end -}} - {{- else if eq .Type "decimal" -}} - {{- if .Nullable -}} - usql.NullFloat64 - {{- else -}} - float64 - {{- end -}} - {{- else if eq .Type "double" -}} - {{- if .Nullable -}} - usql.NullFloat64 - {{- else -}} - float64 - {{- end -}} - {{- else if eq .Type "char" -}} - {{- if .Nullable -}} - usql.NullString - {{- else -}} - string - {{- end -}} - {{- else if eq .Type "binary" -}} - []byte - {{- else if eq .Type "varchar" -}} - {{- if .Nullable -}} - usql.NullString - {{- else -}} - string - {{- end -}} - {{- else if eq .Type "varbinary" -}} - []byte - {{- else if eq .Type "tinyblob" -}} - []byte - {{- else if eq .Type "tinytext" -}} - {{- if .Nullable -}} - usql.NullString - {{- else -}} - string - {{- end -}} - {{- else if eq .Type "blob" -}} - []byte - {{- else if eq .Type "text" -}} - {{- if .Nullable -}} - usql.NullString - {{- else -}} - string - {{- end -}} - {{- else if eq .Type "mediumblob" -}} - []byte - {{- else if eq .Type "mediumtext" -}} - {{- if .Nullable -}} - usql.NullString - {{- else -}} - string - {{- end -}} - {{- else if eq .Type "longblob" -}} - []byte - {{- else if eq .Type "longtext" -}} - {{- if .Nullable -}} - usql.NullString - {{- else -}} - string - {{- end -}} - {{- else if eq .Type "enum" -}} - {{- if .Nullable -}} - usql.NullEnum - {{- else -}} - usql.Enum - {{- end -}} - {{- else if eq .Type "mediumint" -}} - {{- if .Nullable -}} - usql.NullString - {{- else -}} - string - {{- end -}} - {{- else if eq .Type "year" -}} - {{- if .Nullable -}} - usql.NullTime - {{- else -}} - time.Time - {{- end -}} - {{- else if eq .Type "date" -}} - {{- if .Nullable -}} - usql.NullTime - {{- else -}} - time.Time - {{- end -}} - {{- else if eq .Type "time" -}} - {{- if .Nullable -}} - usql.NullDuration - {{- else -}} - usql.Duration - {{- end -}} - {{- else if eq .Type "datetime" -}} - {{- if .Nullable -}} - usql.NullTime - {{- else -}} - time.Time - {{- end -}} - {{- else if eq .Type "timestamp" -}} - {{- if .Nullable -}} - usql.NullTime - {{- else -}} - time.Time - {{- end -}} - {{- end -}} -{{- end -}} \ No newline at end of file diff --git a/templates/_type_reverse_nullable.tmpl b/templates/_type_reverse_nullable.tmpl deleted file mode 100644 index d65f9c8..0000000 --- a/templates/_type_reverse_nullable.tmpl +++ /dev/null @@ -1,89 +0,0 @@ -{{- define "type_reverse_nullable" -}} - {{- if eq .Type "bigint" -}} - {{- if .Nullable -}} - Int64 - {{- end -}} - {{- else if eq .Type "int" -}} - {{- if .Nullable -}} - Int64 - {{- end -}} - {{- else if eq .Type "tinyint" -}} - {{- if eq .TypeSize 1 -}} - {{- if .Nullable -}} - Bool - {{- end -}} - {{- end -}} - {{- else if eq .Type "smallint" -}} - {{- if .Nullable -}} - Int64 - {{- end -}} - {{- else if eq .Type "mediumint" -}} - {{- if .Nullable -}} - Int64 - {{- end -}} - {{- else if eq .Type "float" -}} - {{- if .Nullable -}} - Float64 - {{- end -}} - {{- else if eq .Type "decimal" -}} - {{- if .Nullable -}} - Float64 - {{- end -}} - {{- else if eq .Type "double" -}} - {{- if .Nullable -}} - Float64 - {{- end -}} - {{- else if eq .Type "char" -}} - {{- if .Nullable -}} - String - {{- end -}} - {{- else if eq .Type "varchar" -}} - {{- if .Nullable -}} - String - {{- end -}} - {{- else if eq .Type "tinytext" -}} - {{- if .Nullable -}} - String - {{- end -}} - {{- else if eq .Type "text" -}} - {{- if .Nullable -}} - String - {{- end -}} - {{- else if eq .Type "mediumtext" -}} - {{- if .Nullable -}} - String - {{- end -}} - {{- else if eq .Type "longtext" -}} - {{- if .Nullable -}} - String - {{- end -}} - {{- else if eq .Type "enum" -}} - {{- if .Nullable -}} - Enum - {{- end -}} - {{- else if eq .Type "mediumint" -}} - {{- if .Nullable -}} - String - {{- end -}} - {{- else if eq .Type "year" -}} - {{- if .Nullable -}} - mysql.NullTime - {{- end -}} - {{- else if eq .Type "date" -}} - {{- if .Nullable -}} - mysql.NullTime - {{- end -}} - {{- else if eq .Type "time" -}} - {{- if .Nullable -}} - Duration - {{- end -}} - {{- else if eq .Type "datetime" -}} - {{- if .Nullable -}} - mysql.NullTime - {{- end -}} - {{- else if eq .Type "timestamp" -}} - {{- if .Nullable -}} - mysql.NullTime - {{- end -}} - {{- end -}} -{{- end -}} \ No newline at end of file diff --git a/templates/model.tmpl b/templates/model.tmpl index c9fcf2e..ad6b727 100644 --- a/templates/model.tmpl +++ b/templates/model.tmpl @@ -22,7 +22,7 @@ import ( {{- end }} type {{ $struct }} struct { {{ range $column := .Columns -}} - {{ .Name | structify }} {{ template "type" $column }} {{ template "tags" $column }} {{ if .Comment }}// {{ .Comment }}{{ end }} + {{ .Name | structify }} {{ get_type $column }} {{ template "tags" $column }} {{ if .Comment }}// {{ .Comment }}{{ end }} {{ end -}} } @@ -88,7 +88,7 @@ func (m *{{ $struct }}) SaveOrUpdate(db DB) error { // {{ $struct }}By{{ range $i, $col := $key.Columns }}{{ $col.Name | structify }}{{ end }} retrieves a row from '{{ $.Table.Name }}' as a {{ $struct }}. // // Generated from primary key. -func {{ $struct }}By{{ range $i, $col := $key.Columns }}{{ $col.Name | structify }}{{ end }}(db DB, {{ range $i, $col := $key.Columns }}{{ if $i }}, {{ end }}{{ $col.Name | structify | lcfirst }} {{ template "type" $col}}{{ end }}) (*{{ $struct }}, error) { +func {{ $struct }}By{{ range $i, $col := $key.Columns }}{{ $col.Name | structify }}{{ end }}(db DB, {{ range $i, $col := $key.Columns }}{{ if $i }}, {{ end }}{{ $col.Name | structify | lcfirst }} {{ get_type $col}}{{ end }}) (*{{ $struct }}, error) { t := prometheus.NewTimer(DatabaseLatency.WithLabelValues("insert_{{ $struct | structify -}}")) defer t.ObserveDuration() @@ -121,7 +121,7 @@ func (m *{{ $struct }}) Get{{ $local_col | structify }}{{ $foreign_struct }}(db } {{ end -}} - return {{ $foreign_struct }}By{{ $foreign_col | structify }}(db, m.{{ $local_col | structify }}{{ if $col_data.Nullable }}.{{ template "type_reverse_nullable" $col_data }}{{ end }}) + return {{ $foreign_struct }}By{{ $foreign_col | structify }}(db, m.{{ $local_col | structify }}{{ if $col_data.Nullable }}.{{ get_type $col_data }}{{ end }}) } {{ end -}} {{ end -}} @@ -135,7 +135,7 @@ func (m *{{ $struct }}) Get{{ $local_col | structify }}{{ $foreign_struct }}(db // {{ $struct }}By{{ range $i, $col := $key.Columns }}{{ $col.Name | structify }}{{ end }} retrieves {{ if $uniq }}a row{{ else }}rows{{ end }} from '{{ $.Table.Name }}' as a {{ if $uniq }}*{{ $struct }}{{ else }}[]*{{ $struct }}{{ end }}. // // Generated from index '{{ $key.Name }}' of type '{{ $key.Type }}'. -func {{ $struct }}By{{ range $i, $col := $key.Columns }}{{ $col.Name | structify }}{{ end }}(db DB, {{ range $i, $col := $key.Columns }}{{ if $i }}, {{ end }}{{ $col.Name | structify | lcfirst }} {{ template "type" $col}}{{ end }}) ({{ if not $uniq }}[]{{ end }}*{{ $struct }}, error) { +func {{ $struct }}By{{ range $i, $col := $key.Columns }}{{ $col.Name | structify }}{{ end }}(db DB, {{ range $i, $col := $key.Columns }}{{ if $i }}, {{ end }}{{ $col.Name | structify | lcfirst }} {{ get_type $col}}{{ end }}) ({{ if not $uniq }}[]{{ end }}*{{ $struct }}, error) { t := prometheus.NewTimer(DatabaseLatency.WithLabelValues("insert_{{ $struct | structify -}}")) defer t.ObserveDuration()