17
"github.com/siderolabs/gen/xslices"
18
"github.com/siderolabs/go-api-signature/pkg/client/interceptor"
19
"github.com/siderolabs/go-api-signature/pkg/pgp/client"
20
"google.golang.org/grpc"
21
"google.golang.org/grpc/credentials"
23
clientconfig "github.com/siderolabs/talos/pkg/machinery/client/config"
24
"github.com/siderolabs/talos/pkg/machinery/client/resolver"
25
"github.com/siderolabs/talos/pkg/machinery/constants"
29
// Conn returns the *grpc.ClientConn stored in the client's connection
// wrapper (c.conn.ClientConn).
//
// NOTE(review): dereferences c.conn without a nil check; presumably only
// valid once a connection has been established — confirm callers guarantee
// c.conn != nil.
func (c *Client) Conn() *grpc.ClientConn {
30
return c.conn.ClientConn
34
// getConn assembles the gRPC dial options and opens a client connection to
// the endpoints returned by c.GetEndpoints().
//
// Dial options are accumulated in this order: the built-in defaults below,
// then c.options.grpcDialOptions, then the caller-supplied opts.
//
// Connection paths visible in this function:
//   - c.options.unixSocketPath set: the target is dialed directly with the
//     accumulated options; no transport credentials are added on this path.
//   - c.options.tlsConfig: wrapped with credentials.NewTLS and handed to
//     makeConnection before any config-context resolution (the guarding
//     condition is not visible in this extract).
//   - otherwise: the config context is resolved, basic-auth and/or SideroV1
//     signing options are appended (their guarding conditions are not visible
//     in this extract), transport credentials are built with buildCredentials,
//     and makeConnection creates the connection.
func (c *Client) getConn(opts ...grpc.DialOption) (*grpcConnectionWrapper, error) {
35
endpoints := c.GetEndpoints()
37
// Endpoints may be URLs: reduce them to addresses and pass them through
// resolver.EnsureEndpointsHavePorts before folding them into one dial target.
target := c.getTarget(
38
resolver.EnsureEndpointsHavePorts(
39
reduceURLsToAddresses(endpoints),
43
dialOpts := []grpc.DialOption{
44
grpc.WithDefaultCallOptions(
47
// Per-call receive limit is constants.GRPCMaxMessageSize instead of the
// gRPC default.
grpc.MaxCallRecvMsgSize(constants.GRPCMaxMessageSize),
49
// Let gRPC release the per-connection write buffer after each flush.
grpc.WithSharedWriteBuffer(true),
51
dialOpts = append(dialOpts, c.options.grpcDialOptions...)
52
dialOpts = append(dialOpts, opts...)
54
if c.options.unixSocketPath != "" {
55
conn, err := grpc.NewClient(target, dialOpts...)
57
// NOTE(review): on error this still returns a wrapper built around conn
// (nil when grpc.NewClient fails) alongside err; callers must check err
// before touching the wrapper.
return newGRPCConnectionWrapper(c.GetClusterName(), conn), err
60
tlsConfig := c.options.tlsConfig
63
return c.makeConnection(target, credentials.NewTLS(tlsConfig), dialOpts)
66
if err := c.resolveConfigContext(); err != nil {
67
return nil, fmt.Errorf("failed to resolve configuration context: %w", err)
70
basicAuth := c.options.configContext.Auth.Basic
72
dialOpts = append(dialOpts, WithGRPCBasicAuth(basicAuth.Username, basicAuth.Password))
75
sideroV1 := c.options.configContext.Auth.SideroV1
77
var contextName string
79
// The context name defaults to the loaded config's current context; an
// explicit override (contextOverrideSet) is applied afterwards and wins.
if c.options.config != nil {
80
contextName = c.options.config.Context
83
if c.options.contextOverrideSet {
84
contextName = c.options.contextOverride
87
// Request-signing interceptor from go-api-signature; user keys are looked up
// through a PGP key provider named "talos/keys".
authInterceptor := interceptor.New(interceptor.Options{
88
UserKeyProvider: client.NewKeyProvider("talos/keys"),
89
ContextName: contextName,
90
Identity: sideroV1.Identity,
94
// Install the interceptor on both unary and streaming calls.
// NOTE(review): WithUnaryInterceptor/WithStreamInterceptor each hold a single
// interceptor and later options overwrite earlier ones; because these are
// appended after grpcDialOptions and opts, they replace any interceptor the
// caller set the same way (WithChain*Interceptor is unaffected).
dialOpts = append(dialOpts,
95
grpc.WithUnaryInterceptor(authInterceptor.Unary()),
96
grpc.WithStreamInterceptor(authInterceptor.Stream()),
100
creds, err := buildCredentials(c.options.configContext, endpoints)
105
return c.makeConnection(target, creds, dialOpts)
108
// buildTLSConfig turns a client configuration context into a *tls.Config.
//
// When getCA returns non-empty bytes (presumably PEM, since they are fed to
// AppendCertsFromPEM), they become the only entries of a fresh RootCAs pool.
// Otherwise RootCAs stays nil and crypto/tls verifies the peer against the
// host's root CA set. The certificate from CertificateFromConfigContext is
// appended to Certificates.
//
// Failures from getCA, from parsing the CA PEM and from
// CertificateFromConfigContext are returned as errors with a nil *tls.Config.
func buildTLSConfig(configContext *clientconfig.Context) (*tls.Config, error) {
109
tlsConfig := &tls.Config{}
111
caBytes, err := getCA(configContext)
113
return nil, fmt.Errorf("failed to get CA: %w", err)
116
if len(caBytes) > 0 {
117
tlsConfig.RootCAs = x509.NewCertPool()
119
// AppendCertsFromPEM reports false when no certificate could be parsed.
if ok := tlsConfig.RootCAs.AppendCertsFromPEM(caBytes); !ok {
120
return nil, errors.New("failed to append CA certificate to RootCAs pool")
124
crt, err := CertificateFromConfigContext(configContext)
126
return nil, fmt.Errorf("failed to acquire credentials: %w", err)
130
// NOTE(review): per crypto/tls, ClientAuth is the server's client-auth
// policy; it has no effect when this config is only used to dial.
tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
131
// NOTE(review): dereferences crt. CertificateFromConfigContext has a separate
// branch for an empty Crt/Key pair, and any nil guard here is not visible in
// this extract — confirm crt cannot be nil at this point.
tlsConfig.Certificates = append(tlsConfig.Certificates, *crt)
134
return tlsConfig, nil
137
// makeConnection appends transport credentials and enlarged HTTP/2
// flow-control windows to dialOpts, then creates a gRPC client for target.
//
// Window sizes are expressed as multiples of 65535, the HTTP/2 default
// initial window size: 32x per stream and 16x per connection.
//
// On success the connection is returned in a wrapper tagged with
// c.GetClusterName(). On failure the wrapper is nil and only the error is
// returned, so callers never receive a wrapper around a nil
// *grpc.ClientConn.
func (c *Client) makeConnection(target string, creds credentials.TransportCredentials, dialOpts []grpc.DialOption) (*grpcConnectionWrapper, error) {
138
dialOpts = append(dialOpts,
139
grpc.WithTransportCredentials(creds),
140
grpc.WithInitialWindowSize(65535*32),
141
grpc.WithInitialConnWindowSize(65535*16))
143
conn, err := grpc.NewClient(target, dialOpts...)
if err != nil {
return nil, err
}
145
return newGRPCConnectionWrapper(c.GetClusterName(), conn), nil
148
// getTarget builds the gRPC dial target string for the client:
//   - a configured Unix socket path takes precedence: "unix:///<path>";
//   - more than one endpoint: the round-robin resolver scheme with the
//     endpoints joined by commas;
//   - the final visible return (its case label is not visible in this
//     extract) uses the first endpoint under the "dns" scheme.
//
// NOTE(review): endpoints[0] panics when endpoints is empty and no socket
// path is set — confirm callers never pass an empty list.
// NOTE(review): the "unix:///" prefix already supplies the leading '/', so an
// absolute socket path yields "unix:////..." and a relative path is treated
// as rooted at '/'; confirm this is intended.
func (c *Client) getTarget(endpoints []string) string {
150
case c.options.unixSocketPath != "":
151
return fmt.Sprintf("unix:///%s", c.options.unixSocketPath)
152
case len(endpoints) > 1:
153
return fmt.Sprintf("%s:///%s", resolver.RoundRobinResolverScheme, strings.Join(endpoints, ","))
160
return fmt.Sprintf("dns:///%s", endpoints[0])
164
// getCA returns the CA bytes stored base64-encoded (standard alphabet) in
// context.CA.
//
// An empty context.CA is handled by a separate branch whose return values are
// not visible in this extract; buildTLSConfig treats len(caBytes) == 0 as
// "no custom CA". A base64 decoding failure is returned wrapped.
func getCA(context *clientconfig.Context) ([]byte, error) {
165
if context.CA == "" {
169
caBytes, err := base64.StdEncoding.DecodeString(context.CA)
171
return nil, fmt.Errorf("error decoding CA: %w", err)
178
// CertificateFromConfigContext decodes the client certificate (context.Crt)
// and private key (context.Key), both stored base64-encoded with the standard
// alphabet, and parses them into a tls.Certificate with tls.X509KeyPair.
//
// When both Crt and Key are empty, a separate branch is taken (its return
// values are not visible in this extract).
//
// Errors from base64 decoding and from tls.X509KeyPair are returned wrapped
// with %w, so callers can inspect the cause with errors.Is / errors.As.
func CertificateFromConfigContext(context *clientconfig.Context) (*tls.Certificate, error) {
179
if context.Crt == "" && context.Key == "" {
183
crtBytes, err := base64.StdEncoding.DecodeString(context.Crt)
185
return nil, fmt.Errorf("error decoding certificate: %w", err)
188
keyBytes, err := base64.StdEncoding.DecodeString(context.Key)
190
return nil, fmt.Errorf("error decoding key: %w", err)
193
crt, err := tls.X509KeyPair(crtBytes, keyBytes)
195
// Wrap with %w (previously %s) so the crypto/tls error stays in the chain.
return nil, fmt.Errorf("could not load client key pair: %w", err)
201
func reduceURLsToAddresses(endpoints []string) []string {
202
return xslices.Map(endpoints, func(endpoint string) string {
203
u, err := url.Parse(endpoint)
208
if u.Scheme == "https" && u.Port() == "" {
209
return net.JoinHostPort(u.Hostname(), "443")
214
return net.JoinHostPort(u.Hostname(), u.Port())