package kafka
import (
"context"
"crypto/tls"
"errors"
"fmt"
"io"
"math/rand"
"net"
"runtime/pprof"
"sort"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/segmentio/kafka-go/protocol"
"github.com/segmentio/kafka-go/protocol/apiversions"
"github.com/segmentio/kafka-go/protocol/createtopics"
"github.com/segmentio/kafka-go/protocol/findcoordinator"
meta "github.com/segmentio/kafka-go/protocol/metadata"
"github.com/segmentio/kafka-go/protocol/saslauthenticate"
"github.com/segmentio/kafka-go/protocol/saslhandshake"
"github.com/segmentio/kafka-go/sasl"
)
// Request is an interface implemented by types that represent messages sent
// from kafka clients to brokers.
type Request = protocol.Message
// Response is an interface implemented by types that represent messages sent
// from kafka brokers in response to client requests.
type Response = protocol.Message
// RoundTripper is an interface implemented by types which support interacting
// with kafka brokers.
type RoundTripper interface {
// RoundTrip sends a request to a kafka broker and returns the response that
// was received, or a non-nil error.
//
// The context passed as first argument can be used to asynchronously abort
// the call if needed.
RoundTrip(context.Context, net.Addr, Request) (Response, error)
}
// Transport is an implementation of the RoundTripper interface.
//
// Transport values manage a pool of connections and automatically discover
// the cluster layout to route requests to the appropriate brokers.
//
// Transport values are safe to use concurrently from multiple goroutines.
//
// Note: The intent is for the Transport to become the underlying layer of the
// kafka.Reader and kafka.Writer types.
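//
// A minimal configuration sketch (the field values below are illustrative
// assumptions, not recommended defaults):
//
// transport := &kafka.Transport{
// DialTimeout: 10 * time.Second,
// IdleTimeout: 60 * time.Second,
// ClientID:    "example-client",
// }
// defer transport.CloseIdleConnections()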
type Transport struct {
// A function used to establish connections to the kafka cluster.
Dial func(context.Context, string, string) (net.Conn, error)
// Time limit set for establishing connections to the kafka cluster. This
// limit includes all round trips done to establish the connections (TLS
// handshake, SASL negotiation, etc...).
//
// Defaults to 5s.
DialTimeout time.Duration
// Maximum amount of time that connections will remain open and unused.
// The transport will automatically close connections that have been idle
// for too long, and re-open them on demand when the transport is used
// again.
//
// Defaults to 30s.
IdleTimeout time.Duration
// TTL for the metadata cached by this transport. Note that the value
// configured here is an upper bound; the transport randomizes the TTLs to
// avoid getting into states where multiple clients end up synchronized and
// cause bursts of requests to the kafka brokers.
//
// Defaults to 6s.
MetadataTTL time.Duration
// Unique identifier that the transport communicates to the brokers when it
// sends requests.
ClientID string
// An optional configuration for TLS connections established by this
// transport.
//
// If the ServerName field is left blank, it is defaulted to the host name
// of the broker that the connection is established to.
TLS *tls.Config
// SASL configures the Transport to use SASL authentication.
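//
// For example, with the PLAIN mechanism (a minimal sketch; the credentials
// are placeholders and the mechanism type comes from the sasl/plain
// sub-package):
//
// transport := &kafka.Transport{
// SASL: plain.Mechanism{Username: "user", Password: "pass"},
// }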
SASL sasl.Mechanism
// An optional resolver used to translate broker host names into network
// addresses.
//
// The resolver will be called for every request (not every connection),
// making it possible to implement ACL policies by validating that the
// program is allowed to connect to the kafka broker. This also means that
// the resolver should probably provide a caching layer to avoid storming
// the service discovery backend with requests.
//
// When set, the Dial function is not responsible for performing name
// resolution, and is always called with a pre-resolved address.
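//
// A minimal resolver sketch returning fixed addresses (staticResolver is a
// hypothetical type, shown only to illustrate the BrokerResolver interface):
//
// type staticResolver struct{ ips []net.IPAddr }
//
// func (r staticResolver) LookupBrokerIPAddr(ctx context.Context, b kafka.Broker) ([]net.IPAddr, error) {
// return r.ips, nil
// }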
Resolver BrokerResolver
// The background context used to control goroutines started internally by
// the transport.
//
// If nil, context.Background() is used instead.
Context context.Context
mutex sync.RWMutex
pools map[networkAddress]*connPool
}
// DefaultTransport is the default transport used by kafka clients in this
// package.
var DefaultTransport RoundTripper = &Transport{
Dial: (&net.Dialer{
Timeout: 3 * time.Second,
DualStack: true,
}).DialContext,
}
// CloseIdleConnections closes all idle connections immediately, and marks all
// connections that are in use to be closed when they become idle again.
func (t *Transport) CloseIdleConnections() {
t.mutex.Lock()
defer t.mutex.Unlock()
for _, pool := range t.pools {
pool.unref()
}
for k := range t.pools {
delete(t.pools, k)
}
}
// RoundTrip sends a request to a kafka cluster and returns the response, or an
// error if no responses were received.
//
// Message types are available in sub-packages of the protocol package. Each
// kafka API is implemented in a different sub-package. For example, the request
// and response types for the Fetch API are available in the protocol/fetch
// package.
//
// The type of the response message will match the type of the request. For
// example, if RoundTrip was called with a *fetch.Request as argument, the value
// returned will be of type *fetch.Response. It is safe for the program to do a
// type assertion after checking that no error was returned.
//
// This example illustrates the way this method is expected to be used:
//
// r, err := transport.RoundTrip(ctx, addr, &fetch.Request{ ... })
// if err != nil {
// ...
// } else {
// res := r.(*fetch.Response)
// ...
// }
//
// The transport automatically selects the highest version of the API that is
// supported by both the kafka-go package and the kafka broker. The negotiation
// happens transparently once when connections are established.
//
// This API was introduced in version 0.4 as a way to leverage the lower-level
// features of the kafka protocol, but also provide a more efficient way of
// managing connections to kafka brokers.
func (t *Transport) RoundTrip(ctx context.Context, addr net.Addr, req Request) (Response, error) {
p := t.grabPool(addr)
defer p.unref()
return p.roundTrip(ctx, req)
}
func (t *Transport) dial() func(context.Context, string, string) (net.Conn, error) {
if t.Dial != nil {
return t.Dial
}
return defaultDialer.DialContext
}
func (t *Transport) dialTimeout() time.Duration {
if t.DialTimeout > 0 {
return t.DialTimeout
}
return 5 * time.Second
}
func (t *Transport) idleTimeout() time.Duration {
if t.IdleTimeout > 0 {
return t.IdleTimeout
}
return 30 * time.Second
}
func (t *Transport) metadataTTL() time.Duration {
if t.MetadataTTL > 0 {
return t.MetadataTTL
}
return 6 * time.Second
}
func (t *Transport) grabPool(addr net.Addr) *connPool {
k := networkAddress{
network: addr.Network(),
address: addr.String(),
}
t.mutex.RLock()
p := t.pools[k]
if p != nil {
p.ref()
}
t.mutex.RUnlock()
if p != nil {
return p
}
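// Slow path: take the write lock and re-check, in case another goroutine
// created the pool between the read-unlock above and this point
// (double-checked locking).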
t.mutex.Lock()
defer t.mutex.Unlock()
if p := t.pools[k]; p != nil {
p.ref()
return p
}
ctx, cancel := context.WithCancel(t.context())
p = &connPool{
refc: 2, // one reference for the transport's pools map, one for the caller
dial: t.dial(),
dialTimeout: t.dialTimeout(),
idleTimeout: t.idleTimeout(),
metadataTTL: t.metadataTTL(),
clientID: t.ClientID,
tls: t.TLS,
sasl: t.SASL,
resolver: t.Resolver,
ready: make(event),
wake: make(chan event),
conns: make(map[int32]*connGroup),
cancel: cancel,
}
p.ctrl = p.newConnGroup(addr)
go p.discover(ctx, p.wake)
if t.pools == nil {
t.pools = make(map[networkAddress]*connPool)
}
t.pools[k] = p
return p
}
func (t *Transport) context() context.Context {
if t.Context != nil {
return t.Context
}
return context.Background()
}
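// event is a channel used as a one-shot broadcast: closing the channel
// unblocks every goroutine waiting to receive from it.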
type event chan struct{}
func (e event) trigger() { close(e) }
type connPool struct {
refc uintptr
// Immutable fields of the connection pool. Connections access these fields
// on their parent pool in a read-only fashion, so no synchronization is
// required.
dial func(context.Context, string, string) (net.Conn, error)
dialTimeout time.Duration
idleTimeout time.Duration
metadataTTL time.Duration
clientID string
tls *tls.Config
sasl sasl.Mechanism
resolver BrokerResolver
// Signaling mechanisms to orchestrate communications between the pool and
// the rest of the program.
once sync.Once // ensure that `ready` is triggered only once
ready event // triggered after the first metadata update
wake chan event // used to force metadata updates
cancel context.CancelFunc
// Mutable fields of the connection pool, access must be synchronized.
mutex sync.RWMutex
conns map[int32]*connGroup // data connections used for produce/fetch/etc...
ctrl *connGroup // control connections used for metadata requests
state atomic.Value // cached cluster state
}
type connPoolState struct {
metadata *meta.Response // last metadata response seen by the pool
err error // last error from metadata requests
layout protocol.Cluster // cluster layout built from metadata response
}
func (p *connPool) grabState() connPoolState {
state, _ := p.state.Load().(connPoolState)
return state
}
func (p *connPool) setState(state connPoolState) {
p.state.Store(state)
}
func (p *connPool) ref() {
atomic.AddUintptr(&p.refc, +1)
}
func (p *connPool) unref() {
if atomic.AddUintptr(&p.refc, ^uintptr(0)) == 0 { // adding ^uintptr(0) decrements refc by 1
p.mutex.Lock()
defer p.mutex.Unlock()
for _, conns := range p.conns {
conns.closeIdleConns()
}
p.ctrl.closeIdleConns()
p.cancel()
}
}
func (p *connPool) roundTrip(ctx context.Context, req Request) (Response, error) {
// This first select should never block after the first metadata response
// that would mark the pool as `ready`.
select {
case <-p.ready:
case <-ctx.Done():
return nil, ctx.Err()
}
state := p.grabState()
var response promise
switch m := req.(type) {
case *meta.Request:
// We serve metadata requests directly from the transport cache unless
// we would like to auto create a topic that isn't in our cache.
//
// This reduces the number of round trips to kafka brokers while keeping
// the logic simple when applying partitioning strategies.
if state.err != nil {
return nil, state.err
}
cachedMeta := filterMetadataResponse(m, state.metadata)
// requestNeeded indicates if we need to send this metadata request to the server.
// It's true when we want to auto-create topics and we don't have the topic in our
// cache.
var requestNeeded bool
if m.AllowAutoTopicCreation {
for _, topic := range cachedMeta.Topics {
if topic.ErrorCode == int16(UnknownTopicOrPartition) {
requestNeeded = true
break
}
}
}
if !requestNeeded {
return cachedMeta, nil
}
case protocol.Splitter:
// Messages that implement the Splitter interface trigger the creation of
// multiple requests that are all merged back into a single result by
// a merger.
messages, merger, err := m.Split(state.layout)
if err != nil {
return nil, err
}
promises := make([]promise, len(messages))
for i, m := range messages {
promises[i] = p.sendRequest(ctx, m, state)
}
response = join(promises, messages, merger)
}
if response == nil {
response = p.sendRequest(ctx, req, state)
}
r, err := response.await(ctx)
if err != nil {
return r, err
}
switch resp := r.(type) {
case *createtopics.Response:
// Force an update of the metadata when adding topics,
// otherwise the cached state would get out of sync.
topicsToRefresh := make([]string, 0, len(resp.Topics))
for _, topic := range resp.Topics {
// fixes issue 672: don't refresh topics that failed to create, as doing so causes the library to hang indefinitely
if topic.ErrorCode != 0 {
continue
}
topicsToRefresh = append(topicsToRefresh, topic.Name)
}
p.refreshMetadata(ctx, topicsToRefresh)
case *meta.Response:
m := req.(*meta.Request)
// If we get here with allow auto topic creation then
// we didn't have that topic in our cache so we should update
// the cache.
if m.AllowAutoTopicCreation {
topicsToRefresh := make([]string, 0, len(resp.Topics))
for _, topic := range resp.Topics {
// fixes issue 806: don't refresh topics that failed to create, as it may
// mean that kafka does not have auto topic creation enabled. Refreshing
// them would cause the library to hang indefinitely, as in the createtopics
// case above.
if topic.ErrorCode != 0 {
continue
}
topicsToRefresh = append(topicsToRefresh, topic.Name)
}
p.refreshMetadata(ctx, topicsToRefresh)
}
}
return r, nil
}
// refreshMetadata forces an update of the cached cluster metadata, and waits
// for the given list of topics to appear. This waiting mechanism is necessary
// to account for the fact that topic creation is asynchronous in kafka, and
// causes subsequent requests to fail while the cluster state is propagated to
// all the brokers.
func (p *connPool) refreshMetadata(ctx context.Context, expectTopics []string) {
minBackoff := 100 * time.Millisecond
maxBackoff := 2 * time.Second
cancel := ctx.Done()
for ctx.Err() == nil {
notify := make(event)
select {
case <-cancel:
return
case p.wake <- notify:
select {
case <-notify:
case <-cancel:
return
}
}
state := p.grabState()
found := 0
for _, topic := range expectTopics {
if _, ok := state.layout.Topics[topic]; ok {
found++
}
}
if found == len(expectTopics) {
return
}
if delay := time.Duration(rand.Int63n(int64(minBackoff))); delay > 0 {
timer := time.NewTimer(delay) // wait for the jittered delay computed above, not the full minBackoff
select {
case <-cancel:
case <-timer.C:
}
timer.Stop()
if minBackoff *= 2; minBackoff > maxBackoff {
minBackoff = maxBackoff
}
}
}
}
func (p *connPool) setReady() {
p.once.Do(p.ready.trigger)
}
// update is called periodically by the goroutine running the discover method
// to refresh the cluster layout information used by the transport to route
// requests to brokers.
func (p *connPool) update(ctx context.Context, metadata *meta.Response, err error) {
var layout protocol.Cluster
if metadata != nil {
metadata.ThrottleTimeMs = 0
// Normalize the lists so we can apply binary search on them.
sortMetadataBrokers(metadata.Brokers)
sortMetadataTopics(metadata.Topics)
for i := range metadata.Topics {
t := &metadata.Topics[i]
sortMetadataPartitions(t.Partitions)
}
layout = makeLayout(metadata)
}
state := p.grabState()
addBrokers := make(map[int32]struct{})
delBrokers := make(map[int32]struct{})
if err != nil {
// Only update the error on the transport if the cluster layout was
// unknown. This ensures that we prioritize a previously known state
// of the cluster to reduce the impact of transient failures.
if state.metadata != nil {
return
}
state.err = err
} else {
for id, b2 := range layout.Brokers {
if b1, ok := state.layout.Brokers[id]; !ok {
addBrokers[id] = struct{}{}
} else if b1 != b2 {
addBrokers[id] = struct{}{}
delBrokers[id] = struct{}{}
}
}
for id := range state.layout.Brokers {
if _, ok := layout.Brokers[id]; !ok {
delBrokers[id] = struct{}{}
}
}
state.metadata, state.layout = metadata, layout
state.err = nil
}
defer p.setReady()
defer p.setState(state)
if len(addBrokers) != 0 || len(delBrokers) != 0 {
// Only acquire the lock when there is a change of layout. This is an
// infrequent event, so we avoid the regular mutex contention that locking
// on every update would introduce.
p.mutex.Lock()
defer p.mutex.Unlock()
if ctx.Err() != nil {
return // the pool has been closed, no need to update
}
for id := range delBrokers {
if broker := p.conns[id]; broker != nil {
broker.closeIdleConns()
delete(p.conns, id)
}
}
for id := range addBrokers {
broker := layout.Brokers[id]
p.conns[id] = p.newBrokerConnGroup(Broker{
Rack: broker.Rack,
Host: broker.Host,
Port: int(broker.Port),
ID: int(broker.ID),
})
}
}
}
// discover is the entry point of an internal goroutine for the transport which
// periodically requests updates of the cluster metadata and refreshes the
// transport's cached cluster layout.
func (p *connPool) discover(ctx context.Context, wake <-chan event) {
prng := rand.New(rand.NewSource(time.Now().UnixNano()))
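// Jitter the refresh period in [0, metadataTTL) so that multiple clients
// sharing the same configuration do not synchronize their metadata
// requests (see the MetadataTTL documentation).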
metadataTTL := func() time.Duration {
return time.Duration(prng.Int63n(int64(p.metadataTTL)))
}
timer := time.NewTimer(metadataTTL())
defer timer.Stop()
var notify event
done := ctx.Done()
for {
c, err := p.grabClusterConn(ctx)
if err != nil {
p.update(ctx, nil, err)
} else {
res := make(async, 1)
req := &meta.Request{}
deadline, cancel := context.WithTimeout(ctx, p.metadataTTL)
c.reqs <- connRequest{
ctx: deadline,
req: req,
res: res,
}
r, err := res.await(deadline)
cancel()
if err != nil && errors.Is(err, ctx.Err()) {
return
}
ret, _ := r.(*meta.Response)
p.update(ctx, ret, err)
}
if notify != nil {
notify.trigger()
notify = nil
}
select {
case <-timer.C:
timer.Reset(metadataTTL())
case <-done:
return
case notify = <-wake:
}
}
}
// grabBrokerConn returns a connection to a specific broker represented by the
// broker id passed as argument. If the broker id is not known, an error is
// returned.
func (p *connPool) grabBrokerConn(ctx context.Context, brokerID int32) (*conn, error) {
p.mutex.RLock()
g := p.conns[brokerID]
p.mutex.RUnlock()
if g == nil {
return nil, BrokerNotAvailable
}
return g.grabConnOrConnect(ctx)
}
// grabClusterConn returns the connection to the kafka cluster that the pool is
// configured to connect to.
//
// The transport uses a shared `control` connection to the cluster for any
// requests that aren't supposed to be sent to specific brokers (e.g. Metadata
// requests). Requests intended to be routed to specific brokers (e.g. Fetch
// or Produce requests) are dispatched on a separate pool of connections that
// the transport maintains. This split helps avoid head-of-line blocking
// situations where control requests like Metadata would be queued behind
// large responses from Fetch requests.
//
// In either case, the requests are multiplexed so we can keep a minimal number
// of connections open (N+1, where N is the number of brokers in the cluster).
func (p *connPool) grabClusterConn(ctx context.Context) (*conn, error) {
return p.ctrl.grabConnOrConnect(ctx)
}
func (p *connPool) sendRequest(ctx context.Context, req Request, state connPoolState) promise {
brokerID := int32(-1)
switch m := req.(type) {
case protocol.BrokerMessage:
// Some requests are supposed to be sent to specific brokers (e.g. the
// partition leaders). They implement the BrokerMessage interface to
// delegate the routing decision to each message type.
broker, err := m.Broker(state.layout)
if err != nil {
return reject(err)
}
brokerID = broker.ID
case protocol.GroupMessage:
// Some requests are supposed to be sent to a group coordinator,
// look up which broker is currently the coordinator for the group
// so we can get a connection to that broker.
//
// TODO: should we cache the coordinator info?
p := p.sendRequest(ctx, &findcoordinator.Request{Key: m.Group()}, state)
r, err := p.await(ctx)
if err != nil {
return reject(err)
}
brokerID = r.(*findcoordinator.Response).NodeID
case protocol.TransactionalMessage:
p := p.sendRequest(ctx, &findcoordinator.Request{
Key: m.Transaction(),
KeyType: int8(CoordinatorKeyTypeTransaction),
}, state)
r, err := p.await(ctx)
if err != nil {
return reject(err)
}
brokerID = r.(*findcoordinator.Response).NodeID
}
var c *conn
var err error
if brokerID >= 0 {
c, err = p.grabBrokerConn(ctx, brokerID)
} else {
c, err = p.grabClusterConn(ctx)
}
if err != nil {
return reject(err)
}
res := make(async, 1)
c.reqs <- connRequest{
ctx: ctx,
req: req,
res: res,
}
return res
}
func filterMetadataResponse(req *meta.Request, res *meta.Response) *meta.Response {
ret := *res
if req.TopicNames != nil {
ret.Topics = make([]meta.ResponseTopic, len(req.TopicNames))
for i, topicName := range req.TopicNames {
j, ok := findMetadataTopic(res.Topics, topicName)
if ok {
ret.Topics[i] = res.Topics[j]
} else {
ret.Topics[i] = meta.ResponseTopic{
ErrorCode: int16(UnknownTopicOrPartition),
Name: topicName,
}
}
}
}
return &ret
}
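// findMetadataTopic binary searches topics for topicName and returns its
// index. It relies on the topics being sorted by name, which update
// guarantees by calling sortMetadataTopics on every metadata refresh.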
func findMetadataTopic(topics []meta.ResponseTopic, topicName string) (int, bool) {
i := sort.Search(len(topics), func(i int) bool {
return topics[i].Name >= topicName
})
return i, i >= 0 && i < len(topics) && topics[i].Name == topicName
}
func sortMetadataBrokers(brokers []meta.ResponseBroker) {
sort.Slice(brokers, func(i, j int) bool {
return brokers[i].NodeID < brokers[j].NodeID
})
}
func sortMetadataTopics(topics []meta.ResponseTopic) {
sort.Slice(topics, func(i, j int) bool {
return topics[i].Name < topics[j].Name
})
}
func sortMetadataPartitions(partitions []meta.ResponsePartition) {
sort.Slice(partitions, func(i, j int) bool {
return partitions[i].PartitionIndex < partitions[j].PartitionIndex
})
}
func makeLayout(metadataResponse *meta.Response) protocol.Cluster {
layout := protocol.Cluster{
Controller: metadataResponse.ControllerID,
Brokers: make(map[int32]protocol.Broker),
Topics: make(map[string]protocol.Topic),
}
for _, broker := range metadataResponse.Brokers {
layout.Brokers[broker.NodeID] = protocol.Broker{
Rack: broker.Rack,
Host: broker.Host,
Port: broker.Port,
ID: broker.NodeID,
}
}
for _, topic := range metadataResponse.Topics {
if topic.IsInternal {
continue // TODO: do we need to expose those?
}
layout.Topics[topic.Name] = protocol.Topic{
Name: topic.Name,
Error: topic.ErrorCode,
Partitions: makePartitions(topic.Partitions),
}
}
return layout
}
func makePartitions(metadataPartitions []meta.ResponsePartition) map[int32]protocol.Partition {
protocolPartitions := make(map[int32]protocol.Partition, len(metadataPartitions))
numBrokerIDs := 0
for _, p := range metadataPartitions {
numBrokerIDs += len(p.ReplicaNodes) + len(p.IsrNodes) + len(p.OfflineReplicas)
}
// Reduce the memory footprint a bit by allocating a single buffer to write
// all broker ids.
brokerIDs := make([]int32, 0, numBrokerIDs)
for _, p := range metadataPartitions {
var rep, isr, off []int32
brokerIDs, rep = appendBrokerIDs(brokerIDs, p.ReplicaNodes)
brokerIDs, isr = appendBrokerIDs(brokerIDs, p.IsrNodes)
brokerIDs, off = appendBrokerIDs(brokerIDs, p.OfflineReplicas)
protocolPartitions[p.PartitionIndex] = protocol.Partition{
ID: p.PartitionIndex,
Error: p.ErrorCode,
Leader: p.LeaderID,
Replicas: rep,
ISR: isr,
Offline: off,
}
}
return protocolPartitions
}
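// appendBrokerIDs appends brokers to ids and returns the grown slice along
// with a view of the newly appended elements. The full slice expression caps
// the view's capacity so that later appends to ids cannot overwrite it.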
func appendBrokerIDs(ids, brokers []int32) ([]int32, []int32) {
i := len(ids)
ids = append(ids, brokers...)
return ids, ids[i:len(ids):len(ids)]
}
func (p *connPool) newConnGroup(a net.Addr) *connGroup {
return &connGroup{
addr: a,
pool: p,
broker: Broker{
ID: -1,
},
}
}
func (p *connPool) newBrokerConnGroup(broker Broker) *connGroup {
return &connGroup{
addr: &networkAddress{
network: "tcp",
address: net.JoinHostPort(broker.Host, strconv.Itoa(broker.Port)),
},
pool: p,
broker: broker,
}
}
type connRequest struct {
ctx context.Context
req Request
res async
}
// The promise interface is used as a message passing abstraction to coordinate
// between goroutines that handle requests and responses.
type promise interface {
// Waits until the promise is resolved, rejected, or the context is canceled.
await(context.Context) (Response, error)
}
// async is an implementation of the promise interface which supports resolving
// or rejecting the await call asynchronously.
type async chan interface{}
func (p async) await(ctx context.Context) (Response, error) {
select {
case x := <-p:
switch v := x.(type) {
case nil:
return nil, nil // A nil response is ok (e.g. when RequiredAcks is None)
case Response:
return v, nil
case error:
return nil, v
default:
panic(fmt.Errorf("BUG: promise resolved with impossible value of type %T", v))
}
case <-ctx.Done():
return nil, ctx.Err()
}
}
func (p async) resolve(res Response) { p <- res }
func (p async) reject(err error) { p <- err }
// rejected is an implementation of the promise interface which always
// returns an error. Values of this type are constructed using the reject
// function.
type rejected struct{ err error }
func reject(err error) promise { return &rejected{err: err} }
func (p *rejected) await(ctx context.Context) (Response, error) {
return nil, p.err
}
// joined is an implementation of the promise interface which merges results
// from multiple promises into one await call using a merger.
type joined struct {
promises []promise
requests []Request
merger protocol.Merger
}
func join(promises []promise, requests []Request, merger protocol.Merger) promise {
return &joined{
promises: promises,
requests: requests,
merger: merger,
}
}
func (p *joined) await(ctx context.Context) (Response, error) {
results := make([]interface{}, len(p.promises))
for i, sub := range p.promises {
m, err := sub.await(ctx)
if err != nil {
results[i] = err
} else {
results[i] = m
}
}
return p.merger.Merge(p.requests, results)
}
// Default dialer used by the transport connections when no Dial function
// was configured by the program.
var defaultDialer = net.Dialer{
Timeout: 3 * time.Second,
DualStack: true,
}
// connGroup represents a logical connection group to a kafka broker. The
// actual network connections are lazily opened before sending requests, and
// closed if they are unused for longer than the idle timeout.
type connGroup struct {
addr net.Addr
broker Broker
// Immutable state of the connection.
pool *connPool
// Shared state of the connection group; access to these fields must be
// synchronized on the mutex.
mutex sync.Mutex
closed bool
idleConns []*conn // stack of idle connections
}
func (g *connGroup) closeIdleConns() {
g.mutex.Lock()
conns := g.idleConns
g.idleConns = nil
g.closed = true
g.mutex.Unlock()
for _, c := range conns {
c.close()
}
}
func (g *connGroup) grabConnOrConnect(ctx context.Context) (*conn, error) {
rslv := g.pool.resolver
addr := g.addr
var c *conn
if rslv == nil {
c = g.grabConn()
} else {
var err error
broker := g.broker
if broker.ID < 0 {
host, port, err := splitHostPortNumber(addr.String())
if err != nil {
return nil, err
}
broker.Host = host
broker.Port = port
}
ipAddrs, err := rslv.LookupBrokerIPAddr(ctx, broker)
if err != nil {
return nil, err
}
for _, ipAddr := range ipAddrs {
network := addr.Network()
address := net.JoinHostPort(ipAddr.String(), strconv.Itoa(broker.Port))
if c = g.grabConnTo(network, address); c != nil {
break
}
}
}
if c == nil {
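// No idle connection was available: dial asynchronously so the caller can
// still honor context cancellation. If the dial completes after the caller
// gave up, the new connection is released to the group or closed.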
connChan := make(chan *conn)
errChan := make(chan error)
go func() {
c, err := g.connect(ctx, addr)
if err != nil {
select {
case errChan <- err:
case <-ctx.Done():
}
} else {
select {
case connChan <- c:
case <-ctx.Done():
if !g.releaseConn(c) {
c.close()
}
}
}
}()
select {
case c = <-connChan:
case err := <-errChan:
return nil, err
case <-ctx.Done():
return nil, ctx.Err()
}
}
return c, nil
}
func (g *connGroup) grabConnTo(network, address string) *conn {
g.mutex.Lock()
defer g.mutex.Unlock()
for i := len(g.idleConns) - 1; i >= 0; i-- {
c := g.idleConns[i]
if c.network == network && c.address == address {
copy(g.idleConns[i:], g.idleConns[i+1:])
n := len(g.idleConns) - 1
g.idleConns[n] = nil
g.idleConns = g.idleConns[:n]
if c.timer != nil {
c.timer.Stop()
}
return c
}
}
return nil
}
func (g *connGroup) grabConn() *conn {
g.mutex.Lock()
defer g.mutex.Unlock()
if len(g.idleConns) == 0 {
return nil
}
n := len(g.idleConns) - 1
c := g.idleConns[n]
g.idleConns[n] = nil
g.idleConns = g.idleConns[:n]
if c.timer != nil {
c.timer.Stop()
}
return c
}
func (g *connGroup) removeConn(c *conn) bool {
g.mutex.Lock()
defer g.mutex.Unlock()
if c.timer != nil {
c.timer.Stop()
}
for i, x := range g.idleConns {
if x == c {
copy(g.idleConns[i:], g.idleConns[i+1:])
n := len(g.idleConns) - 1
g.idleConns[n] = nil
g.idleConns = g.idleConns[:n]
return true
}
}
return false
}
func (g *connGroup) releaseConn(c *conn) bool {
idleTimeout := g.pool.idleTimeout
g.mutex.Lock()
defer g.mutex.Unlock()
if g.closed {
return false
}
if c.timer != nil {
c.timer.Reset(idleTimeout)
} else {
c.timer = time.AfterFunc(idleTimeout, func() {
if g.removeConn(c) {
c.close()
}
})
}
g.idleConns = append(g.idleConns, c)
return true
}
func (g *connGroup) connect(ctx context.Context, addr net.Addr) (*conn, error) {
deadline := time.Now().Add(g.pool.dialTimeout)
ctx, cancel := context.WithDeadline(ctx, deadline)
defer cancel()
network := strings.Split(addr.Network(), ",")
address := strings.Split(addr.String(), ",")
var netConn net.Conn
var netAddr net.Addr
var err error
if len(address) > 1 {
// Shuffle the list of addresses to randomize the order in which
// connections are attempted. This prevents routing all connections
// to the first broker (whose dial will usually succeed).
rand.Shuffle(len(address), func(i, j int) {
network[i], network[j] = network[j], network[i]
address[i], address[j] = address[j], address[i]
})
}
for i := range address {
netConn, err = g.pool.dial(ctx, network[i], address[i])
if err == nil {
netAddr = &networkAddress{
network: network[i],
address: address[i],
}
break
}
}
if err != nil {
return nil, err
}
defer func() {
if netConn != nil {
netConn.Close()
}
}()
if tlsConfig := g.pool.tls; tlsConfig != nil {
if tlsConfig.ServerName == "" {
host, _ := splitHostPort(netAddr.String())
tlsConfig = tlsConfig.Clone()
tlsConfig.ServerName = host
}
netConn = tls.Client(netConn, tlsConfig)
}
pc := protocol.NewConn(netConn, g.pool.clientID)
pc.SetDeadline(deadline)
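// Negotiate API versions with the broker so that subsequent requests are
// encoded with the highest version supported by both this client and the
// broker.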
r, err := pc.RoundTrip(new(apiversions.Request))
if err != nil {
return nil, err
}
res := r.(*apiversions.Response)
ver := make(map[protocol.ApiKey]int16, len(res.ApiKeys))
if res.ErrorCode != 0 {
return nil, fmt.Errorf("negotating API versions with kafka broker at %s: %w", g.addr, Error(res.ErrorCode))
}
for _, r := range res.ApiKeys {
apiKey := protocol.ApiKey(r.ApiKey)
ver[apiKey] = apiKey.SelectVersion(r.MinVersion, r.MaxVersion)
}
pc.SetVersions(ver)
pc.SetDeadline(time.Time{})
if g.pool.sasl != nil {
host, port, err := splitHostPortNumber(netAddr.String())
if err != nil {
return nil, err
}
metadata := &sasl.Metadata{
Host: host,
Port: port,
}
if err := authenticateSASL(sasl.WithMetadata(ctx, metadata), pc, g.pool.sasl); err != nil {
return nil, err
}
}
reqs := make(chan connRequest)
c := &conn{
network: netAddr.Network(),
address: netAddr.String(),
reqs: reqs,
group: g,
}
go c.run(pc, reqs)
netConn = nil // ownership moved to c; prevent the deferred Close above from closing it
return c, nil
}
type conn struct {
reqs chan<- connRequest
network string
address string
once sync.Once
group *connGroup
timer *time.Timer
}
func (c *conn) close() {
c.once.Do(func() { close(c.reqs) })
}
func (c *conn) run(pc *protocol.Conn, reqs <-chan connRequest) {
defer pc.Close()
for cr := range reqs {
r, err := c.roundTrip(cr.ctx, pc, cr.req)
if err != nil {
cr.res.reject(err)
if !errors.Is(err, protocol.ErrNoRecord) {
break
}
} else {
cr.res.resolve(r)
}
if !c.group.releaseConn(c) {
break
}
}
}
func (c *conn) roundTrip(ctx context.Context, pc *protocol.Conn, req Request) (Response, error) {
pprof.SetGoroutineLabels(ctx)
defer pprof.SetGoroutineLabels(context.Background())
if deadline, hasDeadline := ctx.Deadline(); hasDeadline {
pc.SetDeadline(deadline)
defer pc.SetDeadline(time.Time{})
}
return pc.RoundTrip(req)
}
// authenticateSASL performs all of the required requests to authenticate this
// connection. If any step fails, this function returns with an error. A nil
// error indicates successful authentication.
func authenticateSASL(ctx context.Context, pc *protocol.Conn, mechanism sasl.Mechanism) error {
if err := saslHandshakeRoundTrip(pc, mechanism.Name()); err != nil {
return err
}
sess, state, err := mechanism.Start(ctx)
if err != nil {
return err
}
for completed := false; !completed; {
challenge, err := saslAuthenticateRoundTrip(pc, state)
if err != nil {
if errors.Is(err, io.EOF) {
// the broker may communicate a failed exchange by closing the
// connection (esp. in the case where we're passing opaque sasl
// data over the wire since there's no protocol info).
return SASLAuthenticationFailed
}
return err
}
completed, state, err = sess.Next(ctx, challenge)
if err != nil {
return err
}
}
return nil
}
// saslHandshakeRoundTrip sends the SASL handshake message. This will determine
// whether the Mechanism is supported by the cluster. If it's not, this function
// will error out with UnsupportedSASLMechanism.
//
// If the mechanism is unsupported, the handshake request will reply with the
// list of the cluster's configured mechanisms, which could potentially be used
// to facilitate negotiation. At the moment, we are not negotiating the
// mechanism as we believe that brokers are usually known to the client, and
// therefore the client should already know which mechanisms are supported.
//
// See http://kafka.apache.org/protocol.html#The_Messages_SaslHandshake
func saslHandshakeRoundTrip(pc *protocol.Conn, mechanism string) error {
msg, err := pc.RoundTrip(&saslhandshake.Request{
Mechanism: mechanism,
})
if err != nil {
return err
}
res := msg.(*saslhandshake.Response)
if res.ErrorCode != 0 {
err = Error(res.ErrorCode)
}
return err
}
// saslAuthenticateRoundTrip sends the SASL authenticate message. This function
// must be immediately preceded by a successful saslHandshakeRoundTrip.
//
// See http://kafka.apache.org/protocol.html#The_Messages_SaslAuthenticate
func saslAuthenticateRoundTrip(pc *protocol.Conn, data []byte) ([]byte, error) {
msg, err := pc.RoundTrip(&saslauthenticate.Request{
AuthBytes: data,
})
if err != nil {
return nil, err
}
res := msg.(*saslauthenticate.Response)
if res.ErrorCode != 0 {
err = makeError(res.ErrorCode, res.ErrorMessage)
}
return res.AuthBytes, err
}
var _ RoundTripper = (*Transport)(nil)