talos

Форк
0
/
apid_test.go 
397 строк · 12.0 Кб
1
// This Source Code Form is subject to the terms of the Mozilla Public
2
// License, v. 2.0. If a copy of the MPL was not distributed with this
3
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
4

5
package backend_test
6

7
import (
8
	"context"
9
	"crypto/tls"
10
	"errors"
11
	"testing"
12

13
	"github.com/siderolabs/go-pointer"
14
	"github.com/stretchr/testify/assert"
15
	"github.com/stretchr/testify/require"
16
	"github.com/stretchr/testify/suite"
17
	"google.golang.org/grpc/metadata"
18
	protobuf "google.golang.org/protobuf/proto" //nolint:depguard
19
	"google.golang.org/protobuf/reflect/protoreflect"
20
	"google.golang.org/protobuf/types/descriptorpb"
21

22
	"github.com/siderolabs/talos/internal/app/apid/pkg/backend"
23
	"github.com/siderolabs/talos/pkg/grpc/middleware/authz"
24
	"github.com/siderolabs/talos/pkg/machinery/api/cluster"
25
	"github.com/siderolabs/talos/pkg/machinery/api/common"
26
	"github.com/siderolabs/talos/pkg/machinery/api/inspect"
27
	"github.com/siderolabs/talos/pkg/machinery/api/machine"
28
	"github.com/siderolabs/talos/pkg/machinery/api/security"
29
	"github.com/siderolabs/talos/pkg/machinery/api/storage"
30
	"github.com/siderolabs/talos/pkg/machinery/api/time"
31
	"github.com/siderolabs/talos/pkg/machinery/config"
32
	"github.com/siderolabs/talos/pkg/machinery/proto"
33
	"github.com/siderolabs/talos/pkg/machinery/role"
34
	"github.com/siderolabs/talos/pkg/machinery/version"
35
)
36

37
type APIDSuite struct {
38
	suite.Suite
39

40
	b *backend.APID
41
}
42

43
// SetupSuite creates the backend under test once for the whole suite. The
// backend targets 127.0.0.1 and is given a TLS provider that always returns a
// zero-value *tls.Config.
func (suite *APIDSuite) SetupSuite() {
	var err error

	suite.b, err = backend.NewAPID("127.0.0.1", func() (*tls.Config, error) {
		return &tls.Config{}, nil
	})
	suite.Require().NoError(err)
}
52

53
func (suite *APIDSuite) TestGetConnection() {
54
	md1 := metadata.New(nil)
55
	md1.Set(":authority", "127.0.0.2")
56
	md1.Set("nodes", "127.0.0.1")
57
	md1.Set("key", "value1", "value2")
58
	ctx1 := metadata.NewIncomingContext(authz.ContextWithRoles(context.Background(), role.MakeSet(role.Admin)), md1)
59

60
	outCtx1, conn1, err1 := suite.b.GetConnection(ctx1, "")
61
	suite.Require().NoError(err1)
62
	suite.Assert().NotNil(conn1)
63
	suite.Assert().Equal(role.MakeSet(role.Admin), authz.GetRoles(outCtx1))
64

65
	mdOut1, ok1 := metadata.FromOutgoingContext(outCtx1)
66
	suite.Require().True(ok1)
67
	suite.Assert().Equal([]string{"value1", "value2"}, mdOut1.Get("key"))
68
	suite.Assert().Equal([]string{"127.0.0.2"}, mdOut1.Get("proxyfrom"))
69
	suite.Assert().Equal([]string{"os:admin"}, mdOut1.Get("talos-role"))
70

71
	suite.Run(
72
		"Same context", func() {
73
			ctx2 := ctx1
74
			outCtx2, conn2, err2 := suite.b.GetConnection(ctx2, "")
75
			suite.Require().NoError(err2)
76
			suite.Assert().Equal(conn1, conn2) // connection is cached
77
			suite.Assert().Equal(role.MakeSet(role.Admin), authz.GetRoles(outCtx2))
78

79
			mdOut2, ok2 := metadata.FromOutgoingContext(outCtx2)
80
			suite.Require().True(ok2)
81
			suite.Assert().Equal([]string{"value1", "value2"}, mdOut2.Get("key"))
82
			suite.Assert().Equal([]string{"127.0.0.2"}, mdOut2.Get("proxyfrom"))
83
			suite.Assert().Equal([]string{"os:admin"}, mdOut2.Get("talos-role"))
84
		},
85
	)
86

87
	suite.Run(
88
		"Other context", func() {
89
			md3 := metadata.New(nil)
90
			md3.Set(":authority", "127.0.0.2")
91
			md3.Set("nodes", "127.0.0.1")
92
			md3.Set("key", "value3", "value4")
93
			ctx3 := metadata.NewIncomingContext(
94
				authz.ContextWithRoles(context.Background(), role.MakeSet(role.Reader)),
95
				md3,
96
			)
97

98
			outCtx3, conn3, err3 := suite.b.GetConnection(ctx3, "")
99
			suite.Require().NoError(err3)
100
			suite.Assert().Equal(conn1, conn3) // connection is cached
101
			suite.Assert().Equal(role.MakeSet(role.Reader), authz.GetRoles(outCtx3))
102

103
			mdOut3, ok3 := metadata.FromOutgoingContext(outCtx3)
104
			suite.Require().True(ok3)
105
			suite.Assert().Equal([]string{"value3", "value4"}, mdOut3.Get("key"))
106
			suite.Assert().Equal([]string{"127.0.0.2"}, mdOut3.Get("proxyfrom"))
107
			suite.Assert().Equal([]string{"os:reader"}, mdOut3.Get("talos-role"))
108
		},
109
	)
110
}
111

