Skip to content

Commit

Permalink
Add CreatePackage persistence method
Browse files Browse the repository at this point in the history
- Add an `WithEvents()` decorator for the persistence Service to add event
  publishing for persistence events
- Add a persistence Service `CreatePackage` method
- Add an ent client implementation of `CreatePackage`
- Add a eventManager `CreatePackage` decorator method
- Add an `convertPkgToPackage()` function to the ent client to convert
  ent `db.Pkg` data objects to `package_.Package` objects
- Add error functions to convert db specific errors to more general
  persistence errors to prevent leaking implementation details
  • Loading branch information
djjuhasz committed Sep 8, 2023
1 parent f7bd81d commit a5bf746
Show file tree
Hide file tree
Showing 7 changed files with 382 additions and 21 deletions.
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ require (
github.com/go-sql-driver/mysql v1.7.1
github.com/golang-migrate/migrate/v4 v4.16.2
github.com/golang/mock v1.6.0
github.com/google/go-cmp v0.5.9
github.com/google/uuid v1.3.0
github.com/gorilla/websocket v1.5.0
github.com/jmoiron/sqlx v1.3.5
Expand Down Expand Up @@ -87,7 +88,6 @@ require (
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
github.com/golang/protobuf v1.5.3 // indirect
github.com/golang/snappy v0.0.4 // indirect
github.com/google/go-cmp v0.5.9 // indirect
github.com/google/wire v0.5.0 // indirect
github.com/googleapis/gax-go/v2 v2.11.0 // indirect
github.com/grpc-ecosystem/go-grpc-middleware v1.3.0 // indirect
Expand Down
62 changes: 59 additions & 3 deletions internal/persistence/ent/client/client.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,19 @@
package entclient

import (
"context"
"time"

"github.com/go-logr/logr"
"github.com/google/uuid"

"github.com/artefactual-sdps/enduro/internal/package_"
"github.com/artefactual-sdps/enduro/internal/persistence"
"github.com/artefactual-sdps/enduro/internal/persistence/ent/db"
)

const TimeFormat = time.RFC3339

type client struct {
logger logr.Logger
ent *db.Client
Expand All @@ -15,8 +22,57 @@ type client struct {
var _ persistence.Service = (*client)(nil)

func New(logger logr.Logger, ent *db.Client) persistence.Service {
return &client{
logger: logger,
ent: ent,
return &client{logger: logger, ent: ent}
}

func (c *client) CreatePackage(ctx context.Context, pkg *package_.Package) (*package_.Package, error) {
// Validate required fields.
if pkg.Name == "" {
return nil, newRequiredFieldError("Name")
}
if pkg.WorkflowID == "" {
return nil, newRequiredFieldError("WorkflowID")
}

if pkg.RunID == "" {
return nil, newRequiredFieldError("RunID")
}
runID, err := uuid.Parse(pkg.RunID)
if err != nil {
return nil, newParseError(err, "RunID")
}

if pkg.AIPID == "" {
return nil, newRequiredFieldError("AIPID")
}
aipID, err := uuid.Parse(pkg.AIPID)
if err != nil {
return nil, newParseError(err, "AIPID")
}

q := c.ent.Pkg.Create().
SetName(pkg.Name).
SetWorkflowID(pkg.WorkflowID).
SetRunID(runID).
SetAipID(aipID).
SetStatus(int8(pkg.Status))

// Add optional fields.
if pkg.LocationID.Valid {
q.SetLocationID(pkg.LocationID.UUID)
}
if pkg.StartedAt.Valid {
q.SetStartedAt(pkg.StartedAt.Time)
}
if pkg.CompletedAt.Valid {
q.SetCompletedAt(pkg.CompletedAt.Time)
}

// Set CreatedAt and Save package.
p, err := q.SetCreatedAt(time.Now()).Save(ctx)
if err != nil {
return nil, newDBErrorWithDetails(err, "create package")
}

return convertPkgToPackage(p), nil
}
185 changes: 169 additions & 16 deletions internal/persistence/ent/client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,13 @@ package entclient_test

import (
"context"
"database/sql"
"fmt"
"testing"
"time"

"github.com/go-logr/logr"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/google/uuid"
_ "github.com/mattn/go-sqlite3"
"gotest.tools/v3/assert"
Expand All @@ -30,24 +33,174 @@ func setUpClient(t *testing.T, logger logr.Logger) (*db.Client, persistence.Serv
}

func TestNew(t *testing.T) {
t.Parallel()
t.Run("Returns a working ent DB client", func(t *testing.T) {
t.Parallel()

entc, _ := setUpClient(t, logr.Discard())
entc, _ := setUpClient(t, logr.Discard())
runID := uuid.New()
aipID := uuid.New()

p, err := entc.Pkg.Create().
SetName("testing 1-2-3").
SetWorkflowID("12345").
SetRunID(runID).
SetAipID(aipID).
SetStatus(int8(package_.NewStatus("in progress"))).
Save(context.Background())

assert.NilError(t, err)
assert.Equal(t, p.Name, "testing 1-2-3")
assert.Equal(t, p.WorkflowID, "12345")
assert.Equal(t, p.RunID, runID)
assert.Equal(t, p.AipID, aipID)
assert.Equal(t, p.Status, int8(package_.StatusInProgress))
})
}

func TestCreatePackage(t *testing.T) {
runID := uuid.New()
aipID := uuid.New()
locID := uuid.NullUUID{UUID: uuid.New(), Valid: true}
started := sql.NullTime{Time: time.Now(), Valid: true}
completed := sql.NullTime{Time: started.Time.Add(time.Second), Valid: true}

type params struct {
pkg *package_.Package
}
tests := []struct {
name string
args params
want *package_.Package
wantErr string
}{
{
name: "Saves a new package in the DB",
args: params{
pkg: &package_.Package{
Name: "Test package 1",
WorkflowID: "workflow-1",
RunID: runID.String(),
AIPID: aipID.String(),
LocationID: locID,
Status: package_.StatusInProgress,
StartedAt: started,
CompletedAt: completed,
},
},
want: &package_.Package{
ID: 1,
Name: "Test package 1",
WorkflowID: "workflow-1",
RunID: runID.String(),
AIPID: aipID.String(),
LocationID: locID,
Status: package_.StatusInProgress,
CreatedAt: time.Now(),
StartedAt: started,
CompletedAt: completed,
},
},
{
name: "Saves a package with missing optional fields",
args: params{
pkg: &package_.Package{
Name: "Test package 2",
WorkflowID: "workflow-2",
RunID: runID.String(),
AIPID: aipID.String(),
Status: package_.StatusInProgress,
},
},
want: &package_.Package{
ID: 1,
Name: "Test package 2",
WorkflowID: "workflow-2",
RunID: runID.String(),
AIPID: aipID.String(),
Status: package_.StatusInProgress,
CreatedAt: time.Now(),
},
},
{
name: "Required field error for missing Name",
args: params{
pkg: &package_.Package{},
},
wantErr: "invalid data error: field \"Name\" is required",
},
{
name: "Required field error for missing WorkflowID",
args: params{
pkg: &package_.Package{
Name: "Missing WorkflowID",
},
},
wantErr: "invalid data error: field \"WorkflowID\" is required",
},
{
name: "Required field error for missing AIPID",
args: params{
pkg: &package_.Package{
Name: "Missing AIPID",
WorkflowID: "workflow-12345",
RunID: runID.String(),
},
},
wantErr: "invalid data error: field \"AIPID\" is required",
},
{
name: "Required field error for missing RunID",
args: params{
pkg: &package_.Package{
Name: "Missing RunID",
WorkflowID: "workflow-12345",
},
},
wantErr: "invalid data error: field \"RunID\" is required",
},
{
name: "Errors on invalid RunID",
args: params{
pkg: &package_.Package{
Name: "Invalid package 1",
WorkflowID: "workflow-invalid",
RunID: "Bad UUID",
},
},
wantErr: "invalid data error: parse error: field \"RunID\": invalid UUID length: 8",
},
{
name: "Errors on invalid AIPID",
args: params{
pkg: &package_.Package{
Name: "Invalid package 2",
WorkflowID: "workflow-invalid",
RunID: runID.String(),
AIPID: "Bad UUID",
},
},
wantErr: "invalid data error: parse error: field \"AIPID\": invalid UUID length: 8",
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()

_, svc := setUpClient(t, logr.Discard())
ctx := context.Background()

pkg, err := svc.CreatePackage(ctx, tt.args.pkg)
if tt.wantErr != "" {
assert.Error(t, err, tt.wantErr)
return
}
assert.NilError(t, err)

p, err := entc.Pkg.Create().
SetName("testing 1-2-3").
SetWorkflowID("12345").
SetRunID(runID).
SetAipID(aipID).
SetStatus(int8(package_.NewStatus("in progress"))).
Save(context.Background())

assert.NilError(t, err)
assert.Equal(t, p.Name, "testing 1-2-3")
assert.Equal(t, p.WorkflowID, "12345")
assert.Equal(t, p.RunID, runID)
assert.Equal(t, p.AipID, aipID)
assert.Equal(t, p.Status, int8(package_.StatusInProgress))
assert.DeepEqual(t, pkg, tt.want,
cmpopts.EquateApproxTime(time.Millisecond*100),
cmpopts.IgnoreUnexported(db.Pkg{}, db.PkgEdges{}),
)
})
}
}
40 changes: 40 additions & 0 deletions internal/persistence/ent/client/convert.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package entclient

import (
"database/sql"

"github.com/google/uuid"

"github.com/artefactual-sdps/enduro/internal/package_"
"github.com/artefactual-sdps/enduro/internal/persistence/ent/db"
)

// convertPkgToPackage converts an ent `db.Pkg` package representation to a
// `package_.Package` representation.
func convertPkgToPackage(pkg *db.Pkg) *package_.Package {
var started, completed sql.NullTime
if !pkg.StartedAt.IsZero() {
started = sql.NullTime{Time: pkg.StartedAt, Valid: true}
}
if !pkg.CompletedAt.IsZero() {
completed = sql.NullTime{Time: pkg.CompletedAt, Valid: true}
}

var locID uuid.NullUUID
if pkg.LocationID != uuid.Nil {
locID = uuid.NullUUID{UUID: pkg.LocationID, Valid: true}
}

return &package_.Package{
ID: uint(pkg.ID),
Name: pkg.Name,
LocationID: locID,
Status: package_.Status(pkg.Status),
WorkflowID: pkg.WorkflowID,
RunID: pkg.RunID.String(),
AIPID: pkg.AipID.String(),
CreatedAt: pkg.CreatedAt,
StartedAt: started,
CompletedAt: completed,
}
}
52 changes: 52 additions & 0 deletions internal/persistence/ent/client/errors.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
package entclient

import (
"fmt"

"github.com/artefactual-sdps/enduro/internal/persistence"
"github.com/artefactual-sdps/enduro/internal/persistence/ent/db"
)

func newDBError(err error) error {
if err == nil {
return nil
}

var pErr error
switch {
case db.IsNotFound(err):
pErr = persistence.ErrNotFound
case db.IsConstraintError(err):
pErr = persistence.ErrNotValid
case db.IsValidationError(err):
pErr = persistence.ErrNotValid
case db.IsNotLoaded(err):
pErr = persistence.ErrInternal
case db.IsNotSingular(err):
pErr = persistence.ErrInternal
default:
pErr = persistence.ErrInternal
}

return fmt.Errorf("%w: %s", pErr, err)
}

func newDBErrorWithDetails(err error, details string) error {
if err == nil {
return nil
}

return fmt.Errorf("%w: %s", newDBError(err), details)
}

func newRequiredFieldError(field string) error {
return fmt.Errorf("%w: field %q is required", persistence.ErrNotValid, field)
}

func newParseError(err error, field string) error {
if err == nil {
return nil
}

return fmt.Errorf("%w: parse error: field %q: %v", persistence.ErrNotValid, field, err)
}
Loading

0 comments on commit a5bf746

Please sign in to comment.