kubelatte-ce
Форк от sbertech/kubelatte-ce
241 строка · 6.2 Кб
1package util
2
3import (
4"crypto/tls"
5"crypto/x509"
6"fmt"
7"sync"
8"time"
9
10"github.com/spf13/afero"
11)
12
// TODO: Make the Reloader its own thing and have a special case for the Cert one?

// CertificateReloader keeps a TLS certificate up to date by reloading it
// automatically.
type CertificateReloader interface {
	// Start begins reloading the certificate in the background.
	Start() error

	// Stop ends the background reloading. The returned channel is closed once
	// reloading has stopped.
	Stop() chan struct{}

	// IsRunning reports whether the reloader is currently running.
	IsRunning() bool

	// GetCertificate returns the most recently loaded certificate, or an error
	// if that certificate has expired.
	GetCertificate() (*tls.Certificate, error)
}
22
23type CertificatePKIReloader struct {
24refreshInterval time.Duration
25lock sync.RWMutex
26stopCh chan struct{}
27stoppedCh chan struct{}
28started bool
29lastModTime time.Time
30certExpiry time.Time
31fs afero.Fs
32certFilename string
33keyFilename string
34cert *tls.Certificate
35errHandler func(error)
36}
37
// FileError indicates there was a problem inspecting or reading the files
// being monitored. The underlying error is embedded and is reachable through
// Unwrap.
type FileError struct {
	error
}

// Unwrap returns the underlying error so that errors.Is and errors.As can see
// through a FileError (for example, to detect a not-exist error from the
// filesystem).
func (e FileError) Unwrap() error {
	return e.error
}
43
// TLSError indicates there was a problem converting the contents of the
// monitored files into x509 certificate/key pair. The underlying error is
// embedded and is reachable through Unwrap.
type TLSError struct {
	error
}