112
func (suite *APIDSuite) TestAppendInfoUnary() {
113
	reply := &common.DataResponse{
114
		Messages: []*common.Data{
115
			{
116
				Bytes: []byte("foobar"),
117
			},
118
		},
119
	}
120

121
	resp, err := proto.Marshal(reply)
122
	suite.Require().NoError(err)
123

124
	newResp, err := suite.b.AppendInfo(false, resp)
125
	suite.Require().NoError(err)
126

127
	var newReply common.DataResponse
128
	err = proto.Unmarshal(newResp, &newReply)
129
	suite.Require().NoError(err)
130

131
	suite.Assert().EqualValues([]byte("foobar"), newReply.Messages[0].Bytes)
132
	suite.Assert().Equal(suite.b.String(), newReply.Messages[0].Metadata.Hostname)
133
	suite.Assert().Empty(newReply.Messages[0].Metadata.Error)
134
}
135

136
func (suite *APIDSuite) TestAppendInfoStreaming() {
137
	response := &common.Data{
138
		Bytes: []byte("foobar"),
139
	}
140

141
	resp, err := proto.Marshal(response)
142
	suite.Require().NoError(err)
143

144
	newResp, err := suite.b.AppendInfo(true, resp)
145
	suite.Require().NoError(err)
146

147
	var newResponse common.Data
148
	err = proto.Unmarshal(newResp, &newResponse)
149
	suite.Require().NoError(err)
150

151
	suite.Assert().EqualValues([]byte("foobar"), newResponse.Bytes)
152
	suite.Assert().Equal(suite.b.String(), newResponse.Metadata.Hostname)
153
	suite.Assert().Empty(newResponse.Metadata.Error)
154
}
155

156
func (suite *APIDSuite) TestAppendInfoStreamingMetadata() {
157
	// this tests the case when metadata field is appended twice
158
	// to the message, but protobuf merges definitions
159
	response := &common.Data{
160
		Metadata: &common.Metadata{
161
			Error: "something went wrong",
162
		},
163
	}
164

165
	resp, err := proto.Marshal(response)
166
	suite.Require().NoError(err)
167

168
	newResp, err := suite.b.AppendInfo(true, resp)
169
	suite.Require().NoError(err)
170

171
	var newResponse common.Data
172
	err = proto.Unmarshal(newResp, &newResponse)
173
	suite.Require().NoError(err)
174

175
	suite.Assert().Nil(newResponse.Bytes)
176
	suite.Assert().Equal(suite.b.String(), newResponse.Metadata.Hostname)
177
	suite.Assert().Equal("something went wrong", newResponse.Metadata.Error)
178
}
179

