Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement multichain support for decoding #70

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions app/datasources/db/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,15 +270,24 @@ async def get_or_create(

@classmethod
async def get_abi_by_contract_address(
cls, session: AsyncSession, address: bytes
cls, session: AsyncSession, address: bytes, chain_id: int | None
) -> ABI | None:
# TODO Add chain_id filter to support multichain
results = await session.exec(
"""
:return: Json ABI given the contract `address` and `chain_id`. If `chain_id` is not given,
sort the ABIs by `chain_id` and return the first one.
"""
query = (
select(Abi.abi_json)
.join(cls)
.where(cls.address == address)
.where(cls.abi_id == Abi.id)
)
if chain_id is not None:
query = query.where(cls.chain_id == chain_id)
else:
query = query.order_by(col(cls.chain_id))

results = await session.exec(query)
if result := results.first():
return cast(ABI, result)
return None
Expand Down
109 changes: 80 additions & 29 deletions app/services/data_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,10 @@ class DataDecoderService:
dummy_w3 = Web3()
session: AsyncSession | None

fn_selectors_with_abis: dict[bytes, ABIFunction]
multisend_abis: list[ABI]
multisend_fn_selectors_with_abis: dict[bytes, ABIFunction]

async def init(self, session: AsyncSession) -> None:
"""
Initialize the data decoder service, loading the ABIs from the database and storing the 4byte selectors
Expand Down Expand Up @@ -129,27 +133,38 @@ async def get_multisend_abis(self) -> AsyncIterator[ABI]:
@alru_cache(maxsize=2048)
@database_session
async def get_contract_abi(
self, address: Address, session: AsyncSession | None = None
self,
address: Address,
chain_id: int | None,
session: AsyncSession | None = None,
) -> ABI | None:
"""
Retrieves the ABI for the contract at the given address.

:param address: Contract address
:param chain_id: Chain id for the contract
:param session: Database session, provided by the decorator
:return: List of ABI data if found, `None` otherwise.
"""
assert session is not None
return await Contract.get_abi_by_contract_address(session, HexBytes(address))
return await Contract.get_abi_by_contract_address(
session, HexBytes(address), chain_id
)

@alru_cache(maxsize=2048)
async def get_contract_fallback_function(
self, address: Address
self, address: Address, chain_id: int | None
) -> ABIFunction | None:
"""
:param address: Contract address
:return: Fallback ABIFunction if found, `None` otherwise.
"""
abi = await self.get_contract_abi(address)
:param chain_id: Chain for the contract
:return: Fallback `ABIFunction` if found, `None` otherwise.
If contract is not found for the chain, return the first one that matches in other chain.
"""
abi = await self.get_contract_abi(address, chain_id)
if not abi and chain_id is not None:
# Try to find an ABI in other network
abi = await self.get_contract_abi(address, None)
if abi:
return next(
(
Expand All @@ -163,25 +178,29 @@ async def get_contract_fallback_function(

@alru_cache(maxsize=2048)
async def get_contract_abi_selectors_with_functions(
self, address: Address
self, address: Address, chain_id: int | None
) -> dict[bytes, ABIFunction] | None:
"""
:param address: Contract address
:return: Dictionary of function selects with ABIFunction if found, `None` otherwise
"""
abi = await self.get_contract_abi(address)
:param chain_id: Chain for the contract
:return: Dictionary of function selects with `ABIFunction` if found, `None` otherwise
If contract is not found for the chain, return the first one that matches in other chain.
"""
abi = await self.get_contract_abi(address, chain_id)
if not abi and chain_id is not None:
# Try to find an ABI in other network
abi = await self.get_contract_abi(address, None)
if abi:
# TODO We should return that there's a fullMatch for this `data` and `address`, so we are sure
# we are decoding the `data` correctly
return self._generate_selectors_with_abis_from_abi(abi)
return None

async def get_abi_function(
self, data: bytes, address: Address | None = None
self, data: bytes, address: Address | None = None, chain_id: int | None = None
) -> ABIFunction | None:
"""
:param data: transaction data
:param address: contract address in case of ABI colliding
:param chain_id: Chain for the contract
:return: Abi function for data if it can be decoded, `None` if not found
"""
selector = data[:4]
Expand All @@ -190,7 +209,9 @@ async def get_abi_function(
# Try to use specific ABI if address provided
if address:
contract_selectors_with_abis = (
await self.get_contract_abi_selectors_with_functions(address)
await self.get_contract_abi_selectors_with_functions(
address, chain_id
)
)
if (
contract_selectors_with_abis
Expand All @@ -202,7 +223,7 @@ async def get_abi_function(
return self.fn_selectors_with_abis[selector]
# Check if the contract has a fallback call and return a minimal ABIFunction for fallback call
elif address:
return await self.get_contract_fallback_function(address)
return await self.get_contract_fallback_function(address, chain_id)
return None

def _parse_decoded_arguments(self, value_decoded: Any) -> Any:
Expand All @@ -220,13 +241,17 @@ def _parse_decoded_arguments(self, value_decoded: Any) -> Any:
return value_decoded

async def _decode_data(
self, data: bytes | str, address: Address | None = None
self,
data: bytes | str,
address: Address | None = None,
chain_id: int | None = None,
) -> tuple[str, list[tuple[str, str, Any]]]:
"""
Decode tx data

:param data: Tx data as `hex string` or `bytes`
:param address: contract address in case of ABI colliding
:param chain_id: Chain for the contract
:return: Tuple with the `function name` and a List of sorted tuples with
the `name` of the argument, `type` and `value`
:raises: CannotDecode if data cannot be decoded. You should catch this exception when using this function
Expand All @@ -238,7 +263,7 @@ async def _decode_data(

data = HexBytes(data)
params = data[4:]
fn_abi = await self.get_abi_function(data, address)
fn_abi = await self.get_abi_function(data, address, chain_id)
if not fn_abi:
raise CannotDecode(data.hex())
try:
Expand All @@ -254,12 +279,13 @@ async def _decode_data(
return fn_abi["name"], list(zip(names, types, values))

async def decode_multisend_data(
self, data: bytes | str
self, data: bytes | str, chain_id: int | None = None
) -> list[MultisendDecoded] | None:
"""
Decodes Multisend raw data to Multisend dictionary

:param data:
:param chain_id:
:return:
"""
try:
Expand All @@ -271,7 +297,9 @@ async def decode_multisend_data(
value=str(multisend_tx.value),
data=HexStr(multisend_tx.data.hex()) if multisend_tx.data else None,
data_decoded=await self.get_data_decoded(
multisend_tx.data, address=cast(Address, multisend_tx.to)
multisend_tx.data,
address=cast(Address, multisend_tx.to),
chain_id=chain_id,
),
)
for multisend_tx in multisend_txs
Expand All @@ -285,27 +313,34 @@ async def decode_multisend_data(
return None

async def get_data_decoded(
self, data: bytes | str, address: Address | None = None
self,
data: bytes | str,
address: Address | None = None,
chain_id: int | None = None,
) -> DataDecoded | None:
"""
Return data prepared for serializing

:param data:
:param address: contract address in case of ABI colliding
:param chain_id: chain for contract
:return:
"""
if not data:
return None
try:
fn_name, parameters = await self.decode_transaction_with_types(
data, address=address
data, address=address, chain_id=chain_id
)
return {"method": fn_name, "parameters": parameters}
except DataDecoderException:
return None

async def decode_parameters_data(
self, data: bytes, parameters: list[ParameterDecoded]
self,
data: bytes,
parameters: list[ParameterDecoded],
chain_id: int | None = None,
) -> list[ParameterDecoded]:
"""
Decode inner data for function parameters for:
Expand All @@ -319,7 +354,9 @@ async def decode_parameters_data(
fn_selector = data[:4]
if fn_selector in self.multisend_fn_selectors_with_abis:
# If MultiSend, decode the transactions
parameters[0]["value_decoded"] = await self.decode_multisend_data(data)
parameters[0]["value_decoded"] = await self.decode_multisend_data(
data, chain_id=chain_id
)

elif (
fn_selector == self.EXEC_TRANSACTION_SELECTOR
Expand All @@ -331,49 +368,63 @@ async def decode_parameters_data(
# selector is `0x6a761202` and parameters[2] is data
try:
parameters[2]["value_decoded"] = await self.get_data_decoded(
data, address=parameters[0]["value"]
data, address=parameters[0]["value"], chain_id=chain_id
)
except DataDecoderException:
logger.warning("Cannot decode `execTransaction`", exc_info=True)
return parameters

async def decode_transaction_with_types(
self, data: bytes | str, address: Address | None = None
self,
data: bytes | str,
address: Address | None = None,
chain_id: int | None = None,
) -> tuple[str, list[ParameterDecoded]]:
"""
Decode tx data and return a list of dictionaries

:param data: Tx data as `hex string` or `bytes`
:param address: contract address in case of ABI colliding
:param chain_id: chain for the contract
:return: Tuple with the `function name` and a list of dictionaries
[{'name': str, 'type': str, 'value': `depending on type`}...]
:raises: CannotDecode if data cannot be decoded. You should catch this exception when using this function
:raises: UnexpectedProblemDecoding if there's an unexpected problem decoding (it shouldn't happen)
"""
data = HexBytes(data)
fn_name, raw_parameters = await self._decode_data(data, address=address)
fn_name, raw_parameters = await self._decode_data(
data, address=address, chain_id=chain_id
)
# Parameters are returned as tuple, convert it to a dictionary
parameters = [
ParameterDecoded(name=name, type=argument_type, value=value)
for name, argument_type, value in raw_parameters
]
nested_parameters = await self.decode_parameters_data(data, parameters)
nested_parameters = await self.decode_parameters_data(
data, parameters, chain_id=chain_id
)
return fn_name, nested_parameters

async def decode_transaction(
self, data: bytes | str, address: Address | None = None
self,
data: bytes | str,
address: Address | None = None,
chain_id: int | None = None,
) -> tuple[str, dict[str, Any]]:
"""
Decode tx data and return all the parameters in the same dictionary

:param data: Tx data as `hex string` or `bytes`
:param address: contract address in case of ABI colliding
:param chain_id: chain for the contract
:return: Tuple with the `function name` and a dictionary with the arguments of the function
:raises: CannotDecode if data cannot be decoded. You should catch this exception when using this function
:raises: UnexpectedProblemDecoding if there's an unexpected problem decoding (it shouldn't happen)
"""
fn_name, decoded_transactions_with_types = (
await self.decode_transaction_with_types(data, address=address)
await self.decode_transaction_with_types(
data, address=address, chain_id=chain_id
)
)
decoded_transactions = {
d["name"]: d["value"] for d in decoded_transactions_with_types
Expand Down
21 changes: 19 additions & 2 deletions app/tests/datasources/db/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,27 @@ async def test_contract_get_abi_by_contract_address(self, session: AsyncSession)
await abi.create(session)
contract = Contract(address=b"a", name="A test contract", chain_id=1, abi=abi)
await contract.create(session)
result = await contract.get_abi_by_contract_address(session, contract.address)
result = await contract.get_abi_by_contract_address(
session, contract.address, 1
)
self.assertEqual(result, abi_json)

self.assertIsNone(await contract.get_abi_by_contract_address(session, b"b"))
# Check chain_id not matching
result = await contract.get_abi_by_contract_address(
session, contract.address, 2
)
self.assertIsNone(result)

# Ignoring chain_id
result = await contract.get_abi_by_contract_address(
session, contract.address, None
)
self.assertEqual(result, abi_json)

# Check address not matching
self.assertIsNone(
await contract.get_abi_by_contract_address(session, b"b", None)
)

@database_session
async def test_project(self, session: AsyncSession):
Expand Down
Loading
Loading