使用 `Decimal`模块执行精确的浮点数运算
2023-01-14 本文已影响0人
Rethink
Python v3.10.6
背景
浮点数的一个普遍问题是它们并不能精确的表示十进制数。 并且,即使是最简单的数学运算也会产生小的误差,比如:
>>> 4.2+2.1
6.300000000000001
这些错误是由底层特征决定的,因此没办法去避免这样的误差。这时候可能首先想到的是使用内置的round(value, ndigits)
函数,但round
函数采用的舍入规则是四舍六入五成双,也就是说如果value刚好在两个边界的中间的时候, round 函数会返回离它最近的偶数(如下示例),除非对精确度没什么要求,否则尽量避开用此方法。
>>> round(1.5) == round(2.5) == 2
True
>>> round(2.675, 2)
2.67
解决办法
在一些对浮点数精度要求较高的领域,需要使用Decimal
模块中的方法来进行精准计算,官方文档:https://docs.python.org/zh-cn/3/library/decimal.html
基于Decimal
封装的工具类,部分源码如下:
from typing import Union
from decimal import ROUND_HALF_UP, Context, Decimal, getcontext
Numeric = Union[int, float, str, Decimal]
def isNumeric(n: Numeric, raise_exception=False) -> bool:
if isinstance(n, (int, float, Decimal)):
flag = True
elif isinstance(n, str) and n.replace(".", "").lstrip("+").lstrip("-").isdigit(): # not stirp
flag = True
else:
flag = False
if raise_exception and flag is False:
raise ValueError(f"Unsupport value type: {n}, type: {type(n)}")
return flag
class DecimalTool:
"""十进制浮点运算工具类.
计算结果与服务端decimal方法的计算结果保持完全一致, 方便在case中对数据进行精确断言.
同时实例方法支持链式调用, 降低python弱语言类型造成的困扰.
使用示例:
DecimalTool("0.348526").round("0.000").scale(3).toString()
>>> 348.226
"""
def __init__(self, n: Numeric, rounding: str = None, context: Context = None):
isNumeric(n, True)
ctx = context if context else getcontext()
self.n = Decimal(str(n))
# 设置舍入模式为四舍五入, 即若最后一个有效数字小于5则朝0方向取整,否则朝0反方向取整
self.rounding = ROUND_HALF_UP if not rounding else rounding
ctx.rounding = self.rounding
ctx.prec = 28
# init
self._cache = self.n
def round(self, exp: str):
"""将数值四舍五入到指定精度.
Usage:
>>> roundPrice1 = DecimalTool('3.14145').round('0.0000').toString()
>>> Assert.assert_equal(roundPrice1, '3.1415')
>>> Assert.assert_not_equal(roundPrice1, round(3.14145, 4))
"""
self._cache = self.n.quantize(Decimal(exp), self.rounding)
return self
def truncate(self, exp):
"""将数值截取到指定精度(不四舍五入)
Usage:
>>> DecimalTool('3.1415').truncate('0.000').toString()
3.141
"""
sourcePrecision = DecimalTool(self._cache).getPrecision()
targetPrecision = DecimalTool(exp).getPrecision()
if sourcePrecision > targetPrecision:
t = str(self._cache).split('.')
self._cache = Decimal(t[0] + "." + t[1][:targetPrecision])
return self
def scale(self, e: int):
"""以10为底, 将参数进行指数级别缩放(e可为负数)
Usage:
>>> DecimalTool('0.348526').scale(3).toString()
348.526
>>> DecimalTool('348.526').scale(-3).toString()
0.348526
"""
self._cache = self.n.scaleb(e)
return self
def cutdown(self, size: Numeric, mode: Literal['truncate', 'round'] = 'truncate'):
"""根据根据size, 将数值进行向下裁减处理, 并与size保持相同精度.
Uasge:
>>> DecimalTool("16000.6").cutdown('0.5').toString()
16000.5
>>> DecimalTool("16000.6").cutdown('0.50').toString()
16000.50
"""
isNumeric(size, True)
temp = Decimal(size) * int(self._cache // Decimal(size))
if mode == 'truncate':
self._cache = DecimalTool(temp).truncate(size).toDecimal()
elif mode == 'round':
self._cache = DecimalTool(temp).round(size).toDecimal()
else:
raise ValueError(f"Unsupport mode: {mode}")
return self
def isZero(self) -> bool:
"""如果参数为0, 则返回True, 否则返回False
Usage:
>>> DecimalTool('0.001').isZero()
Flase
>>> DecimalTool('0.00').isZero()
True
"""
return self._cache.is_zero()
def isSigend(self) -> bool:
"""如果参数带有负号,则返回为True, 否则返回False
特殊地, -0会返回Flase
"""
return self._cache.is_signed()
def deleteExtraZero(self):
"""删除小数点后面多余的0
"""
self._cache = Decimal(str(self._cache).rstrip('0').rstrip("."))
return self
def getPrecision(self) -> int:
"""获取小数位精度
"""
self._cache_str = str(self._cache).lower()
if "." in self._cache_str:
precision = len(self._cache_str.split('.')[-1])
elif "e-" in self._cache_str: # 某些情况下, 如当小数位数>6位时, 会变为科学计数法
_, precision = self._cache_str.replace('-', '').split('e')
precision = int(precision)
else:
precision = 0
return precision
def toDecimal(self) -> Decimal:
return self._cache
def toInt(self) -> int:
return self._cache.to_integral_value()
def toFloat(self) -> float:
return float(self._cache)
def toString(self) -> str:
return str(self._cache)
def toEngString(self) -> str:
return self._cache.to_eng_string()
Testcase (Pytest)
...
def test_decimal_tool():
precision1 = DecimalTool('0.348226000').getPrecision()
Assert.assert_equal(precision1, 9)
precision2 = DecimalTool('0.00000001').getPrecision()
Assert.assert_equal(precision2, 8)
precision3 = DecimalTool('123456').getPrecision()
Assert.assert_equal(precision3, 0)
roundPrice1 = DecimalTool('3.14145').round("0.0000").toString()
Assert.assert_equal(roundPrice1, '3.1415')
Assert.assert_not_equal(roundPrice1, round(3.14145, 4))
roundPrice2 = DecimalTool('0.348526').round("0.000").toString()
Assert.assert_equal(roundPrice2, '0.349')
truncatePrice1 = DecimalTool('0.0348').truncate("0.000").toString()
Assert.assert_equal(truncatePrice1, '0.034')
truncatePrice2 = DecimalTool('0.0348').truncate("0.0000000").toString()
Assert.assert_equal(truncatePrice2, '0.0348')
scaledPrice1 = DecimalTool('0.348526').round("0.000").scale(3).toString()
Assert.assert_equal(scaledPrice1, '348.526')
scaledPrice2 = DecimalTool('348.526').round("0.000").scale(-3).toString()
Assert.assert_equal(scaledPrice2, '0.348526')
cutdownPrice1 = DecimalTool("16000.6").cutdown('0.5').toString()
Assert.assert_equal(cutdownPrice1, '16000.5')
cutdownPrice1 = DecimalTool("16000.6").cutdown('0.50').toString()
Assert.assert_equal(cutdownPrice1, '16000.50') # 与size保持相同精度
isNotZero = DecimalTool('0.001').isZero()
Assert.assert_not_true(isNotZero)
isZero1 = DecimalTool('0.00').isZero()
Assert.assert_true(isZero1)
isZero2 = DecimalTool('0').isZero()
Assert.assert_true(isZero2)