cubefs

Форк
0
1776 строк · 51.0 Кб
1
package sarama
2

3
import (
4
	"crypto/tls"
5
	"encoding/binary"
6
	"errors"
7
	"fmt"
8
	"io"
9
	"math/rand"
10
	"net"
11
	"sort"
12
	"strconv"
13
	"strings"
14
	"sync"
15
	"sync/atomic"
16
	"time"
17

18
	"github.com/rcrowley/go-metrics"
19
)
20

21
// Broker represents a single Kafka broker connection. All operations on this object are entirely concurrency-safe.
22
type Broker struct {
23
	conf *Config
24
	rack *string
25

26
	id            int32
27
	addr          string
28
	correlationID int32
29
	conn          net.Conn
30
	connErr       error
31
	lock          sync.Mutex
32
	opened        int32
33
	responses     chan *responsePromise
34
	done          chan bool
35

36
	registeredMetrics map[string]struct{}
37

38
	incomingByteRate       metrics.Meter
39
	requestRate            metrics.Meter
40
	requestSize            metrics.Histogram
41
	requestLatency         metrics.Histogram
42
	outgoingByteRate       metrics.Meter
43
	responseRate           metrics.Meter
44
	responseSize           metrics.Histogram
45
	requestsInFlight       metrics.Counter
46
	brokerIncomingByteRate metrics.Meter
47
	brokerRequestRate      metrics.Meter
48
	brokerRequestSize      metrics.Histogram
49
	brokerRequestLatency   metrics.Histogram
50
	brokerOutgoingByteRate metrics.Meter
51
	brokerResponseRate     metrics.Meter
52
	brokerResponseSize     metrics.Histogram
53
	brokerRequestsInFlight metrics.Counter
54
	brokerThrottleTime     metrics.Histogram
55

56
	kerberosAuthenticator               GSSAPIKerberosAuth
57
	clientSessionReauthenticationTimeMs int64
58
}
59

60
// SASLMechanism specifies the SASL mechanism the client uses to authenticate
// with the broker.
type SASLMechanism string

// Names of the SASL mechanisms understood by the client.
const (
	// SASLTypeOAuth represents the SASL/OAUTHBEARER mechanism (Kafka 2.0.0+).
	SASLTypeOAuth = "OAUTHBEARER"
	// SASLTypePlaintext represents the SASL/PLAIN mechanism.
	SASLTypePlaintext = "PLAIN"
	// SASLTypeSCRAMSHA256 represents the SCRAM-SHA-256 mechanism.
	SASLTypeSCRAMSHA256 = "SCRAM-SHA-256"
	// SASLTypeSCRAMSHA512 represents the SCRAM-SHA-512 mechanism.
	SASLTypeSCRAMSHA512 = "SCRAM-SHA-512"
	// SASLTypeGSSAPI represents the SASL/GSSAPI (Kerberos) mechanism.
	SASLTypeGSSAPI = "GSSAPI"
)

// Versions of the Kafka SASL handshake protocol.
const (
	// SASLHandshakeV0 is v0 of the Kafka SASL handshake protocol. Client and
	// server negotiate SASL auth using opaque packets.
	SASLHandshakeV0 int16 = 0
	// SASLHandshakeV1 is v1 of the Kafka SASL handshake protocol. Client and
	// server negotiate SASL by wrapping tokens with Kafka protocol headers.
	SASLHandshakeV1 int16 = 1
)

// SASLExtKeyAuth is the reserved extension key name sent as part of the
// SASL/OAUTHBEARER initial client response.
const SASLExtKeyAuth = "auth"
83

84
// AccessToken carries the access token used to authenticate a
// SASL/OAUTHBEARER client, together with associated metadata.
type AccessToken struct {
	// Token is the access token payload.
	Token string
	// Extensions is an optional map of arbitrary key-value pairs that may be
	// sent with the SASL/OAUTHBEARER initial client response. The SASL server
	// ignores them if they are unexpected. Only supported by Kafka >= 2.1.0.
	Extensions map[string]string
}
95

96
// AccessTokenProvider is the interface that encapsulates how implementors
97
// can generate access tokens for Kafka broker authentication.
98
type AccessTokenProvider interface {
99
	// Token returns an access token. The implementation should ensure token
100
	// reuse so that multiple calls at connect time do not create multiple
101
	// tokens. The implementation should also periodically refresh the token in
102
	// order to guarantee that each call returns an unexpired token.  This
103
	// method should not block indefinitely--a timeout error should be returned
104
	// after a short period of inactivity so that the broker connection logic
105
	// can log debugging information and retry.
106
	Token() (*AccessToken, error)
107
}
108

109
// SCRAMClient is an interface to a SCRAM client implementation.
type SCRAMClient interface {
	// Begin prepares the client for the SCRAM exchange with the server using
	// a user name, a password and an authorization ID.
	Begin(user, pass, authorizationID string) error
	// Step advances the client through the SCRAM exchange. It is called
	// repeatedly until it returns an error or Done reports true.
	Step(challenge string) (string, error)
	// Done reports whether the SCRAM conversation is over.
	Done() bool
}
122

123
type responsePromise struct {
124
	requestTime   time.Time
125
	correlationID int32
126
	headerVersion int16
127
	handler       func([]byte, error)
128
	packets       chan []byte
129
	errors        chan error
130
}
131

132
// handle delivers a response (or the error that replaced it) to whoever is
// waiting on this promise: the callback when one was provided, otherwise the
// errors or packets channel.
func (p *responsePromise) handle(packets []byte, err error) {
	switch {
	case p.handler != nil:
		// A callback takes precedence over channel delivery.
		p.handler(packets, err)
	case err != nil:
		p.errors <- err
	default:
		p.packets <- packets
	}
}
145

146
// NewBroker creates and returns a Broker targeting the given host:port address.
147
// This does not attempt to actually connect, you have to call Open() for that.
148
func NewBroker(addr string) *Broker {
149
	return &Broker{id: -1, addr: addr}
150
}
151

152
// Open tries to connect to the Broker if it is not already connected or connecting, but does not block
// waiting for the connection to complete. This means that any subsequent operations on the broker will
// block waiting for the connection to succeed or fail. To get the effect of a fully synchronous Open call,
// follow it by a call to Connected(). The only errors Open will return directly are ConfigurationError or
// AlreadyConnected. If conf is nil, the result of NewConfig() is used.
//
// When conf fails validation the broker is left unopened, so Open can be
// retried with a corrected configuration.
func (b *Broker) Open(conf *Config) error {
	if !atomic.CompareAndSwapInt32(&b.opened, 0, 1) {
		return ErrAlreadyConnected
	}

	if conf == nil {
		conf = NewConfig()
	}

	if err := conf.Validate(); err != nil {
		// Roll back the opened flag claimed above: no connection attempt was
		// made, and leaving it set would make every later Open fail with
		// ErrAlreadyConnected.
		atomic.StoreInt32(&b.opened, 0)
		return err
	}

	usingApiVersionsRequests := conf.Version.IsAtLeast(V2_4_0_0) && conf.ApiVersionsRequest

	// Hold the lock until the background connection attempt finishes so that
	// other operations on the broker block until its outcome is known.
	b.lock.Lock()

	go withRecover(func() {
		defer func() {
			// Read the outcome while still holding the lock that guards it.
			connected := b.connErr == nil
			b.lock.Unlock()

			// Send an ApiVersionsRequest to identify the client (KIP-511).
			// Ideally Sarama would use the response to control protocol versions,
			// but for now just fire-and-forget just to send.
			// Skip it after a failed connection attempt: it could only fail with
			// the connection error that has already been logged.
			if !usingApiVersionsRequests || !connected {
				return
			}
			_, err := b.ApiVersions(&ApiVersionsRequest{
				Version:               3,
				ClientSoftwareName:    defaultClientSoftwareName,
				ClientSoftwareVersion: version(),
			})
			if err != nil {
				Logger.Printf("Error while sending ApiVersionsRequest to broker %s: %s\n", b.addr, err)
			}
		}()

		dialer := conf.getDialer()
		b.conn, b.connErr = dialer.Dial("tcp", b.addr)
		if b.connErr != nil {
			Logger.Printf("Failed to connect to broker %s: %s\n", b.addr, b.connErr)
			b.conn = nil
			atomic.StoreInt32(&b.opened, 0)
			return
		}
		if conf.Net.TLS.Enable {
			b.conn = tls.Client(b.conn, validServerNameTLS(b.addr, conf.Net.TLS.Config))
		}

		b.conn = newBufConn(b.conn)
		b.conf = conf

		// Create or reuse the global metrics shared between brokers
		b.incomingByteRate = metrics.GetOrRegisterMeter("incoming-byte-rate", conf.MetricRegistry)
		b.requestRate = metrics.GetOrRegisterMeter("request-rate", conf.MetricRegistry)
		b.requestSize = getOrRegisterHistogram("request-size", conf.MetricRegistry)
		b.requestLatency = getOrRegisterHistogram("request-latency-in-ms", conf.MetricRegistry)
		b.outgoingByteRate = metrics.GetOrRegisterMeter("outgoing-byte-rate", conf.MetricRegistry)
		b.responseRate = metrics.GetOrRegisterMeter("response-rate", conf.MetricRegistry)
		b.responseSize = getOrRegisterHistogram("response-size", conf.MetricRegistry)
		b.requestsInFlight = metrics.GetOrRegisterCounter("requests-in-flight", conf.MetricRegistry)
		// Do not gather metrics for seeded broker (only used during bootstrap) because they share
		// the same id (-1) and are already exposed through the global metrics above
		if b.id >= 0 && !metrics.UseNilMetrics {
			b.registerMetrics()
		}

		if conf.Net.SASL.Enable {
			b.connErr = b.authenticateViaSASL()

			if b.connErr != nil {
				if err := b.conn.Close(); err == nil {
					DebugLogger.Printf("Closed connection to broker %s\n", b.addr)
				} else {
					Logger.Printf("Error while closing connection to broker %s: %s\n", b.addr, err)
				}
				b.conn = nil
				atomic.StoreInt32(&b.opened, 0)
				return
			}
		}

		b.done = make(chan bool)
		b.responses = make(chan *responsePromise, b.conf.Net.MaxOpenRequests-1)

		if b.id >= 0 {
			DebugLogger.Printf("Connected to broker at %s (registered as #%d)\n", b.addr, b.id)
		} else {
			DebugLogger.Printf("Connected to broker at %s (unregistered)\n", b.addr)
		}
		go withRecover(b.responseReceiver)
	})

	return nil
}
252

