go
/
misc
1
0
Fork 0
misc/networking/failback.go

76 lines
2.1 KiB
Go

package networking
import (
	"context"
	"fmt"
	"log"
	"math/rand"
	"net"
	"strconv"

	"github.com/pkg/errors"
)
// RoundrobinDialer connects to the first reachable address out of a set of
// "host:port" candidates, with an optional last-resort fallback address.
//
// Despite the name, candidates are visited in a random order on each call
// rather than in strict rotation (see DialContext).
type RoundrobinDialer struct {
	// Dialer opens the connections; its Resolver performs the name lookups.
	Dialer *net.Dialer
	// FallbackHost is dialed when none of the candidate hosts could be
	// reached. An unspecified IP (0.0.0.0 or ::) disables the fallback.
	FallbackHost net.TCPAddr
}
// DialContext connects to one of hosts using a fail-over strategy.
//
// Each entry of hosts is a "host:port" string. The hosts are visited in a
// random order (a fresh rand.Perm per call). Every address a host resolves to
// is dialed in turn, and the first successful connection is returned. network
// is forwarded to resolveHost for the name lookup, while each dial uses the
// network of the resolved *net.TCPAddr.
//
// When no candidate connects, FallbackHost is dialed as a last resort if its
// IP is set and not unspecified, and that dial's result is returned.
// Otherwise the returned error wraps the last dial failure, or is
// "name resolve failure" when no address was ever tried. If ctx is done the
// search stops early and ctx.Err() is returned.
func (d *RoundrobinDialer) DialContext(ctx context.Context, network string, hosts ...string) (net.Conn, error) {
	dialer := d.Dialer
	if dialer == nil {
		// Use a zero net.Dialer rather than dereferencing nil when dialing.
		dialer = &net.Dialer{}
	}
	var lastErr error
	for _, idx := range rand.Perm(len(hosts)) {
		host := hosts[idx]
		for _, addr := range d.resolveHost(ctx, network, host) {
			// Stop once the caller has given up; every further attempt would
			// fail with the same context error anyway.
			if err := ctx.Err(); err != nil {
				return nil, err
			}
			// addr.String() renders IPv4 as dotted-quad even when the resolver
			// returned the 16-byte form, so the dialed and logged text agree.
			conn, err := dialer.DialContext(ctx, addr.Network(), addr.String())
			if err == nil {
				return conn, nil
			}
			log.Printf("failed to connect to %s => %s : %s", host, addr.String(), err)
			lastErr = err
		}
	}
	if err := ctx.Err(); err != nil {
		return nil, err
	}
	// A zero-value TCPAddr has a nil IP, for which IsUnspecified reports
	// false; require a non-empty IP so an unset fallback is never dialed.
	if fb := d.FallbackHost; len(fb.IP) != 0 && !fb.IP.IsUnspecified() {
		log.Printf("attempting to connect fallback address(%s)", fb.String())
		return dialer.DialContext(ctx, fb.Network(), fb.String())
	}
	if lastErr != nil {
		return nil, fmt.Errorf("all resolved addresses failed: %w", lastErr)
	}
	return nil, errors.New("name resolve failure")
}
// resolveHost splits connectAddr ("host:port") and resolves host to the TCP
// addresses to try, all sharing the parsed port.
//
// network selects the address family of the lookup. Resolver.LookupIP only
// accepts "ip", "ip4" and "ip6", so the dial-style names "tcp", "tcp4" and
// "tcp6" are mapped to them; any other value is passed through unchanged.
//
// Failures (malformed address, invalid port, lookup error) are logged and
// produce a nil slice; no error is returned to the caller.
func (d *RoundrobinDialer) resolveHost(ctx context.Context, network, connectAddr string) (addrs []net.TCPAddr) {
	log.Printf("attempting to connect %s", connectAddr)
	host, portString, err := net.SplitHostPort(connectAddr)
	if err != nil {
		log.Printf("invalid connection string format: %s", err)
		return nil
	}
	// The parse error must be checked: otherwise a non-numeric port would
	// silently become 0 and pass the range check.
	port, err := strconv.Atoi(portString)
	if err != nil || port < 0 || port > 65535 {
		log.Printf("invalid port format : %s", portString)
		return nil
	}
	lookupNetwork := network
	switch network {
	case "tcp":
		lookupNetwork = "ip"
	case "tcp4":
		lookupNetwork = "ip4"
	case "tcp6":
		lookupNetwork = "ip6"
	}
	// A nil *net.Resolver behaves like the zero Resolver, so a missing
	// Dialer falls back to default resolution instead of panicking.
	var resolver *net.Resolver
	if d.Dialer != nil {
		resolver = d.Dialer.Resolver
	}
	ips, err := resolver.LookupIP(ctx, lookupNetwork, host)
	if err != nil {
		log.Printf("cannot resolve host %s(%s) : %s", host, network, err)
		return nil
	}
	addrs = make([]net.TCPAddr, 0, len(ips))
	for _, ip := range ips {
		addrs = append(addrs, net.TCPAddr{IP: ip, Port: port})
	}
	return addrs
}