go-bot
294 строки · 8.5 Кб
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
4
5package proto
6
7import (
8"google.golang.org/protobuf/encoding/protowire"
9"google.golang.org/protobuf/internal/encoding/messageset"
10"google.golang.org/protobuf/internal/errors"
11"google.golang.org/protobuf/internal/flags"
12"google.golang.org/protobuf/internal/genid"
13"google.golang.org/protobuf/internal/pragma"
14"google.golang.org/protobuf/reflect/protoreflect"
15"google.golang.org/protobuf/reflect/protoregistry"
16"google.golang.org/protobuf/runtime/protoiface"
17)
18
// UnmarshalOptions configures the unmarshaler.
//
// Example usage:
//
//	err := UnmarshalOptions{DiscardUnknown: true}.Unmarshal(b, m)
//
// The zero value is ready to use. All methods take UnmarshalOptions by value
// and fill in defaults (RecursionLimit, Resolver) on their own copy, so the
// caller's options value is never modified.
type UnmarshalOptions struct {
	// Embedded marker; presumably forces keyed composite literals so that
	// fields can be added compatibly (see internal/pragma — confirm there).
	pragma.NoUnkeyedLiterals

	// Merge merges the input into the destination message.
	// The default behavior is to always reset the message before unmarshaling,
	// unless Merge is specified.
	Merge bool

	// AllowPartial accepts input for messages that will result in missing
	// required fields. If AllowPartial is false (the default), Unmarshal will
	// return an error if there are any missing required fields.
	AllowPartial bool

	// If DiscardUnknown is set, unknown fields are ignored.
	// Otherwise the reflection-based decoder appends them to the message's
	// unknown fields.
	DiscardUnknown bool

	// Resolver is used for looking up types when unmarshaling extension fields.
	// If nil, this defaults to using protoregistry.GlobalTypes.
	// The reflection-based decoder calls FindExtensionByNumber; the resolver is
	// also handed to a message's fast-path Unmarshal method.
	Resolver interface {
		FindExtensionByName(field protoreflect.FullName) (protoreflect.ExtensionType, error)
		FindExtensionByNumber(message protoreflect.FullName, field protoreflect.FieldNumber) (protoreflect.ExtensionType, error)
	}

	// RecursionLimit limits how deeply messages may be nested.
	// If zero, a default limit (protowire.DefaultRecursionLimit) is applied.
	// On the reflection-based path each nested message consumes one level, and
	// exceeding the limit fails with an "exceeded max recursion depth" error.
	RecursionLimit int
}
51
52// Unmarshal parses the wire-format message in b and places the result in m.
53// The provided message must be mutable (e.g., a non-nil pointer to a message).
54func Unmarshal(b []byte, m Message) error {
55_, err := UnmarshalOptions{RecursionLimit: protowire.DefaultRecursionLimit}.unmarshal(b, m.ProtoReflect())
56return err
57}
58
59// Unmarshal parses the wire-format message in b and places the result in m.
60// The provided message must be mutable (e.g., a non-nil pointer to a message).
61func (o UnmarshalOptions) Unmarshal(b []byte, m Message) error {
62if o.RecursionLimit == 0 {
63o.RecursionLimit = protowire.DefaultRecursionLimit
64}
65_, err := o.unmarshal(b, m.ProtoReflect())
66return err
67}
68
69// UnmarshalState parses a wire-format message and places the result in m.
70//
71// This method permits fine-grained control over the unmarshaler.
72// Most users should use Unmarshal instead.
73func (o UnmarshalOptions) UnmarshalState(in protoiface.UnmarshalInput) (protoiface.UnmarshalOutput, error) {
74if o.RecursionLimit == 0 {
75o.RecursionLimit = protowire.DefaultRecursionLimit
76}
77return o.unmarshal(in.Buf, in.Message)
78}
79
80// unmarshal is a centralized function that all unmarshal operations go through.
81// For profiling purposes, avoid changing the name of this function or
82// introducing other code paths for unmarshal that do not go through this.
83func (o UnmarshalOptions) unmarshal(b []byte, m protoreflect.Message) (out protoiface.UnmarshalOutput, err error) {
84if o.Resolver == nil {
85o.Resolver = protoregistry.GlobalTypes
86}
87if !o.Merge {
88Reset(m.Interface())
89}
90allowPartial := o.AllowPartial
91o.Merge = true
92o.AllowPartial = true
93methods := protoMethods(m)
94if methods != nil && methods.Unmarshal != nil &&
95!(o.DiscardUnknown && methods.Flags&protoiface.SupportUnmarshalDiscardUnknown == 0) {
96in := protoiface.UnmarshalInput{
97Message: m,
98Buf: b,
99Resolver: o.Resolver,
100Depth: o.RecursionLimit,
101}
102if o.DiscardUnknown {
103in.Flags |= protoiface.UnmarshalDiscardUnknown
104}
105out, err = methods.Unmarshal(in)
106} else {
107o.RecursionLimit--
108if o.RecursionLimit < 0 {
109return out, errors.New("exceeded max recursion depth")
110}
111err = o.unmarshalMessageSlow(b, m)
112}
113if err != nil {
114return out, err
115}
116if allowPartial || (out.Flags&protoiface.UnmarshalInitialized != 0) {
117return out, nil
118}
119return out, checkInitialized(m)
120}
121
122func (o UnmarshalOptions) unmarshalMessage(b []byte, m protoreflect.Message) error {
123_, err := o.unmarshal(b, m)
124return err
125}
126
127func (o UnmarshalOptions) unmarshalMessageSlow(b []byte, m protoreflect.Message) error {
128md := m.Descriptor()
129if messageset.IsMessageSet(md) {
130return o.unmarshalMessageSet(b, m)
131}
132fields := md.Fields()
133for len(b) > 0 {
134// Parse the tag (field number and wire type).
135num, wtyp, tagLen := protowire.ConsumeTag(b)
136if tagLen < 0 {
137return errDecode
138}
139if num > protowire.MaxValidNumber {
140return errDecode
141}
142
143// Find the field descriptor for this field number.
144fd := fields.ByNumber(num)
145if fd == nil && md.ExtensionRanges().Has(num) {
146extType, err := o.Resolver.FindExtensionByNumber(md.FullName(), num)
147if err != nil && err != protoregistry.NotFound {
148return errors.New("%v: unable to resolve extension %v: %v", md.FullName(), num, err)
149}
150if extType != nil {
151fd = extType.TypeDescriptor()
152}
153}
154var err error
155if fd == nil {
156err = errUnknown
157} else if flags.ProtoLegacy {
158if fd.IsWeak() && fd.Message().IsPlaceholder() {
159err = errUnknown // weak referent is not linked in
160}
161}
162
163// Parse the field value.
164var valLen int
165switch {
166case err != nil:
167case fd.IsList():
168valLen, err = o.unmarshalList(b[tagLen:], wtyp, m.Mutable(fd).List(), fd)
169case fd.IsMap():
170valLen, err = o.unmarshalMap(b[tagLen:], wtyp, m.Mutable(fd).Map(), fd)
171default:
172valLen, err = o.unmarshalSingular(b[tagLen:], wtyp, m, fd)
173}
174if err != nil {
175if err != errUnknown {
176return err
177}
178valLen = protowire.ConsumeFieldValue(num, wtyp, b[tagLen:])
179if valLen < 0 {
180return errDecode
181}
182if !o.DiscardUnknown {
183m.SetUnknown(append(m.GetUnknown(), b[:tagLen+valLen]...))
184}
185}
186b = b[tagLen+valLen:]
187}
188return nil
189}
190
191func (o UnmarshalOptions) unmarshalSingular(b []byte, wtyp protowire.Type, m protoreflect.Message, fd protoreflect.FieldDescriptor) (n int, err error) {
192v, n, err := o.unmarshalScalar(b, wtyp, fd)
193if err != nil {
194return 0, err
195}
196switch fd.Kind() {
197case protoreflect.GroupKind, protoreflect.MessageKind:
198m2 := m.Mutable(fd).Message()
199if err := o.unmarshalMessage(v.Bytes(), m2); err != nil {
200return n, err
201}
202default:
203// Non-message scalars replace the previous value.
204m.Set(fd, v)
205}
206return n, nil
207}
208
209func (o UnmarshalOptions) unmarshalMap(b []byte, wtyp protowire.Type, mapv protoreflect.Map, fd protoreflect.FieldDescriptor) (n int, err error) {
210if wtyp != protowire.BytesType {
211return 0, errUnknown
212}
213b, n = protowire.ConsumeBytes(b)
214if n < 0 {
215return 0, errDecode
216}
217var (
218keyField = fd.MapKey()
219valField = fd.MapValue()
220key protoreflect.Value
221val protoreflect.Value
222haveKey bool
223haveVal bool
224)
225switch valField.Kind() {
226case protoreflect.GroupKind, protoreflect.MessageKind:
227val = mapv.NewValue()
228}
229// Map entries are represented as a two-element message with fields
230// containing the key and value.
231for len(b) > 0 {
232num, wtyp, n := protowire.ConsumeTag(b)
233if n < 0 {
234return 0, errDecode
235}
236if num > protowire.MaxValidNumber {
237return 0, errDecode
238}
239b = b[n:]
240err = errUnknown
241switch num {
242case genid.MapEntry_Key_field_number:
243key, n, err = o.unmarshalScalar(b, wtyp, keyField)
244if err != nil {
245break
246}
247haveKey = true
248case genid.MapEntry_Value_field_number:
249var v protoreflect.Value
250v, n, err = o.unmarshalScalar(b, wtyp, valField)
251if err != nil {
252break
253}
254switch valField.Kind() {
255case protoreflect.GroupKind, protoreflect.MessageKind:
256if err := o.unmarshalMessage(v.Bytes(), val.Message()); err != nil {
257return 0, err
258}
259default:
260val = v
261}
262haveVal = true
263}
264if err == errUnknown {
265n = protowire.ConsumeFieldValue(num, wtyp, b)
266if n < 0 {
267return 0, errDecode
268}
269} else if err != nil {
270return 0, err
271}
272b = b[n:]
273}
274// Every map entry should have entries for key and value, but this is not strictly required.
275if !haveKey {
276key = keyField.Default()
277}
278if !haveVal {
279switch valField.Kind() {
280case protoreflect.GroupKind, protoreflect.MessageKind:
281default:
282val = valField.Default()
283}
284}
285mapv.Set(key.MapKey(), val)
286return n, nil
287}
288
289// errUnknown is used internally to indicate fields which should be added
290// to the unknown field set of a message. It is never returned from an exported
291// function.
292var errUnknown = errors.New("BUG: internal error (unknown)")
293
294var errDecode = errors.New("cannot parse invalid wire-format data")
295