253
// Connected returns true if the broker is connected and false otherwise. If the broker is not
254
// connected but it had tried to connect, the error from that connection attempt is also returned.
255
func (b *Broker) Connected() (bool, error) {
256
	b.lock.Lock()
257
	defer b.lock.Unlock()
258

259
	return b.conn != nil, b.connErr
260
}
261

262
// TLSConnectionState returns the client's TLS connection state. The second
// result is false when the connection is not TLS or has not been established.
func (b *Broker) TLSConnectionState() (state tls.ConnectionState, ok bool) {
	b.lock.Lock()
	defer b.lock.Unlock()

	conn := b.conn
	if conn == nil {
		return state, false
	}

	// Look through the buffering wrapper to reach the underlying connection.
	if buffered, isBuffered := conn.(*bufConn); isBuffered {
		conn = buffered.Conn
	}

	tlsConn, isTLS := conn.(*tls.Conn)
	if !isTLS {
		return state, false
	}
	return tlsConn.ConnectionState(), true
}
279

280
// Close releases the broker's connection and associated resources. It
// returns ErrNotConnected when there is no connection, otherwise the error
// from closing the underlying connection.
func (b *Broker) Close() error {
	b.lock.Lock()
	defer b.lock.Unlock()

	if b.conn == nil {
		return ErrNotConnected
	}

	// responseReceiver ranges over b.responses, so closing the channel ends
	// its loop; b.done is presumably signalled when it exits.
	close(b.responses)
	<-b.done

	closeErr := b.conn.Close()

	b.conn, b.connErr, b.done, b.responses = nil, nil, nil, nil

	b.unregisterMetrics()

	if closeErr != nil {
		Logger.Printf("Error while closing connection to broker %s: %s\n", b.addr, closeErr)
	} else {
		DebugLogger.Printf("Closed connection to broker %s\n", b.addr)
	}

	atomic.StoreInt32(&b.opened, 0)

	return closeErr
}
311

312
// ID returns the broker ID retrieved from Kafka's metadata, or -1 if that is not known.
313
func (b *Broker) ID() int32 {
314
	return b.id
315
}
316

317
// Addr returns the broker address as either retrieved from Kafka's metadata or passed to NewBroker.
318
func (b *Broker) Addr() string {
319
	return b.addr
320
}
321

322
// Rack returns the broker's rack as retrieved from Kafka's metadata or the
323
// empty string if it is not known.  The returned value corresponds to the
324
// broker's broker.rack configuration setting.  Requires protocol version to be
325
// at least v0.10.0.0.
326
func (b *Broker) Rack() string {
327
	if b.rack == nil {
328
		return ""
329
	}
330
	return *b.rack
331
}
332

333
// GetMetadata send a metadata request and returns a metadata response or error
334
func (b *Broker) GetMetadata(request *MetadataRequest) (*MetadataResponse, error) {
335
	response := new(MetadataResponse)
336

337
	err := b.sendAndReceive(request, response)
338
	if err != nil {
339
		return nil, err
340
	}
341

342
	return response, nil
343
}
344

345
// GetConsumerMetadata send a consumer metadata request and returns a consumer metadata response or error
346
func (b *Broker) GetConsumerMetadata(request *ConsumerMetadataRequest) (*ConsumerMetadataResponse, error) {
347
	response := new(ConsumerMetadataResponse)
348

349
	err := b.sendAndReceive(request, response)
350
	if err != nil {
351
		return nil, err
352
	}
353

354
	return response, nil
355
}
356

357
// FindCoordinator sends a find coordinate request and returns a response or error
358
func (b *Broker) FindCoordinator(request *FindCoordinatorRequest) (*FindCoordinatorResponse, error) {
359
	response := new(FindCoordinatorResponse)
360

361
	err := b.sendAndReceive(request, response)
362
	if err != nil {
363
		return nil, err
364
	}
365

366
	return response, nil
367
}
368

369
// GetAvailableOffsets return an offset response or error
370
func (b *Broker) GetAvailableOffsets(request *OffsetRequest) (*OffsetResponse, error) {
371
	response := new(OffsetResponse)
372

373
	err := b.sendAndReceive(request, response)
374
	if err != nil {
375
		return nil, err
376
	}
377

378
	return response, nil
379
}
380

381
// ProduceCallback function is called once the produce response has been parsed
382
// or could not be read.
383
type ProduceCallback func(*ProduceResponse, error)
384

385
// AsyncProduce sends a produce request and eventually call the provided callback
386
// with a produce response or an error.
387
//
388
// Waiting for the response is generally not blocking on the contrary to using Produce.
389
// If the maximum number of in flight request configured is reached then
390
// the request will be blocked till a previous response is received.
391
//
392
// When configured with RequiredAcks == NoResponse, the callback will not be invoked.
393
// If an error is returned because the request could not be sent then the callback
394
// will not be invoked either.
395
//
396
// Make sure not to Close the broker in the callback as it will lead to a deadlock.
397
func (b *Broker) AsyncProduce(request *ProduceRequest, cb ProduceCallback) error {
398
	needAcks := request.RequiredAcks != NoResponse
399
	// Use a nil promise when no acks is required
400
	var promise *responsePromise
401

402
	if needAcks {
403
		// Create ProduceResponse early to provide the header version
404
		res := new(ProduceResponse)
405
		promise = &responsePromise{
406
			headerVersion: res.headerVersion(),
407
			// Packets will be converted to a ProduceResponse in the responseReceiver goroutine
408
			handler: func(packets []byte, err error) {
409
				if err != nil {
410
					// Failed request
411
					cb(nil, err)
412
					return
413
				}
414

415
				if err := versionedDecode(packets, res, request.version()); err != nil {
416
					// Malformed response
417
					cb(nil, err)
418
					return
419
				}
420

421
				// Wellformed response
422
				b.updateThrottleMetric(res.ThrottleTime)
423
				cb(res, nil)
424
			},
425
		}
426
	}
427

428
	return b.sendWithPromise(request, promise)
429
}
430

431
//Produce returns a produce response or error
432
func (b *Broker) Produce(request *ProduceRequest) (*ProduceResponse, error) {
433
	var (
434
		response *ProduceResponse
435
		err      error
436
	)
437

438
	if request.RequiredAcks == NoResponse {
439
		err = b.sendAndReceive(request, nil)
440
	} else {
441
		response = new(ProduceResponse)
442
		err = b.sendAndReceive(request, response)
443
		b.updateThrottleMetric(response.ThrottleTime)
444
	}
445

446
	if err != nil {
447
		return nil, err
448
	}
449

450
	return response, nil
451
}
452

453
// Fetch returns a FetchResponse or error
454
func (b *Broker) Fetch(request *FetchRequest) (*FetchResponse, error) {
455
	response := new(FetchResponse)
456

457
	err := b.sendAndReceive(request, response)
458
	if err != nil {
459
		return nil, err
460
	}
461

462
	return response, nil
463
}
464

465
// CommitOffset return an Offset commit response or error
466
func (b *Broker) CommitOffset(request *OffsetCommitRequest) (*OffsetCommitResponse, error) {
467
	response := new(OffsetCommitResponse)
468

469
	err := b.sendAndReceive(request, response)
470
	if err != nil {
471
		return nil, err
472
	}
473

474
	return response, nil
475
}
476

477
// FetchOffset returns an offset fetch response or error
478
func (b *Broker) FetchOffset(request *OffsetFetchRequest) (*OffsetFetchResponse, error) {
479
	response := new(OffsetFetchResponse)
480
	response.Version = request.Version // needed to handle the two header versions
481

482
	err := b.sendAndReceive(request, response)
483
	if err != nil {
484
		return nil, err
485
	}
486

487
	return response, nil
488
}
489

