Michael.W基于Foundry精读Openzeppelin第3期——Arrays.sol

  • Michael.W
  • 更新于 2024-01-02 14:53
  • 阅读 1916

从foundry工程化的角度详细解读Openzeppelin中的Arrays库及对应测试。

0. 版本

[openzeppelin]:v4.8.3,[forge-std]:v1.5.6

0.1 Arrays.sol

Github: https://github.com/OpenZeppelin/openzeppelin-contracts/blob/v4.8.3/contracts/utils/Arrays.sol

Arrays库是一个专门作用于uint256[] storage / address[] storage / bytes32[] storage的工具库。

1. 补充:关于storage的定长数组和动态数组的layout

直接观察demo合约的slot排布:

contract ArrayLayoutChecker {
    // slot0
    uint public a = 2;
    // slot1
    uint[] public arr = [0xdddd, 0xeeee, 0xffff];
    // slot2
    address public addr = address(1024);
    // slot3 ~ slot6
    address[4] public addrs = [address(0xa), address(0xb), address(0xc), address(0xd)];
}

说明:

  • storage的定长数组address[4] public addrs的本位slot为3,那么其元素依次占据slot3~ slot6

  • storage的动态数组uint[] public arr的本位slot为2,那么slot2中仅存储该动态数组的长度。其元素将按照以下算法依次存储于对应slot中:bytes32(uint(keccake256(动态数组本位slot号)) + 元素索引)

可见storage的动态数组和定长数组的元素都是按数组内顺序依次存储于slot之中,只是起始的slot号不一样。动态数组的本位slot存储的是动态数组的长度,而定长数组的本位slot存储的是第一个元素。

foundry代码验证

contract ArraysTest is Test {
    ArrayLayoutChecker alc = new ArrayLayoutChecker();

    function test_LayoutForDynamicAndStaticArrays() external {
        // 向动态数组内增添新的元素
        alc.pushArr(0xabcd);

        // 通过slot号读取对应slot中存储的值
        // slot0: 状态变量a的值——2
        uint valueSlot0 = uint(vm.load(address(alc), bytes32(0)));
        assertEq(alc.a(), valueSlot0);
        // slot1: 存放的是动态数组arr中的元素数量,即arr.length
        uint valueSlot1 = uint(vm.load(address(alc), bytes32(uint(1))));
        assertEq(alc.getArrLength(), valueSlot1);
        // slot2: 状态变量addr的值——address(1024)
        address valueSlot2 = address(uint160(uint(vm.load(address(alc), bytes32(uint(2))))));
        assertEq(alc.addr(), valueSlot2);
        // slot3~slot6: 静态数组address[4] addrs 按顺序排布的四个元素
        address valueSlot3 = address(uint160(uint(vm.load(address(alc), bytes32(uint(3))))));
        address valueSlot4 = address(uint160(uint(vm.load(address(alc), bytes32(uint(4))))));
        address valueSlot5 = address(uint160(uint(vm.load(address(alc), bytes32(uint(5))))));
        address valueSlot6 = address(uint160(uint(vm.load(address(alc), bytes32(uint(6))))));
        assertEq(alc.addrs(0), valueSlot3);
        assertEq(alc.addrs(1), valueSlot4);
        assertEq(alc.addrs(2), valueSlot5);
        assertEq(alc.addrs(3), valueSlot6);

        // 动态数组的元素存储的slot号:keccak256(动态数组本位的slot号) + 索引值
        // 本案例中动态数组的本位slot为slot1,即本位slot号为bytes32(uint(1))
        bytes32 startSlot = keccak256(abi.encodePacked(uint(1)));
        // 动态数组的第1个元素的slot号,即startSlotNumber + 0
        assertEq(alc.arr(0), uint(vm.load(address(alc), bytes32(uint(startSlot) + 0))));
        // 动态数组的第2个元素的slot号,即startSlotNumber + 1
        assertEq(alc.arr(1), uint(vm.load(address(alc), bytes32(uint(startSlot) + 1))));
        // 动态数组的第3个元素的slot号,即startSlotNumber + 2
        assertEq(alc.arr(2), uint(vm.load(address(alc), bytes32(uint(startSlot) + 2))));
        // 动态数组的第4个元素的slot号,即startSlotNumber + 3
        assertEq(alc.arr(3), uint(vm.load(address(alc), bytes32(uint(startSlot) + 3))));
        // 注: 动态数组和静态数组的元素在slot中都是按照顺序依次紧密地向后存储在slot中
    }
}

