import { expect } from "chai";
import { ethers } from "hardhat";
import { Signer } from "ethers";
import { RToken, ERC20Mock } from "../typechain-types";
/**
 * Proof-of-concept suite exercising `RToken.burn` after the liquidity index
 * has moved from `INITIAL_INDEX` to `INTEREST_INDEX`.
 *
 * Every test starts from a freshly deployed `ERC20Mock` asset and `RToken`,
 * with `owner` registered through `setReservePool` and `user` holding 1000
 * asset tokens that are approved for the RToken contract.
 */
describe("RToken Burn Function Vulnerabilities", () => {
  let rToken: RToken;
  let asset: ERC20Mock;
  let owner: Signer;
  let user: Signer;
  // Resolved once per test in `beforeEach` instead of on every call site.
  let ownerAddress: string;
  let userAddress: string;

  // Indices are expressed with 27 decimals: 1.0 and 1.1 respectively.
  const INITIAL_INDEX = ethers.utils.parseUnits("1.0", 27);
  const INTEREST_INDEX = ethers.utils.parseUnits("1.1", 27);

  /**
   * Shared arrangement for both tests: mint 1000 RTokens to `user` at
   * `INITIAL_INDEX`, then move the liquidity index to `INTEREST_INDEX`.
   */
  const mintThenRaiseIndex = async (): Promise<void> => {
    await rToken
      .connect(owner)
      .mint(ownerAddress, userAddress, ethers.utils.parseUnits("1000"), INITIAL_INDEX);
    await rToken.connect(owner).updateLiquidityIndex(INTEREST_INDEX);
  };

  beforeEach(async () => {
    [owner, user] = await ethers.getSigners();
    ownerAddress = await owner.getAddress();
    userAddress = await user.getAddress();

    const assetFactory = await ethers.getContractFactory("ERC20Mock");
    asset = await assetFactory.deploy("Asset", "AST", 18);

    const rTokenFactory = await ethers.getContractFactory("RToken");
    rToken = await rTokenFactory.deploy("RToken", "RTK", ownerAddress, asset.address);

    // Pass the resolved address rather than a pending Promise. Presumably this
    // is what lets `owner` call the reserve-pool-only mint/burn below.
    await rToken.connect(owner).setReservePool(ownerAddress);

    await asset.mint(userAddress, ethers.utils.parseUnits("1000"));
    await asset.connect(user).approve(rToken.address, ethers.constants.MaxUint256);
  });

  it("should demonstrate incorrect scaling leading to overburn", async () => {
    await mintThenRaiseIndex();

    const burnAmount = ethers.utils.parseUnits("100");
    const preBurnBalance = await rToken.balanceOf(userAddress);

    await rToken.connect(owner).burn(userAddress, userAddress, burnAmount, INTEREST_INDEX);

    const postBurnBalance = await rToken.balanceOf(userAddress);

    // Reference value: burnAmount * 1e27 / INTEREST_INDEX. The test asserts
    // that the observed balance drop is strictly larger than this.
    const ray = ethers.BigNumber.from(10).pow(27);
    const expectedBurn = burnAmount.mul(ray).div(INTEREST_INDEX);
    expect(preBurnBalance.sub(postBurnBalance)).to.be.gt(expectedBurn);
  });

  it("should show missing interest accrual", async () => {
    await mintThenRaiseIndex();

    // Burn whatever balance the token reports for the user after the index change.
    const userBalance = await rToken.balanceOf(userAddress);
    await rToken.connect(owner).burn(userAddress, userAddress, userBalance, INTEREST_INDEX);

    // NOTE(review): `user` already holds 1000 AST from `beforeEach`, so this
    // checks that the asset balance is exactly that amount after the burn.
    // Confirm this is the intended demonstration.
    const finalAssetBalance = await asset.balanceOf(userAddress);
    expect(finalAssetBalance).to.equal(ethers.utils.parseUnits("1000"));
  });
});
/**
 * @notice Burns tokens held by `from` and sends the matching amount of the
 *         underlying asset to `receiverOfUnderlying`.
 * @dev Guarded by the `onlyReservePool` modifier.
 *      `amount` is given in underlying units and converted to scaled units with
 *      `rayDiv(index)`. If that exceeds the holder's stored balance, the whole
 *      stored balance is burned and `amount` is recomputed from it.
 *      The payout is exactly `amount`. Interest accrued since the holder's last
 *      recorded index is already contained in `scaled * index`, so adding a
 *      separate balance increase on top would pay that interest twice while the
 *      unburned remainder still carries it.
 *      NOTE(review): assumes `super.balanceOf` returns the raw scaled balance and
 *      that `_burn` takes a scaled amount — confirm against the rest of the contract.
 * @param from Account whose tokens are burned.
 * @param receiverOfUnderlying Recipient of the underlying asset; no transfer is
 *        made when this is the token contract itself.
 * @param amount Amount to withdraw, in underlying units.
 * @param index Liquidity index used for scaling; also stored as the holder's index.
 * @return The scaled amount burned, the total supply after the burn, and the
 *         underlying amount paid out. Returns `(0, totalSupply(), 0)` when
 *         `amount` is zero.
 */
function burn(
    address from,
    address receiverOfUnderlying,
    uint256 amount,
    uint256 index
) external override onlyReservePool returns (uint256, uint256, uint256) {
    if (amount == 0) {
        return (0, totalSupply(), 0);
    }

    uint256 scaledBalance = super.balanceOf(from);

    // Record the index this holder last interacted at.
    _userState[from].index = index.toUint128();

    uint256 amountScaled = amount.rayDiv(index);
    if (amountScaled > scaledBalance) {
        // Cap at the full balance and re-derive the underlying amount from it.
        amountScaled = scaledBalance;
        amount = amountScaled.rayMul(index);
    }

    _burn(from, amountScaled.toUint128());

    if (receiverOfUnderlying != address(this)) {
        // Pay out only what the burned scaled amount is worth at `index`.
        IERC20(_assetAddress).safeTransfer(receiverOfUnderlying, amount);
    }

    emit Burn(from, receiverOfUnderlying, amount, index);

    return (amountScaled, totalSupply(), amount);
}