Description
/// @notice Withdraws `amount` from stream `streamId` and transfers the net amount (after any protocol fee) to `to`.
/// @dev Accounting performed below, relevant to the invariant
///      sum(stream balances) + protocol revenue = aggregate balance:
///        - `_streams[streamId].balance` and the stream debt are reduced by the gross `amount`;
///        - if `protocolFee[token] > ZERO`, `amount` is reassigned to the net amount and
///          `protocolRevenue[token]` is increased by `protocolFeeAmount`;
///        - `aggregateBalance[token]` is reduced by the (possibly reassigned) net `amount`.
///      Both sides of the invariant therefore change by -gross + fee, provided
///      `Helpers.calculateAmountsFromFee` returns `(fee, totalAmount - fee)` -- NOTE(review): confirm in Helpers.
/// @param streamId The ID of the stream to withdraw from.
/// @param to The address that receives the withdrawn tokens.
/// @param amount The gross amount to withdraw, in the token's decimals; any protocol fee is taken out of it.
/// @return withdrawnAmount The net amount transferred to `to`.
/// @return protocolFeeAmount The fee credited to `protocolRevenue[token]` (zero when no fee is configured).
function _withdraw(
uint256 streamId,
address to,
uint128 amount
)
internal
returns (uint128 withdrawnAmount, uint128 protocolFeeAmount)
{
// Check: the withdraw amount is not zero.
if (amount == 0) {
revert Errors.SablierFlow_WithdrawAmountZero(streamId);
}
// Check: the withdrawal address is not the zero address.
if (to == address(0)) {
revert Errors.SablierFlow_WithdrawToZeroAddress(streamId);
}
// Check: unless the caller is the stream recipient or an approved party, funds may only go to the stream owner.
if (to != _ownerOf(streamId) && !_isCallerStreamRecipientOrApproved(streamId)) {
revert Errors.SablierFlow_WithdrawalAddressNotRecipient({ streamId: streamId, caller: msg.sender, to: to });
}
uint8 tokenDecimals = _streams[streamId].tokenDecimals;
// Total debt = ongoing debt + snapshot debt, kept in scaled units and also descaled to token units.
uint256 totalDebtScaled = _ongoingDebtScaledOf(streamId) + _streams[streamId].snapshotDebtScaled;
uint256 totalDebt = Helpers.descaleAmount(totalDebtScaled, tokenDecimals);
// The withdrawable amount is min(stream balance, total debt).
uint128 balance = _streams[streamId].balance;
uint128 withdrawableAmount;
if (balance < totalDebt) {
withdrawableAmount = balance;
} else {
withdrawableAmount = totalDebt.toUint128();
}
// Check: the gross amount (fee included) does not exceed the withdrawable amount.
if (amount > withdrawableAmount) {
revert Errors.SablierFlow_Overdraw(streamId, amount, withdrawableAmount);
}
uint256 amountScaled = Helpers.scaleAmount(amount, tokenDecimals);
// Unchecked: `amount` is bounded above by both the balance and the total debt (check above). The scaled
// subtraction additionally relies on scaleAmount(descaleAmount(x)) <= x -- NOTE(review): confirm in Helpers.
unchecked {
// The snapshot debt alone covers the amount: reduce it and leave the snapshot time unchanged.
if (amountScaled <= _streams[streamId].snapshotDebtScaled) {
_streams[streamId].snapshotDebtScaled -= amountScaled;
}
// Otherwise the ongoing debt is consumed too: store the remaining total debt as the snapshot debt and
// move the snapshot time to the current block timestamp.
else {
_streams[streamId].snapshotDebtScaled = totalDebtScaled - amountScaled;
_streams[streamId].snapshotTime = uint40(block.timestamp);
}
// Effect: the stream balance drops by the gross amount (fee included).
_streams[streamId].balance -= amount;
}
IERC20 token = _streams[streamId].token;
// Per-token fee rate read from the `protocolFee` storage (presumably a mapping keyed by token -- confirm).
UD60x18 protocolFee = protocolFee[token];
if (protocolFee > ZERO) {
// From here on `amount` holds the NET amount and `protocolFeeAmount` holds the fee.
(protocolFeeAmount, amount) = Helpers.calculateAmountsFromFee({ totalAmount: amount, fee: protocolFee });
unchecked {
// Effect: the fee stays in the contract and is credited to the protocol revenue.
protocolRevenue[token] += protocolFeeAmount;
}
}
unchecked {
// Effect: only the net amount leaves the contract, so the aggregate balance drops by the net amount.
// Together with the two effects above, this keeps the balance/revenue/aggregate invariant intact
// (assuming fee + net == gross).
aggregateBalance[token] -= amount;
}
// Interaction: transfer the net amount to the withdrawal address.
token.safeTransfer({ to: to, value: amount });
// Protocol invariant: the drop in total debt equals the drop in the stream balance (both use the gross amount).
assert(totalDebt - _totalDebtOf(streamId) == balance - _streams[streamId].balance);
// Log the withdrawal with the net amount and the fee.
emit ISablierFlow.WithdrawFromFlowStream({
streamId: streamId,
to: to,
token: token,
caller: msg.sender,
withdrawAmount: amount,
protocolFeeAmount: protocolFeeAmount
});
// Return the net amount sent to `to` and the fee charged.
return (amount, protocolFeeAmount);
}
The technical documentation states the following invariant:

$$\sum \text{stream balances} + \text{protocol revenue} = \text{aggregate balance} \tag{i}$$
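In the `_withdraw` function, `aggregateBalance[token]` and `_streams[streamId].balance` are each decremented by the amount being withdrawn. If `protocolFee > 0`, `protocolRevenue[token]` is additionally increased by the protocol fee amount. According to this finding, that makes the left-hand side of equation (i) greater than the right-hand side.

_Reviewer note:_ in the code quoted above, `_streams[streamId].balance -= amount` runs with the gross amount. `Helpers.calculateAmountsFromFee` then reassigns `amount` to the net (post-fee) amount before `aggregateBalance[token] -= amount` executes. The aggregate balance therefore drops by (gross − fee), not by the gross amount. If the helper returns `(fee, gross − fee)`, both sides of (i) change by −gross + fee, and the invariant is preserved. Verify this before treating the finding as valid.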
**Impact:** `_withdraw` could break one of the protocol invariants.
Recommended mitigation
If `protocolFee > 0`, the protocol fee *amount* (not the `UD60x18` fee rate) should first be deducted from the amount to be withdrawn and stored in a new variable, say `amountAfterProtocolFee`. Then:
/// @notice Withdraws `amount` from stream `streamId` and transfers the net amount (after any protocol fee) to `to`.
/// @dev Splits the gross `amount` into `protocolFeeAmount` and `amountAfterProtocolFee` up front, then keeps the
///      invariant sum(stream balances) + protocol revenue = aggregate balance intact:
///        - the stream debt and `_streams[streamId].balance` are reduced by the GROSS `amount`;
///        - `protocolRevenue[token]` is increased by `protocolFeeAmount` (the fee stays in the contract);
///        - `aggregateBalance[token]` is reduced by `amountAfterProtocolFee` only (what actually leaves the contract).
///      Both sides of the invariant change by -amount + protocolFeeAmount, provided
///      `Helpers.calculateAmountsFromFee` returns `(fee, totalAmount - fee)` -- NOTE(review): confirm in Helpers.
///      `protocolFee` is a UD60x18 rate, never a token amount, so it is not subtracted from `amount` directly.
/// @param streamId The ID of the stream to withdraw from.
/// @param to The address that receives the withdrawn tokens.
/// @param amount The gross amount to withdraw, in the token's decimals; any protocol fee is taken out of it.
/// @return withdrawnAmount The net amount transferred to `to`.
/// @return protocolFeeAmount The fee credited to `protocolRevenue[token]` (zero when no fee is configured).
function _withdraw(
    uint256 streamId,
    address to,
    uint128 amount
)
    internal
    returns (uint128 withdrawnAmount, uint128 protocolFeeAmount)
{
    // Check: the withdraw amount is not zero.
    if (amount == 0) {
        revert Errors.SablierFlow_WithdrawAmountZero(streamId);
    }

    // Check: the withdrawal address is not the zero address.
    if (to == address(0)) {
        revert Errors.SablierFlow_WithdrawToZeroAddress(streamId);
    }

    // Check: unless the caller is the stream recipient or an approved party, funds may only go to the stream owner.
    if (to != _ownerOf(streamId) && !_isCallerStreamRecipientOrApproved(streamId)) {
        revert Errors.SablierFlow_WithdrawalAddressNotRecipient({ streamId: streamId, caller: msg.sender, to: to });
    }

    uint8 tokenDecimals = _streams[streamId].tokenDecimals;

    // Total debt = ongoing debt + snapshot debt, kept in scaled units and also descaled to token units.
    uint256 totalDebtScaled = _ongoingDebtScaledOf(streamId) + _streams[streamId].snapshotDebtScaled;
    uint256 totalDebt = Helpers.descaleAmount(totalDebtScaled, tokenDecimals);

    // The withdrawable amount is min(stream balance, total debt).
    uint128 balance = _streams[streamId].balance;
    uint128 withdrawableAmount;
    if (balance < totalDebt) {
        withdrawableAmount = balance;
    } else {
        withdrawableAmount = totalDebt.toUint128();
    }

    // Check: the GROSS amount (fee included) must be covered by the stream. Checking the net amount instead would
    // let the fee be paid out of funds this stream does not hold.
    if (amount > withdrawableAmount) {
        revert Errors.SablierFlow_Overdraw(streamId, amount, withdrawableAmount);
    }

    IERC20 token = _streams[streamId].token;
    // Per-token fee rate read from the `protocolFee` storage (presumably a mapping keyed by token -- confirm).
    UD60x18 protocolFee = protocolFee[token];

    // Split the gross amount into the fee and the net amount sent to `to`. With no fee, the net equals the gross.
    uint128 amountAfterProtocolFee = amount;
    if (protocolFee > ZERO) {
        (protocolFeeAmount, amountAfterProtocolFee) =
            Helpers.calculateAmountsFromFee({ totalAmount: amount, fee: protocolFee });

        unchecked {
            // Effect: the fee stays in the contract and is credited to the protocol revenue.
            protocolRevenue[token] += protocolFeeAmount;
        }
    }

    // The stream is debited by the gross amount: the recipient's claim shrinks by everything withdrawn, fee included.
    uint256 amountScaled = Helpers.scaleAmount(amount, tokenDecimals);

    // Unchecked: `amount` is bounded above by both the balance and the total debt (check above). The scaled
    // subtraction additionally relies on scaleAmount(descaleAmount(x)) <= x -- NOTE(review): confirm in Helpers.
    unchecked {
        if (amountScaled <= _streams[streamId].snapshotDebtScaled) {
            // The snapshot debt alone covers the amount: reduce it and leave the snapshot time unchanged.
            _streams[streamId].snapshotDebtScaled -= amountScaled;
        } else {
            // The ongoing debt is consumed too: store the remaining total debt as the snapshot debt and move the
            // snapshot time to the current block timestamp.
            _streams[streamId].snapshotDebtScaled = totalDebtScaled - amountScaled;
            _streams[streamId].snapshotTime = uint40(block.timestamp);
        }

        // Effect: the stream balance drops by the gross amount.
        _streams[streamId].balance -= amount;

        // Effect: only the net amount leaves the contract, so the aggregate balance drops by the net amount.
        aggregateBalance[token] -= amountAfterProtocolFee;
    }

    // Interaction: transfer the net amount to the withdrawal address.
    token.safeTransfer({ to: to, value: amountAfterProtocolFee });

    // Protocol invariant: the drop in total debt equals the drop in the stream balance (both use the gross amount).
    assert(totalDebt - _totalDebtOf(streamId) == balance - _streams[streamId].balance);

    // Log the withdrawal; the reported amount matches what is transferred and returned.
    emit ISablierFlow.WithdrawFromFlowStream({
        streamId: streamId,
        to: to,
        token: token,
        caller: msg.sender,
        withdrawAmount: amountAfterProtocolFee,
        protocolFeeAmount: protocolFeeAmount
    });

    return (amountAfterProtocolFee, protocolFeeAmount);
}