diff --git a/internal/provider/privateinternetaccess/portforward.go b/internal/provider/privateinternetaccess/portforward.go index 80150481..9c8d9abb 100644 --- a/internal/provider/privateinternetaccess/portforward.go +++ b/internal/provider/privateinternetaccess/portforward.go @@ -13,6 +13,7 @@ import ( "net/netip" "net/url" "os" + "regexp" "strconv" "strings" "time" @@ -79,13 +80,27 @@ func (p *Provider) PortForward(ctx context.Context, } durationToExpiration = data.Expiration.Sub(p.timeNow()) } - logger.Info("Port forwarded data expires in " + format.FriendlyDuration(durationToExpiration)) // First time binding - if err := bindPort(ctx, privateIPClient, p.apiIP, data); err != nil { - return nil, fmt.Errorf("binding port: %w", err) + for ctx.Err() == nil { + err = bindPort(ctx, privateIPClient, p.apiIP, data) + if err == nil { + break + } else if !errors.Is(err, errPortBusy) { + return nil, fmt.Errorf("binding port: %w", err) + } + logger.Warn("refreshing port forward data and trying again because " + err.Error()) + client := objects.Client + data, err = refreshPIAPortForwardData(ctx, client, privateIPClient, p.apiIP, + p.portForwardPath, objects.Username, objects.Password) + if err != nil { + return nil, fmt.Errorf("refreshing port forward data: %w", err) + } + durationToExpiration = data.Expiration.Sub(p.timeNow()) } + logger.Info("Port forwarded data expires in " + format.FriendlyDuration(durationToExpiration)) + return map[uint16]uint16{data.Port: data.Port}, nil } @@ -393,6 +408,13 @@ func fetchPortForwardData(ctx context.Context, client *http.Client, apiIP netip. return port, data.Signature, expiration, err } +var errPortBusy = errors.New("port is busy") + +var ( + regexPortBusy = regexp.MustCompile(`^port \d+ is busy\. `) + regexNumber = regexp.MustCompile(`\d+`) +) + func bindPort(ctx context.Context, client *http.Client, apiIPAddress netip.Addr, data piaPortForwardData) (err error) { // Define a timeout since the default client has a large timeout and we don't // want to wait too long. @@ -431,7 +453,9 @@ func bindPort(ctx context.Context, client *http.Client, apiIPAddress netip.Addr, } defer response.Body.Close() - if response.StatusCode != http.StatusOK { + switch response.StatusCode { + case http.StatusOK, http.StatusConflict: + default: return makeNOKStatusError(response, errSubstitutions) } @@ -444,11 +468,24 @@ func bindPort(ctx context.Context, client *http.Client, apiIPAddress netip.Addr, return fmt.Errorf("decoding response: from %s: %w", bindPortURL.String(), err) } - if responseData.Status != "OK" { - return fmt.Errorf("bad response received with status %q and message %q", responseData.Status, responseData.Message) + switch response.StatusCode { + case http.StatusOK: + if responseData.Status != "OK" { + return fmt.Errorf("bad response received with status %q and message %q", responseData.Status, responseData.Message) + } + return nil + case http.StatusConflict: + portIsBusy := regexPortBusy.FindString(responseData.Message) + if portIsBusy == "" { + return fmt.Errorf("port busy response received with unexpected message %q not matching regex %q", + responseData.Message, regexPortBusy.String()) + } + portStr := regexNumber.FindString(portIsBusy) + rest := strings.TrimPrefix(responseData.Message, portIsBusy) + return fmt.Errorf("%w: %s - %s", errPortBusy, portStr, rest) + default: + panic("unreachable code") } - - return nil } // replaceInErr is used to remove sensitive information from errors.