contract ArrayLayoutChecker {
    uint public a = 2;
    uint[] public arr = [0xdddd, 0xeeee, 0xffff];
    address public addr = address(1024);
    address[4] public addrs = [address(0xa), address(0xb), address(0xc), address(0xd)];

    function pushArr(uint v) external {
        arr.push(v);
    }

    function getArrLength() external view returns (uint){
        return arr.length;
    }
}

2. 目标合约

封装Arrays library成为一个可调用合约:

Github: https://github.com/RevelationOfTuring/foundry-openzeppelin-contracts/blob/master/src/utils/MockArrays.sol

// SPDX-License-Identifier: UNLICENSED
pragma solidity ^0.8.0;

import "openzeppelin-contracts/contracts/utils/Arrays.sol";

contract MockArrays {
    using Arrays for uint[];
    using Arrays for bytes32[];
    using Arrays for address[];

    uint[] public arrUint = [1, 2, 11, 19, 21, 22, 100, 201, 224, 999];
    bytes32[] public arrBytes32 = [bytes32('a'), bytes32('b'), bytes32('c'), bytes32('d'), bytes32('e')];
    address[] public arrAddress = [address(0xff), address(0xee), address(0xdd), address(0xcc), address(0xbb), address(0xaa)];

    function findUpperBound(uint element) external view returns (uint){
        return arrUint.findUpperBound(element);
    }

    function unsafeAccessUintArrays(uint pos) external view returns (uint){
        return arrUint.unsafeAccess(pos).value;
    }

    function unsafeAccessBytes32Arrays(uint pos) external view returns (bytes32){
        return arrBytes32.unsafeAccess(pos).value;
    }

    function unsafeAccessAddressArrays(uint pos) external view returns (address){
        return arrAddress.unsafeAccess(pos).value;
    }

    function clearArrUint() external {
        delete arrUint;
    }

    function addArrUint(uint element) external {
        arrUint.push(element);
    }

    function getLength(uint slotNumber) external view returns (uint){
        if (slotNumber == 0) {
            return arrUint.length;
        } else if (slotNumber == 1) {
            return arrBytes32.length;
        } else if (slotNumber == 2) {
            return arrAddress.length;
        } else {
            return 0;
        }
    }
}

全部foundry测试合约:

Github: https://github.com/RevelationOfTuring/foundry-openzeppelin-contracts/blob/master/test/utils/Arrays.t.sol

3. 代码精读

3.1 unsafeAccess(address[] storage, uint256)

返回动态address数组中指定索引的元素值。该方法节约gas,但是不进行数组索引越界检查。所以只有当你确定你要取的索引值小于动态address数组长度时才去使用该方法。

function unsafeAccess(address[] storage arr, uint256 pos) internal pure returns (StorageSlot.AddressSlot storage) {
    // 声明一个slot变量,用于计算arr[pos]的slot号
        bytes32 slot;
        assembly {
            // 在memory的0~32字节的位置存储动态数组arr的slot号
            mstore(0, arr.slot)
            // keccak256(0, 0x20): 将memory中0~32字节的内容(即动态数组arr的slot号)求keccak256
            // slot赋值为arr的slot号的hash值+偏移值pos的和,即在layout中存储的arr[pos]的slot号
            slot := add(keccak256(0, 0x20), pos)
        }
        // 直接用arr[pos]的slot号从storage中取出值
        return slot.getAddressSlot();
}

foundry代码验证