490
// JoinGroup returns a join group response or error
491
func (b *Broker) JoinGroup(request *JoinGroupRequest) (*JoinGroupResponse, error) {
492
	response := new(JoinGroupResponse)
493

494
	err := b.sendAndReceive(request, response)
495
	if err != nil {
496
		return nil, err
497
	}
498

499
	return response, nil
500
}
501

502
// SyncGroup returns a sync group response or error
503
func (b *Broker) SyncGroup(request *SyncGroupRequest) (*SyncGroupResponse, error) {
504
	response := new(SyncGroupResponse)
505

506
	err := b.sendAndReceive(request, response)
507
	if err != nil {
508
		return nil, err
509
	}
510

511
	return response, nil
512
}
513

514
// LeaveGroup return a leave group response or error
515
func (b *Broker) LeaveGroup(request *LeaveGroupRequest) (*LeaveGroupResponse, error) {
516
	response := new(LeaveGroupResponse)
517

518
	err := b.sendAndReceive(request, response)
519
	if err != nil {
520
		return nil, err
521
	}
522

523
	return response, nil
524
}
525

526
// Heartbeat returns a heartbeat response or error
527
func (b *Broker) Heartbeat(request *HeartbeatRequest) (*HeartbeatResponse, error) {
528
	response := new(HeartbeatResponse)
529

530
	err := b.sendAndReceive(request, response)
531
	if err != nil {
532
		return nil, err
533
	}
534

535
	return response, nil
536
}
537

538
// ListGroups return a list group response or error
539
func (b *Broker) ListGroups(request *ListGroupsRequest) (*ListGroupsResponse, error) {
540
	response := new(ListGroupsResponse)
541

542
	err := b.sendAndReceive(request, response)
543
	if err != nil {
544
		return nil, err
545
	}
546

547
	return response, nil
548
}
549

550
// DescribeGroups return describe group response or error
551
func (b *Broker) DescribeGroups(request *DescribeGroupsRequest) (*DescribeGroupsResponse, error) {
552
	response := new(DescribeGroupsResponse)
553

554
	err := b.sendAndReceive(request, response)
555
	if err != nil {
556
		return nil, err
557
	}
558

559
	return response, nil
560
}
561

562
// ApiVersions return api version response or error
563
func (b *Broker) ApiVersions(request *ApiVersionsRequest) (*ApiVersionsResponse, error) {
564
	response := new(ApiVersionsResponse)
565

566
	err := b.sendAndReceive(request, response)
567
	if err != nil {
568
		return nil, err
569
	}
570

571
	return response, nil
572
}
573

574
// CreateTopics send a create topic request and returns create topic response
575
func (b *Broker) CreateTopics(request *CreateTopicsRequest) (*CreateTopicsResponse, error) {
576
	response := new(CreateTopicsResponse)
577

578
	err := b.sendAndReceive(request, response)
579
	if err != nil {
580
		return nil, err
581
	}
582

583
	return response, nil
584
}
585

586
// DeleteTopics sends a delete topic request and returns delete topic response
587
func (b *Broker) DeleteTopics(request *DeleteTopicsRequest) (*DeleteTopicsResponse, error) {
588
	response := new(DeleteTopicsResponse)
589

590
	err := b.sendAndReceive(request, response)
591
	if err != nil {
592
		return nil, err
593
	}
594

595
	return response, nil
596
}
597

598
// CreatePartitions sends a create partition request and returns create
599
// partitions response or error
600
func (b *Broker) CreatePartitions(request *CreatePartitionsRequest) (*CreatePartitionsResponse, error) {
601
	response := new(CreatePartitionsResponse)
602

603
	err := b.sendAndReceive(request, response)
604
	if err != nil {
605
		return nil, err
606
	}
607

608
	return response, nil
609
}
610

611
// AlterPartitionReassignments sends a alter partition reassignments request and
612
// returns alter partition reassignments response
613
func (b *Broker) AlterPartitionReassignments(request *AlterPartitionReassignmentsRequest) (*AlterPartitionReassignmentsResponse, error) {
614
	response := new(AlterPartitionReassignmentsResponse)
615

616
	err := b.sendAndReceive(request, response)
617
	if err != nil {
618
		return nil, err
619
	}
620

621
	return response, nil
622
}
623

624
// ListPartitionReassignments sends a list partition reassignments request and
625
// returns list partition reassignments response
626
func (b *Broker) ListPartitionReassignments(request *ListPartitionReassignmentsRequest) (*ListPartitionReassignmentsResponse, error) {
627
	response := new(ListPartitionReassignmentsResponse)
628

629
	err := b.sendAndReceive(request, response)
630
	if err != nil {
631
		return nil, err
632
	}
633

634
	return response, nil
635
}
636

637
// DeleteRecords send a request to delete records and return delete record
638
// response or error
639
func (b *Broker) DeleteRecords(request *DeleteRecordsRequest) (*DeleteRecordsResponse, error) {
640
	response := new(DeleteRecordsResponse)
641

642
	err := b.sendAndReceive(request, response)
643
	if err != nil {
644
		return nil, err
645
	}
646

647
	return response, nil
648
}
649

650
// DescribeAcls sends a describe acl request and returns a response or error
651
func (b *Broker) DescribeAcls(request *DescribeAclsRequest) (*DescribeAclsResponse, error) {
652
	response := new(DescribeAclsResponse)
653

654
	err := b.sendAndReceive(request, response)
655
	if err != nil {
656
		return nil, err
657
	}
658

659
	return response, nil
660
}
661

662
// CreateAcls sends a create acl request and returns a response or error
663
func (b *Broker) CreateAcls(request *CreateAclsRequest) (*CreateAclsResponse, error) {
664
	response := new(CreateAclsResponse)
665

666
	err := b.sendAndReceive(request, response)
667
	if err != nil {
668
		return nil, err
669
	}
670

671
	errs := make([]error, 0)
672
	for _, res := range response.AclCreationResponses {
673
		if !errors.Is(res.Err, ErrNoError) {
674
			errs = append(errs, res.Err)
675
		}
676
	}
677

678
	if len(errs) > 0 {
679
		return response, Wrap(ErrCreateACLs, errs...)
680
	}
681

682
	return response, nil
683
}
684

685
// DeleteAcls sends a delete acl request and returns a response or error
686
func (b *Broker) DeleteAcls(request *DeleteAclsRequest) (*DeleteAclsResponse, error) {
687
	response := new(DeleteAclsResponse)
688

689
	err := b.sendAndReceive(request, response)
690
	if err != nil {
691
		return nil, err
692
	}
693

694
	return response, nil
695
}
696

697
// InitProducerID sends an init producer request and returns a response or error
698
func (b *Broker) InitProducerID(request *InitProducerIDRequest) (*InitProducerIDResponse, error) {
699
	response := new(InitProducerIDResponse)
700

701
	err := b.sendAndReceive(request, response)
702
	if err != nil {
703
		return nil, err
704
	}
705

706
	return response, nil
707
}
708

709
// AddPartitionsToTxn send a request to add partition to txn and returns
710
// a response or error
711
func (b *Broker) AddPartitionsToTxn(request *AddPartitionsToTxnRequest) (*AddPartitionsToTxnResponse, error) {
712
	response := new(AddPartitionsToTxnResponse)
713

714
	err := b.sendAndReceive(request, response)
715
	if err != nil {
716
		return nil, err
717
	}
718

719
	return response, nil
720
}
721

722
// AddOffsetsToTxn sends a request to add offsets to txn and returns a response
723
// or error
724
func (b *Broker) AddOffsetsToTxn(request *AddOffsetsToTxnRequest) (*AddOffsetsToTxnResponse, error) {
725
	response := new(AddOffsetsToTxnResponse)
726

727
	err := b.sendAndReceive(request, response)
728
	if err != nil {
729
		return nil, err
730
	}
731

732
	return response, nil
733
}
734

735
// EndTxn sends a request to end txn and returns a response or error
736
func (b *Broker) EndTxn(request *EndTxnRequest) (*EndTxnResponse, error) {
737
	response := new(EndTxnResponse)
738

739
	err := b.sendAndReceive(request, response)
740
	if err != nil {
741
		return nil, err
742
	}
743

744
	return response, nil
745
}
746

747
// TxnOffsetCommit sends a request to commit transaction offsets and returns
748
// a response or error
749
func (b *Broker) TxnOffsetCommit(request *TxnOffsetCommitRequest) (*TxnOffsetCommitResponse, error) {
750
	response := new(TxnOffsetCommitResponse)
751

752
	err := b.sendAndReceive(request, response)
753
	if err != nil {
754
		return nil, err
755
	}
756

757
	return response, nil
758
}
759

760
// DescribeConfigs sends a request to describe config and returns a response or
761
// error
762
func (b *Broker) DescribeConfigs(request *DescribeConfigsRequest) (*DescribeConfigsResponse, error) {
763
	response := new(DescribeConfigsResponse)
764

765
	err := b.sendAndReceive(request, response)
766
	if err != nil {
767
		return nil, err
768
	}
769

770
	return response, nil
771
}
772

