Skip to content

Commit

Permalink
use error helper
Browse files Browse the repository at this point in the history
  • Loading branch information
firefart committed Dec 12, 2023
1 parent 5e05a13 commit d7958b1
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 18 deletions.
4 changes: 2 additions & 2 deletions Readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ func (s *MyCustomHandler) Refresh(ctx context.Context) {
}
}

func (s *MyCustomHandler) ReadFromRemote(remote io.ReadCloser, client io.WriteCloser) error {
func (s *MyCustomHandler) ReadFromRemote(ctx context.Context, remote io.ReadCloser, client io.WriteCloser) error {
i, err := io.Copy(client, remote)
if err != nil {
return err
Expand All @@ -146,7 +146,7 @@ func (s *MyCustomHandler) ReadFromRemote(remote io.ReadCloser, client io.WriteCl
return nil
}

func (s *MyCustomHandler) ReadFromClient(client io.ReadCloser, remote io.WriteCloser) error {
func (s *MyCustomHandler) ReadFromClient(ctx context.Context, client io.ReadCloser, remote io.WriteCloser) error {
i, err := io.Copy(remote, client)
if err != nil {
return err
Expand Down
10 changes: 5 additions & 5 deletions parsers.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ the address is a version-6 IP address, with a length of 16 octets.
func parseRequest(buf []byte) (*Request, *Error) {
r := &Request{}
if len(buf) < 7 {
return nil, &Error{Reason: RequestReplyConnectionRefused, Err: fmt.Errorf("invalid request header length (%d)", len(buf))}
return nil, NewError(RequestReplyConnectionRefused, fmt.Errorf("invalid request header length (%d)", len(buf)))
}
version := buf[0]
switch version {
Expand All @@ -60,7 +60,7 @@ func parseRequest(buf []byte) (*Request, *Error) {
case byte(Version5):
r.Version = Version5
default:
return nil, &Error{Reason: RequestReplyConnectionRefused, Err: fmt.Errorf("Invalid Socks version %#x", version)}
return nil, NewError(RequestReplyConnectionRefused, fmt.Errorf("Invalid Socks version %#x", version))
}
cmd := buf[1]
switch cmd {
Expand All @@ -71,7 +71,7 @@ func parseRequest(buf []byte) (*Request, *Error) {
// case byte(RequestCmdAssociate):
// r.Command = RequestCmdAssociate
default:
return nil, &Error{Reason: RequestReplyCommandNotSupported, Err: fmt.Errorf("Command %#x not supported", cmd)}
return nil, NewError(RequestReplyCommandNotSupported, fmt.Errorf("Command %#x not supported", cmd))
}
addresstype := buf[3]
switch addresstype {
Expand All @@ -82,7 +82,7 @@ func parseRequest(buf []byte) (*Request, *Error) {
case byte(RequestAddressTypeDomainname):
r.AddressType = RequestAddressTypeDomainname
default:
return nil, &Error{Reason: RequestReplyAddressTypeNotSupported, Err: fmt.Errorf("AddressType %#x not supported", addresstype)}
return nil, NewError(RequestReplyAddressTypeNotSupported, fmt.Errorf("AddressType %#x not supported", addresstype))
}

switch r.AddressType {
Expand All @@ -100,7 +100,7 @@ func parseRequest(buf []byte) (*Request, *Error) {
p := buf[5+addrLen : 5+addrLen+2]
r.DestinationPort = binary.BigEndian.Uint16(p)
default:
return nil, &Error{Reason: RequestReplyAddressTypeNotSupported, Err: fmt.Errorf("AddressType %#x not supported", addresstype)}
return nil, NewError(RequestReplyAddressTypeNotSupported, fmt.Errorf("AddressType %#x not supported", addresstype))
}

return r, nil
Expand Down
22 changes: 11 additions & 11 deletions socks.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,10 @@ func (p *Proxy) socks(ctx context.Context, conn net.Conn) *Error {
// stop refreshing the connection
cancel()
if err := <-errChannel1; err != nil {
return &Error{Reason: RequestReplyHostUnreachable, Err: err}
return NewError(RequestReplyHostUnreachable, err)
}
if err := <-errChannel2; err != nil {
return &Error{Reason: RequestReplyHostUnreachable, Err: err}
return NewError(RequestReplyHostUnreachable, err)
}
p.Log.Debug("end of connection handling")

Expand Down Expand Up @@ -143,18 +143,18 @@ func (p *Proxy) socksErrorReply(ctx context.Context, conn io.ReadWriteCloser, re
func (p *Proxy) handleConnect(ctx context.Context, conn io.ReadWriteCloser) *Error {
buf, err := connectionRead(ctx, conn, p.Timeout)
if err != nil {
return &Error{Reason: RequestReplyConnectionRefused, Err: err}
return NewError(RequestReplyConnectionRefused, err)
}
header, err := parseHeader(buf)
if err != nil {
return &Error{Reason: RequestReplyConnectionRefused, Err: err}
return NewError(RequestReplyConnectionRefused, err)
}
switch header.Version {
case Version4:
return &Error{Reason: RequestReplyCommandNotSupported, Err: fmt.Errorf("socks4 not yet implemented")}
return NewError(RequestReplyCommandNotSupported, fmt.Errorf("socks4 not yet implemented"))
case Version5:
default:
return &Error{Reason: RequestReplyCommandNotSupported, Err: fmt.Errorf("version %#x not yet implemented", byte(header.Version))}
return NewError(RequestReplyCommandNotSupported, fmt.Errorf("version %#x not yet implemented", byte(header.Version)))
}

methodSupported := false
Expand All @@ -165,22 +165,22 @@ func (p *Proxy) handleConnect(ctx context.Context, conn io.ReadWriteCloser) *Err
}
}
if !methodSupported {
return &Error{Reason: RequestReplyMethodNotSupported, Err: fmt.Errorf("we currently only support no authentication")}
return NewError(RequestReplyMethodNotSupported, fmt.Errorf("we currently only support no authentication"))
}
reply := make([]byte, 2)
reply[0] = byte(Version5)
reply[1] = byte(MethodNoAuthRequired)
err = connectionWrite(ctx, conn, reply, p.Timeout)
if err != nil {
return &Error{Reason: RequestReplyGeneralFailure, Err: fmt.Errorf("could not send connect reply: %w", err)}
return NewError(RequestReplyGeneralFailure, fmt.Errorf("could not send connect reply: %w", err))
}
return nil
}

func (p *Proxy) handleRequest(ctx context.Context, conn io.ReadWriteCloser) (*Request, *Error) {
buf, err := connectionRead(ctx, conn, p.Timeout)
if err != nil {
return nil, &Error{Reason: RequestReplyGeneralFailure, Err: fmt.Errorf("error on ConnectionRead: %w", err)}
return nil, NewError(RequestReplyGeneralFailure, fmt.Errorf("error on ConnectionRead: %w", err))
}
request, err2 := parseRequest(buf)
if err2 != nil {
Expand All @@ -192,11 +192,11 @@ func (p *Proxy) handleRequest(ctx context.Context, conn io.ReadWriteCloser) (*Re
func (p *Proxy) handleRequestReply(ctx context.Context, conn io.ReadWriteCloser, request *Request) *Error {
repl, err := requestReply(request, RequestReplySucceeded)
if err != nil {
return &Error{Reason: RequestReplyGeneralFailure, Err: fmt.Errorf("error on requestReply: %w", err)}
return NewError(RequestReplyGeneralFailure, fmt.Errorf("error on requestReply: %w", err))
}
err = connectionWrite(ctx, conn, repl, p.Timeout)
if err != nil {
return &Error{Reason: RequestReplyGeneralFailure, Err: fmt.Errorf("error on RequestResponse: %w", err)}
return NewError(RequestReplyGeneralFailure, fmt.Errorf("error on RequestResponse: %w", err))
}

return nil
Expand Down

0 comments on commit d7958b1

Please sign in to comment.