go-bot

Форк
0
294 строки · 8.5 Кб
1
// Copyright 2018 The Go Authors. All rights reserved.
2
// Use of this source code is governed by a BSD-style
3
// license that can be found in the LICENSE file.
4

5
package proto
6

7
import (
8
	"google.golang.org/protobuf/encoding/protowire"
9
	"google.golang.org/protobuf/internal/encoding/messageset"
10
	"google.golang.org/protobuf/internal/errors"
11
	"google.golang.org/protobuf/internal/flags"
12
	"google.golang.org/protobuf/internal/genid"
13
	"google.golang.org/protobuf/internal/pragma"
14
	"google.golang.org/protobuf/reflect/protoreflect"
15
	"google.golang.org/protobuf/reflect/protoregistry"
16
	"google.golang.org/protobuf/runtime/protoiface"
17
)
18

19
// UnmarshalOptions configures the unmarshaler.
20
//
21
// Example usage:
22
//
23
//	err := UnmarshalOptions{DiscardUnknown: true}.Unmarshal(b, m)
24
type UnmarshalOptions struct {
25
	pragma.NoUnkeyedLiterals
26

27
	// Merge merges the input into the destination message.
28
	// The default behavior is to always reset the message before unmarshaling,
29
	// unless Merge is specified.
30
	Merge bool
31

32
	// AllowPartial accepts input for messages that will result in missing
33
	// required fields. If AllowPartial is false (the default), Unmarshal will
34
	// return an error if there are any missing required fields.
35
	AllowPartial bool
36

37
	// If DiscardUnknown is set, unknown fields are ignored.
38
	DiscardUnknown bool
39

40
	// Resolver is used for looking up types when unmarshaling extension fields.
41
	// If nil, this defaults to using protoregistry.GlobalTypes.
42
	Resolver interface {
43
		FindExtensionByName(field protoreflect.FullName) (protoreflect.ExtensionType, error)
44
		FindExtensionByNumber(message protoreflect.FullName, field protoreflect.FieldNumber) (protoreflect.ExtensionType, error)
45
	}
46

47
	// RecursionLimit limits how deeply messages may be nested.
48
	// If zero, a default limit is applied.
49
	RecursionLimit int
50
}
51

52
// Unmarshal parses the wire-format message in b and places the result in m.
53
// The provided message must be mutable (e.g., a non-nil pointer to a message).
54
func Unmarshal(b []byte, m Message) error {
55
	_, err := UnmarshalOptions{RecursionLimit: protowire.DefaultRecursionLimit}.unmarshal(b, m.ProtoReflect())
56
	return err
57
}
58

59
// Unmarshal parses the wire-format message in b and places the result in m.
60
// The provided message must be mutable (e.g., a non-nil pointer to a message).
61
func (o UnmarshalOptions) Unmarshal(b []byte, m Message) error {
62
	if o.RecursionLimit == 0 {
63
		o.RecursionLimit = protowire.DefaultRecursionLimit
64
	}
65
	_, err := o.unmarshal(b, m.ProtoReflect())
66
	return err
67
}
68

69
// UnmarshalState parses a wire-format message and places the result in m.
70
//
71
// This method permits fine-grained control over the unmarshaler.
72
// Most users should use Unmarshal instead.
73
func (o UnmarshalOptions) UnmarshalState(in protoiface.UnmarshalInput) (protoiface.UnmarshalOutput, error) {
74
	if o.RecursionLimit == 0 {
75
		o.RecursionLimit = protowire.DefaultRecursionLimit
76
	}
77
	return o.unmarshal(in.Buf, in.Message)
78
}
79