773
// AlterConfigs sends a request to alter config and return a response or error
774
func (b *Broker) AlterConfigs(request *AlterConfigsRequest) (*AlterConfigsResponse, error) {
775
	response := new(AlterConfigsResponse)
776

777
	err := b.sendAndReceive(request, response)
778
	if err != nil {
779
		return nil, err
780
	}
781

782
	return response, nil
783
}
784

785
// IncrementalAlterConfigs sends a request to incremental alter config and return a response or error
786
func (b *Broker) IncrementalAlterConfigs(request *IncrementalAlterConfigsRequest) (*IncrementalAlterConfigsResponse, error) {
787
	response := new(IncrementalAlterConfigsResponse)
788

789
	err := b.sendAndReceive(request, response)
790
	if err != nil {
791
		return nil, err
792
	}
793

794
	return response, nil
795
}
796

797
// DeleteGroups sends a request to delete groups and returns a response or error
798
func (b *Broker) DeleteGroups(request *DeleteGroupsRequest) (*DeleteGroupsResponse, error) {
799
	response := new(DeleteGroupsResponse)
800

801
	if err := b.sendAndReceive(request, response); err != nil {
802
		return nil, err
803
	}
804

805
	return response, nil
806
}
807

808
// DeleteOffsets sends a request to delete group offsets and returns a response or error
809
func (b *Broker) DeleteOffsets(request *DeleteOffsetsRequest) (*DeleteOffsetsResponse, error) {
810
	response := new(DeleteOffsetsResponse)
811

812
	if err := b.sendAndReceive(request, response); err != nil {
813
		return nil, err
814
	}
815

816
	return response, nil
817
}
818

819
// DescribeLogDirs sends a request to get the broker's log dir paths and sizes
820
func (b *Broker) DescribeLogDirs(request *DescribeLogDirsRequest) (*DescribeLogDirsResponse, error) {
821
	response := new(DescribeLogDirsResponse)
822

823
	err := b.sendAndReceive(request, response)
824
	if err != nil {
825
		return nil, err
826
	}
827

828
	return response, nil
829
}
830

831
// DescribeUserScramCredentials sends a request to get SCRAM users
832
func (b *Broker) DescribeUserScramCredentials(req *DescribeUserScramCredentialsRequest) (*DescribeUserScramCredentialsResponse, error) {
833
	res := new(DescribeUserScramCredentialsResponse)
834

835
	err := b.sendAndReceive(req, res)
836
	if err != nil {
837
		return nil, err
838
	}
839

840
	return res, err
841
}
842

843
// AlterUserScramCredentials sends a request to alter SCRAM user credentials
// and returns the decoded response, or the error that prevented sending it or
// reading its reply.
func (b *Broker) AlterUserScramCredentials(req *AlterUserScramCredentialsRequest) (*AlterUserScramCredentialsResponse, error) {
	out := &AlterUserScramCredentialsResponse{}
	if err := b.sendAndReceive(req, out); err != nil {
		return nil, err
	}
	return out, nil
}
853

854
// DescribeClientQuotas sends a request to get the broker's quotas
855
func (b *Broker) DescribeClientQuotas(request *DescribeClientQuotasRequest) (*DescribeClientQuotasResponse, error) {
856
	response := new(DescribeClientQuotasResponse)
857

858
	err := b.sendAndReceive(request, response)
859
	if err != nil {
860
		return nil, err
861
	}
862

863
	return response, nil
864
}
865

866
// AlterClientQuotas sends a request to alter the broker's quotas
867
func (b *Broker) AlterClientQuotas(request *AlterClientQuotasRequest) (*AlterClientQuotasResponse, error) {
868
	response := new(AlterClientQuotasResponse)
869

870
	err := b.sendAndReceive(request, response)
871
	if err != nil {
872
		return nil, err
873
	}
874

875
	return response, nil
876
}
877

878
// readFull arms the connection's read deadline (now + Net.ReadTimeout) and
// then fills buf completely via io.ReadFull.
func (b *Broker) readFull(buf []byte) (n int, err error) {
	deadline := time.Now().Add(b.conf.Net.ReadTimeout)
	if err = b.conn.SetReadDeadline(deadline); err != nil {
		return 0, err
	}
	return io.ReadFull(b.conn, buf)
}
887

888
// write arms the connection's write deadline (now + Net.WriteTimeout) and
// then writes buf to the connection.
func (b *Broker) write(buf []byte) (n int, err error) {
	deadline := time.Now().Add(b.conf.Net.WriteTimeout)
	if err = b.conn.SetWriteDeadline(deadline); err != nil {
		return 0, err
	}
	return b.conn.Write(buf)
}
897

898
func (b *Broker) send(rb protocolBody, promiseResponse bool, responseHeaderVersion int16) (*responsePromise, error) {
899
	var promise *responsePromise
900
	if promiseResponse {
901
		// Packets or error will be sent to the following channels
902
		// once the response is received
903
		promise = &responsePromise{
904
			headerVersion: responseHeaderVersion,
905
			packets:       make(chan []byte),
906
			errors:        make(chan error),
907
		}
908
	}
909

910
	if err := b.sendWithPromise(rb, promise); err != nil {
911
		return nil, err
912
	}
913

914
	return promise, nil
915
}
916

917
// sendWithPromise encodes rb, writes it to the connection under b.lock and,
// when promise is non-nil, queues the promise on b.responses so the response
// can be matched to it. With a nil promise the request latency is recorded
// immediately because no reply will be read.
func (b *Broker) sendWithPromise(rb protocolBody, promise *responsePromise) error {
	b.lock.Lock()
	defer b.lock.Unlock()

	if b.conn == nil {
		if b.connErr == nil {
			return ErrNotConnected
		}
		return b.connErr
	}

	// Re-authenticate once the stored reauthentication deadline has passed.
	if b.clientSessionReauthenticationTimeMs > 0 && currentUnixMilli() > b.clientSessionReauthenticationTimeMs {
		if err := b.authenticateViaSASL(); err != nil {
			return err
		}
	}

	if !b.conf.Version.IsAtLeast(rb.requiredVersion()) {
		return ErrUnsupportedVersion
	}

	req := &request{correlationID: b.correlationID, clientID: b.conf.ClientID, body: rb}
	encoded, err := encode(req, b.conf.MetricRegistry)
	if err != nil {
		return err
	}

	sentAt := time.Now()
	// Decremented in responseReceiver, except on write failure or when no
	// response is expected (handled below).
	b.addRequestInFlightMetrics(1)
	written, err := b.write(encoded)
	b.updateOutgoingCommunicationMetrics(written)
	if err != nil {
		b.addRequestInFlightMetrics(-1)
		return err
	}
	b.correlationID++

	if promise != nil {
		promise.requestTime = sentAt
		promise.correlationID = req.correlationID
		b.responses <- promise
		return nil
	}

	// No response will be read: record latency for the write alone.
	b.updateRequestLatencyAndInFlightMetrics(time.Since(sentAt))
	return nil
}
968

969
func (b *Broker) sendAndReceive(req protocolBody, res protocolBody) error {
970
	responseHeaderVersion := int16(-1)
971
	if res != nil {
972
		responseHeaderVersion = res.headerVersion()
973
	}
974

975
	promise, err := b.send(req, res != nil, responseHeaderVersion)
976
	if err != nil {
977
		return err
978
	}
979

980
	if promise == nil {
981
		return nil
982
	}
983

984
	select {
985
	case buf := <-promise.packets:
986
		return versionedDecode(buf, res, req.version())
987
	case err = <-promise.errors:
988
		return err
989
	}
990
}
991

992
// decode reads a broker entry from pd: node ID (int32), host (string), port
// (int32) and, for version >= 1, a nullable rack string. On success b.addr is
// set to the joined host:port.
func (b *Broker) decode(pd packetDecoder, version int16) (err error) {
	if b.id, err = pd.getInt32(); err != nil {
		return err
	}

	var host string
	if host, err = pd.getString(); err != nil {
		return err
	}

	var port int32
	if port, err = pd.getInt32(); err != nil {
		return err
	}

	if version >= 1 {
		if b.rack, err = pd.getNullableString(); err != nil {
			return err
		}
	}

	b.addr = net.JoinHostPort(host, fmt.Sprint(port))
	// Sanity-check that the joined address splits back cleanly.
	if _, _, splitErr := net.SplitHostPort(b.addr); splitErr != nil {
		return splitErr
	}
	return nil
}
1022

1023
// encode writes this broker as node ID (int32), host (string), port (int32)
// and, for version >= 1, the nullable rack string. The host and port come from
// splitting b.addr; the port must parse as a 32-bit decimal integer.
func (b *Broker) encode(pe packetEncoder, version int16) (err error) {
	var host, portText string
	if host, portText, err = net.SplitHostPort(b.addr); err != nil {
		return err
	}

	var port int64
	if port, err = strconv.ParseInt(portText, 10, 32); err != nil {
		return err
	}

	pe.putInt32(b.id)
	if err = pe.putString(host); err != nil {
		return err
	}
	pe.putInt32(int32(port))

	if version < 1 {
		return nil
	}
	return pe.putNullableString(b.rack)
}
1052

