coder/coderd/audit/diff.go

171 lines
3.8 KiB
Go

package audit
import (
"database/sql"
"fmt"
"reflect"
"github.com/google/uuid"
)
// Map is a string-keyed collection of arbitrary values. Diff returns one to
// describe the values that changed between two resources.
//
// TODO: this might need to be in the database package.
type Map map[string]any
// Empty returns the zero value of the auditable type T.
func Empty[T Auditable]() T {
	return *new(T)
}
// Diff reports the values that changed between two auditable resources of the
// same type. When left and right are deeply equal the result is an empty Map;
// otherwise the field-by-field comparison is delegated to diffValues using
// the AuditableResources table.
func Diff[T Auditable](left, right T) Map {
	if !reflect.DeepEqual(left, right) {
		return diffValues(left, right, AuditableResources)
	}

	// Nothing changed, so there is nothing to report.
	return Map{}
}
// structName builds the lookup key for a type: its package path and its type
// name joined by a dot. diffValues uses this key to index the audit table.
func structName(t reflect.Type) string {
	pkg, name := t.PkgPath(), t.Name()
	return pkg + "." + name
}
// diffValues walks the fields of two values of the same struct type and
// collects the fields whose values differ, keyed by each field's raw `json`
// struct tag.
//
// The per-field behavior is looked up in table under structName of right's
// type:
//   - ActionIgnore: the field is skipped entirely.
//   - ActionTrack: the right-hand value is recorded when the values differ.
//   - ActionSecret: the zero value of the field's type is recorded when the
//     values differ, so the change is visible but the value is not.
//
// Fields that are structs (after conversion and pointer dereferencing) are
// diffed recursively and their result is always stored, even when it is empty.
//
// It panics when the type has no entry in table or when a field has no entry
// for its json tag; both indicate a developer error.
func diffValues[T any](left, right T, table Table) Map {
	typ := reflect.TypeOf(right)
	actions := table[structName(typ)]
	if actions == nil {
		panic(fmt.Sprintf("dev error: type %q (type %T) attempted audit but not auditable", typ.Name(), right))
	}

	oldStruct, newStruct := reflect.ValueOf(left), reflect.ValueOf(right)
	changes := Map{}

	for idx, count := 0, typ.NumField(); idx < count; idx++ {
		oldField, newField := oldStruct.Field(idx), newStruct.Field(idx)
		oldVal, newVal := oldField.Interface(), newField.Interface()
		name := typ.Field(idx).Tag.Get("json")

		action, known := actions[name]
		if !known {
			panic(fmt.Sprintf("dev error: field %q lacks audit information", name))
		}
		if action == ActionIgnore {
			continue
		}

		// Some external struct types diff poorly as-is; swap them for
		// simpler representations before comparing.
		if convOld, convNew, converted := convertDiffType(oldVal, newVal); converted {
			oldVal, newVal = convOld, convNew
			oldField, newField = reflect.ValueOf(convOld), reflect.ValueOf(convNew)
		}

		// Compare what pointers refer to rather than the pointers
		// themselves. A nil pointer becomes the zero value of its element
		// type.
		if oldField.Kind() == reflect.Ptr && newField.Kind() == reflect.Ptr {
			oldField, newField = derefPointer(oldField), derefPointer(newField)
			oldVal, newVal = oldField.Interface(), newField.Interface()
		}

		// Nested structs get their own diff, stored under this field's key.
		if newField.Kind() == reflect.Struct {
			changes[name] = diffValues(oldVal, newVal, table)
			continue
		}

		if reflect.DeepEqual(oldVal, newVal) {
			continue
		}
		switch action {
		case ActionTrack:
			changes[name] = newVal
		case ActionSecret:
			changes[name] = reflect.Zero(newField.Type()).Interface()
		}
	}

	return changes
}
// convertDiffType replaces values of certain external struct types with
// simpler representations so they compare and display sensibly in a diff:
//   - uuid.UUID becomes its string form.
//   - uuid.NullUUID becomes the text produced by MarshalText.
//   - sql.NullString becomes its string, or "null" when not valid.
//   - sql.NullInt64 becomes a *int64, nil when not valid.
//
// The type of left selects the conversion and right is asserted to be the
// same type; a mismatch panics. For any other type the inputs are returned
// unchanged and changed is false.
//
//nolint:forcetypeassert
func convertDiffType(left, right any) (newLeft, newRight any, changed bool) {
	switch l := left.(type) {
	case uuid.UUID:
		r := right.(uuid.UUID)
		return l.String(), r.String(), true

	case uuid.NullUUID:
		r := right.(uuid.NullUUID)
		lText, _ := l.MarshalText()
		rText, _ := r.MarshalText()
		return string(lText), string(rText), true

	case sql.NullString:
		r := right.(sql.NullString)
		lStr, rStr := "null", "null"
		if l.Valid {
			lStr = l.String
		}
		if r.Valid {
			rStr = r.String
		}
		return lStr, rStr, true

	case sql.NullInt64:
		r := right.(sql.NullInt64)
		var lPtr, rPtr *int64
		if l.Valid {
			lPtr = ptr(l.Int64)
		}
		if r.Valid {
			rPtr = ptr(r.Int64)
		}
		return lPtr, rPtr, true
	}

	return left, right, false
}
// derefPointer follows a pointer reflect.Value down to the first non-pointer
// value it refers to, unwrapping as many levels of indirection as needed. A
// nil pointer encountered at any level is replaced by the zero value of its
// element type.
func derefPointer(ptr reflect.Value) reflect.Value {
	current := ptr
	for {
		if current.IsNil() {
			// Nothing to follow; substitute the element type's zero value.
			current = reflect.Zero(current.Type().Elem())
		} else {
			current = current.Elem()
		}

		if current.Kind() != reflect.Ptr {
			return current
		}
	}
}
// ptr returns a pointer to a fresh copy of x.
func ptr[T any](x T) *T {
	p := new(T)
	*p = x
	return p
}