function deposit(uint256 assets, address receiver) public override returns (uint256) {
require(receiver != address(0));
if (block.timestamp >= eventStartDate) {
revert eventStarted();
}
uint256 fee = _getParticipationFee(assets);
if (minimumAmount + fee > assets) {
revert lowFeeAndAmount();
}
uint256 stakeAsset = assets - fee;
@> stakedAsset[receiver] = stakeAsset;
uint256 participantShares = _convertToShares(stakeAsset);
IERC20(asset()).safeTransferFrom(msg.sender, participationFeeAddress, fee);
IERC20(asset()).safeTransferFrom(msg.sender, address(this), stakeAsset);
_mint(msg.sender, participantShares);
emit deposited (receiver, stakeAsset);
return participantShares;
}
function cancelParticipation () public {
if (block.timestamp >= eventStartDate){
revert eventStarted();
}
@> uint256 refundAmount = stakedAsset[msg.sender];
stakedAsset[msg.sender] = 0;
uint256 shares = balanceOf(msg.sender);
_burn(msg.sender, shares);
IERC20(asset()).safeTransfer(msg.sender, refundAmount);
}
pragma solidity ^0.8.24;
import {Test, console} from "forge-std/Test.sol";
import {BriVault} from "../src/briVault.sol";
import {BriTechToken} from "../src/briTechToken.sol";
import {IERC20} from "@openzeppelin/contracts/token/ERC20/IERC20.sol";
contract StakedAssetOverwritePoC is Test {
BriVault public vault;
BriTechToken public token;
address owner = makeAddr("owner");
address user = makeAddr("user");
address feeAddress = makeAddr("feeAddress");
uint256 constant FIRST_DEPOSIT = 1000e18;
uint256 constant SECOND_DEPOSIT = 500e18;
function setUp() public {
vm.startPrank(owner);
token = new BriTechToken();
token.mint();
uint256 currentTime = block.timestamp;
vault = new BriVault(
IERC20(address(token)),
150,
currentTime + 1 days,
feeAddress,
1e18,
currentTime + 7 days
);
token.transfer(user, 10000e18);
vm.stopPrank();
}
function test_StakedAssetOverwriteOnMultipleDeposits() public {
console.log("=== Staked Asset Overwrite on Multiple Deposits PoC ===");
vm.startPrank(user);
uint256 firstFee = (FIRST_DEPOSIT * 150) / 10000;
uint256 secondFee = (SECOND_DEPOSIT * 150) / 10000;
token.approve(address(vault), FIRST_DEPOSIT + SECOND_DEPOSIT + firstFee + secondFee);
console.log("\n--- FIRST DEPOSIT ---");
console.log("User deposits:", FIRST_DEPOSIT / 1e18, "tokens");
vault.deposit(FIRST_DEPOSIT, user);
uint256 stakedAfterFirst = vault.stakedAsset(user);
console.log("stakedAsset[user] after first deposit:", stakedAfterFirst / 1e18, "tokens");
console.log("Expected:", (FIRST_DEPOSIT - firstFee) / 1e18, "tokens");
console.log("\n--- SECOND DEPOSIT ---");
console.log("User deposits:", SECOND_DEPOSIT / 1e18, "tokens");
vault.deposit(SECOND_DEPOSIT, user);
uint256 stakedAfterSecond = vault.stakedAsset(user);
console.log("stakedAsset[user] after second deposit:", stakedAfterSecond / 1e18, "tokens");
console.log("Expected if accumulated:", ((FIRST_DEPOSIT - firstFee) + (SECOND_DEPOSIT - secondFee)) / 1e18, "tokens");
console.log("\n--- VULNERABILITY DEMONSTRATION ---");
console.log("First deposit was OVERWRITTEN!");
console.log("Lost amount:", (stakedAfterFirst - stakedAfterSecond) / 1e18, "tokens");
console.log("\n--- CANCEL PARTICIPATION ---");
uint256 userBalanceBefore = token.balanceOf(user);
console.log("User balance before cancel:", userBalanceBefore / 1e18, "tokens");
vault.cancelParticipation();
uint256 userBalanceAfter = token.balanceOf(user);
uint256 refunded = userBalanceAfter - userBalanceBefore;
console.log("User balance after cancel:", userBalanceAfter / 1e18, "tokens");
console.log("Refunded amount:", refunded / 1e18, "tokens");
console.log("\n--- FUND LOSS CALCULATION ---");
uint256 totalDeposited = FIRST_DEPOSIT + SECOND_DEPOSIT;
uint256 totalFees = firstFee + secondFee;
uint256 expectedRefund = totalDeposited - totalFees;
console.log("Total deposited:", totalDeposited / 1e18, "tokens");
console.log("Total fees deducted:", totalFees / 1e18, "tokens");
console.log("Expected refund:", expectedRefund / 1e18, "tokens");
console.log("Actual refund:", refunded / 1e18, "tokens");
console.log("FUNDS LOST:", (expectedRefund - refunded) / 1e18, "tokens");
assertLt(refunded, expectedRefund, "User should receive less due to overwrite");
assertEq(refunded, stakedAfterSecond, "Refund equals only the last deposit amount");
vm.stopPrank();
}
}
forge test --match-contract StakedAssetOverwritePoC -vv
[⠒] Compiling...
No files changed, compilation skipped
Ran 1 test for test/StakedAssetOverwritePoC.t.sol:StakedAssetOverwritePoC
[PASS] test_StakedAssetOverwriteOnMultipleDeposits() (gas: 202128)
Logs:
=== Staked Asset Overwrite on Multiple Deposits PoC ===
--- FIRST DEPOSIT ---
User deposits: 1000 tokens
stakedAsset[user] after first deposit: 985 tokens
Expected: 985 tokens
--- SECOND DEPOSIT ---
User deposits: 500 tokens
stakedAsset[user] after second deposit: 492 tokens
Expected if accumulated: 1477 tokens
--- VULNERABILITY DEMONSTRATION ---
First deposit was OVERWRITTEN!
Lost amount: 492 tokens
--- CANCEL PARTICIPATION ---
User balance before cancel: 8500 tokens
User balance after cancel: 8992 tokens
Refunded amount: 492 tokens
--- FUND LOSS CALCULATION ---
Total deposited: 1500 tokens
Total fees deducted: 22 tokens
Expected refund: 1477 tokens
Actual refund: 492 tokens
FUNDS LOST: 985 tokens
Suite result: ok. 1 passed; 0 failed; 0 skipped; finished in 3.54ms (691.63µs CPU time)
Ran 1 test suite in 40.49ms (3.54ms CPU time): 1 tests passed, 0 failed, 0 skipped (1 total tests)
ensures all deposits are properly tracked and users receive complete refunds reflecting their total contribution to the vault.
function deposit(uint256 assets, address receiver) public override returns (uint256) {
require(receiver != address(0));
if (block.timestamp >= eventStartDate) {
revert eventStarted();
}
uint256 fee = _getParticipationFee(assets);
if (minimumAmount + fee > assets) {
revert lowFeeAndAmount();
}
uint256 stakeAsset = assets - fee;
- stakedAsset[receiver] = stakeAsset;
+ stakedAsset[receiver] += stakeAsset; // Accumulate instead of overwrite
uint256 participantShares = _convertToShares(stakeAsset);
IERC20(asset()).safeTransferFrom(msg.sender, participationFeeAddress, fee);
IERC20(asset()).safeTransferFrom(msg.sender, address(this), stakeAsset);
_mint(msg.sender, participantShares);
emit deposited (receiver, stakeAsset);
return participantShares;
}