diff --git a/Dockerfile b/Dockerfile index e7d72a42..afe91c7f 100644 --- a/Dockerfile +++ b/Dockerfile @@ -240,6 +240,11 @@ ENV VPN_SERVICE_PROVIDER=pia \ SHADOWSOCKS_PASSWORD= \ SHADOWSOCKS_PASSWORD_SECRETFILE=/run/secrets/shadowsocks_password \ SHADOWSOCKS_CIPHER=chacha20-ietf-poly1305 \ + # Socks5 + SOCKS5_ENABLED=off \ + SOCKS5_LISTENING_ADDRESS=":1080" \ + SOCKS5_USER= \ + SOCKS5_PASSWORD= \ # Control server HTTP_CONTROL_SERVER_LOG=on \ HTTP_CONTROL_SERVER_ADDRESS=":8000" \ @@ -271,7 +276,7 @@ ENV VPN_SERVICE_PROVIDER=pia \ PUID=1000 \ PGID=1000 ENTRYPOINT ["/gluetun-entrypoint"] -EXPOSE 8000/tcp 8888/tcp 8388/tcp 8388/udp +EXPOSE 8000/tcp 8888/tcp 8388/tcp 8388/udp 1080/tcp HEALTHCHECK --interval=5s --timeout=5s --start-period=10s --retries=3 CMD /gluetun-entrypoint healthcheck ARG TARGETPLATFORM RUN apk add --no-cache --update -l wget && \ diff --git a/cmd/gluetun/main.go b/cmd/gluetun/main.go index 97209889..bc3c3447 100644 --- a/cmd/gluetun/main.go +++ b/cmd/gluetun/main.go @@ -41,6 +41,7 @@ import ( "github.com/qdm12/gluetun/internal/routing" "github.com/qdm12/gluetun/internal/server" "github.com/qdm12/gluetun/internal/shadowsocks" + "github.com/qdm12/gluetun/internal/socks5" "github.com/qdm12/gluetun/internal/storage" updater "github.com/qdm12/gluetun/internal/updater/loop" "github.com/qdm12/gluetun/internal/updater/resolver" @@ -411,6 +412,18 @@ func _main(ctx context.Context, buildInfo models.BuildInformation, return fmt.Errorf("starting public ip loop: %w", err) } + socks5Loop := socks5.NewLoop(socks5.Settings{ + Enabled: *allSettings.Socks5.Enabled, + Username: *allSettings.Socks5.Username, + Password: *allSettings.Socks5.Password, + Address: allSettings.Socks5.ListeningAddress, + Logger: logger.New(log.SetComponent("socks5")), + }) + socks5RunError, err := socks5Loop.Start(ctx) + if err != nil { + return fmt.Errorf("starting SOCKS5 server loop: %w", err) + } + healthLogger := logger.New(log.SetComponent("healthcheck")) healthcheckServer := healthcheck.NewServer(allSettings.Health, healthLogger) healthServerHandler, healthServerCtx, healthServerDone := goshutdown.NewGoRoutineHandler( @@ -506,7 +519,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation, String() string Stop() error }{ - portForwardLooper, publicIPLooper, + portForwardLooper, publicIPLooper, socks5Loop, } for _, stopper := range stoppers { err := stopper.Stop() @@ -518,6 +531,8 @@ func _main(ctx context.Context, buildInfo models.BuildInformation, logger.Errorf("port forwarding loop crashed: %s", err) case err := <-publicIPRunError: logger.Errorf("public IP loop crashed: %s", err) + case err := <-socks5RunError: + logger.Errorf("SOCKS5 server loop crashed: %s", err) } return orderHandler.Shutdown(context.Background()) diff --git a/internal/configuration/settings/settings.go b/internal/configuration/settings/settings.go index 88e02669..f4529885 100644 --- a/internal/configuration/settings/settings.go +++ b/internal/configuration/settings/settings.go @@ -20,6 +20,7 @@ type Settings struct { HTTPProxy HTTPProxy Log Log PublicIP PublicIP + Socks5 Socks5 Shadowsocks Shadowsocks Storage Storage System System @@ -49,6 +50,7 @@ func (s *Settings) Validate(filterChoicesGetter FilterChoicesGetter, ipv6Support "http proxy": s.HTTPProxy.validate, "log": s.Log.validate, "public ip check": s.PublicIP.validate, + "socks5": s.Socks5.validate, "shadowsocks": s.Shadowsocks.validate, "storage": s.Storage.validate, "system": s.System.validate, @@ -81,6 +83,7 @@ func (s *Settings) copy() (copied Settings) { HTTPProxy: s.HTTPProxy.copy(), Log: s.Log.copy(), PublicIP: s.PublicIP.copy(), + Socks5: s.Socks5.copy(), Shadowsocks: s.Shadowsocks.copy(), Storage: s.Storage.copy(), System: s.System.copy(), @@ -104,6 +107,7 @@ func (s *Settings) OverrideWith(other Settings, patchedSettings.HTTPProxy.overrideWith(other.HTTPProxy) patchedSettings.Log.overrideWith(other.Log) patchedSettings.PublicIP.overrideWith(other.PublicIP) + patchedSettings.Socks5.overrideWith(other.Socks5) patchedSettings.Shadowsocks.overrideWith(other.Shadowsocks) patchedSettings.Storage.overrideWith(other.Storage) patchedSettings.System.overrideWith(other.System) @@ -131,6 +135,7 @@ func (s *Settings) SetDefaults() { s.Log.setDefaults() s.IPv6.setDefaults() s.PublicIP.setDefaults() + s.Socks5.setDefaults() s.Shadowsocks.setDefaults() s.Storage.SetDefaults() s.System.setDefaults() @@ -154,6 +159,7 @@ func (s Settings) toLinesNode() (node *gotree.Node) { node.AppendNode(s.Log.toLinesNode()) node.AppendNode(s.IPv6.toLinesNode()) node.AppendNode(s.Health.toLinesNode()) + node.AppendNode(s.Socks5.toLinesNode()) node.AppendNode(s.Shadowsocks.toLinesNode()) node.AppendNode(s.HTTPProxy.toLinesNode()) node.AppendNode(s.ControlServer.toLinesNode()) @@ -212,6 +218,7 @@ func (s *Settings) Read(r *reader.Reader, warner Warner) (err error) { "public ip": func(r *reader.Reader) error { return s.PublicIP.read(r, warner) }, + "socks5": s.Socks5.read, "shadowsocks": s.Shadowsocks.read, "storage": s.Storage.Read, "system": s.System.read, diff --git a/internal/configuration/settings/settings_test.go b/internal/configuration/settings/settings_test.go index 1fbb18a6..96ea4358 100644 --- a/internal/configuration/settings/settings_test.go +++ b/internal/configuration/settings/settings_test.go @@ -81,6 +81,8 @@ func Test_Settings_String(t *testing.T) { | | ├── 1.1.1.1 | | └── 8.8.8.8 | └── Restart VPN on healthcheck failure: yes +├── SOCKS5 proxy server settings: +| └── Enabled: no ├── Shadowsocks server settings: | └── Enabled: no ├── HTTP proxy settings: diff --git a/internal/configuration/settings/socks5.go b/internal/configuration/settings/socks5.go new file mode 100644 index 00000000..7c2cc115 --- /dev/null +++ b/internal/configuration/settings/socks5.go @@ -0,0 +1,91 @@ +package settings + +import ( + "errors" + "fmt" + "os" + + "github.com/qdm12/gosettings" + "github.com/qdm12/gosettings/reader" + "github.com/qdm12/gosettings/validate" + "github.com/qdm12/gotree" +) + +// Socks5 contains settings to configure the Socks5 proxy server. +type Socks5 struct { + Enabled *bool + ListeningAddress string + Username *string + Password *string +} + +func (s Socks5) validate() (err error) { + err = validate.ListeningAddress(s.ListeningAddress, os.Getuid()) + if err != nil { + return fmt.Errorf("server listening address is not valid: %w", err) + } + + switch { + case *s.Username != "" && *s.Password == "": + return errors.New("password must be set if username is set") + case *s.Username == "" && *s.Password != "": + return errors.New("username must be set if password is set") + } + + return nil +} + +func (s *Socks5) copy() (copied Socks5) { + return Socks5{ + Enabled: gosettings.CopyPointer(s.Enabled), + ListeningAddress: s.ListeningAddress, + Username: gosettings.CopyPointer(s.Username), + Password: gosettings.CopyPointer(s.Password), + } +} + +func (s *Socks5) overrideWith(other Socks5) { + s.Enabled = gosettings.OverrideWithPointer(s.Enabled, other.Enabled) + s.ListeningAddress = gosettings.OverrideWithComparable(s.ListeningAddress, other.ListeningAddress) + s.Username = gosettings.OverrideWithPointer(s.Username, other.Username) + s.Password = gosettings.OverrideWithPointer(s.Password, other.Password) +} + +func (s *Socks5) setDefaults() { + s.Enabled = gosettings.DefaultPointer(s.Enabled, false) + s.ListeningAddress = gosettings.DefaultComparable(s.ListeningAddress, ":1080") + s.Username = gosettings.DefaultPointer(s.Username, "") + s.Password = gosettings.DefaultPointer(s.Password, "") +} + +func (s Socks5) String() string { + return s.toLinesNode().String() +} + +func (s Socks5) toLinesNode() (node *gotree.Node) { + node = gotree.New("SOCKS5 proxy server settings:") + node.Appendf("Enabled: %s", gosettings.BoolToYesNo(s.Enabled)) + if !*s.Enabled { + return node + } + + node.Appendf("Listening address: %s", s.ListeningAddress) + if *s.Username != "" || *s.Password != "" { + node.Appendf("Username: %s", *s.Username) + node.Appendf("Password: %s", gosettings.ObfuscateKey(*s.Password)) + } + return node +} + +func (s *Socks5) read(r *reader.Reader) (err error) { + s.Enabled, err = r.BoolPtr("SOCKS5_ENABLED") + if err != nil { + return err + } + + s.ListeningAddress = r.String("SOCKS5_LISTENING_ADDRESS") + s.Username = r.Get("SOCKS5_USER", reader.ForceLowercase(false)) + s.Password = r.Get("SOCKS5_PASSWORD", reader.ForceLowercase(false)) + + return nil +} diff --git a/internal/socks5/constants.go b/internal/socks5/constants.go new file mode 100644 index 00000000..bb185ad6 --- /dev/null +++ b/internal/socks5/constants.go @@ -0,0 +1,86 @@ +package socks5 + +import "fmt" + +// See https://datatracker.ietf.org/doc/html/rfc1928#section-3 +type authMethod byte + +const ( + authNotRequired authMethod = 0 + authGssapi authMethod = 1 + authUsernamePassword authMethod = 2 + authNotAcceptable authMethod = 255 +) + +func (a authMethod) String() string { + switch a { + case authNotRequired: + return "no authentication required" + case authGssapi: + return "GSSAPI" + case authUsernamePassword: + return "username/password" + case authNotAcceptable: + return "no acceptable methods" + default: + return fmt.Sprintf("unknown method (%d)", a) + } +} + +// Subnegotiation version +// See https://datatracker.ietf.org/doc/html/rfc1929#section-2 +const ( + authUsernamePasswordSubNegotiation1 byte = 1 +) + +// SOCKS versions. +const ( + socks5Version byte = 5 +) + +// See https://datatracker.ietf.org/doc/html/rfc1928#section-4 +type cmdType byte + +const ( + connect cmdType = 1 + bind cmdType = 2 + udpAssociate cmdType = 3 +) + +func (c cmdType) String() string { + switch c { + case connect: + return "connect" + case bind: + return "bind" + case udpAssociate: + return "UDP associate" + default: + return fmt.Sprintf("unknown command (%d)", c) + } +} + +// See https://datatracker.ietf.org/doc/html/rfc1928#section-4 and +// https://datatracker.ietf.org/doc/html/rfc1928#section-5 +type addrType byte + +const ( + ipv4 addrType = 1 + domainName addrType = 3 + ipv6 addrType = 4 +) + +// See https://datatracker.ietf.org/doc/html/rfc1928#section-6 +type replyCode byte + +const ( + succeeded replyCode = iota + generalServerFailure + connectionNotAllowedByRuleset + networkUnreachable + hostUnreachable + connectionRefused + ttlExpired + commandNotSupported + addressTypeNotSupported +) diff --git a/internal/socks5/interfaces.go b/internal/socks5/interfaces.go new file mode 100644 index 00000000..a9951848 --- /dev/null +++ b/internal/socks5/interfaces.go @@ -0,0 +1,6 @@ +package socks5 + +type Logger interface { + Infof(format string, a ...interface{}) + Warnf(format string, a ...interface{}) +} diff --git a/internal/socks5/loop.go b/internal/socks5/loop.go new file mode 100644 index 00000000..18b77bff --- /dev/null +++ b/internal/socks5/loop.go @@ -0,0 +1,106 @@ +package socks5 + +import ( + "context" + "sync" + "time" + + "github.com/qdm12/goservices" +) + +type Loop struct { + settings Settings + + mutex sync.Mutex + runCancel context.CancelFunc + runDone <-chan error +} + +func NewLoop(settings Settings) *Loop { + return &Loop{ + settings: settings, + } +} + +func (l *Loop) String() string { + return "SOCKS5 server loop" +} + +func (l *Loop) Start(_ context.Context) (runError <-chan error, err error) { + l.mutex.Lock() + defer l.mutex.Unlock() + + var runCtx context.Context + runCtx, l.runCancel = context.WithCancel(context.Background()) + + runDone := make(chan error) + l.runDone = runDone + + go run(runCtx, runDone, l.settings) + + return nil, nil //nolint:nilnil +} + +func run(ctx context.Context, done chan<- error, settings Settings) { + defer close(done) + logger := settings.Logger + + for ctx.Err() == nil { + var server goservices.Service + if settings.Enabled { + server = newServer(settings) + } else { + server = new(noopService) + } + + errorCh, err := server.Start(ctx) + if err != nil { + logger.Warnf("failed starting SOCKS5 server: %s", err) + waitBeforeRetry(ctx) + continue + } + + select { + case <-ctx.Done(): + done <- server.Stop() + return + case err := <-errorCh: + if ctx.Err() != nil { + return + } + logger.Warnf("SOCKS5 server crashed: %s", err) + waitBeforeRetry(ctx) + } + } +} + +func (l *Loop) Stop() (err error) { + l.mutex.Lock() + defer l.mutex.Unlock() + + l.runCancel() + return <-l.runDone +} + +func waitBeforeRetry(ctx context.Context) { + const retryDelay = 10 * time.Second + timer := time.NewTimer(retryDelay) + select { + case <-timer.C: + case <-ctx.Done(): + } +} + +type noopService struct{} + +func (s noopService) Start(_ context.Context) (runErr <-chan error, err error) { + return nil, nil //nolint:nilnil +} + +func (s noopService) Stop() error { + return nil +} + +func (s noopService) String() string { + return "noop service" +} diff --git a/internal/socks5/mocks_generate_test.go b/internal/socks5/mocks_generate_test.go new file mode 100644 index 00000000..5797ec72 --- /dev/null +++ b/internal/socks5/mocks_generate_test.go @@ -0,0 +1,3 @@ +package socks5 + +//go:generate mockgen -destination=mocks_test.go -package=$GOPACKAGE . Logger diff --git a/internal/socks5/mocks_test.go b/internal/socks5/mocks_test.go new file mode 100644 index 00000000..79aaa8e2 --- /dev/null +++ b/internal/socks5/mocks_test.go @@ -0,0 +1,68 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/qdm12/gluetun/internal/socks5 (interfaces: Logger) + +// Package socks5 is a generated GoMock package. +package socks5 + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" +) + +// MockLogger is a mock of Logger interface. +type MockLogger struct { + ctrl *gomock.Controller + recorder *MockLoggerMockRecorder +} + +// MockLoggerMockRecorder is the mock recorder for MockLogger. +type MockLoggerMockRecorder struct { + mock *MockLogger +} + +// NewMockLogger creates a new mock instance. +func NewMockLogger(ctrl *gomock.Controller) *MockLogger { + mock := &MockLogger{ctrl: ctrl} + mock.recorder = &MockLoggerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockLogger) EXPECT() *MockLoggerMockRecorder { + return m.recorder +} + +// Infof mocks base method. +func (m *MockLogger) Infof(arg0 string, arg1 ...interface{}) { + m.ctrl.T.Helper() + varargs := []interface{}{arg0} + for _, a := range arg1 { + varargs = append(varargs, a) + } + m.ctrl.Call(m, "Infof", varargs...) +} + +// Infof indicates an expected call of Infof. +func (mr *MockLoggerMockRecorder) Infof(arg0 interface{}, arg1 ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{arg0}, arg1...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Infof", reflect.TypeOf((*MockLogger)(nil).Infof), varargs...) +} + +// Warnf mocks base method. +func (m *MockLogger) Warnf(arg0 string, arg1 ...interface{}) { + m.ctrl.T.Helper() + varargs := []interface{}{arg0} + for _, a := range arg1 { + varargs = append(varargs, a) + } + m.ctrl.Call(m, "Warnf", varargs...) +} + +// Warnf indicates an expected call of Warnf. +func (mr *MockLoggerMockRecorder) Warnf(arg0 interface{}, arg1 ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{arg0}, arg1...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Warnf", reflect.TypeOf((*MockLogger)(nil).Warnf), varargs...) +} diff --git a/internal/socks5/response.go b/internal/socks5/response.go new file mode 100644 index 00000000..b65ee5e9 --- /dev/null +++ b/internal/socks5/response.go @@ -0,0 +1,109 @@ +package socks5 + +import ( + "encoding/binary" + "errors" + "fmt" + "io" + "net" + "net/netip" +) + +// See https://datatracker.ietf.org/doc/html/rfc1928#section-6 +func (c *socksConn) encodeFailedResponse(writer io.Writer, socksVersion byte, reply replyCode) { //nolint:unparam + _, err := writer.Write([]byte{ + socksVersion, + byte(reply), + 0, // RSV byte + // The RFC requires a full response frame even for failures. + // Use IPv4 address type with zeroed address and port. + byte(ipv4), // ATYP + 0, 0, 0, 0, // BND.ADDR (zeroed) + 0, 0, // BND.PORT (zeroed) + }) + if err != nil { + c.logger.Warnf("failed writing failed response: %s", err) + } +} + +// See https://datatracker.ietf.org/doc/html/rfc1928#section-6 +func (c *socksConn) encodeSuccessResponse(writer io.Writer, socksVersion byte, + reply replyCode, bindAddrType addrType, bindAddress string, + bindPort uint16, +) (err error) { + bindData, err := encodeBindData(bindAddrType, bindAddress, bindPort) + if err != nil { + return fmt.Errorf("encoding bind data: %w", err) + } + + const initialPacketLength = 3 + capacity := initialPacketLength + len(bindData) + packet := make([]byte, initialPacketLength, capacity) + packet[0] = socksVersion + packet[1] = byte(reply) + packet[2] = 0 // RSV byte + packet = append(packet, bindData...) + + _, err = writer.Write(packet) + if err != nil { + return fmt.Errorf("writing packet: %w", err) + } + return nil +} + +var ( + ErrIPVersionUnexpected = errors.New("ip version is unexpected") + ErrDomainNameTooLong = errors.New("domain name is too long") +) + +func encodeBindData(addrType addrType, address string, port uint16) ( + data []byte, err error, +) { + capacity := bindDataLength(addrType, address) + data = make([]byte, 0, capacity) + + data = append(data, byte(addrType)) + switch addrType { + case ipv4, ipv6: + ip, err := netip.ParseAddr(address) + if err != nil { + return nil, fmt.Errorf("parsing IP address: %w", err) + } + + switch { + case addrType == ipv4 && !ip.Is4(): + return nil, fmt.Errorf("%w: expected IPv4 for %s", ErrIPVersionUnexpected, ip) + case addrType == ipv6 && !ip.Is6(): + return nil, fmt.Errorf("%w: expected IPv6 for %s", ErrIPVersionUnexpected, ip) + } + data = append(data, ip.AsSlice()...) + case domainName: + const maxDomainNameLength = 255 + if len(address) > maxDomainNameLength { + return nil, fmt.Errorf("%w: %s", ErrDomainNameTooLong, address) + } + data = append(data, byte(len(address))) + data = append(data, []byte(address)...) + default: + panic(fmt.Sprintf("unsupported address type %d", addrType)) + } + data = binary.BigEndian.AppendUint16(data, port) + return data, nil +} + +func bindDataLength(addrType addrType, address string) (maxLength uint) { + maxLength++ // address type + switch addrType { + case ipv4: + maxLength += net.IPv4len + case domainName: + maxLength++ // domain name length + maxLength += uint(len([]byte(address))) + case ipv6: + maxLength += net.IPv6len + default: + panic("unsupported address type: " + fmt.Sprint(addrType)) + } + maxLength += 2 // port + return maxLength +} diff --git a/internal/socks5/server.go b/internal/socks5/server.go new file mode 100644 index 00000000..421c9883 --- /dev/null +++ b/internal/socks5/server.go @@ -0,0 +1,122 @@ +package socks5 + +import ( + "context" + "fmt" + "net" + "sync" + "sync/atomic" +) + +type server struct { + username string + password string + address string + logger Logger + + // internal fields + listener net.Listener + listening atomic.Bool + socksConnCtx context.Context //nolint:containedctx + socksConnCancel context.CancelFunc + done <-chan struct{} + stopping atomic.Bool +} + +func newServer(settings Settings) *server { + return &server{ + username: settings.Username, + password: settings.Password, + address: settings.Address, + logger: settings.Logger, + } +} + +func (s *server) String() string { + return "SOCKS5 server" +} + +func (s *server) Start(ctx context.Context) (runErr <-chan error, err error) { + s.socksConnCtx, s.socksConnCancel = context.WithCancel(context.Background()) + config := &net.ListenConfig{} + s.listener, err = config.Listen(ctx, "tcp", s.address) + if err != nil { + return nil, fmt.Errorf("listening on %s: %w", s.address, err) + } + s.listening.Store(true) + s.logger.Infof("SOCKS5 server listening on %s", s.listener.Addr()) + + ready := make(chan struct{}) + runErrCh := make(chan error) + runErr = runErrCh + done := make(chan struct{}) + s.done = done + go s.runServer(ready, runErrCh, done) + select { + case <-ready: + case <-ctx.Done(): + _ = s.Stop() + return nil, fmt.Errorf("starting server: %w", ctx.Err()) + } + return runErr, nil +} + +func (s *server) runServer(ready chan<- struct{}, + runErrCh chan<- error, done chan<- struct{}, +) { + close(ready) + defer close(done) + wg := new(sync.WaitGroup) + defer wg.Wait() + + dialer := &net.Dialer{} + for { + connection, err := s.listener.Accept() + if err != nil { + if !s.stopping.Load() { + _ = s.stop() + runErrCh <- fmt.Errorf("accepting connection: %w", err) + } + return + } + wg.Add(1) + go func(ctx context.Context, connection net.Conn, + dialer *net.Dialer, wg *sync.WaitGroup, + ) { + defer wg.Done() + socksConn := &socksConn{ + dialer: dialer, + username: s.username, + password: s.password, + clientConn: connection, + logger: s.logger, + } + err := socksConn.run(ctx) + if err != nil { + s.logger.Infof("running socks connection: %s", err) + } + }(s.socksConnCtx, connection, dialer, wg) + } +} + +func (s *server) Stop() (err error) { + s.stopping.Store(true) + err = s.stop() + <-s.done // wait for run goroutine to finish + s.stopping.Store(false) + return err +} + +func (s *server) stop() error { + s.listening.Store(false) + err := s.listener.Close() + s.socksConnCancel() // stop ongoing socks connections + return err +} + +func (s *server) listeningAddress() net.Addr { + if s.listening.Load() { + return s.listener.Addr() + } + return nil +} diff --git a/internal/socks5/settings.go b/internal/socks5/settings.go new file mode 100644 index 00000000..f88e940c --- /dev/null +++ b/internal/socks5/settings.go @@ -0,0 +1,9 @@ +package socks5 + +type Settings struct { + Enabled bool + Username string + Password string + Address string + Logger Logger +} diff --git a/internal/socks5/socks5.go b/internal/socks5/socks5.go new file mode 100644 index 00000000..3eb29ce8 --- /dev/null +++ b/internal/socks5/socks5.go @@ -0,0 +1,290 @@ +package socks5 + +import ( + "context" + "encoding/binary" + "errors" + "fmt" + "io" + "net" + "net/netip" + "strconv" + "strings" +) + +var ( + errNoMethodIdentifiers = errors.New("no method identifiers") + errNoValidMethodIdentifier = errors.New("no valid method identifier") +) + +type socksConn struct { + // Injected fields + dialer *net.Dialer + username string + password string + clientConn net.Conn + logger Logger +} + +func (c *socksConn) closeClientConn(ctxErr error) { + err := c.clientConn.Close() + if err != nil && ctxErr == nil { + c.logger.Warnf("closing client connection: %s", err) + } +} + +func (c *socksConn) run(ctx context.Context) error { + // Monitoring context cancellation to close the connection and stop + // reading operations on clientConn. + done := make(chan struct{}) + ctxWatcherDone := make(chan struct{}) + go func() { + defer close(ctxWatcherDone) + select { + case <-done: + case <-ctx.Done(): + // unblock read operations + c.closeClientConn(ctx.Err()) + } + }() + defer func() { + close(done) + <-ctxWatcherDone + }() + + authMethod := authNotRequired + if c.username != "" || c.password != "" { + authMethod = authUsernamePassword + } + + err := verifyFirstNegotiation(c.clientConn, authMethod) + if err != nil { + replyMethod := authMethod + if errors.Is(err, errNoMethodIdentifiers) || errors.Is(err, errNoValidMethodIdentifier) { + replyMethod = authNotAcceptable + } + _, writeErr := c.clientConn.Write([]byte{socks5Version, byte(replyMethod)}) + if writeErr != nil { + c.logger.Warnf("failed writing first negotiation reply: %s", writeErr) + } + c.closeClientConn(ctx.Err()) + return fmt.Errorf("verifying first negotiation: %w", err) + } + + _, err = c.clientConn.Write([]byte{socks5Version, byte(authMethod)}) + if err != nil { + c.closeClientConn(ctx.Err()) + return fmt.Errorf("writing first negotiation reply: %w", err) + } + + switch authMethod { + case authNotRequired, authNotAcceptable: + case authGssapi: + panic("not implemented") + case authUsernamePassword: + // See https://datatracker.ietf.org/doc/html/rfc1929#section-2 + err = usernamePasswordSubnegotiate(c.clientConn, c.username, c.password) + if err != nil { + // If the server returns a `failure' (STATUS value other than X'00') status, + // it MUST close the connection. + c.closeClientConn(ctx.Err()) + return fmt.Errorf("subnegotiating username and password: %w", err) + } + default: + panic(fmt.Sprintf("unimplemented auth method %d", authMethod)) + } + + err = c.handleRequest(ctx) + c.closeClientConn(ctx.Err()) + if err != nil { + return fmt.Errorf("handling request: %w", err) + } + return nil +} + +func (c *socksConn) handleRequest(ctx context.Context) error { + const socksVersion = socks5Version + request, err := decodeRequest(c.clientConn, socksVersion) + if err != nil { + c.encodeFailedResponse(c.clientConn, socksVersion, generalServerFailure) + return err + } + if request.command != connect { + c.encodeFailedResponse(c.clientConn, socksVersion, commandNotSupported) + return fmt.Errorf("command %s is not supported", request.command) + } + + destinationAddress := net.JoinHostPort(request.destination, fmt.Sprint(request.port)) + destinationConn, err := c.dialer.DialContext(ctx, "tcp", destinationAddress) + if err != nil { + c.encodeFailedResponse(c.clientConn, socksVersion, generalServerFailure) + return err + } + defer destinationConn.Close() + + destinationServerAddress := destinationConn.LocalAddr().String() + destinationAddr, destinationPortStr, err := net.SplitHostPort(destinationServerAddress) + if err != nil { + return fmt.Errorf("splitting destination address: %w", err) + } + destinationPort, err := strconv.ParseUint(destinationPortStr, 10, 16) + if err != nil { + return fmt.Errorf("port is malformed: %q", destinationPortStr) + } + + var bindAddrType addrType + if ip := net.ParseIP(destinationAddr); ip != nil { + if ip.To4() != nil { + bindAddrType = ipv4 + } else { + bindAddrType = ipv6 + } + } else { + bindAddrType = domainName + } + + err = c.encodeSuccessResponse(c.clientConn, socksVersion, succeeded, bindAddrType, + destinationAddr, uint16(destinationPort)) + if err != nil { + c.encodeFailedResponse(c.clientConn, socksVersion, generalServerFailure) + return fmt.Errorf("writing successful %s response: %w", request.command, err) + } + + const capacity = 2 // if one goroutine fails, we don't want to leak the other one + errc := make(chan error, capacity) + go func() { + _, err := io.Copy(c.clientConn, destinationConn) + if err != nil { + err = fmt.Errorf("from backend to client: %w", err) + } + errc <- err + }() + go func() { + _, err := io.Copy(destinationConn, c.clientConn) + if err != nil { + err = fmt.Errorf("from client to backend: %w", err) + } + errc <- err + }() + select { + case err := <-errc: + return err + case <-ctx.Done(): + _ = destinationConn.Close() + _ = c.clientConn.Close() + return nil + } +} + +// See https://datatracker.ietf.org/doc/html/rfc1928#section-3 +func verifyFirstNegotiation(reader io.Reader, requiredMethod authMethod) error { + const headerLength = 2 // version + nMethods bytes + header := make([]byte, headerLength) + _, err := io.ReadFull(reader, header) + if err != nil { + return fmt.Errorf("reading header: %w", err) + } + + if header[0] != socks5Version { + return fmt.Errorf("version is not supported: %d", header[0]) + } + + nMethods := header[1] + if nMethods == 0 { + return fmt.Errorf("%w", errNoMethodIdentifiers) + } + + methodIdentifiers := make([]byte, nMethods) + _, err = io.ReadFull(reader, methodIdentifiers) + if err != nil { + return fmt.Errorf("reading method identifiers: %w", err) + } + for _, methodIdentifier := range methodIdentifiers { + if methodIdentifier == byte(requiredMethod) { + return nil + } + } + + return makeNoAcceptableMethodError(requiredMethod, methodIdentifiers) +} + +func makeNoAcceptableMethodError(requiredAuthMethod authMethod, methodIdentifiers []byte) error { + methodNames := make([]string, len(methodIdentifiers)) + for i, methodIdentifier := range methodIdentifiers { + methodNames[i] = fmt.Sprintf("%q", authMethod(methodIdentifier)) + } + + return fmt.Errorf("%w: none of %s matches %s", + errNoValidMethodIdentifier, strings.Join(methodNames, ", "), + requiredAuthMethod) +} + +// See https://datatracker.ietf.org/doc/html/rfc1928#section-4 +type request struct { + command cmdType + destination string + port uint16 + addressType addrType +} + +func decodeRequest(reader io.Reader, expectedVersion byte) (req request, err error) { + const headerLength = 4 + header := [headerLength]byte{} + _, err = io.ReadFull(reader, header[:]) + if err != nil { + return request{}, fmt.Errorf("reading header: %w", err) + } + + version := header[0] + switch { + case version != expectedVersion: + return request{}, fmt.Errorf("version is not supported: expected %d and got %d", + expectedVersion, version) + case header[2] != 0: + return request{}, fmt.Errorf("reserved header byte must be 0 but got %d", header[2]) + } + + req.command = cmdType(header[1]) + // header[2] is RSV byte + req.addressType = addrType(header[3]) + + switch req.addressType { + case ipv4: + var ip [4]byte + _, err = io.ReadFull(reader, ip[:]) + if err != nil { + return request{}, fmt.Errorf("reading IPv4 address: %w", err) + } + req.destination = netip.AddrFrom4(ip).String() + case ipv6: + var ip [16]byte + _, err = io.ReadFull(reader, ip[:]) + if err != nil { + return request{}, fmt.Errorf("reading IPv6 address: %w", err) + } + req.destination = netip.AddrFrom16(ip).String() + case domainName: + var header [1]byte + _, err = io.ReadFull(reader, header[:]) + if err != nil { + return request{}, fmt.Errorf("reading domain name header: %w", err) + } + domainName := make([]byte, header[0]) + _, err = io.ReadFull(reader, domainName) + if err != nil { + return request{}, fmt.Errorf("reading domain name bytes: %w", err) + } + req.destination = string(domainName) + default: + return request{}, fmt.Errorf("address type is not supported: %d", req.addressType) + } + + var portBytes [2]byte + _, err = io.ReadFull(reader, portBytes[:]) + if err != nil { + return request{}, fmt.Errorf("reading port: %w", err) + } + req.port = binary.BigEndian.Uint16(portBytes[:]) + + return req, nil +} diff --git a/internal/socks5/socks5_test.go b/internal/socks5/socks5_test.go new file mode 100644 index 00000000..5b16de16 --- /dev/null +++ b/internal/socks5/socks5_test.go @@ -0,0 +1,622 @@ +package socks5 + +import ( + "bytes" + "encoding/binary" + "io" + "net" + "strconv" + "strings" + "testing" + + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type noopLogger struct{} + +func (noopLogger) Infof(string, ...any) {} +func (noopLogger) Warnf(string, ...any) {} + +func TestServerProxy(t *testing.T) { + t.Parallel() + testCases := map[string]struct { + username string + password string + }{ + "no_auth": {}, + "with_auth": { + username: "user", + password: "pass", + }, + } + + for name, testCase := range testCases { + t.Run(name, func(t *testing.T) { + t.Parallel() + + // Backend TCP server: accepts one connection for the proxy to forward to. + backendListener, err := (&net.ListenConfig{}).Listen(t.Context(), "tcp", "127.0.0.1:0") + require.NoError(t, err) + + backendConnCh := make(chan net.Conn) + go func() { + conn, err := backendListener.Accept() + if err != nil { + return + } + backendConnCh <- conn + }() + + server := newServer(Settings{ + Username: testCase.username, + Password: testCase.password, + Address: "127.0.0.1:0", + Logger: noopLogger{}, + }) + _, err = server.Start(t.Context()) + require.NoError(t, err) + t.Cleanup(func() { + _ = server.Stop() + _ = backendListener.Close() + }) + + // Dial through the SOCKS5 proxy to the backend. + // By the time dialSOCKS5 returns, the SOCKS5 server has already + // established the TCP connection to the backend, so backendConnCh + // is guaranteed to be populated. + clientConn := dialSOCKS5(t, server.listeningAddress().String(), + backendListener.Addr().String(), testCase.username, testCase.password) + defer clientConn.Close() + + backendConn := <-backendConnCh + defer backendConn.Close() + + // Verify client → backend direction. + clientMessage := []byte("hello from client") + _, err = clientConn.Write(clientMessage) + require.NoError(t, err) + + received := make([]byte, len(clientMessage)) + _, err = io.ReadFull(backendConn, received) + require.NoError(t, err) + assert.Equal(t, clientMessage, received) + + // Verify backend → client direction. + backendMessage := []byte("hello from backend") + _, err = backendConn.Write(backendMessage) + require.NoError(t, err) + + receivedByClient := make([]byte, len(backendMessage)) + _, err = io.ReadFull(clientConn, receivedByClient) + require.NoError(t, err) + assert.Equal(t, backendMessage, receivedByClient) + }) + } +} + +// dialSOCKS5 performs the full SOCKS5 handshake (with optional username/password +// subnegotiation) and returns a connected net.Conn ready for data exchange. +func dialSOCKS5(t *testing.T, proxyAddr, targetAddr, username, password string) net.Conn { + t.Helper() + + host, portStr, err := net.SplitHostPort(targetAddr) + require.NoError(t, err) + targetPort, err := strconv.Atoi(portStr) + require.NoError(t, err) + + conn, err := (&net.Dialer{}).DialContext(t.Context(), "tcp", proxyAddr) + require.NoError(t, err) + + var method authMethod + if username != "" || password != "" { + method = authUsernamePassword + } else { + method = authNotRequired + } + _, err = conn.Write([]byte{socks5Version, 1, byte(method)}) + require.NoError(t, err) + + var methodResp [2]byte + _, err = io.ReadFull(conn, methodResp[:]) + require.NoError(t, err) + require.Equal(t, socks5Version, methodResp[0]) + require.Equal(t, byte(method), methodResp[1]) + + if method == authUsernamePassword { + packet := []byte{authUsernamePasswordSubNegotiation1, byte(len(username))} + packet = append(packet, []byte(username)...) + packet = append(packet, byte(len(password))) + packet = append(packet, []byte(password)...) + _, err = conn.Write(packet) + require.NoError(t, err) + + var subnegResp [2]byte + _, err = io.ReadFull(conn, subnegResp[:]) + require.NoError(t, err) + require.Equal(t, authUsernamePasswordSubNegotiation1, subnegResp[0]) + require.Equal(t, byte(0), subnegResp[1]) + } + + var connectRequest []byte + if ip := net.ParseIP(host).To4(); ip != nil { + connectRequest = []byte{socks5Version, byte(connect), 0, byte(ipv4)} + connectRequest = append(connectRequest, ip...) + } else { + connectRequest = []byte{socks5Version, byte(connect), 0, byte(domainName), byte(len(host))} + connectRequest = append(connectRequest, []byte(host)...) + } + connectRequest = binary.BigEndian.AppendUint16(connectRequest, uint16(targetPort)) //nolint:gosec + _, err = conn.Write(connectRequest) + require.NoError(t, err) + + var responseHeader [4]byte + _, err = io.ReadFull(conn, responseHeader[:]) + require.NoError(t, err) + require.Equal(t, socks5Version, responseHeader[0]) + require.Equal(t, byte(succeeded), responseHeader[1]) + + // Consume BND.ADDR and BND.PORT (their values are irrelevant to the caller). + switch addrType(responseHeader[3]) { + case ipv4: + var addrPort [net.IPv4len + 2]byte + _, err = io.ReadFull(conn, addrPort[:]) + require.NoError(t, err) + case ipv6: + var addrPort [net.IPv6len + 2]byte + _, err = io.ReadFull(conn, addrPort[:]) + require.NoError(t, err) + case domainName: + var lenBuf [1]byte + _, err = io.ReadFull(conn, lenBuf[:]) + require.NoError(t, err) + addrPort := make([]byte, int(lenBuf[0])+2) + _, err = io.ReadFull(conn, addrPort) + require.NoError(t, err) + } + + return conn +} + +func Test_newServer(t *testing.T) { + t.Parallel() + testCases := map[string]struct { + settings Settings + expected *server + }{ + "with_auth": { + settings: Settings{ + Username: "user", + Password: "pass", + Address: "127.0.0.1:1080", + }, + expected: &server{ + username: "user", + password: "pass", + address: "127.0.0.1:1080", + }, + }, + "without_auth": { + settings: Settings{ + Address: "127.0.0.1:1080", + }, + expected: &server{ + address: "127.0.0.1:1080", + }, + }, + } + + for name, testCase := range testCases { + t.Run(name, func(t *testing.T) { + t.Parallel() + result := newServer(testCase.settings) + assert.Equal(t, testCase.expected.username, result.username) + assert.Equal(t, testCase.expected.password, result.password) + assert.Equal(t, testCase.expected.address, result.address) + assert.Equal(t, testCase.expected.logger, result.logger) + }) + } +} + +func Test_Server_StartStop(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + + logger := NewMockLogger(ctrl) + logger.EXPECT().Infof("SOCKS5 server listening on %s", gomock.Any()) + + server := newServer(Settings{ + Address: "127.0.0.1:0", + Logger: logger, + }) + + runErr, startErr := server.Start(t.Context()) + require.NoError(t, startErr) + + select { + case err := <-runErr: + t.Fatalf("unexpected error on start: %v", err) + default: + } + + address := server.listeningAddress() + assert.NotNil(t, address) + + err := server.Stop() + require.NoError(t, err) +} + +func Test_encodeBindData(t *testing.T) { + t.Parallel() + testCases := map[string]struct { + addrType addrType + address string + port uint16 + expectedErr string + }{ + "ipv4_valid": { + addrType: ipv4, + address: "127.0.0.1", + port: 8080, + }, + "ipv6_valid": { + addrType: ipv6, + address: "::1", + port: 8080, + }, + "domain_name_valid": { + addrType: domainName, + address: "example.com", + port: 8080, + }, + "ipv4_invalid": { + addrType: ipv4, + address: "invalid", + expectedErr: "parsing IP address", + }, + "ipv4_actual_ipv6": { + addrType: ipv4, + address: "::1", + expectedErr: "ip version is unexpected", + }, + "ipv6_actual_ipv4": { + addrType: ipv6, + address: "127.0.0.1", + expectedErr: "ip version is unexpected", + }, + "domain_too_long": { + addrType: domainName, + address: strings.Repeat("a", 256), + expectedErr: "domain name is too long", + }, + } + + for name, testCase := range testCases { + t.Run(name, func(t *testing.T) { + t.Parallel() + data, err := encodeBindData(testCase.addrType, testCase.address, testCase.port) + + if testCase.expectedErr != "" { + assert.ErrorContains(t, err, testCase.expectedErr) + assert.Nil(t, data) + } else { + assert.NoError(t, err) + assert.NotNil(t, data) + + assert.Equal(t, byte(testCase.addrType), data[0]) + + portOffset := len(data) - 2 + decodedPort := binary.BigEndian.Uint16(data[portOffset:]) + assert.Equal(t, testCase.port, decodedPort) + } + }) + } +} + +func Test_decodeRequest(t *testing.T) { + t.Parallel() + testCases := map[string]struct { + packet []byte + expectedErr string + validate func(*testing.T, request) + }{ + "ipv4_valid": { + packet: []byte{socks5Version, byte(connect), 0, byte(ipv4), 127, 0, 0, 1, byte(0x1f), byte(0x90)}, + validate: func(t *testing.T, request request) { + t.Helper() + assert.Equal(t, connect, request.command) + assert.Equal(t, "127.0.0.1", request.destination) + assert.Equal(t, uint16(8080), request.port) + assert.Equal(t, ipv4, request.addressType) + }, + }, + "domain_name_valid": { + packet: concatBytes( + []byte{socks5Version, byte(connect), 0, byte(domainName)}, + []byte{byte(len("example.com"))}, + []byte("example.com"), + []byte{0x00, 0x50}, + ), + validate: func(t *testing.T, request request) { + t.Helper() + assert.Equal(t, "example.com", request.destination) + assert.Equal(t, uint16(80), request.port) + assert.Equal(t, domainName, request.addressType) + }, + }, + "version_mismatch": { + packet: []byte{4, byte(connect), 0, byte(ipv4), 127, 0, 0, 1, 0, 0}, + expectedErr: "version is not supported", + }, + "truncated_header": { + packet: []byte{socks5Version, byte(connect)}, + expectedErr: "reading header", + }, + "unsupported_address_type": { + packet: []byte{socks5Version, byte(connect), 0, byte(255)}, + expectedErr: "address type is not supported", + }, + } + + for name, testCase := range testCases { + t.Run(name, func(t *testing.T) { + t.Parallel() + + reader := bytes.NewReader(testCase.packet) + + request, err := decodeRequest(reader, socks5Version) + + if testCase.expectedErr != "" { + assert.ErrorContains(t, err, testCase.expectedErr) + } else { + assert.NoError(t, err) + testCase.validate(t, request) + } + }) + } +} + +func Test_verifyFirstNegotiation(t *testing.T) { + t.Parallel() + testCases := map[string]struct { + packet []byte + requiredAuth authMethod + expectedErr string + }{ + "version_mismatch": { + packet: []byte{4, 2, byte(authNotRequired), byte(authUsernamePassword)}, + requiredAuth: authNotRequired, + expectedErr: "version is not supported", + }, + "no_methods": { + packet: []byte{socks5Version, 0}, + requiredAuth: authNotRequired, + expectedErr: "no method identifiers", + }, + "required_method_not_present": { + packet: []byte{socks5Version, 2, byte(authNotRequired), byte(authGssapi)}, + requiredAuth: authUsernamePassword, + expectedErr: "no valid method identifier", + }, + "required_method_present": { + packet: []byte{socks5Version, 3, byte(authNotRequired), byte(authUsernamePassword), byte(authGssapi)}, + requiredAuth: authUsernamePassword, + }, + "no_auth_required": { + packet: []byte{socks5Version, 1, byte(authNotRequired)}, + requiredAuth: authNotRequired, + }, + } + + for name, testCase := range testCases { + t.Run(name, func(t *testing.T) { + t.Parallel() + + reader := bytes.NewReader(testCase.packet) + + err := verifyFirstNegotiation(reader, testCase.requiredAuth) + + if testCase.expectedErr != "" { + assert.ErrorContains(t, err, testCase.expectedErr) + } else { + assert.NoError(t, err) + } + }) + } +} + +func Test_usernamePasswordSubnegotiate(t *testing.T) { + t.Parallel() + testCases := map[string]struct { + packet []byte + username string + password string + expectedErr string + }{ + "valid_credentials": { + packet: concatBytes( + []byte{authUsernamePasswordSubNegotiation1, 4}, + []byte("user"), + []byte{4}, + []byte("pass"), + ), + username: "user", + password: "pass", + }, + "version_mismatch": { + packet: []byte{2, 4, 'u', 's', 'e', 'r'}, + username: "user", + password: "pass", + expectedErr: "subnegotiation version not supported", + }, + "wrong_username": { + packet: concatBytes( + []byte{authUsernamePasswordSubNegotiation1, 4}, + []byte("fake"), + []byte{4}, + []byte("pass"), + ), + username: "user", + password: "pass", + expectedErr: "username received is not valid", + }, + "wrong_password": { + packet: concatBytes( + []byte{authUsernamePasswordSubNegotiation1, 4}, + []byte("user"), + []byte{4}, + []byte("fake"), + ), + username: "user", + password: "pass", + expectedErr: "password not valid", + }, + "truncated_header": { + packet: []byte{authUsernamePasswordSubNegotiation1}, + username: "user", + password: "pass", + expectedErr: "reading header", + }, + } + + for name, testCase := range testCases { + t.Run(name, func(t *testing.T) { + t.Parallel() + + buffer := bytes.NewBuffer(testCase.packet) + + readWriter := struct { + io.Reader + io.Writer + }{ + Reader: buffer, + Writer: io.Discard, + } + + err := usernamePasswordSubnegotiate(readWriter, testCase.username, testCase.password) + + if testCase.expectedErr != "" { + assert.ErrorContains(t, err, testCase.expectedErr) + } else { + assert.NoError(t, err) + } + }) + } +} + +func concatBytes(slices ...[]byte) []byte { + var result []byte + for _, slice := range slices { + result = append(result, slice...) + } + return result +} + +func Test_bindDataLength(t *testing.T) { + t.Parallel() + testCases := map[string]struct { + addrType addrType + address string + wantMaxLength uint + }{ + "ipv4": { + addrType: ipv4, + address: "127.0.0.1", + wantMaxLength: 1 + 4 + 2, + }, + "ipv6": { + addrType: ipv6, + address: "::1", + wantMaxLength: 1 + 16 + 2, + }, + "domain_short": { + addrType: domainName, + address: "example.com", + wantMaxLength: 1 + 1 + uint(len("example.com")) + 2, + }, + "domain_long": { + addrType: domainName, + address: strings.Repeat("a", 100), + wantMaxLength: 1 + 1 + 100 + 2, + }, + } + + for name, testCase := range testCases { + t.Run(name, func(t *testing.T) { + t.Parallel() + length := bindDataLength(testCase.addrType, testCase.address) + assert.Equal(t, testCase.wantMaxLength, length) + }) + } +} + +func Test_authMethod_String(t *testing.T) { + t.Parallel() + testCases := map[string]struct { + method authMethod + expectedName string + }{ + "no_auth": { + method: authNotRequired, + expectedName: "no authentication required", + }, + "gssapi": { + method: authGssapi, + expectedName: "GSSAPI", + }, + "username_password": { + method: authUsernamePassword, + expectedName: "username/password", + }, + "not_acceptable": { + method: authNotAcceptable, + expectedName: "no acceptable methods", + }, + "unknown": { + method: authMethod(99), + expectedName: "unknown method (99)", + }, + } + + for name, testCase := range testCases { + t.Run(name, func(t *testing.T) { + t.Parallel() + result := testCase.method.String() + assert.Equal(t, testCase.expectedName, result) + }) + } +} + +func Test_cmdType_String(t *testing.T) { + t.Parallel() + testCases := map[string]struct { + cmd cmdType + expectedName string + }{ + "connect": { + cmd: connect, + expectedName: "connect", + }, + "bind": { + cmd: bind, + expectedName: "bind", + }, + "udp_associate": { + cmd: udpAssociate, + expectedName: "UDP associate", + }, + "unknown": { + cmd: cmdType(99), + expectedName: "unknown command (99)", + }, + } + + for name, testCase := range testCases { + t.Run(name, func(t *testing.T) { + t.Parallel() + result := testCase.cmd.String() + assert.Equal(t, testCase.expectedName, result) + }) + } +} diff --git a/internal/socks5/usernamepassword.go b/internal/socks5/usernamepassword.go new file mode 100644 index 00000000..34be9835 --- /dev/null +++ b/internal/socks5/usernamepassword.go @@ -0,0 +1,62 @@ +package socks5 + +import ( + "fmt" + "io" +) + +// See https://datatracker.ietf.org/doc/html/rfc1929#section-2 +func usernamePasswordSubnegotiate(conn io.ReadWriter, username, password string) (err error) { + status := byte(1) + const defaultVersion = byte(1) + + const headerLength = 2 + var header [headerLength]byte + _, err = io.ReadFull(conn, header[:]) + if err != nil { + _, _ = conn.Write([]byte{defaultVersion, status}) + return fmt.Errorf("reading header: %w", err) + } + + if header[0] != authUsernamePasswordSubNegotiation1 { + _, _ = conn.Write([]byte{defaultVersion, status}) + return fmt.Errorf("subnegotiation version not supported: %d", header[0]) + } + version := header[0] + + usernameBytes := make([]byte, header[1]) + _, err = io.ReadFull(conn, usernameBytes) + if err != nil { + _, _ = conn.Write([]byte{version, status}) + return fmt.Errorf("reading username bytes: %w", err) + } else if username != string(usernameBytes) { + _, _ = conn.Write([]byte{version, status}) + return fmt.Errorf("username received is not valid") + } + + const passwordHeaderLength = 1 + passwordHeader := make([]byte, passwordHeaderLength) + _, err = io.ReadFull(conn, passwordHeader) + if err != nil { + _, _ = conn.Write([]byte{version, status}) + return fmt.Errorf("reading password length: %w", err) + } + + passwordBytes := make([]byte, passwordHeader[0]) + _, err = io.ReadFull(conn, passwordBytes) + if err != nil { + _, _ = conn.Write([]byte{version, status}) + return fmt.Errorf("reading password bytes: %w", err) + } else if password != string(passwordBytes) { + _, _ = conn.Write([]byte{version, status}) + return fmt.Errorf("password not valid for username %q", string(usernameBytes)) + } + + status = 0 + _, err = conn.Write([]byte{version, status}) + if err != nil { + return fmt.Errorf("writing success status: %w", err) + } + + return nil +}