package simpleauth import ( "bytes" "crypto/sha256" "errors" "fmt" "math/rand" ) var ( // All password storage functions are set here to allow a user to write their // own implementation if required. SaltAndHashPasswordString func(password []byte) (salt, hash []byte) = DefaultSaltAndHashPasswordString SaltPasswordString func(password, salt []byte) []byte = DefaultSaltPasswordString HashSaltedPasswordString func(password []byte) []byte = DefaultHashSaltedPasswordString NumberChars = "0123456789" SpecialChars = "~`!@#$%^&*()_-+={}[]|\\:;\"'<>,./?" LowercaseChars = "abcdefghijklmnopqrstuvwxyz" UppercaseChars = "ABCDEFGHIJKLMNOPQRSTUVWXYZ" MinPasswordLength = 8 MaxPasswordLength = 256 PasswordIsComplexEnough func([]byte) error = DefaultPasswordIsComplexEnough ErrInvalidPassword = errors.New("invalid password") ) type Password struct { Hash []byte Salt []byte } func (p Password) String() string { return "cowardly refusal to print password hash/salt to stdout" } func (p Password) Matches(password []byte) bool { return bytes.Equal(p.Hash, HashSaltedPasswordString(SaltPasswordString(password, p.Salt))) } func NewPassword(password []byte) (*Password, error) { if err := PasswordIsComplexEnough(password); err != nil { return nil, fmt.Errorf("%w: %w", ErrInvalidPassword, err) } salt, hash := SaltAndHashPasswordString(password) return &Password{Hash: hash, Salt: salt}, nil } func DefaultHashSaltedPasswordString(saltedPassword []byte) []byte { return []byte(fmt.Sprintf("%064x", sha256.Sum256([]byte(saltedPassword)))) } func DefaultSaltPasswordString(password, salt []byte) []byte { return []byte(fmt.Sprintf("%s:%s", salt, password)) } func DefaultSaltAndHashPasswordString(password []byte) (salt, hash []byte) { salt = []byte(fmt.Sprintf("%016x", rand.Uint64())) hash = HashSaltedPasswordString(SaltPasswordString(password, salt)) return salt, hash } func DefaultPasswordIsComplexEnough(password []byte) error { errs := make([]error, 0, 5) if len := len(password); len < MinPasswordLength { errs = append(errs, fmt.Errorf("requires >= %d chars (got %d)", MinPasswordLength, len)) } else if len > MaxPasswordLength { errs = append(errs, fmt.Errorf("requires <= %d chars (got %d)", MaxPasswordLength, len)) } if !bytes.ContainsAny(password, LowercaseChars) { errs = append(errs, fmt.Errorf("requires 1+ lowercase char (%s)", LowercaseChars)) } if !bytes.ContainsAny(password, UppercaseChars) { errs = append(errs, fmt.Errorf("requires 1+ uppercase char (%s)", UppercaseChars)) } if !bytes.ContainsAny(password, NumberChars) { errs = append(errs, fmt.Errorf("requires 1+ number (%s)", NumberChars)) } if !bytes.ContainsAny(password, SpecialChars) { errs = append(errs, fmt.Errorf("requires 1+ special char (%s)", SpecialChars)) } return errors.Join(errs...) }