kraken
1// Copyright (c) 2016-2019 Uber Technologies, Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14package dedup15
16import (17"errors"18"sync"19"time"20
21"github.com/andres-erbsen/clock"22)
23
// RequestCacheConfig defines RequestCache configuration.
type RequestCacheConfig struct {
	// NotFoundTTL is how long errors classified as "not found" (by the
	// ErrorMatcher installed via SetNotFound) remain cached.
	NotFoundTTL time.Duration `yaml:"not_found_ttl"`
	// ErrorTTL is how long all other request errors remain cached.
	ErrorTTL time.Duration `yaml:"error_ttl"`
	// CleanupInterval is the minimum interval between sweeps of expired
	// cached errors (performed lazily during reserve).
	CleanupInterval time.Duration `yaml:"cleanup_interval"`
	// NumWorkers bounds how many requests may execute concurrently.
	NumWorkers int `yaml:"num_workers"`
	// BusyTimeout is how long Start waits for a free worker before
	// returning ErrWorkersBusy.
	BusyTimeout time.Duration `yaml:"busy_timeout"`
}
32
33func (c *RequestCacheConfig) applyDefaults() {34// TODO(codyg): If the cached error TTL is lower than the interval in which35// clients are polling a 202 endpoint, then it is possible that the client36// will never hit the actual error because it expires in between requests.37if c.NotFoundTTL == 0 {38c.NotFoundTTL = 15 * time.Second39}40if c.ErrorTTL == 0 {41c.ErrorTTL = 15 * time.Second42}43if c.CleanupInterval == 0 {44c.CleanupInterval = 5 * time.Second45}46if c.NumWorkers == 0 {47c.NumWorkers = 1000048}49if c.BusyTimeout == 0 {50c.BusyTimeout = 5 * time.Second51}52}
53
// RequestCache errors.
var (
	// ErrRequestPending is returned by Start when a request with the same id
	// is already executing.
	ErrRequestPending = errors.New("request pending")
	// ErrWorkersBusy is returned by Start when no worker slot becomes
	// available within the configured BusyTimeout.
	ErrWorkersBusy = errors.New("no workers available to handle request")
)
59
// cachedError pairs a request error with the instant at which it should stop
// being served from the cache.
type cachedError struct {
	err       error
	expiresAt time.Time
}
64
65func (e *cachedError) expired(now time.Time) bool {66return now.After(e.expiresAt)67}
68
// Request defines functions which encapsulate a request. A nil return marks
// success; a non-nil return is cached against the request's id for the
// configured TTL.
type Request func() error
// ErrorMatcher defines functions which RequestCache uses to detect user defined
// errors. Errors matched by the installed matcher are cached with NotFoundTTL
// instead of ErrorTTL.
type ErrorMatcher func(error) bool
// RequestCache tracks pending requests and caches errors for configurable TTLs.
// It is used to prevent request duplication and DDOS-ing external components.
// Each request is represented by an arbitrary id string determined by the user.
type RequestCache struct {
	config RequestCacheConfig
	clk    clock.Clock

	mu         sync.Mutex // Protects access to the following fields:
	pending    map[string]bool
	errors     map[string]*cachedError
	lastClean  time.Time
	isNotFound ErrorMatcher

	// numWorkers is a counting semaphore bounding concurrent request
	// execution; sized by config.NumWorkers and deliberately not guarded
	// by mu (channel ops are safe on their own).
	numWorkers chan struct{}
}
91
92// NewRequestCache creates a new RequestCache.
93func NewRequestCache(config RequestCacheConfig, clk clock.Clock) *RequestCache {94config.applyDefaults()95return &RequestCache{96config: config,97clk: clk,98pending: make(map[string]bool),99errors: make(map[string]*cachedError),100lastClean: clk.Now(),101isNotFound: func(error) bool { return false },102numWorkers: make(chan struct{}, config.NumWorkers),103}104}
105
106// SetNotFound sets the ErrorMatcher for activating the configured NotFoundTTL
107// for errors returned by Request functions.
108func (c *RequestCache) SetNotFound(m ErrorMatcher) {109c.mu.Lock()110defer c.mu.Unlock()111
112c.isNotFound = m113}
114
115// Start concurrently runs r under the given id. Any error returned by r will be
116// cached for the configured TTL. If there is already a function executing under
117// id, Start returns ErrRequestPending. If there are no available workers to run
118// r, Start returns ErrWorkersBusy.
119func (c *RequestCache) Start(id string, r Request) error {120if err := c.reserve(id); err != nil {121return err122}123if err := c.reserveWorker(); err != nil {124c.release(id)125return err126}127go func() {128defer c.releaseWorker()129c.run(id, r)130}()131return nil132}
133
134func (c *RequestCache) reserve(id string) error {135c.mu.Lock()136defer c.mu.Unlock()137
138// Periodically remove expired errors.139if c.clk.Now().Sub(c.lastClean) > c.config.CleanupInterval {140for id, cerr := range c.errors {141if cerr.expired(c.clk.Now()) {142delete(c.errors, id)143}144}145c.lastClean = c.clk.Now()146}147
148if c.pending[id] {149return ErrRequestPending150}151if cerr, ok := c.errors[id]; ok && !cerr.expired(c.clk.Now()) {152return cerr.err153}154
155c.pending[id] = true156
157return nil158}
159
160func (c *RequestCache) run(id string, r Request) {161if err := r(); err != nil {162c.error(id, err)163return164}165c.release(id)166}
167
168func (c *RequestCache) release(id string) {169c.mu.Lock()170defer c.mu.Unlock()171
172delete(c.pending, id)173}
174
175func (c *RequestCache) error(id string, err error) {176c.mu.Lock()177defer c.mu.Unlock()178
179var ttl time.Duration180if c.isNotFound(err) {181ttl = c.config.NotFoundTTL182} else {183ttl = c.config.ErrorTTL184}185delete(c.pending, id)186c.errors[id] = &cachedError{err: err, expiresAt: c.clk.Now().Add(ttl)}187}
188
189func (c *RequestCache) reserveWorker() error {190select {191case c.numWorkers <- struct{}{}:192return nil193case <-c.clk.After(c.config.BusyTimeout):194return ErrWorkersBusy195}196}
197
198func (c *RequestCache) releaseWorker() {199<-c.numWorkers200}
201