podman

Форк
0
290 строк · 9.2 Кб
1
// Copyright 2013-2023 The Cobra Authors
2
//
3
// Licensed under the Apache License, Version 2.0 (the "License");
4
// you may not use this file except in compliance with the License.
5
// You may obtain a copy of the License at
6
//
7
//      http://www.apache.org/licenses/LICENSE-2.0
8
//
9
// Unless required by applicable law or agreed to in writing, software
10
// distributed under the License is distributed on an "AS IS" BASIS,
11
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
// See the License for the specific language governing permissions and
13
// limitations under the License.
14

15
package cobra
16

17
import (
18
	"fmt"
19
	"sort"
20
	"strings"
21

22
	flag "github.com/spf13/pflag"
23
)
24

25
// Annotation keys stored on a flag to record the flag groups it belongs to.
// The value under each key is a list of groups; a group is encoded as the
// names of its member flags joined by single spaces.
const (
	// requiredAsGroup lists groups whose flags must all be set if any is set.
	requiredAsGroup = "cobra_annotation_required_if_others_set"
	// oneRequired lists groups in which at least one flag must be set.
	oneRequired = "cobra_annotation_one_required"
	// mutuallyExclusive lists groups in which at most one flag may be set.
	mutuallyExclusive = "cobra_annotation_mutually_exclusive"
)
30

31
// MarkFlagsRequiredTogether marks the given flags with annotations so that Cobra errors
32
// if the command is invoked with a subset (but not all) of the given flags.
33
func (c *Command) MarkFlagsRequiredTogether(flagNames ...string) {
34
	c.mergePersistentFlags()
35
	for _, v := range flagNames {
36
		f := c.Flags().Lookup(v)
37
		if f == nil {
38
			panic(fmt.Sprintf("Failed to find flag %q and mark it as being required in a flag group", v))
39
		}
40
		if err := c.Flags().SetAnnotation(v, requiredAsGroup, append(f.Annotations[requiredAsGroup], strings.Join(flagNames, " "))); err != nil {
41
			// Only errs if the flag isn't found.
42
			panic(err)
43
		}
44
	}
45
}
46

47
// MarkFlagsOneRequired marks the given flags with annotations so that Cobra errors
48
// if the command is invoked without at least one flag from the given set of flags.
49
func (c *Command) MarkFlagsOneRequired(flagNames ...string) {
50
	c.mergePersistentFlags()
51
	for _, v := range flagNames {
52
		f := c.Flags().Lookup(v)
53
		if f == nil {
54
			panic(fmt.Sprintf("Failed to find flag %q and mark it as being in a one-required flag group", v))
55
		}
56
		if err := c.Flags().SetAnnotation(v, oneRequired, append(f.Annotations[oneRequired], strings.Join(flagNames, " "))); err != nil {
57
			// Only errs if the flag isn't found.
58
			panic(err)
59
		}
60
	}
61
}
62

63
// MarkFlagsMutuallyExclusive marks the given flags with annotations so that Cobra errors
64
// if the command is invoked with more than one flag from the given set of flags.
65
func (c *Command) MarkFlagsMutuallyExclusive(flagNames ...string) {
66
	c.mergePersistentFlags()
67
	for _, v := range flagNames {
68
		f := c.Flags().Lookup(v)
69
		if f == nil {
70
			panic(fmt.Sprintf("Failed to find flag %q and mark it as being in a mutually exclusive flag group", v))
71
		}
72
		// Each time this is called is a single new entry; this allows it to be a member of multiple groups if needed.
73
		if err := c.Flags().SetAnnotation(v, mutuallyExclusive, append(f.Annotations[mutuallyExclusive], strings.Join(flagNames, " "))); err != nil {
74
			panic(err)
75
		}
76
	}
77
}
78