contract ArraysTest is Test {
    MockArrays ma = new MockArrays();

    function test_UnsafeAccess() external {
        uint l = ma.getLength(0);
        for (uint i = 0; i < l; ++i) {
            assertEq(ma.arrUint(i), ma.unsafeAccessUintArrays(i));
        }

        // revert if out of index with []
        vm.expectRevert();
        ma.arrUint(l);
        // not revert with unsafeAccess(), but get zero value
        assertEq(0, ma.unsafeAccessUintArrays(l));

        l = ma.getLength(1);
        for (uint i = 0; i < l; ++i) {
            assertEq(ma.arrBytes32(i), ma.unsafeAccessBytes32Arrays(i));
        }

        // revert if out of index with []
        vm.expectRevert();
        ma.arrBytes32(l);
        // not revert with unsafeAccess(), but get zero value
        assertEq(0, ma.unsafeAccessBytes32Arrays(l));

        l = ma.getLength(2);
        for (uint i = 0; i < l; ++i) {
            assertEq(ma.arrAddress(i), ma.unsafeAccessAddressArrays(i));
        }

        // revert if out of index with []
        vm.expectRevert();
        ma.arrAddress(l);
        // not revert with unsafeAccess(), but get zero value
        assertEq(address(0), ma.unsafeAccessAddressArrays(l));
    }
}

3.2 unsafeAccess(bytes32[] storage, uint256)

返回动态bytes32数组中指定索引的元素值。该方法节约gas,但是不进行数组索引越界检查。所以只有当你确定你要取的索引值小于动态bytes32数组长度时才去使用该方法。

function unsafeAccess(bytes32[] storage arr, uint256 pos) internal pure returns (StorageSlot.Bytes32Slot storage) {
    // 声明一个slot变量,用于计算arr[pos]的slot号
    bytes32 slot;
    // 在memory的0~32字节的位置存储动态数组arr的slot号
    assembly {
        // 在memory的0~32字节的位置存储动态数组arr的slot号
        mstore(0, arr.slot)
        // keccak256(0, 0x20): 将memory中0~32字节的内容(即动态数组arr的slot号)求keccak256
        // slot赋值为arr的slot号的hash值+偏移值pos的和,即在layout中存储的arr[pos]的slot号
        slot := add(keccak256(0, 0x20), pos)
    }
    // 直接用arr[pos]的slot号从storage中取出值
    return slot.getBytes32Slot();
}

foundry代码验证:见3.1

3.3 unsafeAccess(uint256[] storage, uint256)

返回动态uint256数组中指定索引的元素值。该方法节约gas,但是不进行数组索引越界检查。所以只有当你确定你要取的索引值小于动态uint256数组长度时才去使用该方法。

function unsafeAccess(uint256[] storage arr, uint256 pos) internal pure returns (StorageSlot.Uint256Slot storage) {
    // 声明一个slot变量,用于计算arr[pos]的slot号
    bytes32 slot;
    // 在memory的0~32字节的位置存储动态数组arr的slot号
    assembly {
        // 在memory的0~32字节的位置存储动态数组arr的slot号
        mstore(0, arr.slot)
        // keccak256(0, 0x20): 将memory中0~32字节的内容(即动态数组arr的slot号)求keccak256
        // slot赋值为arr的slot号的hash值+偏移值pos的和,即在layout中存储的arr[pos]的slot号
        slot := add(keccak256(0, 0x20), pos)
    }
    // 直接用arr[pos]的slot号从storage中取出值
    return slot.getUint256Slot();
}

foundry代码验证:见3.1

3.4 findUpperBound(uint256[] storage array, uint256 element)

从一个排序好的数组array中,返回第一个大于或等于element的元素的索引值。如果整个数组array中都没有符合条件的元素,则返回整个数组array的长度。

这个操作的时间复杂度为O(log n)

前提条件array为升序排列且其中没有重复的元素值

