Skip to content

Commit

Permalink
use mime type
Browse files Browse the repository at this point in the history
Signed-off-by: He Xian <[email protected]>
  • Loading branch information
hexian000 committed Oct 7, 2024
1 parent 22292bb commit 51221be
Showing 1 changed file with 34 additions and 7 deletions.
41 changes: 34 additions & 7 deletions v3/proto/proto.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,18 @@ import (
"errors"
"io"
"math"
"mime"
"net"
)

const Type = "application/x-tlswrapper; version=3"
var (
versionStr = "3"

mimeType = "application/x-tlswrapper-msg"
mimeParams = map[string]string{"version": versionStr}

Type = mime.FormatMediaType(mimeType, mimeParams)
)

const (
MsgHello = iota
Expand All @@ -28,8 +36,9 @@ type ServerMsg struct {
}

var (
ErrMsgTooLong = errors.New("message too long")
ErrUnsupportedProtocol = errors.New("unsupported protocol")
ErrMsgTooLong = errors.New("message too long")
ErrUnsupportedProtocol = errors.New("unsupported protocol")
ErrIncompatiableVersion = errors.New("incompatible protocol version")
)

func sendmsg(conn net.Conn, msg interface{}) error {
Expand Down Expand Up @@ -63,6 +72,24 @@ func recvmsg(conn net.Conn, msg interface{}) error {
return nil
}

func checkType(s string) error {
mediatype, params, err := mime.ParseMediaType(s)
if err != nil {
return err
}
if mediatype != mimeType {
return ErrUnsupportedProtocol
}
version, ok := params["version"]
if !ok {
return ErrUnsupportedProtocol
}
if version != versionStr {
return ErrIncompatiableVersion
}
return nil
}

func Roundtrip(conn net.Conn, req *ClientMsg) (*ServerMsg, error) {
if err := sendmsg(conn, req); err != nil {
return nil, err
Expand All @@ -71,8 +98,8 @@ func Roundtrip(conn net.Conn, req *ClientMsg) (*ServerMsg, error) {
if err := recvmsg(conn, rsp); err != nil {
return nil, err
}
if rsp.Type != Type {
return nil, ErrUnsupportedProtocol
if err := checkType(rsp.Type); err != nil {
return nil, err
}
return rsp, nil
}
Expand All @@ -82,8 +109,8 @@ func RecvRequest(conn net.Conn) (*ClientMsg, error) {
if err := recvmsg(conn, req); err != nil {
return nil, err
}
if req.Type != Type {
return nil, ErrUnsupportedProtocol
if err := checkType(req.Type); err != nil {
return nil, err
}
return req, nil
}
Expand Down

0 comments on commit 51221be

Please sign in to comment.