1053
// responseReceiver consumes b.responses until the channel is closed. It reads
// one response from the connection per queued promise, in order, and hands
// the body (or an error) to the promise.
//
// After the first read or decode failure, or a correlation-ID mismatch, the
// connection is treated as dead. That error is then delivered to every later
// promise without reading anything. b.done is closed when the loop ends.
func (b *Broker) responseReceiver() {
	// readOne reads and validates the response for promise p. Every path
	// after the header read records the incoming metrics, which also
	// releases p's in-flight count.
	readOne := func(p *responsePromise) ([]byte, error) {
		headerLength := getHeaderLength(p.headerVersion)
		header := make([]byte, headerLength)

		headerBytes, err := b.readFull(header)
		latency := time.Since(p.requestTime)
		if err != nil {
			b.updateIncomingCommunicationMetrics(headerBytes, latency)
			return nil, err
		}

		var decoded responseHeader
		if err := versionedDecode(header, &decoded, p.headerVersion); err != nil {
			b.updateIncomingCommunicationMetrics(headerBytes, latency)
			return nil, err
		}
		if decoded.correlationID != p.correlationID {
			b.updateIncomingCommunicationMetrics(headerBytes, latency)
			// TODO if decoded ID < cur ID, discard until we catch up
			// TODO if decoded ID > cur ID, save it so when cur ID catches up we have a response
			return nil, PacketDecodingError{fmt.Sprintf("correlation ID didn't match, wanted %d, got %d", p.correlationID, decoded.correlationID)}
		}

		// decoded.length does not count its own 4-byte prefix, but
		// headerLength does, hence the +4.
		body := make([]byte, decoded.length-int32(headerLength)+4)
		bodyBytes, err := b.readFull(body)
		b.updateIncomingCommunicationMetrics(headerBytes+bodyBytes, latency)
		if err != nil {
			return nil, err
		}
		return body, nil
	}

	var dead error
	for response := range b.responses {
		if dead != nil {
			// Incremented in sendWithPromise; no read happens here, so
			// updateIncomingCommunicationMetrics will not undo it.
			b.addRequestInFlightMetrics(-1)
			response.handle(nil, dead)
			continue
		}

		payload, err := readOne(response)
		if err != nil {
			dead = err
			response.handle(nil, err)
			continue
		}
		response.handle(payload, nil)
	}
	close(b.done)
}
1107

1108
// getHeaderLength returns how many bytes to read for a response header of the
// given version. That is 8 bytes (length and correlation ID) for versions below
// 1. Version 1 and above need one more byte for the tagged-field count.
func getHeaderLength(headerVersion int16) int8 {
	if headerVersion >= 1 {
		// The extra byte is the tagged-field count (expected to be 0);
		// actual tagged fields are not supported yet.
		return 9
	}
	return 8
}
1116

1117
func (b *Broker) authenticateViaSASL() error {
1118
	switch b.conf.Net.SASL.Mechanism {
1119
	case SASLTypeOAuth:
1120
		return b.sendAndReceiveSASLOAuth(b.conf.Net.SASL.TokenProvider)
1121
	case SASLTypeSCRAMSHA256, SASLTypeSCRAMSHA512:
1122
		return b.sendAndReceiveSASLSCRAM()
1123
	case SASLTypeGSSAPI:
1124
		return b.sendAndReceiveKerberos()
1125
	default:
1126
		return b.sendAndReceiveSASLPlainAuth()
1127
	}
1128
}
1129

1130
// sendAndReceiveKerberos does three things, then returns Authorize's result:
//   - points the broker's GSSAPI authenticator at the configured GSSAPI settings;
//   - defaults its client factory to NewKerberosClient when none is set;
//   - runs its Authorize flow against b.
func (b *Broker) sendAndReceiveKerberos() error {
	auth := &b.kerberosAuthenticator
	auth.Config = &b.conf.Net.SASL.GSSAPI
	if auth.NewKerberosClientFunc == nil {
		auth.NewKerberosClientFunc = NewKerberosClient
	}
	return auth.Authorize(b)
}
1137

1138
// sendAndReceiveSASLHandshake sends a SaslHandshake request for saslType at
// the given request version and reads the reply synchronously from the
// connection. The reply is an 8-byte header (length, correlation ID) followed
// by length-4 bytes of body.
//
// It returns any write or read error, a PacketDecodingError when the length
// prefix is too small to hold a correlation ID, the decode error, or the
// broker's error code when the mechanism is rejected.
func (b *Broker) sendAndReceiveSASLHandshake(saslType SASLMechanism, version int16) error {
	rb := &SaslHandshakeRequest{Mechanism: string(saslType), Version: version}

	req := &request{correlationID: b.correlationID, clientID: b.conf.ClientID, body: rb}
	buf, err := encode(req, b.conf.MetricRegistry)
	if err != nil {
		return err
	}

	requestTime := time.Now()
	// Will be decremented in updateIncomingCommunicationMetrics (except error)
	b.addRequestInFlightMetrics(1)
	bytes, err := b.write(buf)
	b.updateOutgoingCommunicationMetrics(bytes)
	if err != nil {
		b.addRequestInFlightMetrics(-1)
		Logger.Printf("Failed to send SASL handshake %s: %s\n", b.addr, err.Error())
		return err
	}
	b.correlationID++

	header := make([]byte, 8) // response header: length + correlation ID
	_, err = b.readFull(header)
	if err != nil {
		b.addRequestInFlightMetrics(-1)
		Logger.Printf("Failed to read SASL handshake header : %s\n", err.Error())
		return err
	}

	// The length prefix counts the 4-byte correlation ID plus the body.
	// Validate it before subtracting. As an unsigned value, a prefix below 4
	// would wrap around and request a multi-gigabyte buffer.
	length := int32(binary.BigEndian.Uint32(header[:4]))
	if length < 4 {
		b.addRequestInFlightMetrics(-1)
		err = PacketDecodingError{fmt.Sprintf("invalid SASL handshake response length %d", length)}
		Logger.Printf("Failed to read SASL handshake header : %s\n", err.Error())
		return err
	}
	payload := make([]byte, length-4)
	n, err := b.readFull(payload)
	if err != nil {
		b.addRequestInFlightMetrics(-1)
		Logger.Printf("Failed to read SASL handshake payload : %s\n", err.Error())
		return err
	}

	b.updateIncomingCommunicationMetrics(n+8, time.Since(requestTime))
	res := &SaslHandshakeResponse{}

	err = versionedDecode(payload, res, 0)
	if err != nil {
		Logger.Printf("Failed to parse SASL handshake : %s\n", err.Error())
		return err
	}

	if !errors.Is(res.Err, ErrNoError) {
		Logger.Printf("Invalid SASL Mechanism : %s\n", res.Err.Error())
		return res.Err
	}

	DebugLogger.Print("Completed pre-auth SASL handshake. Available mechanisms: ", res.EnabledMechanisms)
	return nil
}
1193

1194
// Kafka 0.10.x supported SASL PLAIN/Kerberos via KAFKA-3149 (KIP-43).
1195
// Kafka 1.x.x onward added a SaslAuthenticate request/response message which
1196
// wraps the SASL flow in the Kafka protocol, which allows for returning
1197
// meaningful errors on authentication failure.
1198
//
1199
// In SASL Plain, Kafka expects the auth header to be in the following format
1200
// Message format (from https://tools.ietf.org/html/rfc4616):
1201
//
1202
//   message   = [authzid] UTF8NUL authcid UTF8NUL passwd
1203
//   authcid   = 1*SAFE ; MUST accept up to 255 octets
1204
//   authzid   = 1*SAFE ; MUST accept up to 255 octets
1205
//   passwd    = 1*SAFE ; MUST accept up to 255 octets
1206
//   UTF8NUL   = %x00 ; UTF-8 encoded NUL character
1207
//
1208
//   SAFE      = UTF1 / UTF2 / UTF3 / UTF4
1209
//                  ;; any UTF-8 encoded Unicode character except NUL
1210
//
1211
// With SASL v0 handshake and auth then:
1212
// When credentials are valid, Kafka returns a 4 byte array of null characters.
1213
// When credentials are invalid, Kafka closes the connection.
1214
//
1215
// With SASL v1 handshake and auth then:
1216
// When credentials are invalid, Kafka replies with a SaslAuthenticate response
1217
// containing an error code and message detailing the authentication failure.
1218
func (b *Broker) sendAndReceiveSASLPlainAuth() error {
1219
	// default to V0 to allow for backward compatibility when SASL is enabled
1220
	// but not the handshake
1221
	if b.conf.Net.SASL.Handshake {
1222
		handshakeErr := b.sendAndReceiveSASLHandshake(SASLTypePlaintext, b.conf.Net.SASL.Version)
1223
		if handshakeErr != nil {
1224
			Logger.Printf("Error while performing SASL handshake %s\n", b.addr)
1225
			return handshakeErr
1226
		}
1227
	}
1228

1229
	if b.conf.Net.SASL.Version == SASLHandshakeV1 {
1230
		return b.sendAndReceiveV1SASLPlainAuth()
1231
	}
1232
	return b.sendAndReceiveV0SASLPlainAuth()
1233
}
1234