180
// TestBuildErrorUnary checks that BuildError for a unary call produces a
// common.DataResponse whose first message has no payload, the backend hostname
// and the error text in its Metadata.
func (suite *APIDSuite) TestBuildErrorUnary() {
	resp, err := suite.b.BuildError(false, errors.New("some error"))
	suite.Require().NoError(err)

	var reply common.DataResponse
	suite.Require().NoError(proto.Unmarshal(resp, &reply))

	// Guard the indexing and dereference below so a malformed reply fails the
	// test with a clear message instead of panicking.
	suite.Require().NotEmpty(reply.Messages)

	msg := reply.Messages[0]
	suite.Require().NotNil(msg.Metadata)

	suite.Assert().Nil(msg.Bytes)
	suite.Assert().Equal(suite.b.String(), msg.Metadata.Hostname)
	suite.Assert().Equal("some error", msg.Metadata.Error)
}
192

193
// TestBuildErrorStreaming checks that BuildError for a streaming call produces
// a single common.Data with no payload, the backend hostname and the error
// text in its Metadata.
func (suite *APIDSuite) TestBuildErrorStreaming() {
	resp, err := suite.b.BuildError(true, errors.New("some error"))
	suite.Require().NoError(err)

	var response common.Data
	suite.Require().NoError(proto.Unmarshal(resp, &response))

	// fail cleanly rather than panic on a nil dereference below
	suite.Require().NotNil(response.Metadata)

	suite.Assert().Nil(response.Bytes)
	suite.Assert().Equal(suite.b.String(), response.Metadata.Hostname)
	suite.Assert().Equal("some error", response.Metadata.Error)
}
205

206
func TestAPIDSuite(t *testing.T) {
207
	suite.Run(t, new(APIDSuite))
208
}
209

210
func TestAPIIdiosyncrasies(t *testing.T) {
211
	for _, services := range []protoreflect.ServiceDescriptors{
212
		common.File_common_common_proto.Services(),
213
		cluster.File_cluster_cluster_proto.Services(),
214
		inspect.File_inspect_inspect_proto.Services(),
215
		machine.File_machine_machine_proto.Services(),
216
		// security.File_security_security_proto.Services() is different
217
		storage.File_storage_storage_proto.Services(),
218
		time.File_time_time_proto.Services(),
219
	} {
220
		for i := range services.Len() {
221
			service := services.Get(i)
222
			methods := service.Methods()
223

224
			for j := range methods.Len() {
225
				method := methods.Get(j)
226

227
				t.Run(
228
					string(method.FullName()), func(t *testing.T) {
229
						response := method.Output()
230
						responseFields := response.Fields()
231

232
						if method.IsStreamingServer() {
233
							metadata := responseFields.Get(0)
234
							assert.Equal(t, "metadata", metadata.TextName())
235
							assert.Equal(t, 1, int(metadata.Number()))
236
						} else {
237
							require.Equal(t, 1, responseFields.Len(), "unary responses should have exactly one field")
238

239
							messages := responseFields.Get(0)
240
							assert.Equal(t, "messages", messages.TextName())
241
							assert.Equal(t, 1, int(messages.Number()))
242

243
							reply := messages.Message()
244
							replyFields := reply.Fields()
245
							require.GreaterOrEqual(
246
								t,
247
								replyFields.Len(),
248
								1,
249
								"unary replies should have at least one field",
250
							)
251

252
							metadata := replyFields.Get(0)
253
							assert.Equal(t, "metadata", metadata.TextName())
254
							assert.Equal(t, 1, int(metadata.Number()))
255
						}
256
					},
257
				)
258
			}
259
		}
260
	}
261
}
262

263
//nolint:nakedret,gocyclo,errcheck,forcetypeassert
264
func getOptions(t *testing.T, descriptor protoreflect.Descriptor) (deprecated bool, version string) {
265
	switch opts := descriptor.Options().(type) {
266
	case *descriptorpb.EnumOptions:
267
		if opts != nil {
268
			deprecated = pointer.SafeDeref(opts.Deprecated)
269
			version = protobuf.GetExtension(opts, common.E_RemoveDeprecatedEnum).(string)
270
		}
271
	case *descriptorpb.EnumValueOptions:
272
		if opts != nil {
273
			deprecated = pointer.SafeDeref(opts.Deprecated)
274
			version = protobuf.GetExtension(opts, common.E_RemoveDeprecatedEnumValue).(string)
275
		}
276
	case *descriptorpb.MessageOptions:
277
		if opts != nil {
278
			deprecated = pointer.SafeDeref(opts.Deprecated)
279
			version = protobuf.GetExtension(opts, common.E_RemoveDeprecatedMessage).(string)
280
		}
281
	case *descriptorpb.FieldOptions:
282
		if opts != nil {
283
			deprecated = pointer.SafeDeref(opts.Deprecated)
284
			version = protobuf.GetExtension(opts, common.E_RemoveDeprecatedField).(string)
285
		}
286
	case *descriptorpb.ServiceOptions:
287
		if opts != nil {
288
			deprecated = pointer.SafeDeref(opts.Deprecated)
289
			version = protobuf.GetExtension(opts, common.E_RemoveDeprecatedService).(string)
290
		}
291
	case *descriptorpb.MethodOptions:
292
		if opts != nil {
293
			deprecated = pointer.SafeDeref(opts.Deprecated)
294
			version = protobuf.GetExtension(opts, common.E_RemoveDeprecatedMethod).(string)
295
		}
296

297
	default:
298
		t.Fatalf("unhandled %T", opts)
299
	}
300

301
	return
302
}
303

