podman

Форк
0
576 строк · 15.1 Кб
1
// Copyright 2019 The Go Authors. All rights reserved.
2
// Use of this source code is governed by a BSD-style
3
// license that can be found in the LICENSE file.
4

5
package impl
6

7
import (
8
	"fmt"
9
	"math"
10
	"math/bits"
11
	"reflect"
12
	"unicode/utf8"
13

14
	"google.golang.org/protobuf/encoding/protowire"
15
	"google.golang.org/protobuf/internal/encoding/messageset"
16
	"google.golang.org/protobuf/internal/flags"
17
	"google.golang.org/protobuf/internal/genid"
18
	"google.golang.org/protobuf/internal/strs"
19
	"google.golang.org/protobuf/reflect/protoreflect"
20
	"google.golang.org/protobuf/reflect/protoregistry"
21
	"google.golang.org/protobuf/runtime/protoiface"
22
)
23

24
// ValidationStatus reports the outcome of checking a buffer against the
// wire-format encoding of a message type.
type ValidationStatus int

const (
	// ValidationUnknown means the validator could not decide whether
	// unmarshaling the message would succeed or fail.
	//
	// This status is only produced when an aberrant message type appears
	// somewhere in the message or when the extension resolver fails.
	ValidationUnknown ValidationStatus = iota + 1

	// ValidationInvalid means that unmarshaling the message will fail.
	ValidationInvalid

	// ValidationValid means that unmarshaling the message will succeed.
	ValidationValid
)

// String returns the identifier of v for the three defined statuses and
// "ValidationStatus(N)" for any other value.
func (v ValidationStatus) String() string {
	names := [...]string{
		ValidationUnknown: "ValidationUnknown",
		ValidationInvalid: "ValidationInvalid",
		ValidationValid:   "ValidationValid",
	}
	if v < ValidationUnknown || v > ValidationValid {
		return fmt.Sprintf("ValidationStatus(%d)", int(v))
	}
	return names[v]
}
54

55
// Validate determines whether the contents of the buffer are a valid wire encoding
56
// of the message type.
57
//
58
// This function is exposed for testing.
59
func Validate(mt protoreflect.MessageType, in protoiface.UnmarshalInput) (out protoiface.UnmarshalOutput, _ ValidationStatus) {
60
	mi, ok := mt.(*MessageInfo)
61
	if !ok {
62
		return out, ValidationUnknown
63
	}
64
	if in.Resolver == nil {
65
		in.Resolver = protoregistry.GlobalTypes
66
	}
67
	o, st := mi.validate(in.Buf, 0, unmarshalOptions{
68
		flags:    in.Flags,
69
		resolver: in.Resolver,
70
	})
71
	if o.initialized {
72
		out.Flags |= protoiface.UnmarshalInitialized
73
	}
74
	return out, st
75
}
76

77
type validationInfo struct {
78
	mi               *MessageInfo
79
	typ              validationType
80
	keyType, valType validationType
81

82
	// For non-required fields, requiredBit is 0.
83
	//
84
	// For required fields, requiredBit's nth bit is set, where n is a
85
	// unique index in the range [0, MessageInfo.numRequiredFields).
86
	//
87
	// If there are more than 64 required fields, requiredBit is 0.
88
	requiredBit uint64
89
}
90

91
// validationType classifies a field by the checks the validator applies to
// its encoded value.
type validationType uint8

const (
	// validationTypeOther is the zero value: no type-specific check applies.
	validationTypeOther validationType = iota

	// Nested values that are validated recursively.
	validationTypeMessage
	validationTypeGroup
	validationTypeMap

	// List fields of scalar kinds, keyed by the element wire type.
	validationTypeRepeatedVarint
	validationTypeRepeatedFixed32
	validationTypeRepeatedFixed64

	// Singular scalar fields, keyed by wire type.
	validationTypeVarint
	validationTypeFixed32
	validationTypeFixed64
	validationTypeBytes

	// validationTypeUTF8String marks length-delimited data that must be
	// valid UTF-8.
	validationTypeUTF8String

	// validationTypeMessageSetItem marks the item field of a message set.
	validationTypeMessageSetItem
)
108

