podman
413 строк · 10.1 Кб
1// Copyright 2018 The go-libvirt Authors.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package libvirt16
17import (18"bytes"19"errors"20"fmt"21"io"22"reflect"23"strings"24"sync/atomic"25
26"github.com/digitalocean/go-libvirt/internal/constants"27"github.com/digitalocean/go-libvirt/internal/event"28xdr "github.com/digitalocean/go-libvirt/internal/go-xdr/xdr2"29"github.com/digitalocean/go-libvirt/socket"30)
31
var (
	// ErrUnsupported is returned when libvirt reports that the requested
	// procedure is unknown to it.
	ErrUnsupported = errors.New("unsupported procedure requested")
)
// response is an internal RPC reply handed to a waiting caller: the raw packet
// payload together with the status code taken from the packet header.
type response struct {
	Payload []byte
	Status  uint32
}
40
// Error is an error response from libvirt. Code holds libvirt's numeric error
// code and Message its human-readable description.
type Error struct {
	Code    uint32
	Message string
}

// Error implements the error interface; it reports only the libvirt message.
func (e Error) Error() string {
	msg := e.Message
	return msg
}
50
51// checkError is used to check whether an error is a libvirtError, and if it is,
52// whether its error code matches the one passed in. It will return false if
53// these conditions are not met.
54func checkError(err error, expectedError ErrorNumber) bool {55for err != nil {56e, ok := err.(Error)57if ok {58return e.Code == uint32(expectedError)59}60err = errors.Unwrap(err)61}62return false63}
64
65// IsNotFound detects libvirt's ERR_NO_DOMAIN.
66func IsNotFound(err error) bool {67return checkError(err, ErrNoDomain)68}
69
70// callback sends RPC responses to respective callers.
71func (l *Libvirt) callback(id int32, res response) {72l.cmux.Lock()73defer l.cmux.Unlock()74
75c, ok := l.callbacks[id]76if !ok {77return78}79
80c <- res81}
82
83// Route sends incoming packets to their listeners.
84func (l *Libvirt) Route(h *socket.Header, buf []byte) {85// Route events to their respective listener86var event event.Event87
88switch {89case h.Program == constants.QEMUProgram && h.Procedure == constants.QEMUProcDomainMonitorEvent:90event = &DomainEvent{}91case h.Program == constants.Program && h.Procedure == constants.ProcDomainEventCallbackLifecycle:92event = &DomainEventCallbackLifecycleMsg{}93}94
95if event != nil {96err := eventDecoder(buf, event)97if err != nil { // event was malformed, drop.98return99}100
101l.stream(event)102return103}104
105// send response to caller106l.callback(h.Serial, response{Payload: buf, Status: h.Status})107}
108
109// serial provides atomic access to the next sequential request serial number.
110func (l *Libvirt) serial() int32 {111return atomic.AddInt32(&l.s, 1)112}
113
114// stream decodes and relays domain events to their respective listener.
115func (l *Libvirt) stream(e event.Event) {116l.emux.RLock()117defer l.emux.RUnlock()118
119q, ok := l.events[e.GetCallbackID()]120if !ok {121return122}123
124q.Push(e)125}
126
127// addStream configures the routing for an event stream.
128func (l *Libvirt) addStream(s *event.Stream) {129l.emux.Lock()130defer l.emux.Unlock()131
132l.events[s.CallbackID] = s133}
134
135// removeStream deletes an event stream. The caller should first notify libvirt
136// to stop sending events for this stream. Subsequent calls to removeStream are
137// idempotent and return nil.
138func (l *Libvirt) removeStream(id int32) error {139l.emux.Lock()140defer l.emux.Unlock()141
142// if the event is already removed, just return nil143q, ok := l.events[id]144if ok {145delete(l.events, id)146q.Shutdown()147}148
149return nil150}
151
152// removeAllStreams deletes all event streams. This is meant to be used to
153// clean up only once the underlying connection to libvirt is disconnected and
154// thus does not attempt to notify libvirt to stop sending events.
155func (l *Libvirt) removeAllStreams() {156l.emux.Lock()157defer l.emux.Unlock()158
159for _, ev := range l.events {160ev.Shutdown()161delete(l.events, ev.CallbackID)162}163}
164
165// register configures a method response callback
166func (l *Libvirt) register(id int32, c chan response) {167l.cmux.Lock()168defer l.cmux.Unlock()169
170l.callbacks[id] = c171}
172
173// deregister destroys a method response callback. It is the responsibility of
174// the caller to manage locking (l.cmux) during this call.
175func (l *Libvirt) deregister(id int32) {176_, ok := l.callbacks[id]177if !ok {178return179}180
181close(l.callbacks[id])182delete(l.callbacks, id)183}
184
185// deregisterAll closes all waiting callback channels. This is used to clean up
186// if the connection to libvirt is lost. Callers waiting for responses will
187// return an error when the response channel is closed, rather than just
188// hanging.
189func (l *Libvirt) deregisterAll() {190l.cmux.Lock()191defer l.cmux.Unlock()192
193for id := range l.callbacks {194l.deregister(id)195}196}
197
198// request performs a libvirt RPC request.
199// returns response returned by server.
200// if response is not OK, decodes error from it and returns it.
201func (l *Libvirt) request(proc uint32, program uint32, payload []byte) (response, error) {202return l.requestStream(proc, program, payload, nil, nil)203}
204
205// requestStream performs a libvirt RPC request. The `out` and `in` parameters
206// are optional, and should be nil when RPC endpoints don't return a stream.
207func (l *Libvirt) requestStream(proc uint32, program uint32, payload []byte,208out io.Reader, in io.Writer) (response, error) {209serial := l.serial()210c := make(chan response)211
212l.register(serial, c)213defer func() {214l.cmux.Lock()215defer l.cmux.Unlock()216
217l.deregister(serial)218}()219
220err := l.socket.SendPacket(serial, proc, program, payload, socket.Call,221socket.StatusOK)222if err != nil {223return response{}, err224}225
226resp, err := l.getResponse(c)227if err != nil {228return resp, err229}230
231if out != nil {232abort := make(chan bool)233outErr := make(chan error)234go func() {235outErr <- l.socket.SendStream(serial, proc, program, out, abort)236}()237
238// Even without incoming stream server sends confirmation once all data is received239resp, err = l.processIncomingStream(c, in)240if err != nil {241abort <- true242return resp, err243}244
245err = <-outErr246if err != nil {247return response{}, err248}249}250
251switch in {252case nil:253return resp, nil254default:255return l.processIncomingStream(c, in)256}257}
258
259// processIncomingStream is called once we've successfully sent a request to
260// libvirt. It writes the responses back to the stream passed by the caller
261// until libvirt sends a packet with statusOK or an error.
262func (l *Libvirt) processIncomingStream(c chan response, inStream io.Writer) (response, error) {263for {264resp, err := l.getResponse(c)265if err != nil {266return resp, err267}268
269// StatusOK indicates end of stream270if resp.Status == socket.StatusOK {271return resp, nil272}273
274// FIXME: this smells.275// StatusError is handled in getResponse, so this must be StatusContinue276// StatusContinue is only valid here for stream packets277// libvirtd breaks protocol and returns StatusContinue with an278// empty response Payload when the stream finishes279if len(resp.Payload) == 0 {280return resp, nil281}282if inStream != nil {283_, err = inStream.Write(resp.Payload)284if err != nil {285return response{}, err286}287}288}289}
290
291func (l *Libvirt) getResponse(c chan response) (response, error) {292resp := <-c293if resp.Status == socket.StatusError {294return resp, decodeError(resp.Payload)295}296
297return resp, nil298}
299
300// encode XDR encodes the provided data.
301func encode(data interface{}) ([]byte, error) {302var buf bytes.Buffer303_, err := xdr.Marshal(&buf, data)304
305return buf.Bytes(), err306}
307
308// decodeError extracts an error message from the provider buffer.
309func decodeError(buf []byte) error {310dec := xdr.NewDecoder(bytes.NewReader(buf))311
312e := struct {313Code uint32314DomainID uint32315Padding uint8316Message string317Level uint32318}{}319_, err := dec.Decode(&e)320if err != nil {321return err322}323
324if strings.Contains(e.Message, "unknown procedure") {325return ErrUnsupported326}327
328// if libvirt returns ERR_OK, ignore the error329if ErrorNumber(e.Code) == ErrOk {330return nil331}332
333return Error{Code: uint32(e.Code), Message: e.Message}334}
335
336// eventDecoder decodes an event from a xdr buffer.
337func eventDecoder(buf []byte, e interface{}) error {338dec := xdr.NewDecoder(bytes.NewReader(buf))339_, err := dec.Decode(e)340return err341}
342
343type typedParamDecoder struct{}344
345// Decode decodes a TypedParam. These are part of the libvirt spec, and not xdr
346// proper. TypedParams contain a name, which is called Field for some reason,
347// and a Value, which itself has a "discriminant" - an integer enum encoding the
348// actual type, and a value, the length of which varies based on the actual
349// type.
350func (tpd typedParamDecoder) Decode(d *xdr.Decoder, v reflect.Value) (int, error) {351// Get the name of the typed param first352name, n, err := d.DecodeString()353if err != nil {354return n, err355}356val, n2, err := tpd.decodeTypedParamValue(d)357n += n2358if err != nil {359return n, err360}361tp := &TypedParam{Field: name, Value: *val}362v.Set(reflect.ValueOf(*tp))363
364return n, nil365}
366
367// decodeTypedParamValue decodes the Value part of a TypedParam.
368func (typedParamDecoder) decodeTypedParamValue(d *xdr.Decoder) (*TypedParamValue, int, error) {369// All TypedParamValues begin with a uint32 discriminant that tells us what370// type they are.371discriminant, n, err := d.DecodeUint()372if err != nil {373return nil, n, err374}375var n2 int376var tpv *TypedParamValue377switch discriminant {378case 1:379var val int32380n2, err = d.Decode(&val)381tpv = &TypedParamValue{D: discriminant, I: val}382case 2:383var val uint32384n2, err = d.Decode(&val)385tpv = &TypedParamValue{D: discriminant, I: val}386case 3:387var val int64388n2, err = d.Decode(&val)389tpv = &TypedParamValue{D: discriminant, I: val}390case 4:391var val uint64392n2, err = d.Decode(&val)393tpv = &TypedParamValue{D: discriminant, I: val}394case 5:395var val float64396n2, err = d.Decode(&val)397tpv = &TypedParamValue{D: discriminant, I: val}398case 6:399var val int32400n2, err = d.Decode(&val)401tpv = &TypedParamValue{D: discriminant, I: val}402case 7:403var val string404n2, err = d.Decode(&val)405tpv = &TypedParamValue{D: discriminant, I: val}406
407default:408err = fmt.Errorf("invalid parameter type %v", discriminant)409}410n += n2411
412return tpv, n, err413}
414