fix(provider/pia): handle "port is busy" messages and retry port forwarding logic

This commit is contained in:
Quentin McGaw
2026-05-08 04:16:15 +00:00
parent 5cae870745
commit 891249849a
@@ -13,6 +13,7 @@ import (
"net/netip"
"net/url"
"os"
"regexp"
"strconv"
"strings"
"time"
@@ -79,13 +80,27 @@ func (p *Provider) PortForward(ctx context.Context,
}
durationToExpiration = data.Expiration.Sub(p.timeNow())
}
logger.Info("Port forwarded data expires in " + format.FriendlyDuration(durationToExpiration))
// First time binding
if err := bindPort(ctx, privateIPClient, p.apiIP, data); err != nil {
return nil, fmt.Errorf("binding port: %w", err)
for ctx.Err() == nil {
err = bindPort(ctx, privateIPClient, p.apiIP, data)
if err == nil {
break
} else if !errors.Is(err, errPortBusy) {
return nil, fmt.Errorf("binding port: %w", err)
}
logger.Warn("refreshing port forward data and trying again because " + err.Error())
client := objects.Client
data, err = refreshPIAPortForwardData(ctx, client, privateIPClient, p.apiIP,
p.portForwardPath, objects.Username, objects.Password)
if err != nil {
return nil, fmt.Errorf("refreshing port forward data: %w", err)
}
durationToExpiration = data.Expiration.Sub(p.timeNow())
}
logger.Info("Port forwarded data expires in " + format.FriendlyDuration(durationToExpiration))
return map[uint16]uint16{data.Port: data.Port}, nil
}
@@ -393,6 +408,13 @@ func fetchPortForwardData(ctx context.Context, client *http.Client, apiIP netip.
return port, data.Signature, expiration, err
}
var errPortBusy = errors.New("port is busy")
var (
regexPortBusy = regexp.MustCompile(`^port \d+ is busy\. `)
regexNumber = regexp.MustCompile(`\d+`)
)
func bindPort(ctx context.Context, client *http.Client, apiIPAddress netip.Addr, data piaPortForwardData) (err error) {
// Define a timeout since the default client has a large timeout and we don't
// want to wait too long.
@@ -431,7 +453,9 @@ func bindPort(ctx context.Context, client *http.Client, apiIPAddress netip.Addr,
}
defer response.Body.Close()
if response.StatusCode != http.StatusOK {
switch response.StatusCode {
case http.StatusOK, http.StatusConflict:
default:
return makeNOKStatusError(response, errSubstitutions)
}
@@ -444,11 +468,24 @@ func bindPort(ctx context.Context, client *http.Client, apiIPAddress netip.Addr,
return fmt.Errorf("decoding response: from %s: %w", bindPortURL.String(), err)
}
if responseData.Status != "OK" {
return fmt.Errorf("bad response received with status %q and message %q", responseData.Status, responseData.Message)
switch response.StatusCode {
case http.StatusOK:
if responseData.Status != "OK" {
return fmt.Errorf("bad response received with status %q and message %q", responseData.Status, responseData.Message)
}
return nil
case http.StatusConflict:
portIsBusy := regexPortBusy.FindString(responseData.Message)
if portIsBusy == "" {
return fmt.Errorf("port busy response received with unexpected message %q not matching regex %q",
responseData.Message, regexPortBusy.String())
}
portStr := regexNumber.FindString(portIsBusy)
rest := strings.TrimPrefix(responseData.Message, portIsBusy)
return fmt.Errorf("%w: %s - %s", errPortBusy, portStr, rest)
default:
panic("unreachable code")
}
return nil
}
// replaceInErr is used to remove sensitive information from errors.