podman
290 строк · 9.2 Кб
1// Copyright 2013-2023 The Cobra Authors
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package cobra
16
17import (
18"fmt"
19"sort"
20"strings"
21
22flag "github.com/spf13/pflag"
23)
24
// Annotation keys under which flag-group membership is recorded in a flag's
// Annotations. The value stored under each key is a list with one entry per
// group the flag belongs to; an entry is the group's flag names joined by a
// single space.
const (
	requiredAsGroup   = "cobra_annotation_required_if_others_set" // all flags of the group, or none
	oneRequired       = "cobra_annotation_one_required"           // at least one flag of the group
	mutuallyExclusive = "cobra_annotation_mutually_exclusive"     // at most one flag of the group
)
30
31// MarkFlagsRequiredTogether marks the given flags with annotations so that Cobra errors
32// if the command is invoked with a subset (but not all) of the given flags.
33func (c *Command) MarkFlagsRequiredTogether(flagNames ...string) {
34c.mergePersistentFlags()
35for _, v := range flagNames {
36f := c.Flags().Lookup(v)
37if f == nil {
38panic(fmt.Sprintf("Failed to find flag %q and mark it as being required in a flag group", v))
39}
40if err := c.Flags().SetAnnotation(v, requiredAsGroup, append(f.Annotations[requiredAsGroup], strings.Join(flagNames, " "))); err != nil {
41// Only errs if the flag isn't found.
42panic(err)
43}
44}
45}
46
47// MarkFlagsOneRequired marks the given flags with annotations so that Cobra errors
48// if the command is invoked without at least one flag from the given set of flags.
49func (c *Command) MarkFlagsOneRequired(flagNames ...string) {
50c.mergePersistentFlags()
51for _, v := range flagNames {
52f := c.Flags().Lookup(v)
53if f == nil {
54panic(fmt.Sprintf("Failed to find flag %q and mark it as being in a one-required flag group", v))
55}
56if err := c.Flags().SetAnnotation(v, oneRequired, append(f.Annotations[oneRequired], strings.Join(flagNames, " "))); err != nil {
57// Only errs if the flag isn't found.
58panic(err)
59}
60}
61}
62
63// MarkFlagsMutuallyExclusive marks the given flags with annotations so that Cobra errors
64// if the command is invoked with more than one flag from the given set of flags.
65func (c *Command) MarkFlagsMutuallyExclusive(flagNames ...string) {
66c.mergePersistentFlags()
67for _, v := range flagNames {
68f := c.Flags().Lookup(v)
69if f == nil {
70panic(fmt.Sprintf("Failed to find flag %q and mark it as being in a mutually exclusive flag group", v))
71}
72// Each time this is called is a single new entry; this allows it to be a member of multiple groups if needed.
73if err := c.Flags().SetAnnotation(v, mutuallyExclusive, append(f.Annotations[mutuallyExclusive], strings.Join(flagNames, " "))); err != nil {
74panic(err)
75}
76}
77}
78
79// ValidateFlagGroups validates the mutuallyExclusive/oneRequired/requiredAsGroup logic and returns the
80// first error encountered.
81func (c *Command) ValidateFlagGroups() error {
82if c.DisableFlagParsing {
83return nil
84}
85
86flags := c.Flags()
87
88// groupStatus format is the list of flags as a unique ID,
89// then a map of each flag name and whether it is set or not.
90groupStatus := map[string]map[string]bool{}
91oneRequiredGroupStatus := map[string]map[string]bool{}
92mutuallyExclusiveGroupStatus := map[string]map[string]bool{}
93flags.VisitAll(func(pflag *flag.Flag) {
94processFlagForGroupAnnotation(flags, pflag, requiredAsGroup, groupStatus)
95processFlagForGroupAnnotation(flags, pflag, oneRequired, oneRequiredGroupStatus)
96processFlagForGroupAnnotation(flags, pflag, mutuallyExclusive, mutuallyExclusiveGroupStatus)
97})
98
99if err := validateRequiredFlagGroups(groupStatus); err != nil {
100return err
101}
102if err := validateOneRequiredFlagGroups(oneRequiredGroupStatus); err != nil {
103return err
104}
105if err := validateExclusiveFlagGroups(mutuallyExclusiveGroupStatus); err != nil {
106return err
107}
108return nil
109}
110
111func hasAllFlags(fs *flag.FlagSet, flagnames ...string) bool {
112for _, fname := range flagnames {
113f := fs.Lookup(fname)
114if f == nil {
115return false
116}
117}
118return true
119}
120
121func processFlagForGroupAnnotation(flags *flag.FlagSet, pflag *flag.Flag, annotation string, groupStatus map[string]map[string]bool) {
122groupInfo, found := pflag.Annotations[annotation]
123if found {
124for _, group := range groupInfo {
125if groupStatus[group] == nil {
126flagnames := strings.Split(group, " ")
127
128// Only consider this flag group at all if all the flags are defined.
129if !hasAllFlags(flags, flagnames...) {
130continue
131}
132
133groupStatus[group] = map[string]bool{}
134for _, name := range flagnames {
135groupStatus[group][name] = false
136}
137}
138
139groupStatus[group][pflag.Name] = pflag.Changed
140}
141}
142}
143
144func validateRequiredFlagGroups(data map[string]map[string]bool) error {
145keys := sortedKeys(data)
146for _, flagList := range keys {
147flagnameAndStatus := data[flagList]
148
149unset := []string{}
150for flagname, isSet := range flagnameAndStatus {
151if !isSet {
152unset = append(unset, flagname)
153}
154}
155if len(unset) == len(flagnameAndStatus) || len(unset) == 0 {
156continue
157}
158
159// Sort values, so they can be tested/scripted against consistently.
160sort.Strings(unset)
161return fmt.Errorf("if any flags in the group [%v] are set they must all be set; missing %v", flagList, unset)
162}
163
164return nil
165}
166
167func validateOneRequiredFlagGroups(data map[string]map[string]bool) error {
168keys := sortedKeys(data)
169for _, flagList := range keys {
170flagnameAndStatus := data[flagList]
171var set []string
172for flagname, isSet := range flagnameAndStatus {
173if isSet {
174set = append(set, flagname)
175}
176}
177if len(set) >= 1 {
178continue
179}
180
181// Sort values, so they can be tested/scripted against consistently.
182sort.Strings(set)
183return fmt.Errorf("at least one of the flags in the group [%v] is required", flagList)
184}
185return nil
186}
187
188func validateExclusiveFlagGroups(data map[string]map[string]bool) error {
189keys := sortedKeys(data)
190for _, flagList := range keys {
191flagnameAndStatus := data[flagList]
192var set []string
193for flagname, isSet := range flagnameAndStatus {
194if isSet {
195set = append(set, flagname)
196}
197}
198if len(set) == 0 || len(set) == 1 {
199continue
200}
201
202// Sort values, so they can be tested/scripted against consistently.
203sort.Strings(set)
204return fmt.Errorf("if any flags in the group [%v] are set none of the others can be; %v were all set", flagList, set)
205}
206return nil
207}
208
// sortedKeys returns the keys of m in ascending lexical order. The result is
// never nil, even when m is empty or nil.
func sortedKeys(m map[string]map[string]bool) []string {
	keys := make([]string, 0, len(m))
	for k := range m {
		keys = append(keys, k)
	}
	sort.Strings(keys)
	return keys
}
219
220// enforceFlagGroupsForCompletion will do the following:
221// - when a flag in a group is present, other flags in the group will be marked required
222// - when none of the flags in a one-required group are present, all flags in the group will be marked required
223// - when a flag in a mutually exclusive group is present, other flags in the group will be marked as hidden
224// This allows the standard completion logic to behave appropriately for flag groups
225func (c *Command) enforceFlagGroupsForCompletion() {
226if c.DisableFlagParsing {
227return
228}
229
230flags := c.Flags()
231groupStatus := map[string]map[string]bool{}
232oneRequiredGroupStatus := map[string]map[string]bool{}
233mutuallyExclusiveGroupStatus := map[string]map[string]bool{}
234c.Flags().VisitAll(func(pflag *flag.Flag) {
235processFlagForGroupAnnotation(flags, pflag, requiredAsGroup, groupStatus)
236processFlagForGroupAnnotation(flags, pflag, oneRequired, oneRequiredGroupStatus)
237processFlagForGroupAnnotation(flags, pflag, mutuallyExclusive, mutuallyExclusiveGroupStatus)
238})
239
240// If a flag that is part of a group is present, we make all the other flags
241// of that group required so that the shell completion suggests them automatically
242for flagList, flagnameAndStatus := range groupStatus {
243for _, isSet := range flagnameAndStatus {
244if isSet {
245// One of the flags of the group is set, mark the other ones as required
246for _, fName := range strings.Split(flagList, " ") {
247_ = c.MarkFlagRequired(fName)
248}
249}
250}
251}
252
253// If none of the flags of a one-required group are present, we make all the flags
254// of that group required so that the shell completion suggests them automatically
255for flagList, flagnameAndStatus := range oneRequiredGroupStatus {
256set := 0
257
258for _, isSet := range flagnameAndStatus {
259if isSet {
260set++
261}
262}
263
264// None of the flags of the group are set, mark all flags in the group
265// as required
266if set == 0 {
267for _, fName := range strings.Split(flagList, " ") {
268_ = c.MarkFlagRequired(fName)
269}
270}
271}
272
273// If a flag that is mutually exclusive to others is present, we hide the other
274// flags of that group so the shell completion does not suggest them
275for flagList, flagnameAndStatus := range mutuallyExclusiveGroupStatus {
276for flagName, isSet := range flagnameAndStatus {
277if isSet {
278// One of the flags of the mutually exclusive group is set, mark the other ones as hidden
279// Don't mark the flag that is already set as hidden because it may be an
280// array or slice flag and therefore must continue being suggested
281for _, fName := range strings.Split(flagList, " ") {
282if fName != flagName {
283flag := c.Flags().Lookup(fName)
284flag.Hidden = true
285}
286}
287}
288}
289}
290}
291