80
// unmarshal is a centralized function that all unmarshal operations go through.
81
// For profiling purposes, avoid changing the name of this function or
82
// introducing other code paths for unmarshal that do not go through this.
83
func (o UnmarshalOptions) unmarshal(b []byte, m protoreflect.Message) (out protoiface.UnmarshalOutput, err error) {
84
	if o.Resolver == nil {
85
		o.Resolver = protoregistry.GlobalTypes
86
	}
87
	if !o.Merge {
88
		Reset(m.Interface())
89
	}
90
	allowPartial := o.AllowPartial
91
	o.Merge = true
92
	o.AllowPartial = true
93
	methods := protoMethods(m)
94
	if methods != nil && methods.Unmarshal != nil &&
95
		!(o.DiscardUnknown && methods.Flags&protoiface.SupportUnmarshalDiscardUnknown == 0) {
96
		in := protoiface.UnmarshalInput{
97
			Message:  m,
98
			Buf:      b,
99
			Resolver: o.Resolver,
100
			Depth:    o.RecursionLimit,
101
		}
102
		if o.DiscardUnknown {
103
			in.Flags |= protoiface.UnmarshalDiscardUnknown
104
		}
105
		out, err = methods.Unmarshal(in)
106
	} else {
107
		o.RecursionLimit--
108
		if o.RecursionLimit < 0 {
109
			return out, errors.New("exceeded max recursion depth")
110
		}
111
		err = o.unmarshalMessageSlow(b, m)
112
	}
113
	if err != nil {
114
		return out, err
115
	}
116
	if allowPartial || (out.Flags&protoiface.UnmarshalInitialized != 0) {
117
		return out, nil
118
	}
119
	return out, checkInitialized(m)
120
}
121

122
// unmarshalMessage decodes b into m through unmarshal and returns only the
// error, dropping the output state.
func (o UnmarshalOptions) unmarshalMessage(b []byte, m protoreflect.Message) error {
	if _, err := o.unmarshal(b, m); err != nil {
		return err
	}
	return nil
}
126

127
func (o UnmarshalOptions) unmarshalMessageSlow(b []byte, m protoreflect.Message) error {
128
	md := m.Descriptor()
129
	if messageset.IsMessageSet(md) {
130
		return o.unmarshalMessageSet(b, m)
131
	}
132
	fields := md.Fields()
133
	for len(b) > 0 {
134
		// Parse the tag (field number and wire type).
135
		num, wtyp, tagLen := protowire.ConsumeTag(b)
136
		if tagLen < 0 {
137
			return errDecode
138
		}
139
		if num > protowire.MaxValidNumber {
140
			return errDecode
141
		}
142

143
		// Find the field descriptor for this field number.
144
		fd := fields.ByNumber(num)
145
		if fd == nil && md.ExtensionRanges().Has(num) {
146
			extType, err := o.Resolver.FindExtensionByNumber(md.FullName(), num)
147
			if err != nil && err != protoregistry.NotFound {
148
				return errors.New("%v: unable to resolve extension %v: %v", md.FullName(), num, err)
149
			}
150
			if extType != nil {
151
				fd = extType.TypeDescriptor()
152
			}
153
		}
154
		var err error
155
		if fd == nil {
156
			err = errUnknown
157
		} else if flags.ProtoLegacy {
158
			if fd.IsWeak() && fd.Message().IsPlaceholder() {
159
				err = errUnknown // weak referent is not linked in
160
			}
161
		}
162

163
		// Parse the field value.
164
		var valLen int
165
		switch {
166
		case err != nil:
167
		case fd.IsList():
168
			valLen, err = o.unmarshalList(b[tagLen:], wtyp, m.Mutable(fd).List(), fd)
169
		case fd.IsMap():
170
			valLen, err = o.unmarshalMap(b[tagLen:], wtyp, m.Mutable(fd).Map(), fd)
171
		default:
172
			valLen, err = o.unmarshalSingular(b[tagLen:], wtyp, m, fd)
173
		}
174
		if err != nil {
175
			if err != errUnknown {
176
				return err
177
			}
178
			valLen = protowire.ConsumeFieldValue(num, wtyp, b[tagLen:])
179
			if valLen < 0 {
180
				return errDecode
181
			}
182
			if !o.DiscardUnknown {
183
				m.SetUnknown(append(m.GetUnknown(), b[:tagLen+valLen]...))
184
			}
185
		}
186
		b = b[tagLen+valLen:]
187
	}
188
	return nil
189
}
190