79
// ValidateFlagGroups validates the mutuallyExclusive/oneRequired/requiredAsGroup logic and returns the
80
// first error encountered.
81
func (c *Command) ValidateFlagGroups() error {
82
	if c.DisableFlagParsing {
83
		return nil
84
	}
85

86
	flags := c.Flags()
87

88
	// groupStatus format is the list of flags as a unique ID,
89
	// then a map of each flag name and whether it is set or not.
90
	groupStatus := map[string]map[string]bool{}
91
	oneRequiredGroupStatus := map[string]map[string]bool{}
92
	mutuallyExclusiveGroupStatus := map[string]map[string]bool{}
93
	flags.VisitAll(func(pflag *flag.Flag) {
94
		processFlagForGroupAnnotation(flags, pflag, requiredAsGroup, groupStatus)
95
		processFlagForGroupAnnotation(flags, pflag, oneRequired, oneRequiredGroupStatus)
96
		processFlagForGroupAnnotation(flags, pflag, mutuallyExclusive, mutuallyExclusiveGroupStatus)
97
	})
98

99
	if err := validateRequiredFlagGroups(groupStatus); err != nil {
100
		return err
101
	}
102
	if err := validateOneRequiredFlagGroups(oneRequiredGroupStatus); err != nil {
103
		return err
104
	}
105
	if err := validateExclusiveFlagGroups(mutuallyExclusiveGroupStatus); err != nil {
106
		return err
107
	}
108
	return nil
109
}
110

111
// hasAllFlags reports whether every name in flagnames is defined in fs.
// It returns true when no names are given.
func hasAllFlags(fs *flag.FlagSet, flagnames ...string) bool {
	for _, name := range flagnames {
		if fs.Lookup(name) == nil {
			return false
		}
	}
	return true
}
120

121
// processFlagForGroupAnnotation records in groupStatus whether pflag is set,
// once for every group listed under the given annotation key on that flag.
//
// groupStatus is keyed by group ID (the member flag names joined by single
// spaces). The first time a group is seen it is registered with all members
// marked as not set, but only when every member is defined in flags; a group
// that names an undefined flag is skipped. Flags without the annotation are
// ignored.
func processFlagForGroupAnnotation(flags *flag.FlagSet, pflag *flag.Flag, annotation string, groupStatus map[string]map[string]bool) {
	groups, ok := pflag.Annotations[annotation]
	if !ok {
		return
	}

	for _, group := range groups {
		status := groupStatus[group]
		if status == nil {
			members := strings.Split(group, " ")

			// Only consider this flag group at all if all the flags are defined.
			if !hasAllFlags(flags, members...) {
				continue
			}

			status = map[string]bool{}
			for _, member := range members {
				status[member] = false
			}
			groupStatus[group] = status
		}

		status[pflag.Name] = pflag.Changed
	}
}
143

144
func validateRequiredFlagGroups(data map[string]map[string]bool) error {
145
	keys := sortedKeys(data)
146
	for _, flagList := range keys {
147
		flagnameAndStatus := data[flagList]
148

149
		unset := []string{}
150
		for flagname, isSet := range flagnameAndStatus {
151
			if !isSet {
152
				unset = append(unset, flagname)
153
			}
154
		}
155
		if len(unset) == len(flagnameAndStatus) || len(unset) == 0 {
156
			continue
157
		}
158

159
		// Sort values, so they can be tested/scripted against consistently.
160
		sort.Strings(unset)
161
		return fmt.Errorf("if any flags in the group [%v] are set they must all be set; missing %v", flagList, unset)
162
	}
163

164
	return nil
165
}
166

167
func validateOneRequiredFlagGroups(data map[string]map[string]bool) error {
168
	keys := sortedKeys(data)
169
	for _, flagList := range keys {
170
		flagnameAndStatus := data[flagList]
171
		var set []string
172
		for flagname, isSet := range flagnameAndStatus {
173
			if isSet {
174
				set = append(set, flagname)
175
			}
176
		}
177
		if len(set) >= 1 {
178
			continue
179
		}
180

181
		// Sort values, so they can be tested/scripted against consistently.
182
		sort.Strings(set)
183
		return fmt.Errorf("at least one of the flags in the group [%v] is required", flagList)
184
	}
185
	return nil
186
}
187

