76 lines
2.1 KiB
Go
76 lines
2.1 KiB
Go
package networking
|
|
|
|
import (
|
|
"context"
|
|
"github.com/pkg/errors"
|
|
"log"
|
|
"math/rand"
|
|
"net"
|
|
"strconv"
|
|
)
|
|
|
|
// RoundrobinDialer is a wrapper for the DialContext function.
//
// Despite the name, DialContext does not rotate through hosts across calls:
// on every call it shuffles the candidate "host:port" strings, resolves each
// one, and returns the first connection that succeeds. FallbackHost is dialed
// only when no candidate connects and its IP is set to a specific address.
//
// Dialer is used both to resolve names (through its Resolver) and to connect.
// A nil Dialer is treated as a zero net.Dialer, so the zero value is usable.
type RoundrobinDialer struct {
	Dialer       *net.Dialer
	FallbackHost net.TCPAddr
}

// netDialer returns the configured Dialer, or a zero net.Dialer when none
// has been set.
func (d *RoundrobinDialer) netDialer() *net.Dialer {
	if d.Dialer != nil {
		return d.Dialer
	}
	return &net.Dialer{}
}

// DialContext is a connector method with Fail-Over approach.
//
// network selects the address family used for name resolution: "ip", "ip4"
// or "ip6", or the dial-style "tcp", "tcp4" or "tcp6". Each entry of hosts
// must have the form "host:port". Hosts are tried in random order, and every
// resolved address of a host is tried before moving on to the next host.
//
// If no candidate connects, FallbackHost is dialed when its IP is set and is
// not the unspecified address, and that dial's result is returned. Otherwise
// the first dial error is returned, or an error "name resolve failure" when
// no candidate could be resolved at all. If ctx is done before a connection
// is made, ctx.Err() is returned.
func (d *RoundrobinDialer) DialContext(ctx context.Context, network string, hosts ...string) (net.Conn, error) {
	dialer := d.netDialer()

	var firstErr error
	// NOTE: math/rand's global source is only seeded automatically since
	// Go 1.20; on older toolchains this order repeats across runs.
	for _, idx := range rand.Perm(len(hosts)) {
		if err := ctx.Err(); err != nil {
			return nil, err
		}
		host := hosts[idx]
		for _, resolvedIP := range d.resolveHost(ctx, network, host) {
			// Stop as soon as the caller gives up instead of dialing every
			// remaining address with a dead context.
			if err := ctx.Err(); err != nil {
				return nil, err
			}
			conn, err := dialer.DialContext(ctx, resolvedIP.Network(), resolvedIP.AddrPort().String())
			if err == nil {
				return conn, nil
			}
			log.Printf("failed to connected to %s => %s : %s", host, resolvedIP.String(), err)
			if firstErr == nil {
				firstErr = err
			}
		}
	}

	if err := ctx.Err(); err != nil {
		return nil, err
	}

	// A zero-value FallbackHost has a nil IP, for which IsUnspecified reports
	// false; the length check keeps an unset fallback from being dialed.
	if len(d.FallbackHost.IP) != 0 && !d.FallbackHost.IP.IsUnspecified() {
		log.Printf("attempting to connect fallback address(%s)\n", d.FallbackHost.String())
		return dialer.DialContext(ctx, d.FallbackHost.Network(), d.FallbackHost.AddrPort().String())
	}

	// Candidates resolved but none connected: surface a real dial error
	// rather than claiming that resolution failed.
	if firstErr != nil {
		return nil, firstErr
	}
	return nil, errors.New("name resolve failure")
}

// resolveHost resolves connectAddr ("host:port") into the TCP addresses to
// dial. network selects the lookup family ("ip", "ip4", "ip6"); the dial
// networks "tcp", "tcp4" and "tcp6" are mapped to their lookup equivalents.
// Any problem is logged and reported as a nil result.
//
// LookupIP returns plain net.IP values, so an IPv6 zone in connectAddr is not
// carried into the returned addresses.
func (d *RoundrobinDialer) resolveHost(ctx context.Context, network, connectAddr string) []net.TCPAddr {
	log.Printf("attempting to connect %s", connectAddr)

	host, portString, err := net.SplitHostPort(connectAddr)
	if err != nil {
		log.Printf("invalid connection string format: %s", err)
		return nil
	}

	// Check the parse error so that a non-numeric port is rejected instead
	// of silently becoming port 0.
	port, err := strconv.ParseInt(portString, 10, 32)
	if err != nil || port < 0 || port > 65535 {
		log.Printf("invalid port format : %s", portString)
		return nil
	}

	// LookupIP accepts only address families, not dial networks.
	family := network
	switch network {
	case "tcp":
		family = "ip"
	case "tcp4":
		family = "ip4"
	case "tcp6":
		family = "ip6"
	}

	resolvedIPs, err := d.netDialer().Resolver.LookupIP(ctx, family, host)
	if err != nil {
		log.Printf("cannot resolve host %s(%s) : %s", host, network, err)
		return nil
	}

	addrs := make([]net.TCPAddr, 0, len(resolvedIPs))
	for _, ip := range resolvedIPs {
		addrs = append(addrs, net.TCPAddr{IP: ip, Port: int(port)})
	}
	return addrs
}
|