1235
// sendAndReceiveV0SASLPlainAuth flows the v0 sasl auth NOT wrapped in the kafka protocol
1236
func (b *Broker) sendAndReceiveV0SASLPlainAuth() error {
1237
	length := len(b.conf.Net.SASL.AuthIdentity) + 1 + len(b.conf.Net.SASL.User) + 1 + len(b.conf.Net.SASL.Password)
1238
	authBytes := make([]byte, length+4) // 4 byte length header + auth data
1239
	binary.BigEndian.PutUint32(authBytes, uint32(length))
1240
	copy(authBytes[4:], b.conf.Net.SASL.AuthIdentity+"\x00"+b.conf.Net.SASL.User+"\x00"+b.conf.Net.SASL.Password)
1241

1242
	requestTime := time.Now()
1243
	// Will be decremented in updateIncomingCommunicationMetrics (except error)
1244
	b.addRequestInFlightMetrics(1)
1245
	bytesWritten, err := b.write(authBytes)
1246
	b.updateOutgoingCommunicationMetrics(bytesWritten)
1247
	if err != nil {
1248
		b.addRequestInFlightMetrics(-1)
1249
		Logger.Printf("Failed to write SASL auth header to broker %s: %s\n", b.addr, err.Error())
1250
		return err
1251
	}
1252

1253
	header := make([]byte, 4)
1254
	n, err := b.readFull(header)
1255
	b.updateIncomingCommunicationMetrics(n, time.Since(requestTime))
1256
	// If the credentials are valid, we would get a 4 byte response filled with null characters.
1257
	// Otherwise, the broker closes the connection and we get an EOF
1258
	if err != nil {
1259
		Logger.Printf("Failed to read response while authenticating with SASL to broker %s: %s\n", b.addr, err.Error())
1260
		return err
1261
	}
1262

1263
	DebugLogger.Printf("SASL authentication successful with broker %s:%v - %v\n", b.addr, n, header)
1264
	return nil
1265
}
1266

1267
// sendAndReceiveV1SASLPlainAuth flows the v1 sasl authentication using the kafka protocol
1268
func (b *Broker) sendAndReceiveV1SASLPlainAuth() error {
1269
	correlationID := b.correlationID
1270

1271
	requestTime := time.Now()
1272

1273
	// Will be decremented in updateIncomingCommunicationMetrics (except error)
1274
	b.addRequestInFlightMetrics(1)
1275
	bytesWritten, resVersion, err := b.sendSASLPlainAuthClientResponse(correlationID)
1276
	b.updateOutgoingCommunicationMetrics(bytesWritten)
1277

1278
	if err != nil {
1279
		b.addRequestInFlightMetrics(-1)
1280
		Logger.Printf("Failed to write SASL auth header to broker %s: %s\n", b.addr, err.Error())
1281
		return err
1282
	}
1283

1284
	b.correlationID++
1285

1286
	res := &SaslAuthenticateResponse{}
1287
	bytesRead, err := b.receiveSASLServerResponse(res, correlationID, resVersion)
1288
	b.updateIncomingCommunicationMetrics(bytesRead, time.Since(requestTime))
1289

1290
	// With v1 sasl we get an error message set in the response we can return
1291
	if err != nil {
1292
		Logger.Printf(
1293
			"Error returned from broker %s during SASL authentication: %v\n",
1294
			b.addr, err.Error())
1295
		return err
1296
	}
1297

1298
	return nil
1299
}
1300

1301
// currentUnixMilli returns the wall-clock time as milliseconds since the Unix
// epoch, truncated toward zero.
func currentUnixMilli() int64 {
	const nanosPerMilli = int64(time.Millisecond)
	nowNanos := time.Now().UnixNano()
	return nowNanos / nanosPerMilli
}
1304

1305
// sendAndReceiveSASLOAuth performs the authentication flow as described by KIP-255
1306
// https://cwiki.apache.org/confluence/pages/viewpage.action?pageId=75968876
1307
func (b *Broker) sendAndReceiveSASLOAuth(provider AccessTokenProvider) error {
1308
	if err := b.sendAndReceiveSASLHandshake(SASLTypeOAuth, SASLHandshakeV1); err != nil {
1309
		return err
1310
	}
1311

1312
	token, err := provider.Token()
1313
	if err != nil {
1314
		return err
1315
	}
1316

1317
	message, err := buildClientFirstMessage(token)
1318
	if err != nil {
1319
		return err
1320
	}
1321

1322
	challenged, err := b.sendClientMessage(message)
1323
	if err != nil {
1324
		return err
1325
	}
1326

1327
	if challenged {
1328
		// Abort the token exchange. The broker returns the failure code.
1329
		_, err = b.sendClientMessage([]byte(`\x01`))
1330
	}
1331

1332
	return err
1333
}
1334

1335
// sendClientMessage sends a SASL/OAUTHBEARER client message and returns true
1336
// if the broker responds with a challenge, in which case the token is
1337
// rejected.
1338
func (b *Broker) sendClientMessage(message []byte) (bool, error) {
1339
	requestTime := time.Now()
1340
	// Will be decremented in updateIncomingCommunicationMetrics (except error)
1341
	b.addRequestInFlightMetrics(1)
1342
	correlationID := b.correlationID
1343

1344
	bytesWritten, resVersion, err := b.sendSASLOAuthBearerClientMessage(message, correlationID)
1345
	b.updateOutgoingCommunicationMetrics(bytesWritten)
1346
	if err != nil {
1347
		b.addRequestInFlightMetrics(-1)
1348
		return false, err
1349
	}
1350

1351
	b.correlationID++
1352

1353
	res := &SaslAuthenticateResponse{}
1354
	bytesRead, err := b.receiveSASLServerResponse(res, correlationID, resVersion)
1355

1356
	requestLatency := time.Since(requestTime)
1357
	b.updateIncomingCommunicationMetrics(bytesRead, requestLatency)
1358

1359
	isChallenge := len(res.SaslAuthBytes) > 0
1360

1361
	if isChallenge && err != nil {
1362
		Logger.Printf("Broker rejected authentication token: %s", res.SaslAuthBytes)
1363
	}
1364

1365
	return isChallenge, err
1366
}
1367

1368
// sendAndReceiveSASLSCRAM runs the SCRAM exchange. Raw SASL frames are used
// when the configured SASL version is SASLHandshakeV0. Any other version wraps
// the exchange in SaslAuthenticate requests.
func (b *Broker) sendAndReceiveSASLSCRAM() error {
	if b.conf.Net.SASL.Version != SASLHandshakeV0 {
		return b.sendAndReceiveSASLSCRAMv1()
	}
	return b.sendAndReceiveSASLSCRAMv0()
}
1374

1375
// sendAndReceiveSASLSCRAMv0 performs a SCRAM exchange after a v0 SaslHandshake.
// Each SCRAM client message is written as a raw frame: a 4-byte big-endian
// length followed by the message, not wrapped in a Kafka request. Each server
// reply is read back in the same framing and fed to the SCRAM client until it
// reports Done.
func (b *Broker) sendAndReceiveSASLSCRAMv0() error {
	if err := b.sendAndReceiveSASLHandshake(b.conf.Net.SASL.Mechanism, SASLHandshakeV0); err != nil {
		return err
	}

	scramClient := b.conf.Net.SASL.SCRAMClientGeneratorFunc()
	if err := scramClient.Begin(b.conf.Net.SASL.User, b.conf.Net.SASL.Password, b.conf.Net.SASL.SCRAMAuthzID); err != nil {
		return fmt.Errorf("failed to start SCRAM exchange with the server: %w", err)
	}

	msg, err := scramClient.Step("")
	if err != nil {
		return fmt.Errorf("failed to advance the SCRAM exchange: %w", err)
	}

	for !scramClient.Done() {
		requestTime := time.Now()
		// Will be decremented in updateIncomingCommunicationMetrics (except error)
		b.addRequestInFlightMetrics(1)
		length := len(msg)
		authBytes := make([]byte, length+4) // 4 byte length header + auth data
		binary.BigEndian.PutUint32(authBytes, uint32(length))
		copy(authBytes[4:], msg)
		// Record what was actually written. A failed write may have sent
		// fewer than len(authBytes) bytes.
		bytesWritten, err := b.write(authBytes)
		b.updateOutgoingCommunicationMetrics(bytesWritten)
		if err != nil {
			b.addRequestInFlightMetrics(-1)
			Logger.Printf("Failed to write SASL auth header to broker %s: %s\n", b.addr, err.Error())
			return err
		}
		b.correlationID++

		header := make([]byte, 4)
		if _, err = b.readFull(header); err != nil {
			b.addRequestInFlightMetrics(-1)
			Logger.Printf("Failed to read response header while authenticating with SASL to broker %s: %s\n", b.addr, err.Error())
			return err
		}

		// make panics on a negative size, so reject a corrupt length
		// prefix instead of crashing.
		payloadLength := int32(binary.BigEndian.Uint32(header))
		if payloadLength < 0 {
			b.addRequestInFlightMetrics(-1)
			err = PacketDecodingError{fmt.Sprintf("invalid SASL response length %d", payloadLength)}
			Logger.Printf("Failed to read response header while authenticating with SASL to broker %s: %s\n", b.addr, err.Error())
			return err
		}
		payload := make([]byte, payloadLength)
		n, err := b.readFull(payload)
		if err != nil {
			b.addRequestInFlightMetrics(-1)
			Logger.Printf("Failed to read response payload while authenticating with SASL to broker %s: %s\n", b.addr, err.Error())
			return err
		}
		b.updateIncomingCommunicationMetrics(n+4, time.Since(requestTime))

		msg, err = scramClient.Step(string(payload))
		if err != nil {
			Logger.Println("SASL authentication failed", err)
			return err
		}
	}

	DebugLogger.Println("SASL authentication succeeded")
	return nil
}
1431