109
func newFieldValidationInfo(mi *MessageInfo, si structInfo, fd protoreflect.FieldDescriptor, ft reflect.Type) validationInfo {
110
	var vi validationInfo
111
	switch {
112
	case fd.ContainingOneof() != nil && !fd.ContainingOneof().IsSynthetic():
113
		switch fd.Kind() {
114
		case protoreflect.MessageKind:
115
			vi.typ = validationTypeMessage
116
			if ot, ok := si.oneofWrappersByNumber[fd.Number()]; ok {
117
				vi.mi = getMessageInfo(ot.Field(0).Type)
118
			}
119
		case protoreflect.GroupKind:
120
			vi.typ = validationTypeGroup
121
			if ot, ok := si.oneofWrappersByNumber[fd.Number()]; ok {
122
				vi.mi = getMessageInfo(ot.Field(0).Type)
123
			}
124
		case protoreflect.StringKind:
125
			if strs.EnforceUTF8(fd) {
126
				vi.typ = validationTypeUTF8String
127
			}
128
		}
129
	default:
130
		vi = newValidationInfo(fd, ft)
131
	}
132
	if fd.Cardinality() == protoreflect.Required {
133
		// Avoid overflow. The required field check is done with a 64-bit mask, with
134
		// any message containing more than 64 required fields always reported as
135
		// potentially uninitialized, so it is not important to get a precise count
136
		// of the required fields past 64.
137
		if mi.numRequiredFields < math.MaxUint8 {
138
			mi.numRequiredFields++
139
			vi.requiredBit = 1 << (mi.numRequiredFields - 1)
140
		}
141
	}
142
	return vi
143
}
144

145
func newValidationInfo(fd protoreflect.FieldDescriptor, ft reflect.Type) validationInfo {
146
	var vi validationInfo
147
	switch {
148
	case fd.IsList():
149
		switch fd.Kind() {
150
		case protoreflect.MessageKind:
151
			vi.typ = validationTypeMessage
152
			if ft.Kind() == reflect.Slice {
153
				vi.mi = getMessageInfo(ft.Elem())
154
			}
155
		case protoreflect.GroupKind:
156
			vi.typ = validationTypeGroup
157
			if ft.Kind() == reflect.Slice {
158
				vi.mi = getMessageInfo(ft.Elem())
159
			}
160
		case protoreflect.StringKind:
161
			vi.typ = validationTypeBytes
162
			if strs.EnforceUTF8(fd) {
163
				vi.typ = validationTypeUTF8String
164
			}
165
		default:
166
			switch wireTypes[fd.Kind()] {
167
			case protowire.VarintType:
168
				vi.typ = validationTypeRepeatedVarint
169
			case protowire.Fixed32Type:
170
				vi.typ = validationTypeRepeatedFixed32
171
			case protowire.Fixed64Type:
172
				vi.typ = validationTypeRepeatedFixed64
173
			}
174
		}
175
	case fd.IsMap():
176
		vi.typ = validationTypeMap
177
		switch fd.MapKey().Kind() {
178
		case protoreflect.StringKind:
179
			if strs.EnforceUTF8(fd) {
180
				vi.keyType = validationTypeUTF8String
181
			}
182
		}
183
		switch fd.MapValue().Kind() {
184
		case protoreflect.MessageKind:
185
			vi.valType = validationTypeMessage
186
			if ft.Kind() == reflect.Map {
187
				vi.mi = getMessageInfo(ft.Elem())
188
			}
189
		case protoreflect.StringKind:
190
			if strs.EnforceUTF8(fd) {
191
				vi.valType = validationTypeUTF8String
192
			}
193
		}
194
	default:
195
		switch fd.Kind() {
196
		case protoreflect.MessageKind:
197
			vi.typ = validationTypeMessage
198
			if !fd.IsWeak() {
199
				vi.mi = getMessageInfo(ft)
200
			}
201
		case protoreflect.GroupKind:
202
			vi.typ = validationTypeGroup
203
			vi.mi = getMessageInfo(ft)
204
		case protoreflect.StringKind:
205
			vi.typ = validationTypeBytes
206
			if strs.EnforceUTF8(fd) {
207
				vi.typ = validationTypeUTF8String
208
			}
209
		default:
210
			switch wireTypes[fd.Kind()] {
211
			case protowire.VarintType:
212
				vi.typ = validationTypeVarint
213
			case protowire.Fixed32Type:
214
				vi.typ = validationTypeFixed32
215
			case protowire.Fixed64Type:
216
				vi.typ = validationTypeFixed64
217
			case protowire.BytesType:
218
				vi.typ = validationTypeBytes
219
			}
220
		}
221
	}
222
	return vi
223
}
224

