1
// This Source Code Form is subject to the terms of the Mozilla Public
2
// License, v. 2.0. If a copy of the MPL was not distributed with this
3
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
14
"github.com/siderolabs/grpc-proxy/proxy"
15
"github.com/siderolabs/net"
16
"google.golang.org/grpc"
17
"google.golang.org/grpc/backoff"
18
"google.golang.org/grpc/connectivity"
19
"google.golang.org/grpc/credentials"
20
"google.golang.org/grpc/metadata"
21
"google.golang.org/grpc/status"
22
"google.golang.org/protobuf/encoding/protowire"
24
"github.com/siderolabs/talos/pkg/grpc/middleware/authz"
25
"github.com/siderolabs/talos/pkg/machinery/api/common"
26
"github.com/siderolabs/talos/pkg/machinery/constants"
27
"github.com/siderolabs/talos/pkg/machinery/proto"
30
// GracefulShutdownTimeout is the timeout for graceful shutdown of the backend connection.
32
// Talos has a few long-running API calls, so we need to give the backend some time to finish them.
34
// The connection will enter the IDLE state after GracefulShutdownTimeout/2, if no RPC is running.
35
const GracefulShutdownTimeout = 30 * time.Minute
37
var _ proxy.Backend = (*APID)(nil)
39
// APID backend performs proxying to another apid instance.
41
// Backend authenticates itself using given grpc credentials.
45
tlsConfigProvider func() (*tls.Config, error)
51
// NewAPID creates new instance of APID backend.
52
func NewAPID(target string, tlsConfigProvider func() (*tls.Config, error)) (*APID, error) {
53
// perform very basic validation on target, trying to weed out empty addresses or addresses with the port appended
54
if target == "" || net.AddressContainsPort(target) {
55
return nil, fmt.Errorf("invalid target %q", target)
60
tlsConfigProvider: tlsConfigProvider,
64
func (a *APID) String() string {
68
// GetConnection returns a grpc connection to the backend.
69
func (a *APID) GetConnection(ctx context.Context, fullMethodName string) (context.Context, *grpc.ClientConn, error) {
70
md, _ := metadata.FromIncomingContext(ctx)
73
authz.SetMetadata(md, authz.GetRoles(ctx))
75
if authority := md[":authority"]; len(authority) > 0 {
76
md.Set("proxyfrom", authority...)
78
md.Set("proxyfrom", "unknown")
81
delete(md, ":authority")
85
outCtx := metadata.NewOutgoingContext(ctx, md)
91
return outCtx, a.conn, nil
94
tlsConfig, err := a.tlsConfigProvider()
96
return outCtx, nil, err
99
// override max delay to avoid excessive backoff when the another node is unavailable (e.g. rebooted),
100
// and apid used as an endpoint considers another node to be down for longer than expected.
102
// default max delay is 2 minutes, which is too long for our use case.
103
backoffConfig := backoff.DefaultConfig
104
backoffConfig.MaxDelay = 15 * time.Second
106
a.conn, err = grpc.NewClient(
107
fmt.Sprintf("%s:%d", net.FormatAddress(a.target), constants.ApidPort),
108
grpc.WithInitialWindowSize(65535*32),
109
grpc.WithInitialConnWindowSize(65535*16),
110
grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig)),
111
grpc.WithIdleTimeout(GracefulShutdownTimeout/2), // use half of the shutdown timeout as idle timeout
112
grpc.WithConnectParams(grpc.ConnectParams{
113
Backoff: backoffConfig,
114
// not published as a constant in gRPC library
115
// see: https://github.com/grpc/grpc-go/blob/d5dee5fdbdeb52f6ea10b37b2cc7ce37814642d7/clientconn.go#L55-L56
116
MinConnectTimeout: 20 * time.Second,
118
grpc.WithDefaultCallOptions(
119
grpc.MaxCallRecvMsgSize(constants.GRPCMaxMessageSize),
121
grpc.WithCodec(proxy.Codec()), //nolint:staticcheck
122
grpc.WithSharedWriteBuffer(true),
125
return outCtx, a.conn, err
128
// AppendInfo is called to enhance response from the backend with additional data.
130
// AppendInfo enhances upstream response with node metadata (target).
132
// This method depends on grpc protobuf response structure, each response should
135
// message SomeResponse {
136
// repeated SomeReply messages = 1; // please note field ID == 1
139
// message SomeReply {
140
// common.Metadata metadata = 1;
141
// <other fields go here ...>
144
// As 'SomeReply' is repeated in 'SomeResponse', if we concatenate protobuf representation
145
// of several 'SomeResponse' messages, we still get valid 'SomeResponse' representation but with more
146
// entries (feature of protobuf binary representation).
148
// If we look at binary representation of any unary 'SomeResponse' message, it will always contain one
149
// protobuf field with field ID 1 (see above) and type 2 (embedded message SomeReply is encoded
150
// as string with length). So if we want to add fields to 'SomeReply', we can simply read field
151
// header, adjust length for new 'SomeReply' representation, and prepend new field header.
153
// At the same time, we can add 'common.Metadata' structure to 'SomeReply' by simply
154
// appending or prepending 'common.Metadata' as a single field. This requires 'metadata'
155
// field to be not defined in original response. (This is due to the fact that protobuf message
156
// representation is concatenation of each field representation).
158
// To build only single field (Metadata) we use helper message which contains exactly this
159
// field with same field ID as in every other 'SomeReply':
162
// common.Metadata metadata = 1;
165
// As streaming replies are not wrapped into 'SomeResponse' with 'repeated', handling is simpler: we just
166
// need to append Empty with details.
168
// So AppendInfo does the following: validates that response contains field ID 1 encoded as string,
169
// cuts field header, rest is representation of some reply. Marshal 'Empty' as protobuf,
170
// which builds 'common.Metadata' field, append it to original response message, build new header
171
// for new length of some response, and add back new field header.
172
func (a *APID) AppendInfo(streaming bool, resp []byte) ([]byte, error) {
173
payload, err := proto.Marshal(&common.Empty{
174
Metadata: &common.Metadata{
180
return append(resp, payload...), err
184
metadataField = 1 // field number in proto definition for repeated response
185
metadataType = 2 // "string" for embedded messages
188
// decode protobuf embedded header
190
typ, n1 := protowire.ConsumeVarint(resp)
192
return nil, protowire.ParseError(n1)
195
_, n2 := protowire.ConsumeVarint(resp[n1:]) // length
197
return nil, protowire.ParseError(n2)
200
if typ != (metadataField<<3)|metadataType {
201
return nil, fmt.Errorf("unexpected message format: %d", typ)
204
if n1+n2 > len(resp) {
205
return nil, fmt.Errorf("unexpected message size: %d", len(resp))
208
// cut off embedded message header
210
// build new embedded message header
211
prefix := protowire.AppendVarint(
212
protowire.AppendVarint(nil, (metadataField<<3)|metadataType),
213
uint64(len(resp)+len(payload)),
215
resp = append(prefix, resp...)
217
return append(resp, payload...), err
220
// BuildError is called to convert error from upstream into response field.
222
// BuildError converts upstream error into message from upstream, so that multiple
223
// successful and failure responses might be returned.
225
// This simply relies on the fact that any response contains 'Empty' message.
226
// So if 'Empty' is unmarshalled into any other reply message, all the fields
227
// are undefined but 'Metadata':
230
// common.Metadata metadata = 1;
233
// message EmptyResponse {
234
// repeated Empty messages = 1;
237
// Streaming responses are not wrapped into EmptyResponse, so we simply marshal Empty
239
func (a *APID) BuildError(streaming bool, err error) ([]byte, error) {
240
var resp proto.Message = &common.Empty{
241
Metadata: &common.Metadata{
244
Status: status.Convert(err).Proto(),
249
resp = &common.EmptyResponse{
250
Messages: []*common.Empty{
251
resp.(*common.Empty),
256
return proto.Marshal(resp)
260
func (a *APID) Close() {
265
gracefulGRPCClose(a.conn, GracefulShutdownTimeout)
270
func gracefulGRPCClose(conn *grpc.ClientConn, timeout time.Duration) {
271
// close the client connection in the background, tries to avoid closing the connection
272
// if the connection is in the middle of a call (e.g. streaming API)
274
// see https://github.com/grpc/grpc/blob/master/doc/connectivity-semantics-and-api.md for details on connection states
276
ctx, cancel := context.WithTimeout(context.Background(), timeout)
279
for ctx.Err() == nil {
280
switch state := conn.GetState(); state { //nolint:exhaustive
281
case connectivity.Idle,
282
connectivity.Shutdown,
283
connectivity.TransientFailure:
284
// close immediately, connection is not used
285
conn.Close() //nolint:errcheck
289
// wait for state change of the connection
290
conn.WaitForStateChange(ctx, state)
294
// close anyways on timeout
295
conn.Close() //nolint:errcheck