1432
// sendAndReceiveSASLSCRAMv1 performs a SCRAM exchange after a v1 SaslHandshake.
// Each SCRAM client message is sent in a SaslAuthenticate request. The
// challenge from the matching response is fed back to the SCRAM client until
// it reports Done.
func (b *Broker) sendAndReceiveSASLSCRAMv1() error {
	sasl := &b.conf.Net.SASL
	if err := b.sendAndReceiveSASLHandshake(sasl.Mechanism, SASLHandshakeV1); err != nil {
		return err
	}

	client := sasl.SCRAMClientGeneratorFunc()
	if err := client.Begin(sasl.User, sasl.Password, sasl.SCRAMAuthzID); err != nil {
		return fmt.Errorf("failed to start SCRAM exchange with the server: %w", err)
	}

	outgoing, err := client.Step("")
	if err != nil {
		return fmt.Errorf("failed to advance the SCRAM exchange: %w", err)
	}

	for !client.Done() {
		startedAt := time.Now()
		// Will be decremented in updateIncomingCommunicationMetrics (except error)
		b.addRequestInFlightMetrics(1)

		id := b.correlationID
		written, sendErr := b.sendSaslAuthenticateRequest(id, []byte(outgoing))
		b.updateOutgoingCommunicationMetrics(written)
		if sendErr != nil {
			b.addRequestInFlightMetrics(-1)
			Logger.Printf("Failed to write SASL auth header to broker %s: %s\n", b.addr, sendErr.Error())
			return sendErr
		}
		b.correlationID++

		challenge, recvErr := b.receiveSaslAuthenticateResponse(id)
		if recvErr != nil {
			b.addRequestInFlightMetrics(-1)
			Logger.Printf("Failed to read response while authenticating with SASL to broker %s: %s\n", b.addr, recvErr.Error())
			return recvErr
		}
		b.updateIncomingCommunicationMetrics(len(challenge), time.Since(startedAt))

		if outgoing, err = client.Step(string(challenge)); err != nil {
			Logger.Println("SASL authentication failed", err)
			return err
		}
	}

	DebugLogger.Println("SASL authentication succeeded")
	return nil
}
1479

1480
// sendSaslAuthenticateRequest wraps msg in a SaslAuthenticate request with the
// given correlation ID and writes it to the connection. It returns the number
// of bytes written, which is 0 if encoding fails.
func (b *Broker) sendSaslAuthenticateRequest(correlationID int32, msg []byte) (int, error) {
	encoded, err := encode(&request{
		correlationID: correlationID,
		clientID:      b.conf.ClientID,
		body:          b.createSaslAuthenticateRequest(msg),
	}, b.conf.MetricRegistry)
	if err != nil {
		return 0, err
	}
	return b.write(encoded)
}
1490

1491
// createSaslAuthenticateRequest builds a SaslAuthenticate request that carries
// msg. It uses version 1 when the configured Kafka version is at least 2.2.0
// and version 0 otherwise.
func (b *Broker) createSaslAuthenticateRequest(msg []byte) *SaslAuthenticateRequest {
	var version int16
	if b.conf.Version.IsAtLeast(V2_2_0_0) {
		version = 1
	}
	return &SaslAuthenticateRequest{Version: version, SaslAuthBytes: msg}
}
1499

1500
// receiveSaslAuthenticateResponse reads the response to a SaslAuthenticate
// request sent with correlationID and returns the server's auth bytes (the
// SCRAM challenge).
//
// It returns an error when reading fails, when the response carries a
// different correlation ID, when decoding fails, or when the response holds a
// non-success error code.
func (b *Broker) receiveSaslAuthenticateResponse(correlationID int32) ([]byte, error) {
	buf := make([]byte, responseLengthSize+correlationIDSize)
	if _, err := b.readFull(buf); err != nil {
		return nil, err
	}

	header := responseHeader{}
	if err := versionedDecode(buf, &header, 0); err != nil {
		return nil, err
	}

	if header.correlationID != correlationID {
		// Report the ID we were waiting for. The caller has already
		// incremented b.correlationID, so that field is one ahead.
		return nil, fmt.Errorf("correlation ID didn't match, wanted %d, got %d", correlationID, header.correlationID)
	}

	buf = make([]byte, header.length-correlationIDSize)
	if _, err := b.readFull(buf); err != nil {
		return nil, err
	}

	// The broker replies using the version the request was sent with. Decode
	// at the version createSaslAuthenticateRequest picks for this config
	// instead of always assuming v0.
	resVersion := b.createSaslAuthenticateRequest(nil).version()
	res := &SaslAuthenticateResponse{}
	if err := versionedDecode(buf, res, resVersion); err != nil {
		return nil, err
	}
	if !errors.Is(res.Err, ErrNoError) {
		return nil, res.Err
	}
	return res.SaslAuthBytes, nil
}
1532

1533
// Build SASL/OAUTHBEARER initial client response as described by RFC-7628
1534
// https://tools.ietf.org/html/rfc7628
1535
func buildClientFirstMessage(token *AccessToken) ([]byte, error) {
1536
	var ext string
1537

1538
	if token.Extensions != nil && len(token.Extensions) > 0 {
1539
		if _, ok := token.Extensions[SASLExtKeyAuth]; ok {
1540
			return []byte{}, fmt.Errorf("the extension `%s` is invalid", SASLExtKeyAuth)
1541
		}
1542
		ext = "\x01" + mapToString(token.Extensions, "=", "\x01")
1543
	}
1544

1545
	resp := []byte(fmt.Sprintf("n,,\x01auth=Bearer %s%s\x01\x01", token.Token, ext))
1546

1547
	return resp, nil
1548
}
1549

1550
// mapToString renders extensions as key+keyValSep+value pairs joined by
// elemSep. The joined pair strings themselves are sorted, which gives a
// deterministic, effectively key-ordered result. An empty or nil map yields "".
func mapToString(extensions map[string]string, keyValSep string, elemSep string) string {
	pairs := make([]string, len(extensions))
	i := 0
	for key, value := range extensions {
		pairs[i] = key + keyValSep + value
		i++
	}
	sort.Strings(pairs)
	return strings.Join(pairs, elemSep)
}
1563

1564
// sendSASLPlainAuthClientResponse sends the SASL/PLAIN message
// "authzid NUL user NUL password" in a SaslAuthenticate request with the given
// correlation ID. It returns the bytes written, the request version (which
// callers use to decode the reply) and any encode/write error.
func (b *Broker) sendSASLPlainAuthClientResponse(correlationID int32) (int, int16, error) {
	sasl := &b.conf.Net.SASL
	rb := b.createSaslAuthenticateRequest([]byte(sasl.AuthIdentity + "\x00" + sasl.User + "\x00" + sasl.Password))

	encoded, err := encode(&request{correlationID: correlationID, clientID: b.conf.ClientID, body: rb}, b.conf.MetricRegistry)
	if err != nil {
		return 0, rb.Version, err
	}

	n, err := b.write(encoded)
	return n, rb.Version, err
}
1576

1577
// sendSASLOAuthBearerClientMessage sends initialResp in a SaslAuthenticate
// request with the given correlation ID. It returns the bytes written, the
// request version (needed to decode the reply) and any encode/write error.
func (b *Broker) sendSASLOAuthBearerClientMessage(initialResp []byte, correlationID int32) (int, int16, error) {
	rb := b.createSaslAuthenticateRequest(initialResp)
	version := rb.version()

	encoded, err := encode(&request{
		correlationID: correlationID,
		clientID:      b.conf.ClientID,
		body:          rb,
	}, b.conf.MetricRegistry)
	if err != nil {
		return 0, version, err
	}

	n, err := b.write(encoded)
	return n, version, err
}
1590