// Unwrap returns the underlying error so that errors.Is and errors.As can see
// through a TLSError.
func (e TLSError) Unwrap() error {
	return e.error
}
49
50// Creates a CertificateReloader based on the files and afero FS.
51func NewCertificatePKIReloaderFull(fs afero.Fs, certFilename, keyFilename string, refreshInterval time.Duration) *CertificatePKIReloader {
52return newCertificatePKIReloaderFull(
53fs,
54certFilename,
55keyFilename,
56refreshInterval,
57nil,
58)
59}
60
61// Creates a CertificateReloader based on the files and afero FS.
62// Calls the given error handler when there are problems reading the given
63// files. The error passed to the handler will be a FileError, TLSError, or
64// error.
65// If errHandler is nil, the default behavior is to do nothing on error.
66func NewCertificatePKIReloaderFullWithErrHandler(fs afero.Fs, certFilename, keyFilename string, refreshInterval time.Duration, errHandler func(error)) *CertificatePKIReloader {
67return newCertificatePKIReloaderFull(
68fs,
69certFilename,
70keyFilename,
71refreshInterval,
72errHandler,
73)
74}
75
76// A simplified version of NewCertificatePKIReloaderFull where the fs is the OS fs by default
77func NewCertificatePKIReloader(certFilename, keyFilename string, refreshInterval time.Duration) *CertificatePKIReloader {
78return NewCertificatePKIReloaderFull(
79afero.NewOsFs(),
80certFilename,
81keyFilename,
82refreshInterval)
83}
84
85// A simplified version of NewCertificatePKIReloaderFullWithErrHandler where the
86// fs is the OS fs by default.
87// Calls the given error handler when there are problems reading the given
88// files. The error passed to the handler will be a FileError, TLSError, or
89// error.
90// If errHandler is nil, the default behavior is to do nothing on error.
91func NewCertificatePKIReloaderWithErrHandler(certFilename, keyFilename string, refreshInterval time.Duration, errHandler func(error)) *CertificatePKIReloader {
92return newCertificatePKIReloaderFull(
93afero.NewOsFs(),
94certFilename,
95keyFilename,
96refreshInterval,
97errHandler,
98)
99}
100
101func newCertificatePKIReloaderFull(fs afero.Fs, certFilename, keyFilename string, refreshInterval time.Duration, errHandler func(error)) *CertificatePKIReloader {
102if errHandler == nil {
103errHandler = func(_ error) { /* Do nothing */ }
104}
105
106return &CertificatePKIReloader{
107fs: fs,
108certFilename: certFilename,
109keyFilename: keyFilename,
110refreshInterval: refreshInterval,
111started: false,
112errHandler: errHandler,
113}
114}
115
116func (r *CertificatePKIReloader) Start() error {
117if r == nil {
118panic("Calling Start on uninit CertificatePKIReloader")
119}
120if !r.started {
121r.runRefresh()
122if _, err := r.GetCertificate(); err != nil {
123return err
124}
125r.stopCh = make(chan struct{})
126r.stoppedCh = make(chan struct{})
127r.started = true
128go r.runRefreshLoop()
129}
130
131return nil
132}
133
134func (r *CertificatePKIReloader) Stop() chan struct{} {
135if r == nil {
136panic("Calling Start on uninit CertificatePKIReloader")
137}
138r.lock.Lock()
139defer r.lock.Unlock()
140
141if !r.started {
142stoppedCh := make(chan struct{})
143close(stoppedCh)
144return stoppedCh
145}
146
147close(r.stopCh)
148r.started = false
149return r.stoppedCh
150}
151
152func (r *CertificatePKIReloader) IsRunning() bool {
153if r == nil {
154return false
155}
156r.lock.RLock()
157defer r.lock.RUnlock()
158
159return r.started
160}
161
162func (r *CertificatePKIReloader) GetCertificate() (*tls.Certificate, error) {
163if r == nil {
164panic("Calling Start on uninit CertificatePKIReloader")
165}
166r.lock.RLock()
167defer r.lock.RUnlock()
168// return error if certificate in cache has expired
169if r.certExpiry.Before(time.Now()) {
170return nil, fmt.Errorf("certificate expired at %v", r.certExpiry)
171}
172return r.cert, nil
173}
174
175func readCert(fs afero.Fs, certFilename, keyFilename string) (*tls.Certificate, error) {
176certPEMBlock, err := afero.ReadFile(fs, certFilename)
177if err != nil {
178return &tls.Certificate{}, FileError{error: err}
179}
180
181keyPEMBlock, err := afero.ReadFile(fs, keyFilename)
182if err != nil {
183return &tls.Certificate{}, FileError{error: err}
184}
185
186cert, err := tls.X509KeyPair(certPEMBlock, keyPEMBlock)
187if err != nil {
188return &tls.Certificate{}, TLSError{error: err}
189}
190return &cert, nil
191}
192
193func readModTime(fs afero.Fs, filename string) (time.Time, error) {
194f, err := fs.Stat(filename)
195if err != nil {
196return time.Time{}, nil
197}
198
199return f.ModTime(), nil
200}
201
202func (r *CertificatePKIReloader) runRefresh() {
203modTime, err := readModTime(r.fs, r.keyFilename)
204if err != nil {
205r.errHandler(err)
206return
207}
208
209if r.lastModTime.Before(modTime) {
210cert, err := readCert(r.fs, r.certFilename, r.keyFilename)
211if err != nil {
212r.errHandler(err)
213return
214}
215clientCert, err := x509.ParseCertificate(cert.Certificate[0])
216if err != nil {
217r.errHandler(err)
218return
219}
220r.lock.Lock()
221// cert, lastModTime, certExpiry are not updated in case of errors reading the cert
222r.lastModTime = modTime
223r.cert = cert
224r.certExpiry = clientCert.NotAfter
225r.lock.Unlock()
226}
227}
228
229func (r *CertificatePKIReloader) runRefreshLoop() {
230defer close(r.stoppedCh)
231
232ticker := time.NewTicker(r.refreshInterval)
233for {
234select {
235case <-ticker.C:
236r.runRefresh()
237case <-r.stopCh:
238return
239}
240}
241}
242