/// @notice Creates `amount` martenitsas as `chasy` and gifts each one to `bob` through the
///         marketplace (list for 1 wei -> approve marketplace -> makePresent).
/// @dev The caller must pass the id the next `createMartenitsa` call will mint: the helper
///      never reads the minted id back, it simply lists/approves/gifts `tokenIdToStart`,
///      `tokenIdToStart + 1`, ... in order.
/// @param amount Number of martenitsas to create and gift to `bob`.
/// @param tokenIdToStart Token id expected for the first martenitsa created in this call.
/// @return The token id expected for the next martenitsa created after this call.
function _getMartenitsas(uint256 amount, uint256 tokenIdToStart) internal returns (uint256) {
    uint256 tokenId = tokenIdToStart;
    vm.startPrank(chasy);
    for (uint256 i = 0; i < amount; i++) {
        // `abi.encodePacked("bracelet-", uint256)` would append the raw 32-byte encoding of the
        // integer (mostly NUL bytes); `vm.toString` yields readable designs like "bracelet-7",
        // and using `tokenId` keeps them unique across repeated calls of this helper.
        martenitsaToken.createMartenitsa(string(abi.encodePacked("bracelet-", vm.toString(tokenId))));
        // NOTE(review): listing presumably is a precondition of `makePresent` — confirm in Marketplace.
        marketplace.listMartenitsaForSale(tokenId, 1 wei);
        // Allow the marketplace to move this token out of chasy's wallet when gifting it.
        martenitsaToken.approve(address(marketplace), tokenId);
        marketplace.makePresent(bob, tokenId);
        tokenId++;
    }
    vm.stopPrank();
    return tokenId;
}
/// @notice PoC: `collectReward` must remember the cumulative amount of rewards already paid;
///         otherwise a user who collects over several rounds is over-paid.
/// @dev Bob receives 6, then 3, then 9 martenitsas (running totals 6 -> 9 -> 18) and calls
///      `collectReward` after each batch. After every round his HealthToken balance must equal
///      (martenitsas held / 3) * 1e18.
function test__CollectReward__ImproperTrackingOfPreviousRewards() public {
    assertEq(healthToken.balanceOf(bob), 0);
    assertEq(martenitsaToken.balanceOf(bob), 0);

    uint256[] memory martenitsasCountToAdd = new uint256[](3);
    martenitsasCountToAdd[0] = 6;
    martenitsasCountToAdd[1] = 3;
    martenitsasCountToAdd[2] = 9;

    uint256 tokenId = 0;
    uint256 totalCount = 0;
    for (uint256 i = 0; i < martenitsasCountToAdd.length; i++) {
        uint256 currCountToAdd = martenitsasCountToAdd[i];
        totalCount += currCountToAdd;
        tokenId = _getMartenitsas(currCountToAdd, tokenId);
        assertEq(martenitsaToken.balanceOf(bob), totalCount);

        vm.prank(bob);
        marketplace.collectReward();
        // Checking the cumulative payout after every round pinpoints the round in which the
        // over-payment happens (round 3 with the vulnerable code, see explanation below).
        assertEq(healthToken.balanceOf(bob), (totalCount / 3) * 10 ** 18);
    }
    assertEq(healthToken.balanceOf(bob), (totalCount / 3) * 10 ** 18);
    /*
     * EXPLANATION (reward amounts below are in whole HealthTokens, i.e. units of 1e18):
     *
     * With `_collectedRewards[msg.sender] = amountRewards;` (vulnerable code)
     * 1. Bob has 6 MartenitsaTokens
     * 1.a. Bob collects rewards
     * 1.b. `_collectedRewards[bob] == 0`, amountRewards = 6 / 3 - 0 = 2
     * 1.c. --> `_collectedRewards[bob]` is set to 2
     *
     * 2. Bob gets 3 more MartenitsaTokens (total = 9)
     * 2.a. Bob collects rewards
     * 2.b. `_collectedRewards[bob] == 2`, amountRewards = 9 / 3 - 2 = 1
     * 2.c. --> `_collectedRewards[bob]` is set to 1
     *      (*issue*: it should be incremented by 1, giving 3)
     *
     * EXPLOIT
     * 3. Bob gets 9 more MartenitsaTokens (total = 18)
     * 3.a. Bob collects rewards
     * 3.b. `_collectedRewards[bob] == 1`, amountRewards = 18 / 3 - 1 = 5
     * 3.c. --> `_collectedRewards[bob]` is set to 5
     *
     * --------------> healthToken.balanceOf(bob) = 2 + 1 + 5 = 8 tokens (incorrect)
     *
     * With `_collectedRewards[msg.sender] += amountRewards;` (fixed code)
     * 1. Bob has 6 MartenitsaTokens
     * 1.a. Bob collects rewards
     * 1.b. `_collectedRewards[bob] == 0`, amountRewards = 6 / 3 - 0 = 2
     * 1.c. --> `_collectedRewards[bob] += 2`
     * 1.d. --> `_collectedRewards[bob] == 2`
     *
     * 2. Bob gets 3 more MartenitsaTokens (total = 9)
     * 2.a. Bob collects rewards
     * 2.b. `_collectedRewards[bob] == 2`, amountRewards = 9 / 3 - 2 = 1
     * 2.c. --> `_collectedRewards[bob] += 1`
     * 2.d. --> `_collectedRewards[bob] == 3`
     *
     * 3. Bob gets 9 more MartenitsaTokens (total = 18)
     * 3.a. Bob collects rewards
     * 3.b. `_collectedRewards[bob] == 3`, amountRewards = 18 / 3 - 3 = 3
     * 3.c. --> `_collectedRewards[bob] += 3`
     * 3.d. --> `_collectedRewards[bob] == 6`
     *
     * --------------> healthToken.balanceOf(bob) = 2 + 1 + 3 = 6 tokens (correct)
     */
}