pragma solidity ^0.8.24;
import "forge-std/Test.sol";
import "../src/briVault.sol";
import "forge-std/console.sol";
/// @notice Shows that the gas cost of `BriVault.setWinner` grows with the number of
///         participants, i.e. the winner-share computation scales with the user count.
contract BriVaultDOSGasTest is Test {
    MockERC20 public token;
    address public owner;

    uint256 public constant INITIAL_SUPPLY = 10_000_000 ether;
    // Event start/end are offsets from the moment each vault is deployed.
    uint256 public constant EVENT_START = 1 days;
    uint256 public constant EVENT_END = 8 days;
    uint256 public constant PARTICIPATION_FEE = 100;
    uint256 public constant MIN_AMOUNT = 1 ether;

    // Synthetic participants are address(USER_ADDRESS_BASE + j).
    uint256 private constant USER_ADDRESS_BASE = 3000;
    // Tokens handed to each participant, and the amount each one deposits.
    uint256 private constant USER_FUNDING = 10 ether;
    uint256 private constant USER_DEPOSIT = 2 ether;

    function setUp() public {
        owner = address(this);
        // MockERC20 mints the whole initial supply to the deployer (this contract).
        token = new MockERC20("Test Token", "TTK", INITIAL_SUPPLY);
    }

    /// @notice Measures `setWinner` gas for 100, 200 and 500 participants and asserts it grows.
    function testSetWinnerGasScaling() public {
        uint256[] memory userCounts = new uint256[](3);
        userCounts[0] = 100;
        userCounts[1] = 200;
        userCounts[2] = 500;

        // The country list is identical for every vault, so build it once.
        string[48] memory countries = _countryNames();

        uint256[] memory gasConsumptions = new uint256[](userCounts.length);
        for (uint256 i = 0; i < userCounts.length; i++) {
            gasConsumptions[i] = _measureSetWinnerGas(userCounts[i], countries);
        }

        // Printed values are percentages of the 100-user baseline (200 % == twice the gas).
        console.log("setWinner gas, 200 users relative to 100 users:", gasConsumptions[1] * 100 / gasConsumptions[0], "%");
        console.log("setWinner gas, 500 users relative to 100 users:", gasConsumptions[2] * 100 / gasConsumptions[0], "%");

        assertGt(gasConsumptions[1], gasConsumptions[0]);
        assertGt(gasConsumptions[2], gasConsumptions[1]);
    }

    /// @dev Deploys a fresh vault, enrolls `userCount` participants on country 0 before the
    ///      event starts, moves past the event end and returns the gas used by `setWinner(0)`.
    function _measureSetWinnerGas(uint256 userCount, string[48] memory countries) private returns (uint256 gasUsed) {
        uint256 deployedAt = block.timestamp;
        BriVault vault = new BriVault(
            IERC20(address(token)),
            PARTICIPATION_FEE,
            deployedAt + EVENT_START,
            owner,
            MIN_AMOUNT,
            deployedAt + EVENT_END
        );
        vault.setCountry(countries);

        // Deposits and joins happen one hour before the event starts.
        vm.warp(deployedAt + EVENT_START - 1 hours);
        for (uint256 j = 0; j < userCount; j++) {
            _enroll(vault, address(uint160(USER_ADDRESS_BASE + j)));
        }

        // The winner is set one hour after the event ends.
        vm.warp(deployedAt + EVENT_END + 1 hours);

        // NOTE: by default Foundry runs a test function as a single transaction, so storage
        // touched while enrolling is already warm (EIP-2929). Absolute numbers therefore
        // understate a real cold call; the growth with `userCount` is what matters here.
        uint256 gasBefore = gasleft();
        vault.setWinner(0);
        gasUsed = gasBefore - gasleft();
    }

    /// @dev Funds `user`, then (as `user`) approves the vault, deposits and joins country 0.
    function _enroll(BriVault vault, address user) private {
        token.transfer(user, USER_FUNDING);
        vm.startPrank(user);
        token.approve(address(vault), USER_FUNDING);
        vault.deposit(USER_DEPOSIT, user);
        vault.joinEvent(0);
        vm.stopPrank();
    }

    /// @dev Builds the placeholder names "Country0" .. "Country47".
    function _countryNames() private view returns (string[48] memory countries) {
        for (uint256 j = 0; j < countries.length; j++) {
            countries[j] = string(abi.encodePacked("Country", vm.toString(j)));
        }
    }
}
/// @notice Minimal ERC-20 used as the vault asset in tests.
/// @dev There is no access control. Over-spending a balance or allowance reverts through
///      Solidity 0.8 checked arithmetic rather than a custom error. Emits the
///      `Transfer`/`Approval` events that EIP-20 requires.
contract MockERC20 {
    event Transfer(address indexed from, address indexed to, uint256 value);
    event Approval(address indexed owner, address indexed spender, uint256 value);

    string public name;
    string public symbol;
    uint8 public decimals = 18;
    uint256 public totalSupply;
    mapping(address => uint256) public balanceOf;
    mapping(address => mapping(address => uint256)) public allowance;

    /// @param _name Token name.
    /// @param _symbol Token symbol.
    /// @param _initialSupply Entire supply, minted to the deployer.
    constructor(string memory _name, string memory _symbol, uint256 _initialSupply) {
        name = _name;
        symbol = _symbol;
        totalSupply = _initialSupply;
        balanceOf[msg.sender] = _initialSupply;
        // EIP-20: token creation SHOULD be signalled as a transfer from address(0).
        emit Transfer(address(0), msg.sender, _initialSupply);
    }

    /// @notice Moves `amount` from the caller to `to`; reverts if the caller's balance is too low.
    function transfer(address to, uint256 amount) public returns (bool) {
        _move(msg.sender, to, amount);
        return true;
    }

    /// @notice Sets (overwrites) the caller's allowance for `spender`.
    function approve(address spender, uint256 amount) public returns (bool) {
        allowance[msg.sender][spender] = amount;
        emit Approval(msg.sender, spender, amount);
        return true;
    }

    /// @notice Moves `amount` from `from` to `to`, spending the caller's allowance.
    function transferFrom(address from, address to, uint256 amount) public returns (bool) {
        // Checked subtraction reverts when the allowance is insufficient.
        allowance[from][msg.sender] -= amount;
        _move(from, to, amount);
        return true;
    }

    /// @dev Shared balance update for `transfer` / `transferFrom`.
    function _move(address from, address to, uint256 amount) private {
        balanceOf[from] -= amount;
        balanceOf[to] += amount;
        emit Transfer(from, to, amount);
    }
}
// 1. Declare a `countryTotalShares` mapping (countryId => total shares committed to that country) alongside the contract's other state variables.
+ mapping(uint256 => uint256) public countryTotalShares;
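// 2. Rewrite `_getWinnerShares` to read the pre-maintained total for the winning country instead of looping over every user, so its cost no longer grows with the number of participants.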
- function _getWinnerShares () internal returns (uint256) {
- for (uint256 i = 0; i < usersAddress.length; ++i){
- address user = usersAddress[i];
- totalWinnerShares += userSharesToCountry[user][winnerCountryId];
- }
- return totalWinnerShares;
- }
+ function _getWinnerShares () internal returns (uint256) {
+ // Directly obtain the total share of the winning country from the pre-maintained mapping.
+ totalWinnerShares = countryTotalShares[winnerCountryId];
+ return totalWinnerShares;
+ }
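// 3. Update `joinEvent` to add the participant's shares to the chosen country's running total when the user joins.
// NOTE(review): `userSharesToCountry` is overwritten while `countryTotalShares` is incremented. If the same user can call `joinEvent` more than once, subtract their previous allocation first to avoid double counting.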
- function joinEvent(uint256 countryId) public {
- uint256 participantShares = balanceOf(msg.sender);
- userSharesToCountry[msg.sender][countryId] = participantShares;
+ function joinEvent(uint256 countryId) public {
+ uint256 participantShares = balanceOf(msg.sender);
+ userSharesToCountry[msg.sender][countryId] = participantShares;
+ countryTotalShares[countryId] += participantShares;
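// 4. Update `cancelParticipation` to subtract the user's shares from the country total when the user cancels their participation.
// NOTE(review): the loop below stops at the first country with a non-zero allocation, so it assumes each user backs at most one country; also confirm that `userSharesToCountry[msg.sender][i]` is reset, since the visible hunk does not clear it.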
- function cancelParticipation () public {
- uint256 shares = balanceOf(msg.sender);
+ function cancelParticipation () public {
+ uint256 shares = balanceOf(msg.sender);
+ // Search for the countries that the user has joined and update the total share for each country
+ for (uint256 i = 0; i < teams.length; ++i) {
+ if (userSharesToCountry[msg.sender][i] > 0) {
+ countryTotalShares[i] -= userSharesToCountry[msg.sender][i];
+ break;
+ }
+ }
+ _burn(msg.sender, shares);