go-tg-screenshot-bot
212 строк · 5.2 Кб
1//+build !windows,!solaris
2
3package dbus
4
5import (
6"bytes"
7"encoding/binary"
8"errors"
9"io"
10"net"
11"syscall"
12)
13
// oobReader is an io.Reader over a unix socket that also keeps the
// out-of-band (ancillary) bytes delivered together with the data it reads.
type oobReader struct {
	// conn is the socket that Read pulls data from via ReadMsgUnix.
	conn *net.UnixConn
	// oob accumulates control-message bytes across Read calls until the
	// owner resets it.
	oob []byte
	// buf is per-read scratch space for control data; if a read's control
	// data does not fit, Read reports it as truncated.
	buf [4 << 10]byte
}
19
20func (o *oobReader) Read(b []byte) (n int, err error) {
21n, oobn, flags, _, err := o.conn.ReadMsgUnix(b, o.buf[:])
22if err != nil {
23return n, err
24}
25if flags&syscall.MSG_CTRUNC != 0 {
26return n, errors.New("dbus: control data truncated (too many fds received)")
27}
28o.oob = append(o.oob, o.buf[:oobn]...)
29return n, nil
30}
31
32type unixTransport struct {
33*net.UnixConn
34rdr *oobReader
35hasUnixFDs bool
36}
37
38func newUnixTransport(keys string) (transport, error) {
39var err error
40
41t := new(unixTransport)
42abstract := getKey(keys, "abstract")
43path := getKey(keys, "path")
44switch {
45case abstract == "" && path == "":
46return nil, errors.New("dbus: invalid address (neither path nor abstract set)")
47case abstract != "" && path == "":
48t.UnixConn, err = net.DialUnix("unix", nil, &net.UnixAddr{Name: "@" + abstract, Net: "unix"})
49if err != nil {
50return nil, err
51}
52return t, nil
53case abstract == "" && path != "":
54t.UnixConn, err = net.DialUnix("unix", nil, &net.UnixAddr{Name: path, Net: "unix"})
55if err != nil {
56return nil, err
57}
58return t, nil
59default:
60return nil, errors.New("dbus: invalid address (both path and abstract set)")
61}
62}
63
64func init() {
65transports["unix"] = newUnixTransport
66}
67
68func (t *unixTransport) EnableUnixFDs() {
69t.hasUnixFDs = true
70}
71
72func (t *unixTransport) ReadMessage() (*Message, error) {
73var (
74blen, hlen uint32
75csheader [16]byte
76headers []header
77order binary.ByteOrder
78unixfds uint32
79)
80// To be sure that all bytes of out-of-band data are read, we use a special
81// reader that uses ReadUnix on the underlying connection instead of Read
82// and gathers the out-of-band data in a buffer.
83if t.rdr == nil {
84t.rdr = &oobReader{conn: t.UnixConn}
85} else {
86t.rdr.oob = nil
87}
88
89// read the first 16 bytes (the part of the header that has a constant size),
90// from which we can figure out the length of the rest of the message
91if _, err := io.ReadFull(t.rdr, csheader[:]); err != nil {
92return nil, err
93}
94switch csheader[0] {
95case 'l':
96order = binary.LittleEndian
97case 'B':
98order = binary.BigEndian
99default:
100return nil, InvalidMessageError("invalid byte order")
101}
102// csheader[4:8] -> length of message body, csheader[12:16] -> length of
103// header fields (without alignment)
104binary.Read(bytes.NewBuffer(csheader[4:8]), order, &blen)
105binary.Read(bytes.NewBuffer(csheader[12:]), order, &hlen)
106if hlen%8 != 0 {
107hlen += 8 - (hlen % 8)
108}
109
110// decode headers and look for unix fds
111headerdata := make([]byte, hlen+4)
112copy(headerdata, csheader[12:])
113if _, err := io.ReadFull(t.rdr, headerdata[4:]); err != nil {
114return nil, err
115}
116dec := newDecoder(bytes.NewBuffer(headerdata), order, make([]int, 0))
117dec.pos = 12
118vs, err := dec.Decode(Signature{"a(yv)"})
119if err != nil {
120return nil, err
121}
122Store(vs, &headers)
123for _, v := range headers {
124if v.Field == byte(FieldUnixFDs) {
125unixfds, _ = v.Variant.value.(uint32)
126}
127}
128all := make([]byte, 16+hlen+blen)
129copy(all, csheader[:])
130copy(all[16:], headerdata[4:])
131if _, err := io.ReadFull(t.rdr, all[16+hlen:]); err != nil {
132return nil, err
133}
134if unixfds != 0 {
135if !t.hasUnixFDs {
136return nil, errors.New("dbus: got unix fds on unsupported transport")
137}
138// read the fds from the OOB data
139scms, err := syscall.ParseSocketControlMessage(t.rdr.oob)
140if err != nil {
141return nil, err
142}
143if len(scms) != 1 {
144return nil, errors.New("dbus: received more than one socket control message")
145}
146fds, err := syscall.ParseUnixRights(&scms[0])
147if err != nil {
148return nil, err
149}
150msg, err := DecodeMessageWithFDs(bytes.NewBuffer(all), fds)
151if err != nil {
152return nil, err
153}
154// substitute the values in the message body (which are indices for the
155// array receiver via OOB) with the actual values
156for i, v := range msg.Body {
157switch index := v.(type) {
158case UnixFDIndex:
159if uint32(index) >= unixfds {
160return nil, InvalidMessageError("invalid index for unix fd")
161}
162msg.Body[i] = UnixFD(fds[index])
163case []UnixFDIndex:
164fdArray := make([]UnixFD, len(index))
165for k, j := range index {
166if uint32(j) >= unixfds {
167return nil, InvalidMessageError("invalid index for unix fd")
168}
169fdArray[k] = UnixFD(fds[j])
170}
171msg.Body[i] = fdArray
172}
173}
174return msg, nil
175}
176return DecodeMessage(bytes.NewBuffer(all))
177}
178
179func (t *unixTransport) SendMessage(msg *Message) error {
180fdcnt, err := msg.CountFds()
181if err != nil {
182return err
183}
184if fdcnt != 0 {
185if !t.hasUnixFDs {
186return errors.New("dbus: unix fd passing not enabled")
187}
188msg.Headers[FieldUnixFDs] = MakeVariant(uint32(fdcnt))
189buf := new(bytes.Buffer)
190fds, err := msg.EncodeToWithFDs(buf, nativeEndian)
191if err != nil {
192return err
193}
194oob := syscall.UnixRights(fds...)
195n, oobn, err := t.UnixConn.WriteMsgUnix(buf.Bytes(), oob, nil)
196if err != nil {
197return err
198}
199if n != buf.Len() || oobn != len(oob) {
200return io.ErrShortWrite
201}
202} else {
203if err := msg.EncodeTo(t, nativeEndian); err != nil {
204return err
205}
206}
207return nil
208}
209
210func (t *unixTransport) SupportsUnixFDs() bool {
211return true
212}
213