118 lines
2.6 KiB
Go
118 lines
2.6 KiB
Go
package types
|
|
|
|
import (
|
|
"database/sql/driver"
|
|
"github.com/ericlagergren/decimal"
|
|
"github.com/pkg/errors"
|
|
)
|
|
|
|
// Decimal is a DECIMAL in sql. Its zero value is valid for use with both
|
|
// Value and Scan.
|
|
//
|
|
// Although decimal can represent NaN and Infinity it will return an error
|
|
// if an attempt to store these values in the database is made.
|
|
//
|
|
// Because it cannot be nil, when Big is nil Value() will return "0"
|
|
// It will error if an attempt to Scan() a "null" value into it.
|
|
type Decimal struct {
|
|
decimal.Big
|
|
}
|
|
|
|
// NullDecimal is the same as Decimal, but allows the Big pointer to be nil.
|
|
// See docmentation for Decimal for more details.
|
|
//
|
|
// When going into a database, if Big is nil it's value will be "null".
|
|
type NullDecimal struct {
|
|
*decimal.Big
|
|
}
|
|
|
|
// NewDecimal creates a new decimal from a decimal
|
|
func NewDecimal(d *decimal.Big) (num Decimal) {
|
|
if d != nil {
|
|
num.Big.Copy(d)
|
|
}
|
|
return
|
|
}
|
|
|
|
// NewNullDecimal creates a new null decimal from a decimal
|
|
func NewNullDecimal(d *decimal.Big) NullDecimal {
|
|
return NullDecimal{Big: d}
|
|
}
|
|
|
|
// Value implements driver.Valuer.
|
|
func (d Decimal) Value() (driver.Value, error) {
|
|
return decimalValue(&d.Big, false)
|
|
}
|
|
|
|
// Scan implements sql.Scanner.
|
|
func (d *Decimal) Scan(val any) (err error) {
|
|
_, err = decimalScan(&d.Big, val, false)
|
|
return
|
|
}
|
|
|
|
// Value implements driver.Valuer.
|
|
func (n NullDecimal) Value() (driver.Value, error) {
|
|
return decimalValue(n.Big, true)
|
|
}
|
|
|
|
// Scan implements sql.Scanner.
|
|
func (n *NullDecimal) Scan(val any) error {
|
|
newD, err := decimalScan(n.Big, val, true)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
n.Big = newD
|
|
return nil
|
|
}
|
|
|
|
func decimalValue(d *decimal.Big, canNull bool) (driver.Value, error) {
|
|
if canNull && d == nil {
|
|
return nil, nil
|
|
}
|
|
|
|
if d.IsNaN(0) {
|
|
return nil, errors.New("refusing to allow NaN into database")
|
|
}
|
|
if d.IsInf(0) {
|
|
return nil, errors.New("refusing to allow infinity into database")
|
|
}
|
|
|
|
return d.String(), nil
|
|
}
|
|
|
|
func decimalScan(d *decimal.Big, val any, canNull bool) (*decimal.Big, error) {
|
|
if val == nil {
|
|
if !canNull {
|
|
return nil, errors.New("null cannot be scanned into decimal")
|
|
}
|
|
|
|
return nil, nil
|
|
}
|
|
|
|
if d == nil {
|
|
d = new(decimal.Big)
|
|
}
|
|
|
|
switch t := val.(type) {
|
|
case float64:
|
|
d.SetFloat64(t)
|
|
return d, nil
|
|
case string:
|
|
if _, ok := d.SetString(t); !ok {
|
|
if err := d.Context.Err(); err != nil {
|
|
return nil, err
|
|
}
|
|
return nil, errors.Errorf("invalid decimal syntax: %q", t)
|
|
}
|
|
return d, nil
|
|
case []byte:
|
|
if err := d.UnmarshalText(t); err != nil {
|
|
return nil, err
|
|
}
|
|
return d, nil
|
|
default:
|
|
return nil, errors.Errorf("cannot scan decimal value: %#v", val)
|
|
}
|
|
}
|