Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: introduce handshake to client and gRPC server #42

Merged
merged 8 commits into from
Apr 11, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 17 additions & 3 deletions pkg/function/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,12 @@ import (
"context"
"fmt"
"io"
"log"

functionpb "github.com/numaproj/numaflow-go/pkg/apis/proto/function/v1"
"github.com/numaproj/numaflow-go/pkg/function"
"github.com/numaproj/numaflow-go/pkg/info"
infoclient "github.com/numaproj/numaflow-go/pkg/info/client"
"golang.org/x/sync/errgroup"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
Expand All @@ -21,16 +24,27 @@ type client struct {

// New creates a new client object.
func New(inputOptions ...Option) (*client, error) {

var opts = &options{
sockAddr: function.Addr,
maxMessageSize: function.DefaultMaxMessageSize,
sockAddr: function.Addr,
maxMessageSize: function.DefaultMaxMessageSize,
infoSvrSockAddr: info.SocketAddress,
}

for _, inputOption := range inputOptions {
inputOption(opts)
}

infoClient := infoclient.NewInfoClient(infoclient.WithSocketAddress(opts.infoSvrSockAddr))
serverInfo, err := infoClient.GetServerInfo(context.Background())
if err != nil {
// TODO: return nil, err
log.Println("Failed to execute infoClient.GetServerInfo(): ", err)
}
// TODO: Use serverInfo to check compatibility and start the right gRPC client.
if serverInfo != nil {
log.Printf("ServerInfo: %v\n", serverInfo)
}

c := new(client)
sockAddr := fmt.Sprintf("%s:%s", function.Protocol, opts.sockAddr)
conn, err := grpc.Dial(sockAddr, grpc.WithTransportCredentials(insecure.NewCredentials()),
Expand Down
12 changes: 10 additions & 2 deletions pkg/function/client/options.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
package client

type options struct {
sockAddr string
maxMessageSize int
sockAddr string
maxMessageSize int
infoSvrSockAddr string
}

// Option is the interface to apply options.
Expand All @@ -21,3 +22,10 @@ func WithMaxMessageSize(size int) Option {
opts.maxMessageSize = size
}
}

// WithInfoServerSocketAddr start the client with the given info server sock addr. This is mainly used for testing purpose.
func WithInfoServerSocketAddr(addr string) Option {
return func(opts *options) {
opts.infoSvrSockAddr = addr
}
}
12 changes: 10 additions & 2 deletions pkg/function/server/options.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
package server

type options struct {
sockAddr string
maxMessageSize int
sockAddr string
maxMessageSize int
infoSvrSockAddr string
}

// Option is the interface to apply options.
Expand All @@ -21,3 +22,10 @@ func WithSockAddr(addr string) Option {
opts.sockAddr = addr
}
}

// WithInfoServerSocketAddr sets the info server socket address.
func WithInfoServerSocketAddr(addr string) Option {
return func(opts *options) {
opts.infoSvrSockAddr = addr
}
}
14 changes: 12 additions & 2 deletions pkg/function/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ import (

functionpb "github.com/numaproj/numaflow-go/pkg/apis/proto/function/v1"
functionsdk "github.com/numaproj/numaflow-go/pkg/function"
"github.com/numaproj/numaflow-go/pkg/info"
infoserver "github.com/numaproj/numaflow-go/pkg/info/server"
"google.golang.org/grpc"
)

Expand Down Expand Up @@ -83,8 +85,9 @@ func (s *server) RegisterReducer(r functionsdk.ReduceHandler) *server {
// Start starts the gRPC server via unix domain socket at configs.Addr and return error.
func (s *server) Start(ctx context.Context, inputOptions ...Option) error {
var opts = &options{
sockAddr: functionsdk.Addr,
maxMessageSize: functionsdk.DefaultMaxMessageSize,
sockAddr: functionsdk.Addr,
maxMessageSize: functionsdk.DefaultMaxMessageSize,
infoSvrSockAddr: info.SocketAddress,
}

for _, inputOption := range inputOptions {
Expand All @@ -106,10 +109,17 @@ func (s *server) Start(ctx context.Context, inputOptions ...Option) error {
ctxWithSignal, stop := signal.NotifyContext(ctx, syscall.SIGINT, syscall.SIGTERM)
defer stop()

go func() {
if err := infoserver.Start(ctxWithSignal, infoserver.WithSocketAddress(opts.infoSvrSockAddr)); err != nil {
log.Fatalf("Failed to start info server: %v", err)
}
}()

lis, err := net.Listen(functionsdk.Protocol, opts.sockAddr)
if err != nil {
return fmt.Errorf("failed to execute net.Listen(%q, %q): %v", functionsdk.Protocol, functionsdk.Addr, err)
}
defer func() { _ = lis.Close() }()
grpcServer := grpc.NewServer(
grpc.MaxRecvMsgSize(opts.maxMessageSize),
grpc.MaxSendMsgSize(opts.maxMessageSize),
Expand Down
49 changes: 36 additions & 13 deletions pkg/function/server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,21 @@ type fields struct {
}

func Test_server_map(t *testing.T) {
file, err := os.CreateTemp("/tmp", "numaflow-test.sock")

socketFile, err := os.CreateTemp("/tmp", "numaflow-test.sock")
assert.NoError(t, err)
defer func() {
err = os.RemoveAll(file.Name())
err = os.RemoveAll(socketFile.Name())
assert.NoError(t, err)
}()

infoSocketFile, err := os.CreateTemp("/tmp", "numaflow-test-info.sock")
assert.NoError(t, err)
defer func() {
err = os.RemoveAll(infoSocketFile.Name())
assert.NoError(t, err)
}()

tests := []struct {
name string
fields fields
Expand All @@ -50,9 +59,9 @@ func Test_server_map(t *testing.T) {
// note: using actual UDS connection
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go New().RegisterMapper(tt.fields.mapHandler).Start(ctx, WithSockAddr(file.Name()))
go New().RegisterMapper(tt.fields.mapHandler).Start(ctx, WithSockAddr(socketFile.Name()), WithInfoServerSocketAddr(infoSocketFile.Name()))

c, err := client.New(client.WithSockAddr(file.Name()))
c, err := client.New(client.WithSockAddr(socketFile.Name()), client.WithInfoServerSocketAddr(infoSocketFile.Name()))
assert.NoError(t, err)
defer func() {
err = c.CloseConn(ctx)
Expand Down Expand Up @@ -82,12 +91,20 @@ func Test_server_map(t *testing.T) {
}

func Test_server_mapT(t *testing.T) {
file, err := os.CreateTemp("/tmp", "numaflow-test.sock")
socketFile, err := os.CreateTemp("/tmp", "numaflow-test.sock")
assert.NoError(t, err)
defer func() {
err = os.RemoveAll(file.Name())
err = os.RemoveAll(socketFile.Name())
assert.NoError(t, err)
}()

infoSocketFile, err := os.CreateTemp("/tmp", "numaflow-test-info.sock")
assert.NoError(t, err)
defer func() {
err = os.RemoveAll(infoSocketFile.Name())
assert.NoError(t, err)
}()

tests := []struct {
name string
fields fields
Expand All @@ -107,9 +124,9 @@ func Test_server_mapT(t *testing.T) {
// note: using actual UDS connection
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go New().RegisterMapperT(tt.fields.mapTHandler).Start(ctx, WithSockAddr(file.Name()))
go New().RegisterMapperT(tt.fields.mapTHandler).Start(ctx, WithSockAddr(socketFile.Name()), WithInfoServerSocketAddr(infoSocketFile.Name()))

c, err := client.New(client.WithSockAddr(file.Name()))
c, err := client.New(client.WithSockAddr(socketFile.Name()), client.WithInfoServerSocketAddr(infoSocketFile.Name()))
assert.NoError(t, err)
defer func() {
err = c.CloseConn(ctx)
Expand Down Expand Up @@ -145,6 +162,14 @@ func Test_server_reduce(t *testing.T) {
err = os.RemoveAll(file.Name())
assert.NoError(t, err)
}()

infoSocketFile, err := os.CreateTemp("/tmp", "numaflow-test-info.sock")
assert.NoError(t, err)
defer func() {
err = os.RemoveAll(infoSocketFile.Name())
assert.NoError(t, err)
}()

var testKey = "reduce_key"
tests := []struct {
name string
Expand Down Expand Up @@ -187,9 +212,9 @@ func Test_server_reduce(t *testing.T) {
// note: using actual UDS connection
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go New().RegisterReducer(tt.fields.reduceHandler).Start(ctx, WithSockAddr(file.Name()))
go New().RegisterReducer(tt.fields.reduceHandler).Start(ctx, WithSockAddr(file.Name()), WithInfoServerSocketAddr(infoSocketFile.Name()))

c, err := client.New(client.WithSockAddr(file.Name()))
c, err := client.New(client.WithSockAddr(file.Name()), client.WithInfoServerSocketAddr(infoSocketFile.Name()))
assert.NoError(t, err)
defer func() {
err = c.CloseConn(ctx)
Expand Down Expand Up @@ -218,9 +243,7 @@ func Test_server_reduce(t *testing.T) {
wg.Add(1)
go func() {
defer wg.Done()
for _, d := range resultDatumList {
dList.Elements = append(dList.Elements, d)
}
dList.Elements = append(dList.Elements, resultDatumList...)
}()

wg.Wait()
Expand Down
104 changes: 104 additions & 0 deletions pkg/info/client/client.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
package client

import (
"context"
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"time"

"github.com/numaproj/numaflow-go/pkg/info"
)

type InfoClient struct {
client *http.Client
}

type options struct {
socketAddress string
clientTimeout time.Duration
}

type Option func(*options)

func WithSocketAddress(addr string) Option {
return func(o *options) {
o.socketAddress = addr
}
}

func WithClientTimeout(timeout time.Duration) Option {
return func(o *options) {
o.clientTimeout = timeout
}
}

// NewInfoClient creates a new info client
func NewInfoClient(opts ...Option) *InfoClient {
options := &options{
socketAddress: info.SocketAddress,
clientTimeout: 3 * time.Second,
}
for _, opt := range opts {
opt(options)
}

var httpClient *http.Client
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
httpTransport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
return net.Dial("unix", options.socketAddress)
}
httpClient = &http.Client{
Transport: httpTransport,
Timeout: options.clientTimeout,
}
return &InfoClient{
client: httpClient,
}
}

func (c *InfoClient) waitUntilReady(ctx context.Context) error {
for {
select {
case <-ctx.Done():
return fmt.Errorf("failed to wait for ready: %w", ctx.Err())
default:
if resp, err := c.client.Get("http://unix/ready"); err == nil {
_, _ = io.Copy(io.Discard, resp.Body)
_ = resp.Body.Close()
if resp.StatusCode < 300 {
return nil
}
}
time.Sleep(1 * time.Second)
}
}
}

// GetServerInfo gets the server info
func (c *InfoClient) GetServerInfo(ctx context.Context) (*info.ServerInfo, error) {
cctx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel()
if err := c.waitUntilReady(cctx); err != nil {
return nil, fmt.Errorf("info server is not ready: %w", err)
}
resp, err := c.client.Get("http://unix/info")
if err != nil {
return nil, err
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("http.NewRequestWithContext failed with status, %s", resp.Status)
}
data, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("io.ReadAll failed: %w", err)
}
info := &info.ServerInfo{}
if err := json.Unmarshal(data, info); err != nil {
return nil, fmt.Errorf("json.Unmarshal failed: %w", err)
}
return info, nil
}
39 changes: 39 additions & 0 deletions pkg/info/client/client_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package client

import (
"context"
"os"
"testing"
"time"

"github.com/numaproj/numaflow-go/pkg/info"
"github.com/numaproj/numaflow-go/pkg/info/server"
"github.com/stretchr/testify/assert"
)

func Test_client(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

file, err := os.CreateTemp("/tmp", "test-info.sock")
assert.NoError(t, err)
defer func() {
err = os.RemoveAll(file.Name())
assert.NoError(t, err)
}()

go func() {
if err := server.Start(ctx, server.WithSocketAddress(file.Name())); err != nil {
t.Errorf("Start() error = %v", err)
}
}()

c := NewInfoClient(WithSocketAddress(file.Name()))
cctx, cancelFunc := context.WithTimeout(ctx, 5*time.Second)
defer cancelFunc()
err = c.waitUntilReady(cctx)
assert.NoError(t, err)
si, err := c.GetServerInfo(ctx)
assert.NoError(t, err)
assert.Equal(t, si.Language, info.Go)
}
10 changes: 10 additions & 0 deletions pkg/info/doc.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// Package info starts an HTTP service to provide the gRPC server information.
//
// The server information can be used by the client to determine:
// - what is right protocol to use (UDS or TCP)
// - what is the numaflow sdk version used by the server
// - what is language used by the server
//
// The gPRC server (UDF, UDSink, etc) must start the InfoServer with correct ServerInfo populated when it starts.
// The client is supposed to call the InfoServer to get the server information, before it starts to communicate with the gRPC server.
package info
Loading