304
func testDeprecated(t *testing.T, descriptor protoreflect.Descriptor, currentVersion *config.VersionContract) {
305
	deprecated, version := getOptions(t, descriptor)
306

307
	assert.Equal(
308
		t, deprecated, version != "",
309
		"%s: `deprecated` and `remove_deprecated_XXX_in` options should be used together", descriptor.FullName(),
310
	)
311

312
	if !deprecated || version == "" {
313
		return
314
	}
315

316
	v, err := config.ParseContractFromVersion(version)
317
	require.NoError(t, err, "%s", descriptor.FullName())
318

319
	assert.True(t, v.Greater(currentVersion), "%s should be removed in this version", descriptor.FullName())
320
}
321

322
// testEnum runs the deprecation checks on an enum type and on each of its values.
func testEnum(t *testing.T, enum protoreflect.EnumDescriptor, currentVersion *config.VersionContract) {
	testDeprecated(t, enum, currentVersion)

	for values, idx := enum.Values(), 0; idx < values.Len(); idx++ {
		testDeprecated(t, values.Get(idx), currentVersion)
	}
}
330

331
// testMessage runs the deprecation checks on a message type and then on
// everything declared inside it, in this order: fields, oneofs, nested enums,
// and (recursively) nested messages.
func testMessage(t *testing.T, message protoreflect.MessageDescriptor, currentVersion *config.VersionContract) {
	testDeprecated(t, message, currentVersion)

	for fields, idx := message.Fields(), 0; idx < fields.Len(); idx++ {
		testDeprecated(t, fields.Get(idx), currentVersion)
	}

	for oneofs, idx := message.Oneofs(), 0; idx < oneofs.Len(); idx++ {
		testDeprecated(t, oneofs.Get(idx), currentVersion)
	}

	for enums, idx := message.Enums(), 0; idx < enums.Len(); idx++ {
		testEnum(t, enums.Get(idx), currentVersion)
	}

	// descend into message types declared inside this one
	for nested, idx := message.Messages(), 0; idx < nested.Len(); idx++ {
		testMessage(t, nested.Get(idx), currentVersion)
	}
}
355

356
func TestDeprecatedAPIs(t *testing.T) {
357
	currentVersion, err := config.ParseContractFromVersion(version.Tag)
358
	require.NoError(t, err)
359

360
	for _, file := range []protoreflect.FileDescriptor{
361
		common.File_common_common_proto,
362
		cluster.File_cluster_cluster_proto,
363
		inspect.File_inspect_inspect_proto,
364
		machine.File_machine_machine_proto,
365
		security.File_security_security_proto,
366
		storage.File_storage_storage_proto,
367
		time.File_time_time_proto,
368
	} {
369
		enums := file.Enums()
370
		for i := range enums.Len() {
371
			testEnum(t, enums.Get(i), currentVersion)
372
		}
373

374
		messages := file.Messages()
375
		for i := range messages.Len() {
376
			testMessage(t, messages.Get(i), currentVersion)
377
		}
378

379
		services := file.Services()
380
		for i := range services.Len() {
381
			service := services.Get(i)
382
			testDeprecated(t, service, currentVersion)
383

384
			methods := service.Methods()
385
			for j := range methods.Len() {
386
				method := methods.Get(j)
387
				testDeprecated(t, method, currentVersion)
388

389
				message := method.Input()
390
				testMessage(t, message, currentVersion)
391

392
				message = method.Output()
393
				testMessage(t, message, currentVersion)
394
			}
395
		}
396
	}
397
}
398

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.