diff --git a/nsqlookup/resolver.go b/nsqlookup/resolver.go index 408afec..418b904 100644 --- a/nsqlookup/resolver.go +++ b/nsqlookup/resolver.go @@ -114,8 +114,6 @@ type ConsulResolver struct { func (r *ConsulResolver) Resolve(ctx context.Context) (list []string, err error) { var address = r.Address var service = r.Service - var req *http.Request - var res *http.Response var t http.RoundTripper if t = r.Transport; t == nil { @@ -134,9 +132,75 @@ func (r *ConsulResolver) Resolve(ctx context.Context) (list []string, err error) address = "http://" + address } - if req, err = http.NewRequest("GET", address+"/v1/catalog/service/"+service, nil); err != nil { + var checksResults []struct { + Node string + } + + err = r.get(ctx, "v1/health/checks/"+service+"?passing", &checksResults) + if err != nil { return } + + var serviceResults []struct { + Node string + Address string + ServiceAddress string + ServicePort int + } + + err = r.get(ctx, "v1/catalog/service/"+service, &serviceResults) + if err != nil { + return + } + + list = make([]string, 0, len(checksResults)) + + for _, r := range serviceResults { + var passing bool + for _, c := range checksResults { + if c.Node == r.Node { + passing = true + break + } + } + + if passing { + host := r.ServiceAddress + port := r.ServicePort + + if len(host) == 0 { + host = r.Address + } + + list = append(list, net.JoinHostPort(host, strconv.Itoa(port))) + } + } + + return +} + +func (r *ConsulResolver) get(ctx context.Context, endpoint string, result interface{}) error { + var address = r.Address + var req *http.Request + var res *http.Response + var t http.RoundTripper + var err error + + if t = r.Transport; t == nil { + t = http.DefaultTransport + } + + if len(address) == 0 { + address = "http://localhost:8500" + } + + if strings.Index(address, "://") < 0 { + address = "http://" + address + } + + if req, err = http.NewRequest("GET", address+"/"+endpoint, nil); err != nil { + return err + } req.Header.Set("User-Agent", "nsqlookup consul resolver") req.Header.Set("Accept", "application/json") @@ -145,42 +209,24 @@ func (r *ConsulResolver) Resolve(ctx context.Context) (list []string, err error) } if res, err = t.RoundTrip(req); err != nil { - return + return err } defer res.Body.Close() switch res.StatusCode { case http.StatusOK: case http.StatusNotFound: - return + return err default: - err = fmt.Errorf("error looking up %s on consul agent at %s: %d %s", service, address, res.StatusCode, res.Status) - return - } - - var results []struct { - Address string - ServiceAddress string - ServicePort int + err = fmt.Errorf("error looking up %s on consul agent at %s: %d %s", endpoint, address, res.StatusCode, res.Status) + return err } - if err = json.NewDecoder(res.Body).Decode(&results); err != nil { - return + if err = json.NewDecoder(res.Body).Decode(result); err != nil { + return err } - list = make([]string, 0, len(results)) - - for _, r := range results { - host := r.ServiceAddress - port := r.ServicePort - - if len(host) == 0 { - host = r.Address - } - list = append(list, net.JoinHostPort(host, strconv.Itoa(port))) - } - - return + return nil } // MultiResolver returns a resolver that merges all resolves from rslv when its diff --git a/nsqlookup/resolver_test.go b/nsqlookup/resolver_test.go index 3264b61..3206c93 100644 --- a/nsqlookup/resolver_test.go +++ b/nsqlookup/resolver_test.go @@ -109,27 +109,46 @@ func TestResolveCached(t *testing.T) { func TestResolveConsul(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(res http.ResponseWriter, req *http.Request) { - if req.URL.Path != "/v1/catalog/service/nsqlookupd" { + if req.URL.Path == "/v1/catalog/service/nsqlookupd" { + json.NewEncoder(res).Encode([]struct { + Node string + ServiceAddress string + ServicePort int + }{ + { + Node: "A", + ServiceAddress: "127.0.0.1", + ServicePort: 4242, + }, + { + Node: "B", + ServiceAddress: "192.168.0.1", + ServicePort: 4161, + }, + { + Node: "C", + ServiceAddress: "192.168.0.2", + ServicePort: 4161, + }, + }) + } else if req.URL.Path == "/v1/health/checks/nsqlookupd" { + json.NewEncoder(res).Encode([]struct { + Node string + }{ + { + Node: "A", + }, + { + Node: "B", + }, + { + Node: "C", + }, + }) + } else { t.Error("bad URL path:", req.URL.Path) } res.Header().Set("Content-Type", "application/json; charset=utf-8") - json.NewEncoder(res).Encode([]struct { - ServiceAddress string - ServicePort int - }{ - { - ServiceAddress: "127.0.0.1", - ServicePort: 4242, - }, - { - ServiceAddress: "192.168.0.1", - ServicePort: 4161, - }, - { - ServiceAddress: "192.168.0.2", - ServicePort: 4161, - }, - }) })) defer server.Close()