1591
// receiveSASLServerResponse reads the SaslAuthenticate response for
// correlationID from the connection and decodes it into res at resVersion. It
// returns the total number of bytes read, even on error.
//
// A non-success error code is returned, wrapped with the broker's error
// message when one is present. On success, a positive SessionLifetimeMs
// schedules re-authentication (b.clientSessionReauthenticationTimeMs); any
// other value clears it.
func (b *Broker) receiveSASLServerResponse(res *SaslAuthenticateResponse, correlationID int32, resVersion int16) (int, error) {
	buf := make([]byte, responseLengthSize+correlationIDSize)
	bytesRead, err := b.readFull(buf)
	if err != nil {
		return bytesRead, err
	}

	header := responseHeader{}
	err = versionedDecode(buf, &header, 0)
	if err != nil {
		return bytesRead, err
	}

	if header.correlationID != correlationID {
		// Report the expected ID. Callers increment b.correlationID before
		// reading, so that field is one ahead of what we wait for.
		return bytesRead, fmt.Errorf("correlation ID didn't match, wanted %d, got %d", correlationID, header.correlationID)
	}

	buf = make([]byte, header.length-correlationIDSize)
	c, err := b.readFull(buf)
	bytesRead += c
	if err != nil {
		return bytesRead, err
	}

	if err := versionedDecode(buf, res, resVersion); err != nil {
		return bytesRead, err
	}

	if !errors.Is(res.Err, ErrNoError) {
		var authErr error = res.Err
		if res.ErrorMessage != nil {
			authErr = Wrap(res.Err, errors.New(*res.ErrorMessage))
		}
		return bytesRead, authErr
	}

	if res.SessionLifetimeMs > 0 {
		// Follows the Java Kafka implementation from SaslClientAuthenticator.ReauthInfo#setAuthenticationEndAndSessionReauthenticationTimes
		// pick a random percentage between 85% and 95% for session re-authentication
		positiveSessionLifetimeMs := res.SessionLifetimeMs
		authenticationEndMs := currentUnixMilli()
		pctWindowFactorToTakeNetworkLatencyAndClockDriftIntoAccount := 0.85
		pctWindowJitterToAvoidReauthenticationStormAcrossManyChannelsSimultaneously := 0.10
		pctToUse := pctWindowFactorToTakeNetworkLatencyAndClockDriftIntoAccount + rand.Float64()*pctWindowJitterToAvoidReauthenticationStormAcrossManyChannelsSimultaneously
		sessionLifetimeMsToUse := int64(float64(positiveSessionLifetimeMs) * pctToUse)
		DebugLogger.Printf("Session expiration in %d ms and session re-authentication on or after %d ms", positiveSessionLifetimeMs, sessionLifetimeMsToUse)
		b.clientSessionReauthenticationTimeMs = authenticationEndMs + sessionLifetimeMsToUse
	} else {
		b.clientSessionReauthenticationTimeMs = 0
	}

	return bytesRead, nil
}
1644

1645
// updateIncomingCommunicationMetrics records a received response: its latency
// (which also decrements the in-flight counter), the response rate, the
// incoming byte rate and the response size. Both the shared metrics and the
// per-broker ones are updated when the latter are set.
func (b *Broker) updateIncomingCommunicationMetrics(bytes int, requestLatency time.Duration) {
	b.updateRequestLatencyAndInFlightMetrics(requestLatency)

	size := int64(bytes)

	b.responseRate.Mark(1)
	b.incomingByteRate.Mark(size)
	b.responseSize.Update(size)

	// The per-broker metrics may be nil; registerMetrics assigns them.
	if b.brokerResponseRate != nil {
		b.brokerResponseRate.Mark(1)
	}
	if b.brokerIncomingByteRate != nil {
		b.brokerIncomingByteRate.Mark(size)
	}
	if b.brokerResponseSize != nil {
		b.brokerResponseSize.Update(size)
	}
}
1664

1665
// updateRequestLatencyAndInFlightMetrics records requestLatency in whole
// milliseconds (truncated) in the latency histograms and marks one request
// as no longer in flight.
func (b *Broker) updateRequestLatencyAndInFlightMetrics(requestLatency time.Duration) {
	ms := requestLatency.Milliseconds()

	b.requestLatency.Update(ms)
	if h := b.brokerRequestLatency; h != nil {
		h.Update(ms)
	}

	b.addRequestInFlightMetrics(-1)
}
1675

1676
// addRequestInFlightMetrics adjusts the in-flight request counters by i. The
// shared counter is always updated; the per-broker one only when it is set.
func (b *Broker) addRequestInFlightMetrics(i int64) {
	b.requestsInFlight.Inc(i)
	if perBroker := b.brokerRequestsInFlight; perBroker != nil {
		perBroker.Inc(i)
	}
}
1682

1683
// updateOutgoingCommunicationMetrics records one sent request of the given
// size in bytes in the request rate, outgoing byte rate and request size
// metrics. Both the shared metrics and the per-broker ones are updated when
// the latter are set.
func (b *Broker) updateOutgoingCommunicationMetrics(bytes int) {
	size := int64(bytes)

	b.requestRate.Mark(1)
	b.outgoingByteRate.Mark(size)
	b.requestSize.Update(size)

	// The per-broker metrics may be nil; registerMetrics assigns them.
	if b.brokerRequestRate != nil {
		b.brokerRequestRate.Mark(1)
	}
	if b.brokerOutgoingByteRate != nil {
		b.brokerOutgoingByteRate.Mark(size)
	}
	if b.brokerRequestSize != nil {
		b.brokerRequestSize.Update(size)
	}
}
1700

1701
func (b *Broker) updateThrottleMetric(throttleTime time.Duration) {
1702
	if throttleTime != time.Duration(0) {
1703
		DebugLogger.Printf(
1704
			"producer/broker/%d ProduceResponse throttled %v\n",
1705
			b.ID(), throttleTime)
1706
		if b.brokerThrottleTime != nil {
1707
			throttleTimeInMs := int64(throttleTime / time.Millisecond)
1708
			b.brokerThrottleTime.Update(throttleTimeInMs)
1709
		}
1710
	}
1711
}
1712

1713
// registerMetrics creates (or fetches) this broker's own meters, histograms
// and counter in the configured registry and stores them on b. The update*
// helpers feed these alongside the shared metrics.
func (b *Broker) registerMetrics() {
	// Meters.
	b.brokerIncomingByteRate = b.registerMeter("incoming-byte-rate")
	b.brokerOutgoingByteRate = b.registerMeter("outgoing-byte-rate")
	b.brokerRequestRate = b.registerMeter("request-rate")
	b.brokerResponseRate = b.registerMeter("response-rate")

	// Histograms.
	b.brokerRequestSize = b.registerHistogram("request-size")
	b.brokerResponseSize = b.registerHistogram("response-size")
	b.brokerRequestLatency = b.registerHistogram("request-latency-in-ms")
	b.brokerThrottleTime = b.registerHistogram("throttle-time-in-ms")

	// Counters.
	b.brokerRequestsInFlight = b.registerCounter("requests-in-flight")
}
1724

1725
// unregisterMetrics removes every metric name recorded by the register*
// helpers from the configured registry, then forgets the recorded names.
func (b *Broker) unregisterMetrics() {
	registry := b.conf.MetricRegistry
	for metricName := range b.registeredMetrics {
		registry.Unregister(metricName)
	}
	b.registeredMetrics = nil
}
1731

1732
func (b *Broker) registerMeter(name string) metrics.Meter {
1733
	nameForBroker := getMetricNameForBroker(name, b)
1734
	if b.registeredMetrics == nil {
1735
		b.registeredMetrics = map[string]struct{}{}
1736
	}
1737
	b.registeredMetrics[nameForBroker] = struct{}{}
1738
	return metrics.GetOrRegisterMeter(nameForBroker, b.conf.MetricRegistry)
1739
}
1740

1741
func (b *Broker) registerHistogram(name string) metrics.Histogram {
1742
	nameForBroker := getMetricNameForBroker(name, b)
1743
	if b.registeredMetrics == nil {
1744
		b.registeredMetrics = map[string]struct{}{}
1745
	}
1746
	b.registeredMetrics[nameForBroker] = struct{}{}
1747
	return getOrRegisterHistogram(nameForBroker, b.conf.MetricRegistry)
1748
}
1749

1750
func (b *Broker) registerCounter(name string) metrics.Counter {
1751
	nameForBroker := getMetricNameForBroker(name, b)
1752
	if b.registeredMetrics == nil {
1753
		b.registeredMetrics = map[string]struct{}{}
1754
	}
1755
	b.registeredMetrics[nameForBroker] = struct{}{}
1756
	return metrics.GetOrRegisterCounter(nameForBroker, b.conf.MetricRegistry)
1757
}
1758

1759
func validServerNameTLS(addr string, cfg *tls.Config) *tls.Config {
1760
	if cfg == nil {
1761
		cfg = &tls.Config{
1762
			MinVersion: tls.VersionTLS12,
1763
		}
1764
	}
1765
	if cfg.ServerName != "" {
1766
		return cfg
1767
	}
1768

1769
	c := cfg.Clone()
1770
	sn, _, err := net.SplitHostPort(addr)
1771
	if err != nil {
1772
		Logger.Println(fmt.Errorf("failed to get ServerName from addr %w", err))
1773
	}
1774
	c.ServerName = sn
1775
	return c
1776
}
1777

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.