Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 30 additions & 10 deletions pkg/cmd/credential/update/update.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ type Options struct {
Desc string
PluginsJSON string
Labels []string
DescSet bool
PluginsSet bool
LabelsSet bool
}

func NewCmd(f *cmd.Factory) *cobra.Command {
Expand All @@ -41,6 +44,9 @@ func NewCmd(f *cmd.Factory) *cobra.Command {
opts.Output, _ = c.Flags().GetString("output")
opts.GatewayGroup, _ = c.Flags().GetString("gateway-group")
opts.Consumer, _ = c.Flags().GetString("consumer")
opts.DescSet = c.Flags().Changed("desc")
opts.PluginsSet = c.Flags().Changed("plugins-json")
opts.LabelsSet = c.Flags().Changed("labels")
return actionRun(opts)
},
}
Expand Down Expand Up @@ -97,30 +103,44 @@ func actionRun(opts *Options) error {
}

pl := make(map[string]interface{})
if opts.PluginsJSON != "" {
if opts.PluginsSet {
if err := json.Unmarshal([]byte(opts.PluginsJSON), &pl); err != nil {
return fmt.Errorf("invalid --plugins-json: %w", err)
}
}

labels := make(map[string]string)
for _, label := range opts.Labels {
parts := strings.SplitN(label, "=", 2)
if len(parts) != 2 || parts[0] == "" {
return fmt.Errorf("invalid label %q, expected key=value", label)
if opts.LabelsSet {
for _, label := range opts.Labels {
parts := strings.SplitN(label, "=", 2)
if len(parts) != 2 || parts[0] == "" {
return fmt.Errorf("invalid label %q, expected key=value", label)
}
labels[parts[0]] = parts[1]
}
labels[parts[0]] = parts[1]
}

bodyReq := api.Credential{ID: opts.ID, Desc: opts.Desc}
if len(pl) > 0 {
client := api.NewClient(httpClient, cfg.BaseURL())
currentBody, err := client.Get(path, nil)
if err != nil {
return fmt.Errorf("%s", cmdutil.FormatAPIError(err))
}

var bodyReq api.Credential
if err := json.Unmarshal(currentBody, &bodyReq); err != nil {
return fmt.Errorf("failed to decode current credential: %w", err)
}
bodyReq.ID = opts.ID
if opts.DescSet {
bodyReq.Desc = opts.Desc
}
if opts.PluginsSet {
bodyReq.Plugins = pl
}
if len(labels) > 0 {
if opts.LabelsSet {
bodyReq.Labels = labels
}

client := api.NewClient(httpClient, cfg.BaseURL())
body, err := client.Put(path, bodyReq)
if err != nil {
return fmt.Errorf("%s", cmdutil.FormatAPIError(err))
Expand Down
32 changes: 30 additions & 2 deletions pkg/cmd/credential/update/update_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package update

import (
"encoding/json"
"io"
"net/http"
"strings"
"testing"
Expand Down Expand Up @@ -34,11 +35,37 @@ func (m *mockConfig) Save() error { return n
func TestUpdateCredential_JSONOutput(t *testing.T) {
ios, _, out, _ := iostreams.Test()
registry := &httpmock.Registry{}
registry.Register(http.MethodPut, "/apisix/admin/consumers/alice/credentials/cred1", httpmock.JSONResponse(`{"id":"cred1","desc":"updated"}`))
registry.Register(http.MethodGet, "/apisix/admin/consumers/alice/credentials/cred1", httpmock.JSONResponse(`{
"id":"cred1",
"name":"apikey",
"desc":"old",
"plugins":{"key-auth":{"key":"old-key"}},
"labels":{"env":"dev"}
}`))
registry.RegisterResponder(http.MethodPut, "/apisix/admin/consumers/alice/credentials/cred1", func(r *http.Request) (httpmock.Response, error) {
body, err := io.ReadAll(r.Body)
if err != nil {
return httpmock.Response{}, err
}
var payload api.Credential
if err := json.Unmarshal(body, &payload); err != nil {
return httpmock.Response{}, err
}
if payload.Name != "apikey" {
t.Fatalf("expected existing credential name to be preserved, got %q", payload.Name)
}
if payload.Plugins["key-auth"] == nil {
t.Fatalf("expected existing plugins to be preserved: %+v", payload.Plugins)
}
if payload.Desc != "updated" {
t.Fatalf("expected desc to be updated, got %q", payload.Desc)
}
return httpmock.JSONResponse(`{"id":"cred1","name":"apikey","desc":"updated","plugins":{"key-auth":{"key":"old-key"}}}`), nil
})

opts := &Options{IO: ios, Client: func() (*http.Client, error) { return registry.GetClient(), nil }, Config: func() (config.Config, error) {
return &mockConfig{baseURL: "http://api.local", gatewayGroup: "gg1"}, nil
}, Consumer: "alice", ID: "cred1", GatewayGroup: "gg1", Desc: "updated"}
}, Consumer: "alice", ID: "cred1", GatewayGroup: "gg1", Desc: "updated", DescSet: true}

if err := actionRun(opts); err != nil {
t.Fatalf("actionRun failed: %v", err)
Expand Down Expand Up @@ -82,6 +109,7 @@ func TestUpdateCredential_MissingConsumer(t *testing.T) {
func TestUpdateCredential_APIError(t *testing.T) {
ios, _, _, _ := iostreams.Test()
registry := &httpmock.Registry{}
registry.Register(http.MethodGet, "/apisix/admin/consumers/alice/credentials/cred1", httpmock.JSONResponse(`{"id":"cred1","name":"apikey","plugins":{"key-auth":{}}}`))
registry.Register(http.MethodPut, "/apisix/admin/consumers/alice/credentials/cred1", httpmock.StringResponse(http.StatusInternalServerError, `{"message":"boom"}`))

opts := &Options{IO: ios, Client: func() (*http.Client, error) { return registry.GetClient(), nil }, Config: func() (config.Config, error) {
Expand Down
71 changes: 50 additions & 21 deletions pkg/cmd/gateway-group/update/update.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,20 @@ import (
)

type Options struct {
IO *iostreams.IOStreams
Client func() (*http.Client, error)
Config func() (config.Config, error)
Output string
File string
ID string
Name string
Description string
Labels []string
Prefix string
IO *iostreams.IOStreams
Client func() (*http.Client, error)
Config func() (config.Config, error)
Output string
File string
ID string
Name string
Description string
Labels []string
Prefix string
NameSet bool
DescriptionSet bool
LabelsSet bool
PrefixSet bool
}

func NewCmd(f *cmd.Factory) *cobra.Command {
Expand All @@ -42,6 +46,10 @@ func NewCmd(f *cmd.Factory) *cobra.Command {
RunE: func(c *cobra.Command, args []string) error {
opts.ID = args[0]
opts.Output, _ = c.Flags().GetString("output")
opts.NameSet = c.Flags().Changed("name")
opts.DescriptionSet = c.Flags().Changed("description")
opts.LabelsSet = c.Flags().Changed("labels")
opts.PrefixSet = c.Flags().Changed("prefix")
return updateRun(opts)
},
}
Expand Down Expand Up @@ -84,29 +92,50 @@ func updateRun(opts *Options) error {
}

labels := map[string]string{}
for _, item := range opts.Labels {
key, value, found := strings.Cut(item, "=")
if !found || key == "" {
return &cmdutil.FlagError{Err: fmt.Errorf("invalid --labels value %q, expected key=value", item)}
if opts.LabelsSet {
for _, item := range opts.Labels {
key, value, found := strings.Cut(item, "=")
if !found || key == "" {
return &cmdutil.FlagError{Err: fmt.Errorf("invalid --labels value %q, expected key=value", item)}
}
labels[key] = value
}
labels[key] = value
}

request := map[string]interface{}{}
if opts.Name != "" {
client := api.NewClient(httpClient, cfg.BaseURL())
currentBody, err := client.Get(fmt.Sprintf("/api/gateway_groups/%s", opts.ID), nil)
if err != nil {
return fmt.Errorf("%s", cmdutil.FormatAPIError(err))
}

var current api.GatewayGroup
if err := json.Unmarshal(currentBody, &current); err != nil {
return fmt.Errorf("failed to decode current gateway group: %w", err)
}

request := map[string]interface{}{
"name": current.Name,
"description": current.Description,
}
if current.Prefix != "" {
request["prefix"] = current.Prefix
}
if current.Labels != nil {
request["labels"] = current.Labels
}
if opts.NameSet {
request["name"] = opts.Name
}
if opts.Description != "" {
if opts.DescriptionSet {
request["description"] = opts.Description
}
if opts.Prefix != "" {
if opts.PrefixSet {
request["prefix"] = opts.Prefix
}
if len(opts.Labels) > 0 {
if opts.LabelsSet {
request["labels"] = labels
}

client := api.NewClient(httpClient, cfg.BaseURL())
body, err := client.Put(fmt.Sprintf("/api/gateway_groups/%s", opts.ID), request)
if err != nil {
return fmt.Errorf("failed to update gateway group: %s", cmdutil.FormatAPIError(err))
Expand Down
102 changes: 102 additions & 0 deletions pkg/cmd/gateway-group/update/update_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
package update

import (
"encoding/json"
"io"
"net/http"
"testing"

"github.com/api7/a7/internal/config"
"github.com/api7/a7/pkg/api"
"github.com/api7/a7/pkg/httpmock"
"github.com/api7/a7/pkg/iostreams"
)

type mockConfig struct{}

func (m *mockConfig) BaseURL() string { return "" }
func (m *mockConfig) Token() string { return "" }
func (m *mockConfig) GatewayGroup() string { return "" }
func (m *mockConfig) TLSSkipVerify() bool { return false }
func (m *mockConfig) CACert() string { return "" }
func (m *mockConfig) CurrentContext() string { return "" }
func (m *mockConfig) Contexts() []config.Context { return nil }
func (m *mockConfig) GetContext(name string) (*config.Context, error) { return nil, nil }
func (m *mockConfig) AddContext(ctx config.Context) error { return nil }
func (m *mockConfig) RemoveContext(name string) error { return nil }
func (m *mockConfig) SetCurrentContext(name string) error { return nil }
func (m *mockConfig) Save() error { return nil }

func TestUpdateGatewayGroup_PreservesRequiredFields(t *testing.T) {
ios, _, out, _ := iostreams.Test()
registry := &httpmock.Registry{}
registry.Register(http.MethodGet, "/api/gateway_groups/gg1", httpmock.JSONResponse(`{
"id":"gg1",
"name":"default",
"description":"old description",
"prefix":"/old",
"labels":{"env":"old"},
"status":1
}`))
registry.RegisterResponder(http.MethodPut, "/api/gateway_groups/gg1", func(r *http.Request) (httpmock.Response, error) {
body, err := io.ReadAll(r.Body)
if err != nil {
return httpmock.Response{}, err
}
var raw map[string]interface{}
if err := json.Unmarshal(body, &raw); err != nil {
return httpmock.Response{}, err
}
if _, ok := raw["status"]; ok {
t.Fatalf("request must not include response-only field status: %s", string(body))
}
if _, ok := raw["created_at"]; ok {
t.Fatalf("request must not include response-only field created_at: %s", string(body))
}
if _, ok := raw["updated_at"]; ok {
t.Fatalf("request must not include response-only field updated_at: %s", string(body))
}

var payload api.GatewayGroup
if err := json.Unmarshal(body, &payload); err != nil {
return httpmock.Response{}, err
}
if payload.Name != "default" {
t.Fatalf("expected existing name to be preserved, got %q", payload.Name)
}
if payload.Description != "new description" {
t.Fatalf("expected description to be updated, got %q", payload.Description)
}
if payload.Prefix != "/old" {
t.Fatalf("expected existing prefix to be preserved, got %q", payload.Prefix)
}
if payload.Labels["env"] != "new" {
t.Fatalf("expected labels to be replaced from flags, got %+v", payload.Labels)
}
return httpmock.JSONResponse(`{"id":"gg1","name":"default","description":"new description","prefix":"/old","labels":{"env":"new"},"status":1}`), nil
})

opts := &Options{
IO: ios,
Client: func() (*http.Client, error) { return registry.GetClient(), nil },
Config: func() (config.Config, error) { return &mockConfig{}, nil },
Output: "json",
ID: "gg1",
Description: "new description",
DescriptionSet: true,
Labels: []string{"env=new"},
LabelsSet: true,
}
if err := updateRun(opts); err != nil {
t.Fatalf("updateRun failed: %v", err)
}

var item api.GatewayGroup
if err := json.Unmarshal(out.Bytes(), &item); err != nil {
t.Fatalf("failed to parse output: %v", err)
}
if item.Name != "default" || item.Description != "new description" {
t.Fatalf("unexpected gateway group output: %+v", item)
}
registry.Verify(t)
}
Loading
Loading