mirror of
https://github.com/qdm12/gluetun.git
synced 2026-06-13 07:42:24 +02:00
feat: socks5 proxy server (#3336)
- `SOCKS5_ENABLED=off` - `SOCKS5_LISTENING_ADDRESS=":1080"` - `SOCKS5_USER=` - `SOCKS5_PASSWORD=`
This commit is contained in:
+6
-1
@@ -240,6 +240,11 @@ ENV VPN_SERVICE_PROVIDER=pia \
|
|||||||
SHADOWSOCKS_PASSWORD= \
|
SHADOWSOCKS_PASSWORD= \
|
||||||
SHADOWSOCKS_PASSWORD_SECRETFILE=/run/secrets/shadowsocks_password \
|
SHADOWSOCKS_PASSWORD_SECRETFILE=/run/secrets/shadowsocks_password \
|
||||||
SHADOWSOCKS_CIPHER=chacha20-ietf-poly1305 \
|
SHADOWSOCKS_CIPHER=chacha20-ietf-poly1305 \
|
||||||
|
# Socks5
|
||||||
|
SOCKS5_ENABLED=off \
|
||||||
|
SOCKS5_LISTENING_ADDRESS=":1080" \
|
||||||
|
SOCKS5_USER= \
|
||||||
|
SOCKS5_PASSWORD= \
|
||||||
# Control server
|
# Control server
|
||||||
HTTP_CONTROL_SERVER_LOG=on \
|
HTTP_CONTROL_SERVER_LOG=on \
|
||||||
HTTP_CONTROL_SERVER_ADDRESS=":8000" \
|
HTTP_CONTROL_SERVER_ADDRESS=":8000" \
|
||||||
@@ -271,7 +276,7 @@ ENV VPN_SERVICE_PROVIDER=pia \
|
|||||||
PUID=1000 \
|
PUID=1000 \
|
||||||
PGID=1000
|
PGID=1000
|
||||||
ENTRYPOINT ["/gluetun-entrypoint"]
|
ENTRYPOINT ["/gluetun-entrypoint"]
|
||||||
EXPOSE 8000/tcp 8888/tcp 8388/tcp 8388/udp
|
EXPOSE 8000/tcp 8888/tcp 8388/tcp 8388/udp 1080/tcp
|
||||||
HEALTHCHECK --interval=5s --timeout=5s --start-period=10s --retries=3 CMD /gluetun-entrypoint healthcheck
|
HEALTHCHECK --interval=5s --timeout=5s --start-period=10s --retries=3 CMD /gluetun-entrypoint healthcheck
|
||||||
ARG TARGETPLATFORM
|
ARG TARGETPLATFORM
|
||||||
RUN apk add --no-cache --update -l wget && \
|
RUN apk add --no-cache --update -l wget && \
|
||||||
|
|||||||
+16
-1
@@ -41,6 +41,7 @@ import (
|
|||||||
"github.com/qdm12/gluetun/internal/routing"
|
"github.com/qdm12/gluetun/internal/routing"
|
||||||
"github.com/qdm12/gluetun/internal/server"
|
"github.com/qdm12/gluetun/internal/server"
|
||||||
"github.com/qdm12/gluetun/internal/shadowsocks"
|
"github.com/qdm12/gluetun/internal/shadowsocks"
|
||||||
|
"github.com/qdm12/gluetun/internal/socks5"
|
||||||
"github.com/qdm12/gluetun/internal/storage"
|
"github.com/qdm12/gluetun/internal/storage"
|
||||||
updater "github.com/qdm12/gluetun/internal/updater/loop"
|
updater "github.com/qdm12/gluetun/internal/updater/loop"
|
||||||
"github.com/qdm12/gluetun/internal/updater/resolver"
|
"github.com/qdm12/gluetun/internal/updater/resolver"
|
||||||
@@ -411,6 +412,18 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
|
|||||||
return fmt.Errorf("starting public ip loop: %w", err)
|
return fmt.Errorf("starting public ip loop: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
socks5Loop := socks5.NewLoop(socks5.Settings{
|
||||||
|
Enabled: *allSettings.Socks5.Enabled,
|
||||||
|
Username: *allSettings.Socks5.Username,
|
||||||
|
Password: *allSettings.Socks5.Password,
|
||||||
|
Address: allSettings.Socks5.ListeningAddress,
|
||||||
|
Logger: logger.New(log.SetComponent("socks5")),
|
||||||
|
})
|
||||||
|
socks5RunError, err := socks5Loop.Start(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("starting SOCKS5 server loop: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
healthLogger := logger.New(log.SetComponent("healthcheck"))
|
healthLogger := logger.New(log.SetComponent("healthcheck"))
|
||||||
healthcheckServer := healthcheck.NewServer(allSettings.Health, healthLogger)
|
healthcheckServer := healthcheck.NewServer(allSettings.Health, healthLogger)
|
||||||
healthServerHandler, healthServerCtx, healthServerDone := goshutdown.NewGoRoutineHandler(
|
healthServerHandler, healthServerCtx, healthServerDone := goshutdown.NewGoRoutineHandler(
|
||||||
@@ -506,7 +519,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
|
|||||||
String() string
|
String() string
|
||||||
Stop() error
|
Stop() error
|
||||||
}{
|
}{
|
||||||
portForwardLooper, publicIPLooper,
|
portForwardLooper, publicIPLooper, socks5Loop,
|
||||||
}
|
}
|
||||||
for _, stopper := range stoppers {
|
for _, stopper := range stoppers {
|
||||||
err := stopper.Stop()
|
err := stopper.Stop()
|
||||||
@@ -518,6 +531,8 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
|
|||||||
logger.Errorf("port forwarding loop crashed: %s", err)
|
logger.Errorf("port forwarding loop crashed: %s", err)
|
||||||
case err := <-publicIPRunError:
|
case err := <-publicIPRunError:
|
||||||
logger.Errorf("public IP loop crashed: %s", err)
|
logger.Errorf("public IP loop crashed: %s", err)
|
||||||
|
case err := <-socks5RunError:
|
||||||
|
logger.Errorf("SOCKS5 server loop crashed: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return orderHandler.Shutdown(context.Background())
|
return orderHandler.Shutdown(context.Background())
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ type Settings struct {
|
|||||||
HTTPProxy HTTPProxy
|
HTTPProxy HTTPProxy
|
||||||
Log Log
|
Log Log
|
||||||
PublicIP PublicIP
|
PublicIP PublicIP
|
||||||
|
Socks5 Socks5
|
||||||
Shadowsocks Shadowsocks
|
Shadowsocks Shadowsocks
|
||||||
Storage Storage
|
Storage Storage
|
||||||
System System
|
System System
|
||||||
@@ -49,6 +50,7 @@ func (s *Settings) Validate(filterChoicesGetter FilterChoicesGetter, ipv6Support
|
|||||||
"http proxy": s.HTTPProxy.validate,
|
"http proxy": s.HTTPProxy.validate,
|
||||||
"log": s.Log.validate,
|
"log": s.Log.validate,
|
||||||
"public ip check": s.PublicIP.validate,
|
"public ip check": s.PublicIP.validate,
|
||||||
|
"socks5": s.Socks5.validate,
|
||||||
"shadowsocks": s.Shadowsocks.validate,
|
"shadowsocks": s.Shadowsocks.validate,
|
||||||
"storage": s.Storage.validate,
|
"storage": s.Storage.validate,
|
||||||
"system": s.System.validate,
|
"system": s.System.validate,
|
||||||
@@ -81,6 +83,7 @@ func (s *Settings) copy() (copied Settings) {
|
|||||||
HTTPProxy: s.HTTPProxy.copy(),
|
HTTPProxy: s.HTTPProxy.copy(),
|
||||||
Log: s.Log.copy(),
|
Log: s.Log.copy(),
|
||||||
PublicIP: s.PublicIP.copy(),
|
PublicIP: s.PublicIP.copy(),
|
||||||
|
Socks5: s.Socks5.copy(),
|
||||||
Shadowsocks: s.Shadowsocks.copy(),
|
Shadowsocks: s.Shadowsocks.copy(),
|
||||||
Storage: s.Storage.copy(),
|
Storage: s.Storage.copy(),
|
||||||
System: s.System.copy(),
|
System: s.System.copy(),
|
||||||
@@ -104,6 +107,7 @@ func (s *Settings) OverrideWith(other Settings,
|
|||||||
patchedSettings.HTTPProxy.overrideWith(other.HTTPProxy)
|
patchedSettings.HTTPProxy.overrideWith(other.HTTPProxy)
|
||||||
patchedSettings.Log.overrideWith(other.Log)
|
patchedSettings.Log.overrideWith(other.Log)
|
||||||
patchedSettings.PublicIP.overrideWith(other.PublicIP)
|
patchedSettings.PublicIP.overrideWith(other.PublicIP)
|
||||||
|
patchedSettings.Socks5.overrideWith(other.Socks5)
|
||||||
patchedSettings.Shadowsocks.overrideWith(other.Shadowsocks)
|
patchedSettings.Shadowsocks.overrideWith(other.Shadowsocks)
|
||||||
patchedSettings.Storage.overrideWith(other.Storage)
|
patchedSettings.Storage.overrideWith(other.Storage)
|
||||||
patchedSettings.System.overrideWith(other.System)
|
patchedSettings.System.overrideWith(other.System)
|
||||||
@@ -131,6 +135,7 @@ func (s *Settings) SetDefaults() {
|
|||||||
s.Log.setDefaults()
|
s.Log.setDefaults()
|
||||||
s.IPv6.setDefaults()
|
s.IPv6.setDefaults()
|
||||||
s.PublicIP.setDefaults()
|
s.PublicIP.setDefaults()
|
||||||
|
s.Socks5.setDefaults()
|
||||||
s.Shadowsocks.setDefaults()
|
s.Shadowsocks.setDefaults()
|
||||||
s.Storage.SetDefaults()
|
s.Storage.SetDefaults()
|
||||||
s.System.setDefaults()
|
s.System.setDefaults()
|
||||||
@@ -154,6 +159,7 @@ func (s Settings) toLinesNode() (node *gotree.Node) {
|
|||||||
node.AppendNode(s.Log.toLinesNode())
|
node.AppendNode(s.Log.toLinesNode())
|
||||||
node.AppendNode(s.IPv6.toLinesNode())
|
node.AppendNode(s.IPv6.toLinesNode())
|
||||||
node.AppendNode(s.Health.toLinesNode())
|
node.AppendNode(s.Health.toLinesNode())
|
||||||
|
node.AppendNode(s.Socks5.toLinesNode())
|
||||||
node.AppendNode(s.Shadowsocks.toLinesNode())
|
node.AppendNode(s.Shadowsocks.toLinesNode())
|
||||||
node.AppendNode(s.HTTPProxy.toLinesNode())
|
node.AppendNode(s.HTTPProxy.toLinesNode())
|
||||||
node.AppendNode(s.ControlServer.toLinesNode())
|
node.AppendNode(s.ControlServer.toLinesNode())
|
||||||
@@ -212,6 +218,7 @@ func (s *Settings) Read(r *reader.Reader, warner Warner) (err error) {
|
|||||||
"public ip": func(r *reader.Reader) error {
|
"public ip": func(r *reader.Reader) error {
|
||||||
return s.PublicIP.read(r, warner)
|
return s.PublicIP.read(r, warner)
|
||||||
},
|
},
|
||||||
|
"socks5": s.Socks5.read,
|
||||||
"shadowsocks": s.Shadowsocks.read,
|
"shadowsocks": s.Shadowsocks.read,
|
||||||
"storage": s.Storage.Read,
|
"storage": s.Storage.Read,
|
||||||
"system": s.System.read,
|
"system": s.System.read,
|
||||||
|
|||||||
@@ -81,6 +81,8 @@ func Test_Settings_String(t *testing.T) {
|
|||||||
| | ├── 1.1.1.1
|
| | ├── 1.1.1.1
|
||||||
| | └── 8.8.8.8
|
| | └── 8.8.8.8
|
||||||
| └── Restart VPN on healthcheck failure: yes
|
| └── Restart VPN on healthcheck failure: yes
|
||||||
|
├── SOCKS5 proxy server settings:
|
||||||
|
| └── Enabled: no
|
||||||
├── Shadowsocks server settings:
|
├── Shadowsocks server settings:
|
||||||
| └── Enabled: no
|
| └── Enabled: no
|
||||||
├── HTTP proxy settings:
|
├── HTTP proxy settings:
|
||||||
|
|||||||
@@ -0,0 +1,91 @@
|
|||||||
|
package settings
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"github.com/qdm12/gosettings"
|
||||||
|
"github.com/qdm12/gosettings/reader"
|
||||||
|
"github.com/qdm12/gosettings/validate"
|
||||||
|
"github.com/qdm12/gotree"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Socks5 contains settings to configure the Socks5 proxy server.
|
||||||
|
type Socks5 struct {
|
||||||
|
Enabled *bool
|
||||||
|
ListeningAddress string
|
||||||
|
Username *string
|
||||||
|
Password *string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s Socks5) validate() (err error) {
|
||||||
|
err = validate.ListeningAddress(s.ListeningAddress, os.Getuid())
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("server listening address is not valid: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case *s.Username != "" && *s.Password == "":
|
||||||
|
return errors.New("password must be set if username is set")
|
||||||
|
case *s.Username == "" && *s.Password != "":
|
||||||
|
return errors.New("username must be set if password is set")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Socks5) copy() (copied Socks5) {
|
||||||
|
return Socks5{
|
||||||
|
Enabled: gosettings.CopyPointer(s.Enabled),
|
||||||
|
ListeningAddress: s.ListeningAddress,
|
||||||
|
Username: gosettings.CopyPointer(s.Username),
|
||||||
|
Password: gosettings.CopyPointer(s.Password),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Socks5) overrideWith(other Socks5) {
|
||||||
|
s.Enabled = gosettings.OverrideWithPointer(s.Enabled, other.Enabled)
|
||||||
|
s.ListeningAddress = gosettings.OverrideWithComparable(s.ListeningAddress, other.ListeningAddress)
|
||||||
|
s.Username = gosettings.OverrideWithPointer(s.Username, other.Username)
|
||||||
|
s.Password = gosettings.OverrideWithPointer(s.Password, other.Password)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Socks5) setDefaults() {
|
||||||
|
s.Enabled = gosettings.DefaultPointer(s.Enabled, false)
|
||||||
|
s.ListeningAddress = gosettings.DefaultComparable(s.ListeningAddress, ":1080")
|
||||||
|
s.Username = gosettings.DefaultPointer(s.Username, "")
|
||||||
|
s.Password = gosettings.DefaultPointer(s.Password, "")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s Socks5) String() string {
|
||||||
|
return s.toLinesNode().String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s Socks5) toLinesNode() (node *gotree.Node) {
|
||||||
|
node = gotree.New("SOCKS5 proxy server settings:")
|
||||||
|
node.Appendf("Enabled: %s", gosettings.BoolToYesNo(s.Enabled))
|
||||||
|
if !*s.Enabled {
|
||||||
|
return node
|
||||||
|
}
|
||||||
|
|
||||||
|
node.Appendf("Listening address: %s", s.ListeningAddress)
|
||||||
|
if *s.Username != "" || *s.Password != "" {
|
||||||
|
node.Appendf("Username: %s", *s.Username)
|
||||||
|
node.Appendf("Password: %s", gosettings.ObfuscateKey(*s.Password))
|
||||||
|
}
|
||||||
|
return node
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Socks5) read(r *reader.Reader) (err error) {
|
||||||
|
s.Enabled, err = r.BoolPtr("SOCKS5_ENABLED")
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
s.ListeningAddress = r.String("SOCKS5_LISTENING_ADDRESS")
|
||||||
|
s.Username = r.Get("SOCKS5_USER", reader.ForceLowercase(false))
|
||||||
|
s.Password = r.Get("SOCKS5_PASSWORD", reader.ForceLowercase(false))
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,86 @@
|
|||||||
|
package socks5
|
||||||
|
|
||||||
|
import "fmt"
|
||||||
|
|
||||||
|
// See https://datatracker.ietf.org/doc/html/rfc1928#section-3
|
||||||
|
type authMethod byte
|
||||||
|
|
||||||
|
const (
|
||||||
|
authNotRequired authMethod = 0
|
||||||
|
authGssapi authMethod = 1
|
||||||
|
authUsernamePassword authMethod = 2
|
||||||
|
authNotAcceptable authMethod = 255
|
||||||
|
)
|
||||||
|
|
||||||
|
func (a authMethod) String() string {
|
||||||
|
switch a {
|
||||||
|
case authNotRequired:
|
||||||
|
return "no authentication required"
|
||||||
|
case authGssapi:
|
||||||
|
return "GSSAPI"
|
||||||
|
case authUsernamePassword:
|
||||||
|
return "username/password"
|
||||||
|
case authNotAcceptable:
|
||||||
|
return "no acceptable methods"
|
||||||
|
default:
|
||||||
|
return fmt.Sprintf("unknown method (%d)", a)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Subnegotiation version
|
||||||
|
// See https://datatracker.ietf.org/doc/html/rfc1929#section-2
|
||||||
|
const (
|
||||||
|
authUsernamePasswordSubNegotiation1 byte = 1
|
||||||
|
)
|
||||||
|
|
||||||
|
// SOCKS versions.
|
||||||
|
const (
|
||||||
|
socks5Version byte = 5
|
||||||
|
)
|
||||||
|
|
||||||
|
// See https://datatracker.ietf.org/doc/html/rfc1928#section-4
|
||||||
|
type cmdType byte
|
||||||
|
|
||||||
|
const (
|
||||||
|
connect cmdType = 1
|
||||||
|
bind cmdType = 2
|
||||||
|
udpAssociate cmdType = 3
|
||||||
|
)
|
||||||
|
|
||||||
|
func (c cmdType) String() string {
|
||||||
|
switch c {
|
||||||
|
case connect:
|
||||||
|
return "connect"
|
||||||
|
case bind:
|
||||||
|
return "bind"
|
||||||
|
case udpAssociate:
|
||||||
|
return "UDP associate"
|
||||||
|
default:
|
||||||
|
return fmt.Sprintf("unknown command (%d)", c)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// See https://datatracker.ietf.org/doc/html/rfc1928#section-4 and
|
||||||
|
// https://datatracker.ietf.org/doc/html/rfc1928#section-5
|
||||||
|
type addrType byte
|
||||||
|
|
||||||
|
const (
|
||||||
|
ipv4 addrType = 1
|
||||||
|
domainName addrType = 3
|
||||||
|
ipv6 addrType = 4
|
||||||
|
)
|
||||||
|
|
||||||
|
// See https://datatracker.ietf.org/doc/html/rfc1928#section-6
|
||||||
|
type replyCode byte
|
||||||
|
|
||||||
|
const (
|
||||||
|
succeeded replyCode = iota
|
||||||
|
generalServerFailure
|
||||||
|
connectionNotAllowedByRuleset
|
||||||
|
networkUnreachable
|
||||||
|
hostUnreachable
|
||||||
|
connectionRefused
|
||||||
|
ttlExpired
|
||||||
|
commandNotSupported
|
||||||
|
addressTypeNotSupported
|
||||||
|
)
|
||||||
@@ -0,0 +1,6 @@
|
|||||||
|
package socks5
|
||||||
|
|
||||||
|
type Logger interface {
|
||||||
|
Infof(format string, a ...interface{})
|
||||||
|
Warnf(format string, a ...interface{})
|
||||||
|
}
|
||||||
@@ -0,0 +1,106 @@
|
|||||||
|
package socks5
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/qdm12/goservices"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Loop struct {
|
||||||
|
settings Settings
|
||||||
|
|
||||||
|
mutex sync.Mutex
|
||||||
|
runCancel context.CancelFunc
|
||||||
|
runDone <-chan error
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewLoop(settings Settings) *Loop {
|
||||||
|
return &Loop{
|
||||||
|
settings: settings,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Loop) String() string {
|
||||||
|
return "SOCKS5 server loop"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Loop) Start(_ context.Context) (runError <-chan error, err error) {
|
||||||
|
l.mutex.Lock()
|
||||||
|
defer l.mutex.Unlock()
|
||||||
|
|
||||||
|
var runCtx context.Context
|
||||||
|
runCtx, l.runCancel = context.WithCancel(context.Background())
|
||||||
|
|
||||||
|
runDone := make(chan error)
|
||||||
|
l.runDone = runDone
|
||||||
|
|
||||||
|
go run(runCtx, runDone, l.settings)
|
||||||
|
|
||||||
|
return nil, nil //nolint:nilnil
|
||||||
|
}
|
||||||
|
|
||||||
|
func run(ctx context.Context, done chan<- error, settings Settings) {
|
||||||
|
defer close(done)
|
||||||
|
logger := settings.Logger
|
||||||
|
|
||||||
|
for ctx.Err() == nil {
|
||||||
|
var server goservices.Service
|
||||||
|
if settings.Enabled {
|
||||||
|
server = newServer(settings)
|
||||||
|
} else {
|
||||||
|
server = new(noopService)
|
||||||
|
}
|
||||||
|
|
||||||
|
errorCh, err := server.Start(ctx)
|
||||||
|
if err != nil {
|
||||||
|
logger.Warnf("failed starting SOCKS5 server: %s", err)
|
||||||
|
waitBeforeRetry(ctx)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
done <- server.Stop()
|
||||||
|
return
|
||||||
|
case err := <-errorCh:
|
||||||
|
if ctx.Err() != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
logger.Warnf("SOCKS5 server crashed: %s", err)
|
||||||
|
waitBeforeRetry(ctx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Loop) Stop() (err error) {
|
||||||
|
l.mutex.Lock()
|
||||||
|
defer l.mutex.Unlock()
|
||||||
|
|
||||||
|
l.runCancel()
|
||||||
|
return <-l.runDone
|
||||||
|
}
|
||||||
|
|
||||||
|
func waitBeforeRetry(ctx context.Context) {
|
||||||
|
const retryDelay = 10 * time.Second
|
||||||
|
timer := time.NewTimer(retryDelay)
|
||||||
|
select {
|
||||||
|
case <-timer.C:
|
||||||
|
case <-ctx.Done():
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type noopService struct{}
|
||||||
|
|
||||||
|
func (s noopService) Start(_ context.Context) (runErr <-chan error, err error) {
|
||||||
|
return nil, nil //nolint:nilnil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s noopService) Stop() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s noopService) String() string {
|
||||||
|
return "noop service"
|
||||||
|
}
|
||||||
@@ -0,0 +1,3 @@
|
|||||||
|
package socks5
|
||||||
|
|
||||||
|
//go:generate mockgen -destination=mocks_test.go -package=$GOPACKAGE . Logger
|
||||||
@@ -0,0 +1,68 @@
|
|||||||
|
// Code generated by MockGen. DO NOT EDIT.
|
||||||
|
// Source: github.com/qdm12/gluetun/internal/socks5 (interfaces: Logger)
|
||||||
|
|
||||||
|
// Package socks5 is a generated GoMock package.
|
||||||
|
package socks5
|
||||||
|
|
||||||
|
import (
|
||||||
|
reflect "reflect"
|
||||||
|
|
||||||
|
gomock "github.com/golang/mock/gomock"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MockLogger is a mock of Logger interface.
|
||||||
|
type MockLogger struct {
|
||||||
|
ctrl *gomock.Controller
|
||||||
|
recorder *MockLoggerMockRecorder
|
||||||
|
}
|
||||||
|
|
||||||
|
// MockLoggerMockRecorder is the mock recorder for MockLogger.
|
||||||
|
type MockLoggerMockRecorder struct {
|
||||||
|
mock *MockLogger
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewMockLogger creates a new mock instance.
|
||||||
|
func NewMockLogger(ctrl *gomock.Controller) *MockLogger {
|
||||||
|
mock := &MockLogger{ctrl: ctrl}
|
||||||
|
mock.recorder = &MockLoggerMockRecorder{mock}
|
||||||
|
return mock
|
||||||
|
}
|
||||||
|
|
||||||
|
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||||
|
func (m *MockLogger) EXPECT() *MockLoggerMockRecorder {
|
||||||
|
return m.recorder
|
||||||
|
}
|
||||||
|
|
||||||
|
// Infof mocks base method.
|
||||||
|
func (m *MockLogger) Infof(arg0 string, arg1 ...interface{}) {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
varargs := []interface{}{arg0}
|
||||||
|
for _, a := range arg1 {
|
||||||
|
varargs = append(varargs, a)
|
||||||
|
}
|
||||||
|
m.ctrl.Call(m, "Infof", varargs...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Infof indicates an expected call of Infof.
|
||||||
|
func (mr *MockLoggerMockRecorder) Infof(arg0 interface{}, arg1 ...interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
varargs := append([]interface{}{arg0}, arg1...)
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Infof", reflect.TypeOf((*MockLogger)(nil).Infof), varargs...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Warnf mocks base method.
|
||||||
|
func (m *MockLogger) Warnf(arg0 string, arg1 ...interface{}) {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
varargs := []interface{}{arg0}
|
||||||
|
for _, a := range arg1 {
|
||||||
|
varargs = append(varargs, a)
|
||||||
|
}
|
||||||
|
m.ctrl.Call(m, "Warnf", varargs...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Warnf indicates an expected call of Warnf.
|
||||||
|
func (mr *MockLoggerMockRecorder) Warnf(arg0 interface{}, arg1 ...interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
varargs := append([]interface{}{arg0}, arg1...)
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Warnf", reflect.TypeOf((*MockLogger)(nil).Warnf), varargs...)
|
||||||
|
}
|
||||||
@@ -0,0 +1,109 @@
|
|||||||
|
package socks5
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
)
|
||||||
|
|
||||||
|
// See https://datatracker.ietf.org/doc/html/rfc1928#section-6
|
||||||
|
func (c *socksConn) encodeFailedResponse(writer io.Writer, socksVersion byte, reply replyCode) { //nolint:unparam
|
||||||
|
_, err := writer.Write([]byte{
|
||||||
|
socksVersion,
|
||||||
|
byte(reply),
|
||||||
|
0, // RSV byte
|
||||||
|
// The RFC requires a full response frame even for failures.
|
||||||
|
// Use IPv4 address type with zeroed address and port.
|
||||||
|
byte(ipv4), // ATYP
|
||||||
|
0, 0, 0, 0, // BND.ADDR (zeroed)
|
||||||
|
0, 0, // BND.PORT (zeroed)
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
c.logger.Warnf("failed writing failed response: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// See https://datatracker.ietf.org/doc/html/rfc1928#section-6
|
||||||
|
func (c *socksConn) encodeSuccessResponse(writer io.Writer, socksVersion byte,
|
||||||
|
reply replyCode, bindAddrType addrType, bindAddress string,
|
||||||
|
bindPort uint16,
|
||||||
|
) (err error) {
|
||||||
|
bindData, err := encodeBindData(bindAddrType, bindAddress, bindPort)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("encoding bind data: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
const initialPacketLength = 3
|
||||||
|
capacity := initialPacketLength + len(bindData)
|
||||||
|
packet := make([]byte, initialPacketLength, capacity)
|
||||||
|
packet[0] = socksVersion
|
||||||
|
packet[1] = byte(reply)
|
||||||
|
packet[2] = 0 // RSV byte
|
||||||
|
packet = append(packet, bindData...)
|
||||||
|
|
||||||
|
_, err = writer.Write(packet)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("writing packet: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrIPVersionUnexpected = errors.New("ip version is unexpected")
|
||||||
|
ErrDomainNameTooLong = errors.New("domain name is too long")
|
||||||
|
)
|
||||||
|
|
||||||
|
func encodeBindData(addrType addrType, address string, port uint16) (
|
||||||
|
data []byte, err error,
|
||||||
|
) {
|
||||||
|
capacity := bindDataLength(addrType, address)
|
||||||
|
data = make([]byte, 0, capacity)
|
||||||
|
|
||||||
|
data = append(data, byte(addrType))
|
||||||
|
switch addrType {
|
||||||
|
case ipv4, ipv6:
|
||||||
|
ip, err := netip.ParseAddr(address)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("parsing IP address: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case addrType == ipv4 && !ip.Is4():
|
||||||
|
return nil, fmt.Errorf("%w: expected IPv4 for %s", ErrIPVersionUnexpected, ip)
|
||||||
|
case addrType == ipv6 && !ip.Is6():
|
||||||
|
return nil, fmt.Errorf("%w: expected IPv6 for %s", ErrIPVersionUnexpected, ip)
|
||||||
|
}
|
||||||
|
data = append(data, ip.AsSlice()...)
|
||||||
|
case domainName:
|
||||||
|
const maxDomainNameLength = 255
|
||||||
|
if len(address) > maxDomainNameLength {
|
||||||
|
return nil, fmt.Errorf("%w: %s", ErrDomainNameTooLong, address)
|
||||||
|
}
|
||||||
|
data = append(data, byte(len(address)))
|
||||||
|
data = append(data, []byte(address)...)
|
||||||
|
default:
|
||||||
|
panic(fmt.Sprintf("unsupported address type %d", addrType))
|
||||||
|
}
|
||||||
|
data = binary.BigEndian.AppendUint16(data, port)
|
||||||
|
return data, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func bindDataLength(addrType addrType, address string) (maxLength uint) {
|
||||||
|
maxLength++ // address type
|
||||||
|
switch addrType {
|
||||||
|
case ipv4:
|
||||||
|
maxLength += net.IPv4len
|
||||||
|
case domainName:
|
||||||
|
maxLength++ // domain name length
|
||||||
|
maxLength += uint(len([]byte(address)))
|
||||||
|
case ipv6:
|
||||||
|
maxLength += net.IPv6len
|
||||||
|
default:
|
||||||
|
panic("unsupported address type: " + fmt.Sprint(addrType))
|
||||||
|
}
|
||||||
|
maxLength += 2 // port
|
||||||
|
return maxLength
|
||||||
|
}
|
||||||
@@ -0,0 +1,122 @@
|
|||||||
|
package socks5
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
)
|
||||||
|
|
||||||
|
type server struct {
|
||||||
|
username string
|
||||||
|
password string
|
||||||
|
address string
|
||||||
|
logger Logger
|
||||||
|
|
||||||
|
// internal fields
|
||||||
|
listener net.Listener
|
||||||
|
listening atomic.Bool
|
||||||
|
socksConnCtx context.Context //nolint:containedctx
|
||||||
|
socksConnCancel context.CancelFunc
|
||||||
|
done <-chan struct{}
|
||||||
|
stopping atomic.Bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func newServer(settings Settings) *server {
|
||||||
|
return &server{
|
||||||
|
username: settings.Username,
|
||||||
|
password: settings.Password,
|
||||||
|
address: settings.Address,
|
||||||
|
logger: settings.Logger,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *server) String() string {
|
||||||
|
return "SOCKS5 server"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *server) Start(ctx context.Context) (runErr <-chan error, err error) {
|
||||||
|
s.socksConnCtx, s.socksConnCancel = context.WithCancel(context.Background())
|
||||||
|
config := &net.ListenConfig{}
|
||||||
|
s.listener, err = config.Listen(ctx, "tcp", s.address)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("listening on %s: %w", s.address, err)
|
||||||
|
}
|
||||||
|
s.listening.Store(true)
|
||||||
|
s.logger.Infof("SOCKS5 server listening on %s", s.listener.Addr())
|
||||||
|
|
||||||
|
ready := make(chan struct{})
|
||||||
|
runErrCh := make(chan error)
|
||||||
|
runErr = runErrCh
|
||||||
|
done := make(chan struct{})
|
||||||
|
s.done = done
|
||||||
|
go s.runServer(ready, runErrCh, done)
|
||||||
|
select {
|
||||||
|
case <-ready:
|
||||||
|
case <-ctx.Done():
|
||||||
|
_ = s.Stop()
|
||||||
|
return nil, fmt.Errorf("starting server: %w", ctx.Err())
|
||||||
|
}
|
||||||
|
return runErr, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *server) runServer(ready chan<- struct{},
|
||||||
|
runErrCh chan<- error, done chan<- struct{},
|
||||||
|
) {
|
||||||
|
close(ready)
|
||||||
|
defer close(done)
|
||||||
|
wg := new(sync.WaitGroup)
|
||||||
|
defer wg.Wait()
|
||||||
|
|
||||||
|
dialer := &net.Dialer{}
|
||||||
|
for {
|
||||||
|
connection, err := s.listener.Accept()
|
||||||
|
if err != nil {
|
||||||
|
if !s.stopping.Load() {
|
||||||
|
_ = s.stop()
|
||||||
|
runErrCh <- fmt.Errorf("accepting connection: %w", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
wg.Add(1)
|
||||||
|
go func(ctx context.Context, connection net.Conn,
|
||||||
|
dialer *net.Dialer, wg *sync.WaitGroup,
|
||||||
|
) {
|
||||||
|
defer wg.Done()
|
||||||
|
socksConn := &socksConn{
|
||||||
|
dialer: dialer,
|
||||||
|
username: s.username,
|
||||||
|
password: s.password,
|
||||||
|
clientConn: connection,
|
||||||
|
logger: s.logger,
|
||||||
|
}
|
||||||
|
err := socksConn.run(ctx)
|
||||||
|
if err != nil {
|
||||||
|
s.logger.Infof("running socks connection: %s", err)
|
||||||
|
}
|
||||||
|
}(s.socksConnCtx, connection, dialer, wg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *server) Stop() (err error) {
|
||||||
|
s.stopping.Store(true)
|
||||||
|
err = s.stop()
|
||||||
|
<-s.done // wait for run goroutine to finish
|
||||||
|
s.stopping.Store(false)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *server) stop() error {
|
||||||
|
s.listening.Store(false)
|
||||||
|
err := s.listener.Close()
|
||||||
|
s.socksConnCancel() // stop ongoing socks connections
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *server) listeningAddress() net.Addr {
|
||||||
|
if s.listening.Load() {
|
||||||
|
return s.listener.Addr()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,9 @@
|
|||||||
|
package socks5
|
||||||
|
|
||||||
|
type Settings struct {
|
||||||
|
Enabled bool
|
||||||
|
Username string
|
||||||
|
Password string
|
||||||
|
Address string
|
||||||
|
Logger Logger
|
||||||
|
}
|
||||||
@@ -0,0 +1,290 @@
|
|||||||
|
package socks5
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/binary"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
errNoMethodIdentifiers = errors.New("no method identifiers")
|
||||||
|
errNoValidMethodIdentifier = errors.New("no valid method identifier")
|
||||||
|
)
|
||||||
|
|
||||||
|
type socksConn struct {
|
||||||
|
// Injected fields
|
||||||
|
dialer *net.Dialer
|
||||||
|
username string
|
||||||
|
password string
|
||||||
|
clientConn net.Conn
|
||||||
|
logger Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *socksConn) closeClientConn(ctxErr error) {
|
||||||
|
err := c.clientConn.Close()
|
||||||
|
if err != nil && ctxErr == nil {
|
||||||
|
c.logger.Warnf("closing client connection: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *socksConn) run(ctx context.Context) error {
|
||||||
|
// Monitoring context cancellation to close the connection and stop
|
||||||
|
// reading operations on clientConn.
|
||||||
|
done := make(chan struct{})
|
||||||
|
ctxWatcherDone := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
defer close(ctxWatcherDone)
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-ctx.Done():
|
||||||
|
// unblock read operations
|
||||||
|
c.closeClientConn(ctx.Err())
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
defer func() {
|
||||||
|
close(done)
|
||||||
|
<-ctxWatcherDone
|
||||||
|
}()
|
||||||
|
|
||||||
|
authMethod := authNotRequired
|
||||||
|
if c.username != "" || c.password != "" {
|
||||||
|
authMethod = authUsernamePassword
|
||||||
|
}
|
||||||
|
|
||||||
|
err := verifyFirstNegotiation(c.clientConn, authMethod)
|
||||||
|
if err != nil {
|
||||||
|
replyMethod := authMethod
|
||||||
|
if errors.Is(err, errNoMethodIdentifiers) || errors.Is(err, errNoValidMethodIdentifier) {
|
||||||
|
replyMethod = authNotAcceptable
|
||||||
|
}
|
||||||
|
_, writeErr := c.clientConn.Write([]byte{socks5Version, byte(replyMethod)})
|
||||||
|
if writeErr != nil {
|
||||||
|
c.logger.Warnf("failed writing first negotiation reply: %s", writeErr)
|
||||||
|
}
|
||||||
|
c.closeClientConn(ctx.Err())
|
||||||
|
return fmt.Errorf("verifying first negotiation: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = c.clientConn.Write([]byte{socks5Version, byte(authMethod)})
|
||||||
|
if err != nil {
|
||||||
|
c.closeClientConn(ctx.Err())
|
||||||
|
return fmt.Errorf("writing first negotiation reply: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
switch authMethod {
|
||||||
|
case authNotRequired, authNotAcceptable:
|
||||||
|
case authGssapi:
|
||||||
|
panic("not implemented")
|
||||||
|
case authUsernamePassword:
|
||||||
|
// See https://datatracker.ietf.org/doc/html/rfc1929#section-2
|
||||||
|
err = usernamePasswordSubnegotiate(c.clientConn, c.username, c.password)
|
||||||
|
if err != nil {
|
||||||
|
// If the server returns a `failure' (STATUS value other than X'00') status,
|
||||||
|
// it MUST close the connection.
|
||||||
|
c.closeClientConn(ctx.Err())
|
||||||
|
return fmt.Errorf("subnegotiating username and password: %w", err)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
panic(fmt.Sprintf("unimplemented auth method %d", authMethod))
|
||||||
|
}
|
||||||
|
|
||||||
|
err = c.handleRequest(ctx)
|
||||||
|
c.closeClientConn(ctx.Err())
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("handling request: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *socksConn) handleRequest(ctx context.Context) error {
|
||||||
|
const socksVersion = socks5Version
|
||||||
|
request, err := decodeRequest(c.clientConn, socksVersion)
|
||||||
|
if err != nil {
|
||||||
|
c.encodeFailedResponse(c.clientConn, socksVersion, generalServerFailure)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if request.command != connect {
|
||||||
|
c.encodeFailedResponse(c.clientConn, socksVersion, commandNotSupported)
|
||||||
|
return fmt.Errorf("command %s is not supported", request.command)
|
||||||
|
}
|
||||||
|
|
||||||
|
destinationAddress := net.JoinHostPort(request.destination, fmt.Sprint(request.port))
|
||||||
|
destinationConn, err := c.dialer.DialContext(ctx, "tcp", destinationAddress)
|
||||||
|
if err != nil {
|
||||||
|
c.encodeFailedResponse(c.clientConn, socksVersion, generalServerFailure)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer destinationConn.Close()
|
||||||
|
|
||||||
|
destinationServerAddress := destinationConn.LocalAddr().String()
|
||||||
|
destinationAddr, destinationPortStr, err := net.SplitHostPort(destinationServerAddress)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("splitting destination address: %w", err)
|
||||||
|
}
|
||||||
|
destinationPort, err := strconv.ParseUint(destinationPortStr, 10, 16)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("port is malformed: %q", destinationPortStr)
|
||||||
|
}
|
||||||
|
|
||||||
|
var bindAddrType addrType
|
||||||
|
if ip := net.ParseIP(destinationAddr); ip != nil {
|
||||||
|
if ip.To4() != nil {
|
||||||
|
bindAddrType = ipv4
|
||||||
|
} else {
|
||||||
|
bindAddrType = ipv6
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
bindAddrType = domainName
|
||||||
|
}
|
||||||
|
|
||||||
|
err = c.encodeSuccessResponse(c.clientConn, socksVersion, succeeded, bindAddrType,
|
||||||
|
destinationAddr, uint16(destinationPort))
|
||||||
|
if err != nil {
|
||||||
|
c.encodeFailedResponse(c.clientConn, socksVersion, generalServerFailure)
|
||||||
|
return fmt.Errorf("writing successful %s response: %w", request.command, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
const capacity = 2 // if one goroutine fails, we don't want to leak the other one
|
||||||
|
errc := make(chan error, capacity)
|
||||||
|
go func() {
|
||||||
|
_, err := io.Copy(c.clientConn, destinationConn)
|
||||||
|
if err != nil {
|
||||||
|
err = fmt.Errorf("from backend to client: %w", err)
|
||||||
|
}
|
||||||
|
errc <- err
|
||||||
|
}()
|
||||||
|
go func() {
|
||||||
|
_, err := io.Copy(destinationConn, c.clientConn)
|
||||||
|
if err != nil {
|
||||||
|
err = fmt.Errorf("from client to backend: %w", err)
|
||||||
|
}
|
||||||
|
errc <- err
|
||||||
|
}()
|
||||||
|
select {
|
||||||
|
case err := <-errc:
|
||||||
|
return err
|
||||||
|
case <-ctx.Done():
|
||||||
|
_ = destinationConn.Close()
|
||||||
|
_ = c.clientConn.Close()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// See https://datatracker.ietf.org/doc/html/rfc1928#section-3
|
||||||
|
func verifyFirstNegotiation(reader io.Reader, requiredMethod authMethod) error {
|
||||||
|
const headerLength = 2 // version + nMethods bytes
|
||||||
|
header := make([]byte, headerLength)
|
||||||
|
_, err := io.ReadFull(reader, header)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("reading header: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if header[0] != socks5Version {
|
||||||
|
return fmt.Errorf("version is not supported: %d", header[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
nMethods := header[1]
|
||||||
|
if nMethods == 0 {
|
||||||
|
return fmt.Errorf("%w", errNoMethodIdentifiers)
|
||||||
|
}
|
||||||
|
|
||||||
|
methodIdentifiers := make([]byte, nMethods)
|
||||||
|
_, err = io.ReadFull(reader, methodIdentifiers)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("reading method identifiers: %w", err)
|
||||||
|
}
|
||||||
|
for _, methodIdentifier := range methodIdentifiers {
|
||||||
|
if methodIdentifier == byte(requiredMethod) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return makeNoAcceptableMethodError(requiredMethod, methodIdentifiers)
|
||||||
|
}
|
||||||
|
|
||||||
|
func makeNoAcceptableMethodError(requiredAuthMethod authMethod, methodIdentifiers []byte) error {
|
||||||
|
methodNames := make([]string, len(methodIdentifiers))
|
||||||
|
for i, methodIdentifier := range methodIdentifiers {
|
||||||
|
methodNames[i] = fmt.Sprintf("%q", authMethod(methodIdentifier))
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Errorf("%w: none of %s matches %s",
|
||||||
|
errNoValidMethodIdentifier, strings.Join(methodNames, ", "),
|
||||||
|
requiredAuthMethod)
|
||||||
|
}
|
||||||
|
|
||||||
|
// See https://datatracker.ietf.org/doc/html/rfc1928#section-4
|
||||||
|
type request struct {
|
||||||
|
command cmdType
|
||||||
|
destination string
|
||||||
|
port uint16
|
||||||
|
addressType addrType
|
||||||
|
}
|
||||||
|
|
||||||
|
func decodeRequest(reader io.Reader, expectedVersion byte) (req request, err error) {
|
||||||
|
const headerLength = 4
|
||||||
|
header := [headerLength]byte{}
|
||||||
|
_, err = io.ReadFull(reader, header[:])
|
||||||
|
if err != nil {
|
||||||
|
return request{}, fmt.Errorf("reading header: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
version := header[0]
|
||||||
|
switch {
|
||||||
|
case version != expectedVersion:
|
||||||
|
return request{}, fmt.Errorf("version is not supported: expected %d and got %d",
|
||||||
|
expectedVersion, version)
|
||||||
|
case header[2] != 0:
|
||||||
|
return request{}, fmt.Errorf("reserved header byte must be 0 but got %d", header[2])
|
||||||
|
}
|
||||||
|
|
||||||
|
req.command = cmdType(header[1])
|
||||||
|
// header[2] is RSV byte
|
||||||
|
req.addressType = addrType(header[3])
|
||||||
|
|
||||||
|
switch req.addressType {
|
||||||
|
case ipv4:
|
||||||
|
var ip [4]byte
|
||||||
|
_, err = io.ReadFull(reader, ip[:])
|
||||||
|
if err != nil {
|
||||||
|
return request{}, fmt.Errorf("reading IPv4 address: %w", err)
|
||||||
|
}
|
||||||
|
req.destination = netip.AddrFrom4(ip).String()
|
||||||
|
case ipv6:
|
||||||
|
var ip [16]byte
|
||||||
|
_, err = io.ReadFull(reader, ip[:])
|
||||||
|
if err != nil {
|
||||||
|
return request{}, fmt.Errorf("reading IPv6 address: %w", err)
|
||||||
|
}
|
||||||
|
req.destination = netip.AddrFrom16(ip).String()
|
||||||
|
case domainName:
|
||||||
|
var header [1]byte
|
||||||
|
_, err = io.ReadFull(reader, header[:])
|
||||||
|
if err != nil {
|
||||||
|
return request{}, fmt.Errorf("reading domain name header: %w", err)
|
||||||
|
}
|
||||||
|
domainName := make([]byte, header[0])
|
||||||
|
_, err = io.ReadFull(reader, domainName)
|
||||||
|
if err != nil {
|
||||||
|
return request{}, fmt.Errorf("reading domain name bytes: %w", err)
|
||||||
|
}
|
||||||
|
req.destination = string(domainName)
|
||||||
|
default:
|
||||||
|
return request{}, fmt.Errorf("address type is not supported: %d", req.addressType)
|
||||||
|
}
|
||||||
|
|
||||||
|
var portBytes [2]byte
|
||||||
|
_, err = io.ReadFull(reader, portBytes[:])
|
||||||
|
if err != nil {
|
||||||
|
return request{}, fmt.Errorf("reading port: %w", err)
|
||||||
|
}
|
||||||
|
req.port = binary.BigEndian.Uint16(portBytes[:])
|
||||||
|
|
||||||
|
return req, nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,622 @@
|
|||||||
|
package socks5
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/binary"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/golang/mock/gomock"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
type noopLogger struct{}
|
||||||
|
|
||||||
|
func (noopLogger) Infof(string, ...any) {}
|
||||||
|
func (noopLogger) Warnf(string, ...any) {}
|
||||||
|
|
||||||
|
func TestServerProxy(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
testCases := map[string]struct {
|
||||||
|
username string
|
||||||
|
password string
|
||||||
|
}{
|
||||||
|
"no_auth": {},
|
||||||
|
"with_auth": {
|
||||||
|
username: "user",
|
||||||
|
password: "pass",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for name, testCase := range testCases {
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
// Backend TCP server: accepts one connection for the proxy to forward to.
|
||||||
|
backendListener, err := (&net.ListenConfig{}).Listen(t.Context(), "tcp", "127.0.0.1:0")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
backendConnCh := make(chan net.Conn)
|
||||||
|
go func() {
|
||||||
|
conn, err := backendListener.Accept()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
backendConnCh <- conn
|
||||||
|
}()
|
||||||
|
|
||||||
|
server := newServer(Settings{
|
||||||
|
Username: testCase.username,
|
||||||
|
Password: testCase.password,
|
||||||
|
Address: "127.0.0.1:0",
|
||||||
|
Logger: noopLogger{},
|
||||||
|
})
|
||||||
|
_, err = server.Start(t.Context())
|
||||||
|
require.NoError(t, err)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
_ = server.Stop()
|
||||||
|
_ = backendListener.Close()
|
||||||
|
})
|
||||||
|
|
||||||
|
// Dial through the SOCKS5 proxy to the backend.
|
||||||
|
// By the time dialSOCKS5 returns, the SOCKS5 server has already
|
||||||
|
// established the TCP connection to the backend, so backendConnCh
|
||||||
|
// is guaranteed to be populated.
|
||||||
|
clientConn := dialSOCKS5(t, server.listeningAddress().String(),
|
||||||
|
backendListener.Addr().String(), testCase.username, testCase.password)
|
||||||
|
defer clientConn.Close()
|
||||||
|
|
||||||
|
backendConn := <-backendConnCh
|
||||||
|
defer backendConn.Close()
|
||||||
|
|
||||||
|
// Verify client → backend direction.
|
||||||
|
clientMessage := []byte("hello from client")
|
||||||
|
_, err = clientConn.Write(clientMessage)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
received := make([]byte, len(clientMessage))
|
||||||
|
_, err = io.ReadFull(backendConn, received)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, clientMessage, received)
|
||||||
|
|
||||||
|
// Verify backend → client direction.
|
||||||
|
backendMessage := []byte("hello from backend")
|
||||||
|
_, err = backendConn.Write(backendMessage)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
receivedByClient := make([]byte, len(backendMessage))
|
||||||
|
_, err = io.ReadFull(clientConn, receivedByClient)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, backendMessage, receivedByClient)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// dialSOCKS5 performs the full SOCKS5 handshake (with optional username/password
|
||||||
|
// subnegotiation) and returns a connected net.Conn ready for data exchange.
|
||||||
|
func dialSOCKS5(t *testing.T, proxyAddr, targetAddr, username, password string) net.Conn {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
host, portStr, err := net.SplitHostPort(targetAddr)
|
||||||
|
require.NoError(t, err)
|
||||||
|
targetPort, err := strconv.Atoi(portStr)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
conn, err := (&net.Dialer{}).DialContext(t.Context(), "tcp", proxyAddr)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
var method authMethod
|
||||||
|
if username != "" || password != "" {
|
||||||
|
method = authUsernamePassword
|
||||||
|
} else {
|
||||||
|
method = authNotRequired
|
||||||
|
}
|
||||||
|
_, err = conn.Write([]byte{socks5Version, 1, byte(method)})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
var methodResp [2]byte
|
||||||
|
_, err = io.ReadFull(conn, methodResp[:])
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, socks5Version, methodResp[0])
|
||||||
|
require.Equal(t, byte(method), methodResp[1])
|
||||||
|
|
||||||
|
if method == authUsernamePassword {
|
||||||
|
packet := []byte{authUsernamePasswordSubNegotiation1, byte(len(username))}
|
||||||
|
packet = append(packet, []byte(username)...)
|
||||||
|
packet = append(packet, byte(len(password)))
|
||||||
|
packet = append(packet, []byte(password)...)
|
||||||
|
_, err = conn.Write(packet)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
var subnegResp [2]byte
|
||||||
|
_, err = io.ReadFull(conn, subnegResp[:])
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, authUsernamePasswordSubNegotiation1, subnegResp[0])
|
||||||
|
require.Equal(t, byte(0), subnegResp[1])
|
||||||
|
}
|
||||||
|
|
||||||
|
var connectRequest []byte
|
||||||
|
if ip := net.ParseIP(host).To4(); ip != nil {
|
||||||
|
connectRequest = []byte{socks5Version, byte(connect), 0, byte(ipv4)}
|
||||||
|
connectRequest = append(connectRequest, ip...)
|
||||||
|
} else {
|
||||||
|
connectRequest = []byte{socks5Version, byte(connect), 0, byte(domainName), byte(len(host))}
|
||||||
|
connectRequest = append(connectRequest, []byte(host)...)
|
||||||
|
}
|
||||||
|
connectRequest = binary.BigEndian.AppendUint16(connectRequest, uint16(targetPort)) //nolint:gosec
|
||||||
|
_, err = conn.Write(connectRequest)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
var responseHeader [4]byte
|
||||||
|
_, err = io.ReadFull(conn, responseHeader[:])
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, socks5Version, responseHeader[0])
|
||||||
|
require.Equal(t, byte(succeeded), responseHeader[1])
|
||||||
|
|
||||||
|
// Consume BND.ADDR and BND.PORT (their values are irrelevant to the caller).
|
||||||
|
switch addrType(responseHeader[3]) {
|
||||||
|
case ipv4:
|
||||||
|
var addrPort [net.IPv4len + 2]byte
|
||||||
|
_, err = io.ReadFull(conn, addrPort[:])
|
||||||
|
require.NoError(t, err)
|
||||||
|
case ipv6:
|
||||||
|
var addrPort [net.IPv6len + 2]byte
|
||||||
|
_, err = io.ReadFull(conn, addrPort[:])
|
||||||
|
require.NoError(t, err)
|
||||||
|
case domainName:
|
||||||
|
var lenBuf [1]byte
|
||||||
|
_, err = io.ReadFull(conn, lenBuf[:])
|
||||||
|
require.NoError(t, err)
|
||||||
|
addrPort := make([]byte, int(lenBuf[0])+2)
|
||||||
|
_, err = io.ReadFull(conn, addrPort)
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return conn
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_newServer(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
testCases := map[string]struct {
|
||||||
|
settings Settings
|
||||||
|
expected *server
|
||||||
|
}{
|
||||||
|
"with_auth": {
|
||||||
|
settings: Settings{
|
||||||
|
Username: "user",
|
||||||
|
Password: "pass",
|
||||||
|
Address: "127.0.0.1:1080",
|
||||||
|
},
|
||||||
|
expected: &server{
|
||||||
|
username: "user",
|
||||||
|
password: "pass",
|
||||||
|
address: "127.0.0.1:1080",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"without_auth": {
|
||||||
|
settings: Settings{
|
||||||
|
Address: "127.0.0.1:1080",
|
||||||
|
},
|
||||||
|
expected: &server{
|
||||||
|
address: "127.0.0.1:1080",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for name, testCase := range testCases {
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
result := newServer(testCase.settings)
|
||||||
|
assert.Equal(t, testCase.expected.username, result.username)
|
||||||
|
assert.Equal(t, testCase.expected.password, result.password)
|
||||||
|
assert.Equal(t, testCase.expected.address, result.address)
|
||||||
|
assert.Equal(t, testCase.expected.logger, result.logger)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_Server_StartStop(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
|
||||||
|
logger := NewMockLogger(ctrl)
|
||||||
|
logger.EXPECT().Infof("SOCKS5 server listening on %s", gomock.Any())
|
||||||
|
|
||||||
|
server := newServer(Settings{
|
||||||
|
Address: "127.0.0.1:0",
|
||||||
|
Logger: logger,
|
||||||
|
})
|
||||||
|
|
||||||
|
runErr, startErr := server.Start(t.Context())
|
||||||
|
require.NoError(t, startErr)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case err := <-runErr:
|
||||||
|
t.Fatalf("unexpected error on start: %v", err)
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
|
address := server.listeningAddress()
|
||||||
|
assert.NotNil(t, address)
|
||||||
|
|
||||||
|
err := server.Stop()
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_encodeBindData(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
testCases := map[string]struct {
|
||||||
|
addrType addrType
|
||||||
|
address string
|
||||||
|
port uint16
|
||||||
|
expectedErr string
|
||||||
|
}{
|
||||||
|
"ipv4_valid": {
|
||||||
|
addrType: ipv4,
|
||||||
|
address: "127.0.0.1",
|
||||||
|
port: 8080,
|
||||||
|
},
|
||||||
|
"ipv6_valid": {
|
||||||
|
addrType: ipv6,
|
||||||
|
address: "::1",
|
||||||
|
port: 8080,
|
||||||
|
},
|
||||||
|
"domain_name_valid": {
|
||||||
|
addrType: domainName,
|
||||||
|
address: "example.com",
|
||||||
|
port: 8080,
|
||||||
|
},
|
||||||
|
"ipv4_invalid": {
|
||||||
|
addrType: ipv4,
|
||||||
|
address: "invalid",
|
||||||
|
expectedErr: "parsing IP address",
|
||||||
|
},
|
||||||
|
"ipv4_actual_ipv6": {
|
||||||
|
addrType: ipv4,
|
||||||
|
address: "::1",
|
||||||
|
expectedErr: "ip version is unexpected",
|
||||||
|
},
|
||||||
|
"ipv6_actual_ipv4": {
|
||||||
|
addrType: ipv6,
|
||||||
|
address: "127.0.0.1",
|
||||||
|
expectedErr: "ip version is unexpected",
|
||||||
|
},
|
||||||
|
"domain_too_long": {
|
||||||
|
addrType: domainName,
|
||||||
|
address: strings.Repeat("a", 256),
|
||||||
|
expectedErr: "domain name is too long",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for name, testCase := range testCases {
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
data, err := encodeBindData(testCase.addrType, testCase.address, testCase.port)
|
||||||
|
|
||||||
|
if testCase.expectedErr != "" {
|
||||||
|
assert.ErrorContains(t, err, testCase.expectedErr)
|
||||||
|
assert.Nil(t, data)
|
||||||
|
} else {
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NotNil(t, data)
|
||||||
|
|
||||||
|
assert.Equal(t, byte(testCase.addrType), data[0])
|
||||||
|
|
||||||
|
portOffset := len(data) - 2
|
||||||
|
decodedPort := binary.BigEndian.Uint16(data[portOffset:])
|
||||||
|
assert.Equal(t, testCase.port, decodedPort)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_decodeRequest(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
testCases := map[string]struct {
|
||||||
|
packet []byte
|
||||||
|
expectedErr string
|
||||||
|
validate func(*testing.T, request)
|
||||||
|
}{
|
||||||
|
"ipv4_valid": {
|
||||||
|
packet: []byte{socks5Version, byte(connect), 0, byte(ipv4), 127, 0, 0, 1, byte(0x1f), byte(0x90)},
|
||||||
|
validate: func(t *testing.T, request request) {
|
||||||
|
t.Helper()
|
||||||
|
assert.Equal(t, connect, request.command)
|
||||||
|
assert.Equal(t, "127.0.0.1", request.destination)
|
||||||
|
assert.Equal(t, uint16(8080), request.port)
|
||||||
|
assert.Equal(t, ipv4, request.addressType)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"domain_name_valid": {
|
||||||
|
packet: concatBytes(
|
||||||
|
[]byte{socks5Version, byte(connect), 0, byte(domainName)},
|
||||||
|
[]byte{byte(len("example.com"))},
|
||||||
|
[]byte("example.com"),
|
||||||
|
[]byte{0x00, 0x50},
|
||||||
|
),
|
||||||
|
validate: func(t *testing.T, request request) {
|
||||||
|
t.Helper()
|
||||||
|
assert.Equal(t, "example.com", request.destination)
|
||||||
|
assert.Equal(t, uint16(80), request.port)
|
||||||
|
assert.Equal(t, domainName, request.addressType)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"version_mismatch": {
|
||||||
|
packet: []byte{4, byte(connect), 0, byte(ipv4), 127, 0, 0, 1, 0, 0},
|
||||||
|
expectedErr: "version is not supported",
|
||||||
|
},
|
||||||
|
"truncated_header": {
|
||||||
|
packet: []byte{socks5Version, byte(connect)},
|
||||||
|
expectedErr: "reading header",
|
||||||
|
},
|
||||||
|
"unsupported_address_type": {
|
||||||
|
packet: []byte{socks5Version, byte(connect), 0, byte(255)},
|
||||||
|
expectedErr: "address type is not supported",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for name, testCase := range testCases {
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
reader := bytes.NewReader(testCase.packet)
|
||||||
|
|
||||||
|
request, err := decodeRequest(reader, socks5Version)
|
||||||
|
|
||||||
|
if testCase.expectedErr != "" {
|
||||||
|
assert.ErrorContains(t, err, testCase.expectedErr)
|
||||||
|
} else {
|
||||||
|
assert.NoError(t, err)
|
||||||
|
testCase.validate(t, request)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_verifyFirstNegotiation(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
testCases := map[string]struct {
|
||||||
|
packet []byte
|
||||||
|
requiredAuth authMethod
|
||||||
|
expectedErr string
|
||||||
|
}{
|
||||||
|
"version_mismatch": {
|
||||||
|
packet: []byte{4, 2, byte(authNotRequired), byte(authUsernamePassword)},
|
||||||
|
requiredAuth: authNotRequired,
|
||||||
|
expectedErr: "version is not supported",
|
||||||
|
},
|
||||||
|
"no_methods": {
|
||||||
|
packet: []byte{socks5Version, 0},
|
||||||
|
requiredAuth: authNotRequired,
|
||||||
|
expectedErr: "no method identifiers",
|
||||||
|
},
|
||||||
|
"required_method_not_present": {
|
||||||
|
packet: []byte{socks5Version, 2, byte(authNotRequired), byte(authGssapi)},
|
||||||
|
requiredAuth: authUsernamePassword,
|
||||||
|
expectedErr: "no valid method identifier",
|
||||||
|
},
|
||||||
|
"required_method_present": {
|
||||||
|
packet: []byte{socks5Version, 3, byte(authNotRequired), byte(authUsernamePassword), byte(authGssapi)},
|
||||||
|
requiredAuth: authUsernamePassword,
|
||||||
|
},
|
||||||
|
"no_auth_required": {
|
||||||
|
packet: []byte{socks5Version, 1, byte(authNotRequired)},
|
||||||
|
requiredAuth: authNotRequired,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for name, testCase := range testCases {
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
reader := bytes.NewReader(testCase.packet)
|
||||||
|
|
||||||
|
err := verifyFirstNegotiation(reader, testCase.requiredAuth)
|
||||||
|
|
||||||
|
if testCase.expectedErr != "" {
|
||||||
|
assert.ErrorContains(t, err, testCase.expectedErr)
|
||||||
|
} else {
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_usernamePasswordSubnegotiate(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
testCases := map[string]struct {
|
||||||
|
packet []byte
|
||||||
|
username string
|
||||||
|
password string
|
||||||
|
expectedErr string
|
||||||
|
}{
|
||||||
|
"valid_credentials": {
|
||||||
|
packet: concatBytes(
|
||||||
|
[]byte{authUsernamePasswordSubNegotiation1, 4},
|
||||||
|
[]byte("user"),
|
||||||
|
[]byte{4},
|
||||||
|
[]byte("pass"),
|
||||||
|
),
|
||||||
|
username: "user",
|
||||||
|
password: "pass",
|
||||||
|
},
|
||||||
|
"version_mismatch": {
|
||||||
|
packet: []byte{2, 4, 'u', 's', 'e', 'r'},
|
||||||
|
username: "user",
|
||||||
|
password: "pass",
|
||||||
|
expectedErr: "subnegotiation version not supported",
|
||||||
|
},
|
||||||
|
"wrong_username": {
|
||||||
|
packet: concatBytes(
|
||||||
|
[]byte{authUsernamePasswordSubNegotiation1, 4},
|
||||||
|
[]byte("fake"),
|
||||||
|
[]byte{4},
|
||||||
|
[]byte("pass"),
|
||||||
|
),
|
||||||
|
username: "user",
|
||||||
|
password: "pass",
|
||||||
|
expectedErr: "username received is not valid",
|
||||||
|
},
|
||||||
|
"wrong_password": {
|
||||||
|
packet: concatBytes(
|
||||||
|
[]byte{authUsernamePasswordSubNegotiation1, 4},
|
||||||
|
[]byte("user"),
|
||||||
|
[]byte{4},
|
||||||
|
[]byte("fake"),
|
||||||
|
),
|
||||||
|
username: "user",
|
||||||
|
password: "pass",
|
||||||
|
expectedErr: "password not valid",
|
||||||
|
},
|
||||||
|
"truncated_header": {
|
||||||
|
packet: []byte{authUsernamePasswordSubNegotiation1},
|
||||||
|
username: "user",
|
||||||
|
password: "pass",
|
||||||
|
expectedErr: "reading header",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for name, testCase := range testCases {
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
buffer := bytes.NewBuffer(testCase.packet)
|
||||||
|
|
||||||
|
readWriter := struct {
|
||||||
|
io.Reader
|
||||||
|
io.Writer
|
||||||
|
}{
|
||||||
|
Reader: buffer,
|
||||||
|
Writer: io.Discard,
|
||||||
|
}
|
||||||
|
|
||||||
|
err := usernamePasswordSubnegotiate(readWriter, testCase.username, testCase.password)
|
||||||
|
|
||||||
|
if testCase.expectedErr != "" {
|
||||||
|
assert.ErrorContains(t, err, testCase.expectedErr)
|
||||||
|
} else {
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func concatBytes(slices ...[]byte) []byte {
|
||||||
|
var result []byte
|
||||||
|
for _, slice := range slices {
|
||||||
|
result = append(result, slice...)
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_bindDataLength(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
testCases := map[string]struct {
|
||||||
|
addrType addrType
|
||||||
|
address string
|
||||||
|
wantMaxLength uint
|
||||||
|
}{
|
||||||
|
"ipv4": {
|
||||||
|
addrType: ipv4,
|
||||||
|
address: "127.0.0.1",
|
||||||
|
wantMaxLength: 1 + 4 + 2,
|
||||||
|
},
|
||||||
|
"ipv6": {
|
||||||
|
addrType: ipv6,
|
||||||
|
address: "::1",
|
||||||
|
wantMaxLength: 1 + 16 + 2,
|
||||||
|
},
|
||||||
|
"domain_short": {
|
||||||
|
addrType: domainName,
|
||||||
|
address: "example.com",
|
||||||
|
wantMaxLength: 1 + 1 + uint(len("example.com")) + 2,
|
||||||
|
},
|
||||||
|
"domain_long": {
|
||||||
|
addrType: domainName,
|
||||||
|
address: strings.Repeat("a", 100),
|
||||||
|
wantMaxLength: 1 + 1 + 100 + 2,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for name, testCase := range testCases {
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
length := bindDataLength(testCase.addrType, testCase.address)
|
||||||
|
assert.Equal(t, testCase.wantMaxLength, length)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_authMethod_String(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
testCases := map[string]struct {
|
||||||
|
method authMethod
|
||||||
|
expectedName string
|
||||||
|
}{
|
||||||
|
"no_auth": {
|
||||||
|
method: authNotRequired,
|
||||||
|
expectedName: "no authentication required",
|
||||||
|
},
|
||||||
|
"gssapi": {
|
||||||
|
method: authGssapi,
|
||||||
|
expectedName: "GSSAPI",
|
||||||
|
},
|
||||||
|
"username_password": {
|
||||||
|
method: authUsernamePassword,
|
||||||
|
expectedName: "username/password",
|
||||||
|
},
|
||||||
|
"not_acceptable": {
|
||||||
|
method: authNotAcceptable,
|
||||||
|
expectedName: "no acceptable methods",
|
||||||
|
},
|
||||||
|
"unknown": {
|
||||||
|
method: authMethod(99),
|
||||||
|
expectedName: "unknown method (99)",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for name, testCase := range testCases {
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
result := testCase.method.String()
|
||||||
|
assert.Equal(t, testCase.expectedName, result)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_cmdType_String(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
testCases := map[string]struct {
|
||||||
|
cmd cmdType
|
||||||
|
expectedName string
|
||||||
|
}{
|
||||||
|
"connect": {
|
||||||
|
cmd: connect,
|
||||||
|
expectedName: "connect",
|
||||||
|
},
|
||||||
|
"bind": {
|
||||||
|
cmd: bind,
|
||||||
|
expectedName: "bind",
|
||||||
|
},
|
||||||
|
"udp_associate": {
|
||||||
|
cmd: udpAssociate,
|
||||||
|
expectedName: "UDP associate",
|
||||||
|
},
|
||||||
|
"unknown": {
|
||||||
|
cmd: cmdType(99),
|
||||||
|
expectedName: "unknown command (99)",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for name, testCase := range testCases {
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
result := testCase.cmd.String()
|
||||||
|
assert.Equal(t, testCase.expectedName, result)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,62 @@
|
|||||||
|
package socks5
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
)
|
||||||
|
|
||||||
|
// See https://datatracker.ietf.org/doc/html/rfc1929#section-2
|
||||||
|
func usernamePasswordSubnegotiate(conn io.ReadWriter, username, password string) (err error) {
|
||||||
|
status := byte(1)
|
||||||
|
const defaultVersion = byte(1)
|
||||||
|
|
||||||
|
const headerLength = 2
|
||||||
|
var header [headerLength]byte
|
||||||
|
_, err = io.ReadFull(conn, header[:])
|
||||||
|
if err != nil {
|
||||||
|
_, _ = conn.Write([]byte{defaultVersion, status})
|
||||||
|
return fmt.Errorf("reading header: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if header[0] != authUsernamePasswordSubNegotiation1 {
|
||||||
|
_, _ = conn.Write([]byte{defaultVersion, status})
|
||||||
|
return fmt.Errorf("subnegotiation version not supported: %d", header[0])
|
||||||
|
}
|
||||||
|
version := header[0]
|
||||||
|
|
||||||
|
usernameBytes := make([]byte, header[1])
|
||||||
|
_, err = io.ReadFull(conn, usernameBytes)
|
||||||
|
if err != nil {
|
||||||
|
_, _ = conn.Write([]byte{version, status})
|
||||||
|
return fmt.Errorf("reading username bytes: %w", err)
|
||||||
|
} else if username != string(usernameBytes) {
|
||||||
|
_, _ = conn.Write([]byte{version, status})
|
||||||
|
return fmt.Errorf("username received is not valid")
|
||||||
|
}
|
||||||
|
|
||||||
|
const passwordHeaderLength = 1
|
||||||
|
passwordHeader := make([]byte, passwordHeaderLength)
|
||||||
|
_, err = io.ReadFull(conn, passwordHeader)
|
||||||
|
if err != nil {
|
||||||
|
_, _ = conn.Write([]byte{version, status})
|
||||||
|
return fmt.Errorf("reading password length: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
passwordBytes := make([]byte, passwordHeader[0])
|
||||||
|
_, err = io.ReadFull(conn, passwordBytes)
|
||||||
|
if err != nil {
|
||||||
|
_, _ = conn.Write([]byte{version, status})
|
||||||
|
return fmt.Errorf("reading password bytes: %w", err)
|
||||||
|
} else if password != string(passwordBytes) {
|
||||||
|
_, _ = conn.Write([]byte{version, status})
|
||||||
|
return fmt.Errorf("password not valid for username %q", string(usernameBytes))
|
||||||
|
}
|
||||||
|
|
||||||
|
status = 0
|
||||||
|
_, err = conn.Write([]byte{version, status})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("writing success status: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user