diff --git a/README.md b/README.md index 78aa797f28d7def85568d3f7b2033bce6109938b..d1717dd73ae6bc7ddc1a9b13dd09a981214269a0 100644 --- a/README.md +++ b/README.md @@ -125,7 +125,13 @@ the `group`, `policy name` and `version` are directories inside the Git repo and 2. In the same directory there could be a data file containing static JSON, which is automatically available for use during policy evaluation by using the `data` variable. The file *must* be named `data.json`. Example: `/gaiax/example/1.0/data.json` -3. The policy package name inside the policy source code file *must* exactly match +3. In the same directory there could be a configuration file containing information for getting static JSON +data from external URL. The file *must* be named `data-config.json`. +Example: `/gaiax/example/1.0/data-config.json` +> Note that there should only be one of the two files `data.json` or `data-config.json` in the same directory. +> If both files exist in the same directory tha data from the `data.json` file will be eventually overwritten by the data +> acquired using the configuration from the `data-config.json` file. +4. The policy package name inside the policy source code file *must* exactly match the `group` and `policy` (name) of the policy. *What does it mean?* @@ -164,7 +170,26 @@ Example: If the `/gaiax/example/1.0/data.json` file is: ``` one could access the data using `data.name` within the Rego source code. -The 3rd rule for package naming is needed so that a generic evaluation function +The 3rd rule for configuration file is to provide configurations for getting static JSON data from external URL. +The file must contain a URL, an HTTP method and a period, after which an HTTP request is made to get the latest data. +> The period must be added as duration e.g. `10h`, `1h30m` etc. + +The file MAY contain body for the request. +Example file contents: +```json +{ + "url": "http://example.com/data.json?page=3", + "method": "GET", + "period": "10h", + "body": { + "key": "value" + } +} +``` +This means that every 10 hours an HTTP request is going to be made on the given URL, with `GET` method and the result is going +to be stored as static data for this policy and passed during evaluation. + +The 4th rule for package naming is needed so that a generic evaluation function can be mapped and used for evaluating all kinds of different policies. Without a package naming rule, there's no way the service can automatically generate HTTP endpoints for working with arbitrary dynamically uploaded policies. diff --git a/cmd/policy/main.go b/cmd/policy/main.go index ff64fe269e9ee9d2d8ca6af76ae66bfc9ee02df1..0b26b2f920c22a9409d1443b8a989956cb854906 100644 --- a/cmd/policy/main.go +++ b/cmd/policy/main.go @@ -37,6 +37,7 @@ import ( "gitlab.com/gaia-x/data-infrastructure-federation-services/tsa/policy/internal/service" "gitlab.com/gaia-x/data-infrastructure-federation-services/tsa/policy/internal/service/health" "gitlab.com/gaia-x/data-infrastructure-federation-services/tsa/policy/internal/service/policy" + "gitlab.com/gaia-x/data-infrastructure-federation-services/tsa/policy/internal/service/policy/policydata" "gitlab.com/gaia-x/data-infrastructure-federation-services/tsa/policy/internal/storage" ) @@ -93,6 +94,14 @@ func main() { // subscribe the cache for policy data changes storage.AddPolicyChangeSubscriber(regocache) + // create policy data refresher + dataRefresher := policydata.NewRefresher( + storage, + cfg.Refresher.PollInterval, + httpClient, + logger, + ) + // register rego extension functions { cacheFuncs := regofunc.NewCacheFuncs(cfg.Cache.Addr, oauthClient) @@ -203,6 +212,9 @@ func main() { } return nil }) + g.Go(func() error { + return dataRefresher.Start(ctx) + }) if err := g.Wait(); err != nil { logger.Error("run group stopped", zap.Error(err)) } diff --git a/cmd/sync/main.go b/cmd/sync/main.go index a14cc9b98c42cc01c36d8eb934b004e4c6161daf..42ba5d48bafe3d9621b73e7e529353ecc89f7866 100644 --- a/cmd/sync/main.go +++ b/cmd/sync/main.go @@ -296,7 +296,7 @@ func upsert(ctx context.Context, policies []*Policy, db *mongo.Collection) error "data": policy.Data, "dataConfig": policy.DataConfig, "lastUpdate": time.Now(), - "nextConfigExecution": nextConfigExecution(policy), + "nextDataRefreshTime": nextDataRefreshTime(policy), }, }) op.SetUpsert(true) @@ -311,7 +311,7 @@ func upsert(ctx context.Context, policies []*Policy, db *mongo.Collection) error return nil } -func nextConfigExecution(p *Policy) time.Time { +func nextDataRefreshTime(p *Policy) time.Time { if p.DataConfig != "" { return time.Now() } diff --git a/internal/config/config.go b/internal/config/config.go index 1aedbedfb7a7ec5a972b074400e9fcf598f15074..c092a9fa560f58465b470cdd1d250789bd682154 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -12,6 +12,7 @@ type Config struct { Metrics metricsConfig OCM ocmConfig OAuth oauthConfig + Refresher refresherConfig LogLevel string `envconfig:"LOG_LEVEL" default:"INFO"` } @@ -61,3 +62,7 @@ type oauthConfig struct { ClientSecret string `envconfig:"OAUTH_CLIENT_SECRET" required:"true"` TokenURL string `envconfig:"OAUTH_TOKEN_URL" required:"true"` } + +type refresherConfig struct { + PollInterval time.Duration `envconfig:"REFRESHER_POLL_INTERVAL" default:"10s"` +} diff --git a/internal/service/policy/policydata/dataconfig.go b/internal/service/policy/policydata/dataconfig.go new file mode 100644 index 0000000000000000000000000000000000000000..fcffacac189e076d615630eaceb706a4e6c99df6 --- /dev/null +++ b/internal/service/policy/policydata/dataconfig.go @@ -0,0 +1,38 @@ +package policydata + +import ( + "encoding/json" + "time" + + "gitlab.com/gaia-x/data-infrastructure-federation-services/tsa/golib/errors" +) + +type DataConfig struct { + URL string + Method string + Period Duration + Body interface{} +} + +type Duration time.Duration + +func (d *Duration) UnmarshalJSON(b []byte) error { + var v interface{} + if err := json.Unmarshal(b, &v); err != nil { + return err + } + switch value := v.(type) { + case float64: + *d = Duration(time.Duration(value)) + return nil + case string: + tmp, err := time.ParseDuration(value) + if err != nil { + return err + } + *d = Duration(tmp) + return nil + default: + return errors.New("invalid duration") + } +} diff --git a/internal/service/policy/policydata/policydatafakes/fake_storage.go b/internal/service/policy/policydata/policydatafakes/fake_storage.go new file mode 100644 index 0000000000000000000000000000000000000000..521c346f7a4fe49d8ce765aabc1bfbf4270d4fea --- /dev/null +++ b/internal/service/policy/policydata/policydatafakes/fake_storage.go @@ -0,0 +1,278 @@ +// Code generated by counterfeiter. DO NOT EDIT. +package policydatafakes + +import ( + "context" + "sync" + "time" + + "gitlab.com/gaia-x/data-infrastructure-federation-services/tsa/policy/internal/service/policy/policydata" + "gitlab.com/gaia-x/data-infrastructure-federation-services/tsa/policy/internal/storage" +) + +type FakeStorage struct { + GetRefreshPoliciesStub func(context.Context) ([]*storage.Policy, error) + getRefreshPoliciesMutex sync.RWMutex + getRefreshPoliciesArgsForCall []struct { + arg1 context.Context + } + getRefreshPoliciesReturns struct { + result1 []*storage.Policy + result2 error + } + getRefreshPoliciesReturnsOnCall map[int]struct { + result1 []*storage.Policy + result2 error + } + PostponeRefreshStub func(context.Context, []*storage.Policy) error + postponeRefreshMutex sync.RWMutex + postponeRefreshArgsForCall []struct { + arg1 context.Context + arg2 []*storage.Policy + } + postponeRefreshReturns struct { + result1 error + } + postponeRefreshReturnsOnCall map[int]struct { + result1 error + } + UpdateNextRefreshTimeStub func(context.Context, *storage.Policy, time.Time) error + updateNextRefreshTimeMutex sync.RWMutex + updateNextRefreshTimeArgsForCall []struct { + arg1 context.Context + arg2 *storage.Policy + arg3 time.Time + } + updateNextRefreshTimeReturns struct { + result1 error + } + updateNextRefreshTimeReturnsOnCall map[int]struct { + result1 error + } + invocations map[string][][]interface{} + invocationsMutex sync.RWMutex +} + +func (fake *FakeStorage) GetRefreshPolicies(arg1 context.Context) ([]*storage.Policy, error) { + fake.getRefreshPoliciesMutex.Lock() + ret, specificReturn := fake.getRefreshPoliciesReturnsOnCall[len(fake.getRefreshPoliciesArgsForCall)] + fake.getRefreshPoliciesArgsForCall = append(fake.getRefreshPoliciesArgsForCall, struct { + arg1 context.Context + }{arg1}) + stub := fake.GetRefreshPoliciesStub + fakeReturns := fake.getRefreshPoliciesReturns + fake.recordInvocation("GetRefreshPolicies", []interface{}{arg1}) + fake.getRefreshPoliciesMutex.Unlock() + if stub != nil { + return stub(arg1) + } + if specificReturn { + return ret.result1, ret.result2 + } + return fakeReturns.result1, fakeReturns.result2 +} + +func (fake *FakeStorage) GetRefreshPoliciesCallCount() int { + fake.getRefreshPoliciesMutex.RLock() + defer fake.getRefreshPoliciesMutex.RUnlock() + return len(fake.getRefreshPoliciesArgsForCall) +} + +func (fake *FakeStorage) GetRefreshPoliciesCalls(stub func(context.Context) ([]*storage.Policy, error)) { + fake.getRefreshPoliciesMutex.Lock() + defer fake.getRefreshPoliciesMutex.Unlock() + fake.GetRefreshPoliciesStub = stub +} + +func (fake *FakeStorage) GetRefreshPoliciesArgsForCall(i int) context.Context { + fake.getRefreshPoliciesMutex.RLock() + defer fake.getRefreshPoliciesMutex.RUnlock() + argsForCall := fake.getRefreshPoliciesArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeStorage) GetRefreshPoliciesReturns(result1 []*storage.Policy, result2 error) { + fake.getRefreshPoliciesMutex.Lock() + defer fake.getRefreshPoliciesMutex.Unlock() + fake.GetRefreshPoliciesStub = nil + fake.getRefreshPoliciesReturns = struct { + result1 []*storage.Policy + result2 error + }{result1, result2} +} + +func (fake *FakeStorage) GetRefreshPoliciesReturnsOnCall(i int, result1 []*storage.Policy, result2 error) { + fake.getRefreshPoliciesMutex.Lock() + defer fake.getRefreshPoliciesMutex.Unlock() + fake.GetRefreshPoliciesStub = nil + if fake.getRefreshPoliciesReturnsOnCall == nil { + fake.getRefreshPoliciesReturnsOnCall = make(map[int]struct { + result1 []*storage.Policy + result2 error + }) + } + fake.getRefreshPoliciesReturnsOnCall[i] = struct { + result1 []*storage.Policy + result2 error + }{result1, result2} +} + +func (fake *FakeStorage) PostponeRefresh(arg1 context.Context, arg2 []*storage.Policy) error { + var arg2Copy []*storage.Policy + if arg2 != nil { + arg2Copy = make([]*storage.Policy, len(arg2)) + copy(arg2Copy, arg2) + } + fake.postponeRefreshMutex.Lock() + ret, specificReturn := fake.postponeRefreshReturnsOnCall[len(fake.postponeRefreshArgsForCall)] + fake.postponeRefreshArgsForCall = append(fake.postponeRefreshArgsForCall, struct { + arg1 context.Context + arg2 []*storage.Policy + }{arg1, arg2Copy}) + stub := fake.PostponeRefreshStub + fakeReturns := fake.postponeRefreshReturns + fake.recordInvocation("PostponeRefresh", []interface{}{arg1, arg2Copy}) + fake.postponeRefreshMutex.Unlock() + if stub != nil { + return stub(arg1, arg2) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeStorage) PostponeRefreshCallCount() int { + fake.postponeRefreshMutex.RLock() + defer fake.postponeRefreshMutex.RUnlock() + return len(fake.postponeRefreshArgsForCall) +} + +func (fake *FakeStorage) PostponeRefreshCalls(stub func(context.Context, []*storage.Policy) error) { + fake.postponeRefreshMutex.Lock() + defer fake.postponeRefreshMutex.Unlock() + fake.PostponeRefreshStub = stub +} + +func (fake *FakeStorage) PostponeRefreshArgsForCall(i int) (context.Context, []*storage.Policy) { + fake.postponeRefreshMutex.RLock() + defer fake.postponeRefreshMutex.RUnlock() + argsForCall := fake.postponeRefreshArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeStorage) PostponeRefreshReturns(result1 error) { + fake.postponeRefreshMutex.Lock() + defer fake.postponeRefreshMutex.Unlock() + fake.PostponeRefreshStub = nil + fake.postponeRefreshReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeStorage) PostponeRefreshReturnsOnCall(i int, result1 error) { + fake.postponeRefreshMutex.Lock() + defer fake.postponeRefreshMutex.Unlock() + fake.PostponeRefreshStub = nil + if fake.postponeRefreshReturnsOnCall == nil { + fake.postponeRefreshReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.postponeRefreshReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeStorage) UpdateNextRefreshTime(arg1 context.Context, arg2 *storage.Policy, arg3 time.Time) error { + fake.updateNextRefreshTimeMutex.Lock() + ret, specificReturn := fake.updateNextRefreshTimeReturnsOnCall[len(fake.updateNextRefreshTimeArgsForCall)] + fake.updateNextRefreshTimeArgsForCall = append(fake.updateNextRefreshTimeArgsForCall, struct { + arg1 context.Context + arg2 *storage.Policy + arg3 time.Time + }{arg1, arg2, arg3}) + stub := fake.UpdateNextRefreshTimeStub + fakeReturns := fake.updateNextRefreshTimeReturns + fake.recordInvocation("UpdateNextRefreshTime", []interface{}{arg1, arg2, arg3}) + fake.updateNextRefreshTimeMutex.Unlock() + if stub != nil { + return stub(arg1, arg2, arg3) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeStorage) UpdateNextRefreshTimeCallCount() int { + fake.updateNextRefreshTimeMutex.RLock() + defer fake.updateNextRefreshTimeMutex.RUnlock() + return len(fake.updateNextRefreshTimeArgsForCall) +} + +func (fake *FakeStorage) UpdateNextRefreshTimeCalls(stub func(context.Context, *storage.Policy, time.Time) error) { + fake.updateNextRefreshTimeMutex.Lock() + defer fake.updateNextRefreshTimeMutex.Unlock() + fake.UpdateNextRefreshTimeStub = stub +} + +func (fake *FakeStorage) UpdateNextRefreshTimeArgsForCall(i int) (context.Context, *storage.Policy, time.Time) { + fake.updateNextRefreshTimeMutex.RLock() + defer fake.updateNextRefreshTimeMutex.RUnlock() + argsForCall := fake.updateNextRefreshTimeArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3 +} + +func (fake *FakeStorage) UpdateNextRefreshTimeReturns(result1 error) { + fake.updateNextRefreshTimeMutex.Lock() + defer fake.updateNextRefreshTimeMutex.Unlock() + fake.UpdateNextRefreshTimeStub = nil + fake.updateNextRefreshTimeReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeStorage) UpdateNextRefreshTimeReturnsOnCall(i int, result1 error) { + fake.updateNextRefreshTimeMutex.Lock() + defer fake.updateNextRefreshTimeMutex.Unlock() + fake.UpdateNextRefreshTimeStub = nil + if fake.updateNextRefreshTimeReturnsOnCall == nil { + fake.updateNextRefreshTimeReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.updateNextRefreshTimeReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeStorage) Invocations() map[string][][]interface{} { + fake.invocationsMutex.RLock() + defer fake.invocationsMutex.RUnlock() + fake.getRefreshPoliciesMutex.RLock() + defer fake.getRefreshPoliciesMutex.RUnlock() + fake.postponeRefreshMutex.RLock() + defer fake.postponeRefreshMutex.RUnlock() + fake.updateNextRefreshTimeMutex.RLock() + defer fake.updateNextRefreshTimeMutex.RUnlock() + copiedInvocations := map[string][][]interface{}{} + for key, value := range fake.invocations { + copiedInvocations[key] = value + } + return copiedInvocations +} + +func (fake *FakeStorage) recordInvocation(key string, args []interface{}) { + fake.invocationsMutex.Lock() + defer fake.invocationsMutex.Unlock() + if fake.invocations == nil { + fake.invocations = map[string][][]interface{}{} + } + if fake.invocations[key] == nil { + fake.invocations[key] = [][]interface{}{} + } + fake.invocations[key] = append(fake.invocations[key], args) +} + +var _ policydata.Storage = new(FakeStorage) diff --git a/internal/service/policy/policydata/refresher.go b/internal/service/policy/policydata/refresher.go new file mode 100644 index 0000000000000000000000000000000000000000..c5a7127c2ac7232edd4ffcca90d0acb72d17e2a3 --- /dev/null +++ b/internal/service/policy/policydata/refresher.go @@ -0,0 +1,152 @@ +package policydata + +import ( + "bytes" + "context" + "encoding/json" + "io" + "net/http" + "net/url" + "time" + + "go.uber.org/zap" + + "gitlab.com/gaia-x/data-infrastructure-federation-services/tsa/golib/errors" + "gitlab.com/gaia-x/data-infrastructure-federation-services/tsa/policy/internal/storage" +) + +//go:generate counterfeiter . Storage + +type Storage interface { + GetRefreshPolicies(ctx context.Context) ([]*storage.Policy, error) + PostponeRefresh(ctx context.Context, policies []*storage.Policy) error + UpdateNextRefreshTime(ctx context.Context, p *storage.Policy, nextDataRefreshTime time.Time) error +} + +type Refresher struct { + storage Storage + pollInterval time.Duration + + httpClient *http.Client + logger *zap.Logger +} + +func NewRefresher( + storage Storage, + pollInterval time.Duration, + httpClient *http.Client, + logger *zap.Logger, +) *Refresher { + return &Refresher{ + storage: storage, + pollInterval: pollInterval, + httpClient: httpClient, + logger: logger, + } +} + +func (e *Refresher) Start(ctx context.Context) error { + defer e.logger.Info("policy data refresher stopped") + +loop: + for { + select { + case <-ctx.Done(): + break loop + case <-time.After(e.pollInterval): + policies, err := e.storage.GetRefreshPolicies(ctx) + if err != nil { + if !errors.Is(errors.NotFound, err) { + e.logger.Error("error getting policies for data refresh from storage", zap.Error(err)) + } + continue + } + for _, policy := range policies { + e.Execute(ctx, policy) + } + } + } + + return ctx.Err() +} + +func (e *Refresher) Execute(ctx context.Context, p *storage.Policy) { + logger := e.logger.With( + zap.String("policyName", p.Name), + zap.String("policyGroup", p.Group), + zap.String("policyVersion", p.Version), + ) + + var config DataConfig + if err := json.Unmarshal([]byte(p.DataConfig), &config); err != nil { + // data configuration is corrupted, set next refresh time to Go's zero date + _ = e.storage.UpdateNextRefreshTime(ctx, p, time.Time{}) + logger.Error("error unmarshalling data configuration", zap.Error(err)) + return + } + if config.URL == "" || config.Period == Duration(0) || config.Method == "" { + // data configuration is missing required fields, set next refresh time to Go's zero date + _ = e.storage.UpdateNextRefreshTime(ctx, p, time.Time{}) + logger.Error("required fields are missing in data configuration") + return + } + + req, err := e.createHTTPRequest(ctx, &config) + if err != nil { + // cannot create a request, set next refresh time to Go's zero date + _ = e.storage.UpdateNextRefreshTime(ctx, p, time.Time{}) + logger.Error("error creating an http request", zap.Error(err)) + return + } + + resp, err := e.httpClient.Do(req) + if err != nil { + // making data configuration request failed, set next refresh time to current time added data config's period + _ = e.storage.UpdateNextRefreshTime(ctx, p, time.Now().Add(time.Duration(config.Period))) + logger.Error("error making a data refresh request", zap.Error(err)) + return + } + defer resp.Body.Close() // nolint:errcheck + + if resp.StatusCode != http.StatusOK { + // unexpected response on data refresh request, set next refresh time to current time added data config's period + _ = e.storage.UpdateNextRefreshTime(ctx, p, time.Now().Add(time.Duration(config.Period))) + logger.Error("unexpected response on data refresh request", zap.Int("response code", resp.StatusCode)) + return + } + + dataBytes, err := io.ReadAll(resp.Body) + if err != nil { + // error reading response from data refresh request, set next refresh time to current time added data config's period + _ = e.storage.UpdateNextRefreshTime(ctx, p, time.Now().Add(time.Duration(config.Period))) + logger.Error("error reading response from data refresh request", zap.Error(err)) + return + } + + p.Data = string(dataBytes) + if err = e.storage.UpdateNextRefreshTime(ctx, p, time.Now().Add(time.Duration(config.Period))); err != nil { + logger.Error("error updating data after successful refresh request", zap.Error(err)) + return + } + logger.Debug("data refresh is successfully executed") +} + +func (e *Refresher) createHTTPRequest(ctx context.Context, config *DataConfig) (*http.Request, error) { + bodyBytes, err := json.Marshal(config.Body) + if err != nil { + return nil, errors.New("error marshaling data configuration body") + } + + url, err := url.Parse(config.URL) + if err != nil { + return nil, errors.New("invalid data configuration url") + } + if url.Scheme == "" { + url.Scheme = "https" + } + + if config.Method == http.MethodPost { + return http.NewRequestWithContext(ctx, config.Method, url.String(), bytes.NewReader(bodyBytes)) + } + return http.NewRequestWithContext(ctx, config.Method, url.String(), nil) +} diff --git a/internal/service/policy/policydata/refresher_test.go b/internal/service/policy/policydata/refresher_test.go new file mode 100644 index 0000000000000000000000000000000000000000..59492f0c54b40a2de48eefebbe3b677eaecf6019 --- /dev/null +++ b/internal/service/policy/policydata/refresher_test.go @@ -0,0 +1,133 @@ +package policydata_test + +import ( + "context" + "errors" + "net/http" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "go.uber.org/zap" + "go.uber.org/zap/zaptest/observer" + + "gitlab.com/gaia-x/data-infrastructure-federation-services/tsa/policy/internal/service/policy/policydata" + "gitlab.com/gaia-x/data-infrastructure-federation-services/tsa/policy/internal/service/policy/policydata/policydatafakes" + "gitlab.com/gaia-x/data-infrastructure-federation-services/tsa/policy/internal/storage" +) + +type RoundTripFunc func(req *http.Request) *http.Response + +func (f RoundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return f(req), nil +} + +// NewTestClient returns *http.Client with Transport replaced to avoid making real calls +func NewTestClient(fn RoundTripFunc) *http.Client { + return &http.Client{ + Transport: fn, + } +} + +func Test_Execute(t *testing.T) { + tests := []struct { + // test input + name string + statusCode int + policy storage.Policy + storage policydata.Storage + // expected result + logCnt int + firstLog string + }{ + { + name: "invalid data configuration", + policy: storage.Policy{DataConfig: "<invalid data configuration>"}, + storage: &policydatafakes.FakeStorage{ + UpdateNextRefreshTimeStub: func(ctx context.Context, policy *storage.Policy, t time.Time) error { + return nil + }, + }, + logCnt: 1, + firstLog: "error unmarshalling data configuration", + }, + { + name: "data configuration is missing required fields", + policy: storage.Policy{DataConfig: `{"url": "https://example.com"}`}, + storage: &policydatafakes.FakeStorage{ + UpdateNextRefreshTimeStub: func(ctx context.Context, policy *storage.Policy, t time.Time) error { + return nil + }, + }, + logCnt: 1, + firstLog: "required fields are missing in data configuration", + }, + { + name: "error making an http request", + policy: storage.Policy{DataConfig: `{"url": "htt//example.com", "method": "GET", "period": "1h"}`}, + storage: &policydatafakes.FakeStorage{ + UpdateNextRefreshTimeStub: func(ctx context.Context, policy *storage.Policy, t time.Time) error { + return nil + }, + }, + logCnt: 1, + firstLog: "error making a data refresh request", + }, + { + name: "unexpected response code", + statusCode: 500, + policy: storage.Policy{DataConfig: `{"url": "https://example.com", "method": "GET", "period": "1h"}`}, + storage: &policydatafakes.FakeStorage{ + UpdateNextRefreshTimeStub: func(ctx context.Context, policy *storage.Policy, t time.Time) error { + return nil + }, + }, + logCnt: 1, + firstLog: "unexpected response on data refresh request", + }, + { + name: "error updating data after successful refresh request", + policy: storage.Policy{DataConfig: `{"url": "https://example.com", "method": "GET", "period": "1h"}`}, + storage: &policydatafakes.FakeStorage{ + UpdateNextRefreshTimeStub: func(ctx context.Context, policy *storage.Policy, t time.Time) error { + return errors.New("storage error") + }, + }, + logCnt: 1, + firstLog: "error updating data after successful refresh request", + }, + { + name: "data refresh is successfully executed", + policy: storage.Policy{DataConfig: `{"url": "https://example.com", "method": "GET", "period": "1h"}`}, + storage: &policydatafakes.FakeStorage{ + UpdateNextRefreshTimeStub: func(ctx context.Context, policy *storage.Policy, t time.Time) error { + return nil + }, + }, + logCnt: 0, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + observedZapCore, observedLogs := observer.New(zap.ErrorLevel) + logger := zap.New(observedZapCore) + httpClient := http.DefaultClient + if test.statusCode != 0 { + httpClient = NewTestClient(func(req *http.Request) *http.Response { + return &http.Response{ + StatusCode: test.statusCode, + } + }) + } + refresher := policydata.NewRefresher(test.storage, time.Duration(0), httpClient, logger) + refresher.Execute(context.Background(), &test.policy) + + assert.Equal(t, test.logCnt, observedLogs.Len()) + if observedLogs.Len() > 0 { + firstLog := observedLogs.All()[0] + assert.Equal(t, test.firstLog, firstLog.Message) + } + }) + } +} diff --git a/internal/storage/storage.go b/internal/storage/storage.go index fed1d6a1481b87f63e4efc349124f7b06933d95e..2901bfe0b05f6f448bfc0b014bcb039ad2b12007 100644 --- a/internal/storage/storage.go +++ b/internal/storage/storage.go @@ -6,28 +6,39 @@ import ( "time" "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/bson/primitive" "go.mongodb.org/mongo-driver/mongo" zap "go.uber.org/zap" "gitlab.com/gaia-x/data-infrastructure-federation-services/tsa/golib/errors" ) +const ( + dataField = "data" + nextDataRefreshTimeField = "nextDataRefreshTime" + refreshPostponePeriod = 5 * time.Minute +) + type PolicyChangeSubscriber interface { PolicyDataChange() } type Policy struct { - Filename string - Name string - Group string - Version string - Rego string - Data string - Locked bool - LastUpdate time.Time + ID primitive.ObjectID `bson:"_id"` + Filename string + Name string + Group string + Version string + Rego string + Data string + DataConfig string + Locked bool + LastUpdate time.Time + NextDataRefreshTime time.Time } type Storage struct { + db *mongo.Client policy *mongo.Collection subscriber PolicyChangeSubscriber logger *zap.Logger @@ -39,6 +50,7 @@ func New(db *mongo.Client, dbname, collection string, logger *zap.Logger) (*Stor } return &Storage{ + db: db, policy: db.Database(dbname).Collection(collection), logger: logger, }, nil @@ -109,3 +121,84 @@ func (s *Storage) ListenPolicyDataChanges(ctx context.Context) error { func (s *Storage) AddPolicyChangeSubscriber(subscriber PolicyChangeSubscriber) { s.subscriber = subscriber } + +func (s *Storage) GetRefreshPolicies(ctx context.Context) ([]*Policy, error) { + // create a callback for the mongodb transaction + callback := func(mCtx mongo.SessionContext) (interface{}, error) { + filter := bson.M{nextDataRefreshTimeField: bson.M{ + "$gt": time.Time{}, // greater than the Go's zero date + "$lte": time.Now(), + }} + + cursor, err := s.policy.Find(ctx, filter) + if err != nil { + return nil, err + } + + var policies []*Policy + if err := cursor.All(ctx, &policies); err != nil { + return nil, err + } + if len(policies) == 0 { + return nil, errors.New(errors.NotFound, "policies for data refresh not found") + } + + err = s.PostponeRefresh(ctx, policies) + if err != nil { + return nil, err + } + + return policies, nil + } + + // execute transaction + res, err := s.Transaction(ctx, callback) + if err != nil { + return nil, err + } + policies, _ := res.([]*Policy) + + return policies, nil +} + +// PostponeRefresh adds a refreshPostponePeriod Duration to each policy's +// nextDataRefreshTimeField in order to prevent concurrent data refresh +func (s *Storage) PostponeRefresh(ctx context.Context, policies []*Policy) error { + var ids []*primitive.ObjectID + for _, p := range policies { + ids = append(ids, &p.ID) + } + + filter := bson.M{"_id": bson.M{"$in": ids}} + update := bson.M{"$set": bson.M{nextDataRefreshTimeField: time.Now().Add(refreshPostponePeriod)}} + _, err := s.policy.UpdateMany(ctx, filter, update) + + return err +} + +// UpdateNextRefreshTime updates policy's data and nextDataRefreshTimeField fields +func (s *Storage) UpdateNextRefreshTime(ctx context.Context, p *Policy, nextDataRefreshTime time.Time) error { + filter := bson.M{"_id": p.ID} + update := bson.M{"$set": bson.M{ + nextDataRefreshTimeField: nextDataRefreshTime, + dataField: p.Data, + }} + _, err := s.policy.UpdateOne(ctx, filter, update) + + return err +} + +func (s *Storage) Transaction(ctx context.Context, callback func(mCtx mongo.SessionContext) (interface{}, error)) (interface{}, error) { + session, err := s.db.StartSession() + if err != nil { + return nil, errors.New("failed creating session", err) + } + defer session.EndSession(ctx) + + res, err := session.WithTransaction(ctx, callback) + if err != nil { + return nil, errors.New("failed executing transaction", err) + } + + return res, nil +} diff --git a/vendor/go.uber.org/zap/zaptest/observer/logged_entry.go b/vendor/go.uber.org/zap/zaptest/observer/logged_entry.go new file mode 100644 index 0000000000000000000000000000000000000000..a4ea7ec36c1e4162f9a56e17183d78be218c4147 Binary files /dev/null and b/vendor/go.uber.org/zap/zaptest/observer/logged_entry.go differ diff --git a/vendor/go.uber.org/zap/zaptest/observer/observer.go b/vendor/go.uber.org/zap/zaptest/observer/observer.go new file mode 100644 index 0000000000000000000000000000000000000000..f77f1308baf29fa3699df5d3cb0eb07763f50d9b Binary files /dev/null and b/vendor/go.uber.org/zap/zaptest/observer/observer.go differ diff --git a/vendor/modules.txt b/vendor/modules.txt index d89f7493d4e8b469894bad11de6f6e18926e9768..3ba5a11ff2e5cdc00a682f5607bfc9e33d09ba15 100644 Binary files a/vendor/modules.txt and b/vendor/modules.txt differ