188
func validateExclusiveFlagGroups(data map[string]map[string]bool) error {
189
	keys := sortedKeys(data)
190
	for _, flagList := range keys {
191
		flagnameAndStatus := data[flagList]
192
		var set []string
193
		for flagname, isSet := range flagnameAndStatus {
194
			if isSet {
195
				set = append(set, flagname)
196
			}
197
		}
198
		if len(set) == 0 || len(set) == 1 {
199
			continue
200
		}
201

202
		// Sort values, so they can be tested/scripted against consistently.
203
		sort.Strings(set)
204
		return fmt.Errorf("if any flags in the group [%v] are set none of the others can be; %v were all set", flagList, set)
205
	}
206
	return nil
207
}
208

209
// sortedKeys returns the keys of m in ascending order. The result is an
// empty, non-nil slice when m has no entries.
func sortedKeys(m map[string]map[string]bool) []string {
	keys := make([]string, 0, len(m))
	for key := range m {
		keys = append(keys, key)
	}
	sort.Strings(keys)
	return keys
}
219

220
// enforceFlagGroupsForCompletion will do the following:
221
// - when a flag in a group is present, other flags in the group will be marked required
222
// - when none of the flags in a one-required group are present, all flags in the group will be marked required
223
// - when a flag in a mutually exclusive group is present, other flags in the group will be marked as hidden
224
// This allows the standard completion logic to behave appropriately for flag groups
225
func (c *Command) enforceFlagGroupsForCompletion() {
226
	if c.DisableFlagParsing {
227
		return
228
	}
229

230
	flags := c.Flags()
231
	groupStatus := map[string]map[string]bool{}
232
	oneRequiredGroupStatus := map[string]map[string]bool{}
233
	mutuallyExclusiveGroupStatus := map[string]map[string]bool{}
234
	c.Flags().VisitAll(func(pflag *flag.Flag) {
235
		processFlagForGroupAnnotation(flags, pflag, requiredAsGroup, groupStatus)
236
		processFlagForGroupAnnotation(flags, pflag, oneRequired, oneRequiredGroupStatus)
237
		processFlagForGroupAnnotation(flags, pflag, mutuallyExclusive, mutuallyExclusiveGroupStatus)
238
	})
239

240
	// If a flag that is part of a group is present, we make all the other flags
241
	// of that group required so that the shell completion suggests them automatically
242
	for flagList, flagnameAndStatus := range groupStatus {
243
		for _, isSet := range flagnameAndStatus {
244
			if isSet {
245
				// One of the flags of the group is set, mark the other ones as required
246
				for _, fName := range strings.Split(flagList, " ") {
247
					_ = c.MarkFlagRequired(fName)
248
				}
249
			}
250
		}
251
	}
252

253
	// If none of the flags of a one-required group are present, we make all the flags
254
	// of that group required so that the shell completion suggests them automatically
255
	for flagList, flagnameAndStatus := range oneRequiredGroupStatus {
256
		set := 0
257

258
		for _, isSet := range flagnameAndStatus {
259
			if isSet {
260
				set++
261
			}
262
		}
263

264
		// None of the flags of the group are set, mark all flags in the group
265
		// as required
266
		if set == 0 {
267
			for _, fName := range strings.Split(flagList, " ") {
268
				_ = c.MarkFlagRequired(fName)
269
			}
270
		}
271
	}
272

273
	// If a flag that is mutually exclusive to others is present, we hide the other
274
	// flags of that group so the shell completion does not suggest them
275
	for flagList, flagnameAndStatus := range mutuallyExclusiveGroupStatus {
276
		for flagName, isSet := range flagnameAndStatus {
277
			if isSet {
278
				// One of the flags of the mutually exclusive group is set, mark the other ones as hidden
279
				// Don't mark the flag that is already set as hidden because it may be an
280
				// array or slice flag and therefore must continue being suggested
281
				for _, fName := range strings.Split(flagList, " ") {
282
					if fName != flagName {
283
						flag := c.Flags().Lookup(fName)
284
						flag.Hidden = true
285
					}
286
				}
287
			}
288
		}
289
	}
290
}
291

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.