function findUpperBound(uint256[] storage array, uint256 element) internal view returns (uint256) {
    // 如果是空数组,返回0
    if (array.length == 0) {
        return 0;
    }
    // 两个边界flag
    uint256 low = 0;
    uint256 high = array.length;

    // 开始二分法查找,直到low>=high时停止
    while (low < high) {
        // 如果low<high,mid为low和high的均值
        uint256 mid = Math.average(low, high);
        // 注:Math.average()如果均值为小数,则向下取整
        if (unsafeAccess(array, mid).value > element) {
            // 如果索引为mid的元素大于目标值element,缩小范围:令high=mid
            high = mid;
        } else {
            // 如果索引为mid的元素小于等于目标值element,缩小范围:令low=mid+1
            low = mid + 1;
        }
    }

    // 如果此时low>0,说明此时的low已经是唯一的上界(因为low只有大于等于high才会跳出循环)
    if (low > 0 && unsafeAccess(array, low - 1).value == element) {
        // 如果array[low-1]等于目标值element,则直接返回索引low-1,即等于目标值的索引
        return low - 1;
    } else {
        // 如果array[low-1]不等于目标值element,说明array中所有元素都小于目标值element,那么直接返回low(即array的长度)
        return low;
    }
}

foundry代码验证

contract ArraysTest is Test {
    MockArrays ma = new MockArrays();

    // 目标数组元素个数为偶数
    function test_FindUpperBound_WithEvenLength() external {
        // arrUint: [1, 2, 11, 19, 21, 22, 100, 201, 224, 999]
        assertEq(ma.getLength(0), 10);
        assertEq(0, ma.findUpperBound(0));
        assertEq(0, ma.findUpperBound(1));
        assertEq(1, ma.findUpperBound(2));
        assertEq(2, ma.findUpperBound(3));
        assertEq(2, ma.findUpperBound(10));
        assertEq(2, ma.findUpperBound(11));
        assertEq(3, ma.findUpperBound(12));
        assertEq(3, ma.findUpperBound(19));
        assertEq(4, ma.findUpperBound(21));
        assertEq(5, ma.findUpperBound(22));
        assertEq(6, ma.findUpperBound(100));
        assertEq(7, ma.findUpperBound(201));
        assertEq(8, ma.findUpperBound(224));
        assertEq(9, ma.findUpperBound(999));
        // greater than all elements in the array, it will return the length of the array
        assertEq(10, ma.findUpperBound(1000));
    }

    // 目标数组元素个数为奇数
    function test_FindUpperBound_WithOddLength() external {
        ma.addArrUint(2000);
        // arrUint: [1, 2, 11, 19, 21, 22, 100, 201, 224, 999, 2000]
        assertEq(ma.getLength(0), 11);
        assertEq(0, ma.findUpperBound(0));
        assertEq(0, ma.findUpperBound(1));
        assertEq(1, ma.findUpperBound(2));
        assertEq(2, ma.findUpperBound(3));
        assertEq(2, ma.findUpperBound(10));
        assertEq(2, ma.findUpperBound(11));
        assertEq(3, ma.findUpperBound(12));
        assertEq(3, ma.findUpperBound(19));
        assertEq(4, ma.findUpperBound(21));
        assertEq(5, ma.findUpperBound(22));
        assertEq(6, ma.findUpperBound(100));
        assertEq(7, ma.findUpperBound(201));
        assertEq(8, ma.findUpperBound(224));
        assertEq(9, ma.findUpperBound(999));
        assertEq(10, ma.findUpperBound(2000));
        // greater than all elements in the array, it will return the length of the array
        assertEq(11, ma.findUpperBound(2001));
    }

    // 目标数组元素个数为0
    function test_FindUpperBound_WithZeroLength() external {
        ma.clearArrUint();
        assertEq(ma.getLength(0), 0);
        // return 0 when the target array is empty
        assertEq(0, ma.findUpperBound(0));
        assertEq(0, ma.findUpperBound(1));
    }
}

ps:\ 本人热爱图灵,热爱中本聪,热爱V神。 以下是我个人的公众号,如果有技术问题可以关注我的公众号来跟我交流。 同时我也会在这个公众号上每周更新我的原创文章,喜欢的小伙伴或者老伙计可以支持一下! 如果需要转发,麻烦注明作者。十分感谢!

1.jpeg

公众号名称:后现代泼痞浪漫主义奠基人

点赞 3
收藏 2
分享
本文参与登链社区写作激励计划 ,好文好收益,欢迎正在阅读的你也加入。

0 条评论

请先 登录 后评论
Michael.W
Michael.W
0x93E7...0000
狂热的区块链爱好者