225
func (mi *MessageInfo) validate(b []byte, groupTag protowire.Number, opts unmarshalOptions) (out unmarshalOutput, result ValidationStatus) {
226
	mi.init()
227
	type validationState struct {
228
		typ              validationType
229
		keyType, valType validationType
230
		endGroup         protowire.Number
231
		mi               *MessageInfo
232
		tail             []byte
233
		requiredMask     uint64
234
	}
235

236
	// Pre-allocate some slots to avoid repeated slice reallocation.
237
	states := make([]validationState, 0, 16)
238
	states = append(states, validationState{
239
		typ: validationTypeMessage,
240
		mi:  mi,
241
	})
242
	if groupTag > 0 {
243
		states[0].typ = validationTypeGroup
244
		states[0].endGroup = groupTag
245
	}
246
	initialized := true
247
	start := len(b)
248
State:
249
	for len(states) > 0 {
250
		st := &states[len(states)-1]
251
		for len(b) > 0 {
252
			// Parse the tag (field number and wire type).
253
			var tag uint64
254
			if b[0] < 0x80 {
255
				tag = uint64(b[0])
256
				b = b[1:]
257
			} else if len(b) >= 2 && b[1] < 128 {
258
				tag = uint64(b[0]&0x7f) + uint64(b[1])<<7
259
				b = b[2:]
260
			} else {
261
				var n int
262
				tag, n = protowire.ConsumeVarint(b)
263
				if n < 0 {
264
					return out, ValidationInvalid
265
				}
266
				b = b[n:]
267
			}
268
			var num protowire.Number
269
			if n := tag >> 3; n < uint64(protowire.MinValidNumber) || n > uint64(protowire.MaxValidNumber) {
270
				return out, ValidationInvalid
271
			} else {
272
				num = protowire.Number(n)
273
			}
274
			wtyp := protowire.Type(tag & 7)
275

276
			if wtyp == protowire.EndGroupType {
277
				if st.endGroup == num {
278
					goto PopState
279
				}
280
				return out, ValidationInvalid
281
			}
282
			var vi validationInfo
283
			switch {
284
			case st.typ == validationTypeMap:
285
				switch num {
286
				case genid.MapEntry_Key_field_number:
287
					vi.typ = st.keyType
288
				case genid.MapEntry_Value_field_number:
289
					vi.typ = st.valType
290
					vi.mi = st.mi
291
					vi.requiredBit = 1
292
				}
293
			case flags.ProtoLegacy && st.mi.isMessageSet:
294
				switch num {
295
				case messageset.FieldItem:
296
					vi.typ = validationTypeMessageSetItem
297
				}
298
			default:
299
				var f *coderFieldInfo
300
				if int(num) < len(st.mi.denseCoderFields) {
301
					f = st.mi.denseCoderFields[num]
302
				} else {
303
					f = st.mi.coderFields[num]
304
				}
305
				if f != nil {
306
					vi = f.validation
307
					if vi.typ == validationTypeMessage && vi.mi == nil {
308
						// Probable weak field.
309
						//
310
						// TODO: Consider storing the results of this lookup somewhere
311
						// rather than recomputing it on every validation.
312
						fd := st.mi.Desc.Fields().ByNumber(num)
313
						if fd == nil || !fd.IsWeak() {
314
							break
315
						}
316
						messageName := fd.Message().FullName()
317
						messageType, err := protoregistry.GlobalTypes.FindMessageByName(messageName)
318
						switch err {
319
						case nil:
320
							vi.mi, _ = messageType.(*MessageInfo)
321
						case protoregistry.NotFound:
322
							vi.typ = validationTypeBytes
323
						default:
324
							return out, ValidationUnknown
325
						}
326
					}
327
					break
328
				}
329
				// Possible extension field.
330
				//
331
				// TODO: We should return ValidationUnknown when:
332
				//   1. The resolver is not frozen. (More extensions may be added to it.)
333
				//   2. The resolver returns preg.NotFound.
334
				// In this case, a type added to the resolver in the future could cause
335
				// unmarshaling to begin failing. Supporting this requires some way to
336
				// determine if the resolver is frozen.
337
				xt, err := opts.resolver.FindExtensionByNumber(st.mi.Desc.FullName(), num)
338
				if err != nil && err != protoregistry.NotFound {
339
					return out, ValidationUnknown
340
				}
341
				if err == nil {
342
					vi = getExtensionFieldInfo(xt).validation
343
				}
344
			}
345
			if vi.requiredBit != 0 {
346
				// Check that the field has a compatible wire type.
347
				// We only need to consider non-repeated field types,
348
				// since repeated fields (and maps) can never be required.
349
				ok := false
350
				switch vi.typ {
351
				case validationTypeVarint:
352
					ok = wtyp == protowire.VarintType
353
				case validationTypeFixed32:
354
					ok = wtyp == protowire.Fixed32Type
355
				case validationTypeFixed64:
356
					ok = wtyp == protowire.Fixed64Type
357
				case validationTypeBytes, validationTypeUTF8String, validationTypeMessage:
358
					ok = wtyp == protowire.BytesType
359
				case validationTypeGroup:
360
					ok = wtyp == protowire.StartGroupType
361
				}
362
				if ok {
363
					st.requiredMask |= vi.requiredBit
364
				}
365
			}
366

367
			switch wtyp {
368
			case protowire.VarintType:
369
				if len(b) >= 10 {
370
					switch {
371
					case b[0] < 0x80:
372
						b = b[1:]
373
					case b[1] < 0x80:
374
						b = b[2:]
375
					case b[2] < 0x80:
376
						b = b[3:]
377
					case b[3] < 0x80:
378
						b = b[4:]
379
					case b[4] < 0x80:
380
						b = b[5:]
381
					case b[5] < 0x80:
382
						b = b[6:]
383
					case b[6] < 0x80:
384
						b = b[7:]
385
					case b[7] < 0x80:
386
						b = b[8:]
387
					case b[8] < 0x80:
388
						b = b[9:]
389
					case b[9] < 0x80 && b[9] < 2:
390
						b = b[10:]
391
					default:
392
						return out, ValidationInvalid
393
					}
394
				} else {
395
					switch {
396
					case len(b) > 0 && b[0] < 0x80:
397
						b = b[1:]
398
					case len(b) > 1 && b[1] < 0x80:
399
						b = b[2:]
400
					case len(b) > 2 && b[2] < 0x80:
401
						b = b[3:]
402
					case len(b) > 3 && b[3] < 0x80:
403
						b = b[4:]
404
					case len(b) > 4 && b[4] < 0x80:
405
						b = b[5:]
406
					case len(b) > 5 && b[5] < 0x80:
407
						b = b[6:]
408
					case len(b) > 6 && b[6] < 0x80:
409
						b = b[7:]
410
					case len(b) > 7 && b[7] < 0x80:
411
						b = b[8:]
412
					case len(b) > 8 && b[8] < 0x80:
413
						b = b[9:]
414
					case len(b) > 9 && b[9] < 2:
415
						b = b[10:]
416
					default:
417
						return out, ValidationInvalid
418
					}
419
				}
420
				continue State
421
			case protowire.BytesType:
422
				var size uint64
423
				if len(b) >= 1 && b[0] < 0x80 {
424
					size = uint64(b[0])
425
					b = b[1:]
426
				} else if len(b) >= 2 && b[1] < 128 {
427
					size = uint64(b[0]&0x7f) + uint64(b[1])<<7
428
					b = b[2:]
429
				} else {
430
					var n int
431
					size, n = protowire.ConsumeVarint(b)
432
					if n < 0 {
433
						return out, ValidationInvalid
434
					}
435
					b = b[n:]
436
				}
437
				if size > uint64(len(b)) {
438
					return out, ValidationInvalid
439
				}
440
				v := b[:size]
441
				b = b[size:]
442
				switch vi.typ {
443
				case validationTypeMessage:
444
					if vi.mi == nil {
445
						return out, ValidationUnknown
446
					}
447
					vi.mi.init()
448
					fallthrough
449
				case validationTypeMap:
450
					if vi.mi != nil {
451
						vi.mi.init()
452
					}
453
					states = append(states, validationState{
454
						typ:     vi.typ,
455
						keyType: vi.keyType,
456
						valType: vi.valType,
457
						mi:      vi.mi,
458
						tail:    b,
459
					})
460
					b = v
461
					continue State
462
				case validationTypeRepeatedVarint:
463
					// Packed field.
464
					for len(v) > 0 {
465
						_, n := protowire.ConsumeVarint(v)
466
						if n < 0 {
467
							return out, ValidationInvalid
468
						}
469
						v = v[n:]
470
					}
471
				case validationTypeRepeatedFixed32:
472
					// Packed field.
473
					if len(v)%4 != 0 {
474
						return out, ValidationInvalid
475
					}
476
				case validationTypeRepeatedFixed64:
477
					// Packed field.
478
					if len(v)%8 != 0 {
479
						return out, ValidationInvalid
480
					}
481
				case validationTypeUTF8String:
482
					if !utf8.Valid(v) {
483
						return out, ValidationInvalid
484
					}
485
				}
486
			case protowire.Fixed32Type:
487
				if len(b) < 4 {
488
					return out, ValidationInvalid
489
				}
490
				b = b[4:]
491
			case protowire.Fixed64Type:
492
				if len(b) < 8 {
493
					return out, ValidationInvalid
494
				}
495
				b = b[8:]
496
			case protowire.StartGroupType:
497
				switch {
498
				case vi.typ == validationTypeGroup:
499
					if vi.mi == nil {
500
						return out, ValidationUnknown
501
					}
502
					vi.mi.init()
503
					states = append(states, validationState{
504
						typ:      validationTypeGroup,
505
						mi:       vi.mi,
506
						endGroup: num,
507
					})
508
					continue State
509
				case flags.ProtoLegacy && vi.typ == validationTypeMessageSetItem:
510
					typeid, v, n, err := messageset.ConsumeFieldValue(b, false)
511
					if err != nil {
512
						return out, ValidationInvalid
513
					}
514
					xt, err := opts.resolver.FindExtensionByNumber(st.mi.Desc.FullName(), typeid)
515
					switch {
516
					case err == protoregistry.NotFound:
517
						b = b[n:]
518
					case err != nil:
519
						return out, ValidationUnknown
520
					default:
521
						xvi := getExtensionFieldInfo(xt).validation
522
						if xvi.mi != nil {
523
							xvi.mi.init()
524
						}
525
						states = append(states, validationState{
526
							typ:  xvi.typ,
527
							mi:   xvi.mi,
528
							tail: b[n:],
529
						})
530
						b = v
531
						continue State
532
					}
533
				default:
534
					n := protowire.ConsumeFieldValue(num, wtyp, b)
535
					if n < 0 {
536
						return out, ValidationInvalid
537
					}
538
					b = b[n:]
539
				}
540
			default:
541
				return out, ValidationInvalid
542
			}
543
		}
544
		if st.endGroup != 0 {
545
			return out, ValidationInvalid
546
		}
547
		if len(b) != 0 {
548
			return out, ValidationInvalid
549
		}
550
		b = st.tail
551
	PopState:
552
		numRequiredFields := 0
553
		switch st.typ {
554
		case validationTypeMessage, validationTypeGroup:
555
			numRequiredFields = int(st.mi.numRequiredFields)
556
		case validationTypeMap:
557
			// If this is a map field with a message value that contains
558
			// required fields, require that the value be present.
559
			if st.mi != nil && st.mi.numRequiredFields > 0 {
560
				numRequiredFields = 1
561
			}
562
		}
563
		// If there are more than 64 required fields, this check will
564
		// always fail and we will report that the message is potentially
565
		// uninitialized.
566
		if numRequiredFields > 0 && bits.OnesCount64(st.requiredMask) != numRequiredFields {
567
			initialized = false
568
		}
569
		states = states[:len(states)-1]
570
	}
571
	out.n = start - len(b)
572
	if initialized {
573
		out.initialized = true
574
	}
575
	return out, ValidationValid
576
}
577

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.