go
/
misc
1
0
Fork 0
misc/types/big.go

118 lines
2.6 KiB
Go

package types
import (
"database/sql/driver"
"github.com/ericlagergren/decimal"
"github.com/pkg/errors"
)
// Decimal is a DECIMAL in sql. Its zero value is valid for use with both
// Value and Scan.
//
// Although decimal can represent NaN and Infinity, Value returns an error
// if an attempt is made to store either of them in the database.
//
// The decimal.Big is embedded by value, so a Decimal can never be nil and
// has no null representation: Scan returns an error when asked to scan a
// null value into it. Use NullDecimal for nullable columns.
type Decimal struct {
decimal.Big
}
// NullDecimal is the same as Decimal, but allows the Big pointer to be nil.
// See the documentation for Decimal for more details.
//
// When going into a database, if Big is nil its value will be "null".
// When coming out of a database, scanning a null value sets Big to nil.
type NullDecimal struct {
*decimal.Big
}
// NewDecimal builds a Decimal holding a copy of d. When d is nil the
// returned Decimal is left at its zero value.
func NewDecimal(d *decimal.Big) (num Decimal) {
	if d == nil {
		return num
	}
	num.Big.Copy(d)
	return num
}
// NewNullDecimal wraps d in a NullDecimal. The pointer is stored as-is (no
// copy is made), so a nil d produces a NullDecimal whose Big is nil.
func NewNullDecimal(d *decimal.Big) NullDecimal {
	var wrapped NullDecimal
	wrapped.Big = d
	return wrapped
}
// Value implements driver.Valuer. It delegates to decimalValue in
// non-nullable mode, passing the address of the embedded Big.
func (d Decimal) Value() (driver.Value, error) {
	big := &d.Big
	return decimalValue(big, false)
}
// Scan implements sql.Scanner. The scanned value is written into the
// embedded Big in place; a nil val is rejected because Decimal is not
// nullable.
func (d *Decimal) Scan(val any) (err error) {
	_, scanErr := decimalScan(&d.Big, val, false)
	return scanErr
}
// Value implements driver.Valuer. It delegates to decimalValue in nullable
// mode, so a nil Big is sent to the database as a nil value.
func (n NullDecimal) Value() (driver.Value, error) {
	const nullable = true
	return decimalValue(n.Big, nullable)
}
// Scan implements sql.Scanner. A nil val sets Big to nil. Otherwise the
// current Big is handed to decimalScan, which reuses it when non-nil and
// allocates a new one when nil. On error Big is not reassigned.
func (n *NullDecimal) Scan(val any) error {
	scanned, err := decimalScan(n.Big, val, true)
	if err == nil {
		n.Big = scanned
	}
	return err
}
// decimalValue converts d into a driver.Value for storage in the database.
//
// A nil d yields a nil value when canNull is true. When canNull is false
// there is no null representation, so a nil d is sent as "0" instead of
// being dereferenced. NaN and infinity are rejected with an error; any
// other value is returned in its String form.
func decimalValue(d *decimal.Big, canNull bool) (driver.Value, error) {
	if d == nil {
		if canNull {
			return nil, nil
		}
		// Without this guard the checks below would dereference a nil
		// pointer for a non-nullable decimal.
		return "0", nil
	}

	if d.IsNaN(0) {
		return nil, errors.New("refusing to allow NaN into database")
	}
	if d.IsInf(0) {
		return nil, errors.New("refusing to allow infinity into database")
	}

	return d.String(), nil
}
// decimalScan reads a database value into a decimal and returns the decimal
// that now holds it.
//
// A nil val returns (nil, nil) when canNull is true and an error otherwise.
// For any other val the destination is d itself when d is non-nil, or a newly
// allocated decimal.Big when d is nil. Supported source types are float64,
// string and []byte; any other type is an error. On every error path the
// returned pointer is nil, although a non-nil d may already have been
// written to by the failed parse.
func decimalScan(d *decimal.Big, val any, canNull bool) (*decimal.Big, error) {
	if val == nil {
		if canNull {
			return nil, nil
		}
		return nil, errors.New("null cannot be scanned into decimal")
	}

	dst := d
	if dst == nil {
		dst = new(decimal.Big)
	}

	switch src := val.(type) {
	case float64:
		dst.SetFloat64(src)
	case string:
		if _, ok := dst.SetString(src); !ok {
			// Prefer the condition recorded on the decimal's context, if
			// any, over the generic syntax error.
			if err := dst.Context.Err(); err != nil {
				return nil, err
			}
			return nil, errors.Errorf("invalid decimal syntax: %q", src)
		}
	case []byte:
		if err := dst.UnmarshalText(src); err != nil {
			return nil, err
		}
	default:
		return nil, errors.Errorf("cannot scan decimal value: %#v", val)
	}

	return dst, nil
}