191
func (o UnmarshalOptions) unmarshalSingular(b []byte, wtyp protowire.Type, m protoreflect.Message, fd protoreflect.FieldDescriptor) (n int, err error) {
192
	v, n, err := o.unmarshalScalar(b, wtyp, fd)
193
	if err != nil {
194
		return 0, err
195
	}
196
	switch fd.Kind() {
197
	case protoreflect.GroupKind, protoreflect.MessageKind:
198
		m2 := m.Mutable(fd).Message()
199
		if err := o.unmarshalMessage(v.Bytes(), m2); err != nil {
200
			return n, err
201
		}
202
	default:
203
		// Non-message scalars replace the previous value.
204
		m.Set(fd, v)
205
	}
206
	return n, nil
207
}
208

209
func (o UnmarshalOptions) unmarshalMap(b []byte, wtyp protowire.Type, mapv protoreflect.Map, fd protoreflect.FieldDescriptor) (n int, err error) {
210
	if wtyp != protowire.BytesType {
211
		return 0, errUnknown
212
	}
213
	b, n = protowire.ConsumeBytes(b)
214
	if n < 0 {
215
		return 0, errDecode
216
	}
217
	var (
218
		keyField = fd.MapKey()
219
		valField = fd.MapValue()
220
		key      protoreflect.Value
221
		val      protoreflect.Value
222
		haveKey  bool
223
		haveVal  bool
224
	)
225
	switch valField.Kind() {
226
	case protoreflect.GroupKind, protoreflect.MessageKind:
227
		val = mapv.NewValue()
228
	}
229
	// Map entries are represented as a two-element message with fields
230
	// containing the key and value.
231
	for len(b) > 0 {
232
		num, wtyp, n := protowire.ConsumeTag(b)
233
		if n < 0 {
234
			return 0, errDecode
235
		}
236
		if num > protowire.MaxValidNumber {
237
			return 0, errDecode
238
		}
239
		b = b[n:]
240
		err = errUnknown
241
		switch num {
242
		case genid.MapEntry_Key_field_number:
243
			key, n, err = o.unmarshalScalar(b, wtyp, keyField)
244
			if err != nil {
245
				break
246
			}
247
			haveKey = true
248
		case genid.MapEntry_Value_field_number:
249
			var v protoreflect.Value
250
			v, n, err = o.unmarshalScalar(b, wtyp, valField)
251
			if err != nil {
252
				break
253
			}
254
			switch valField.Kind() {
255
			case protoreflect.GroupKind, protoreflect.MessageKind:
256
				if err := o.unmarshalMessage(v.Bytes(), val.Message()); err != nil {
257
					return 0, err
258
				}
259
			default:
260
				val = v
261
			}
262
			haveVal = true
263
		}
264
		if err == errUnknown {
265
			n = protowire.ConsumeFieldValue(num, wtyp, b)
266
			if n < 0 {
267
				return 0, errDecode
268
			}
269
		} else if err != nil {
270
			return 0, err
271
		}
272
		b = b[n:]
273
	}
274
	// Every map entry should have entries for key and value, but this is not strictly required.
275
	if !haveKey {
276
		key = keyField.Default()
277
	}
278
	if !haveVal {
279
		switch valField.Kind() {
280
		case protoreflect.GroupKind, protoreflect.MessageKind:
281
		default:
282
			val = valField.Default()
283
		}
284
	}
285
	mapv.Set(key.MapKey(), val)
286
	return n, nil
287
}
288

289
// errUnknown is used internally to indicate fields which should be added
290
// to the unknown field set of a message. It is never returned from an exported
291
// function.
292
var errUnknown = errors.New("BUG: internal error (unknown)")
293

294
var errDecode = errors.New("cannot parse invalid wire-format data")
295

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.