Skip to content

Constant Product AMM

This is the Cairo adaptation of the Solidity by Example - Constant Product AMM.

In this contract, we implement a simple Automated Market Maker (AMM) following the constant product formula: $x \cdot y = k$. This formula ensures that the product of the two token reserves ($x$ and $y$ being the reserves of the two tokens being swapped) remains constant, regardless of trades. Here, we provide liquidity pools that allow users to trade between two tokens or add and remove liquidity from the pool.

Key Concepts

  1. approve() before swap or adding liquidity: Before interacting with the AMM (whether through swaps or adding liquidity), the user must approve the contract to spend their tokens. This is done by calling the approve() function on the ERC20 token contracts, allowing the AMM to transfer the required tokens on behalf of the user.

  2. Constant Product Formula for Swaps: The swap function operates based on the constant product formula $x \cdot y = k$, where $x$ and $y$ are the token reserves. When a user swaps one token for another, the product of the reserves remains constant, which determines how much of the other token the user will receive.

  3. Shares and Token Ratios for Liquidity: When adding liquidity, users provide both tokens in the ratio of the current reserves. The number of shares (liquidity tokens) the user receives represents their contribution to the pool. Similarly, when removing liquidity, users receive back tokens proportional to the number of shares they burn.

use starknet::ContractAddress;
 
/// Public interface of the constant product AMM pool.
#[starknet::interface]
pub trait IConstantProductAmm<TContractState> {
    /// Sells `amount_in` of `token_in` (which must be one of the pool's two tokens) and
    /// returns the amount of the other token transferred to the caller.
    fn swap(ref self: TContractState, token_in: ContractAddress, amount_in: u256) -> u256;
    /// Deposits `amount0` of token0 and `amount1` of token1 from the caller and returns the
    /// number of liquidity shares minted to the caller.
    fn add_liquidity(ref self: TContractState, amount0: u256, amount1: u256) -> u256;
    /// Burns `shares` of the caller's liquidity shares and returns the
    /// `(token0, token1)` amounts transferred back to the caller.
    fn remove_liquidity(ref self: TContractState, shares: u256) -> (u256, u256);
}
 
#[starknet::contract]
pub mod ConstantProductAmm {
    use starknet::storage::{StoragePointerReadAccess, StoragePointerWriteAccess};
    use openzeppelin::token::erc20::interface::{IERC20Dispatcher, IERC20DispatcherTrait};
    use starknet::{ContractAddress, get_caller_address, get_contract_address};
    use starknet::storage::{Map, StorageMapReadAccess, StorageMapWriteAccess};
    use core::num::traits::Sqrt;
 
    // Persistent state of the pool.
    #[storage]
    struct Storage {
        // Dispatchers for the two ERC20 tokens traded by this pool.
        token0: IERC20Dispatcher,
        token1: IERC20Dispatcher,
        // Token balances recorded by the most recent `_update` call.
        reserve0: u256,
        reserve1: u256,
        // Total number of liquidity shares outstanding (changed by `_mint` / `_burn`).
        total_supply: u256,
        // Liquidity shares held by each address.
        balance_of: Map::<ContractAddress, u256>,
        // Swap fee in tenths of a percent: 0 - 1000 (0% - 100%, 1 decimal place)
        // E.g. 3 = 0.3%
        fee: u16,
    }
 
    /// Initializes the pool with its two tokens and the swap fee.
    ///
    /// # Arguments
    /// * `token0`, `token1` - addresses of the ERC20 contracts traded by this pool.
    /// * `fee` - swap fee in tenths of a percent (0 - 1000, e.g. 3 = 0.3%).
    ///
    /// # Panics
    /// With 'fee > 1000' when `fee` exceeds 100%.
    #[constructor]
    fn constructor(
        ref self: ContractState, token0: ContractAddress, token1: ContractAddress, fee: u16
    ) {
        // `swap` computes `1000 - fee` in unsigned arithmetic; storing a fee above 1000
        // would make that subtraction underflow and revert every swap, so the value is
        // rejected at deployment time instead.
        assert(fee <= 1000, 'fee > 1000');
        self.token0.write(IERC20Dispatcher { contract_address: token0 });
        self.token1.write(IERC20Dispatcher { contract_address: token1 });
        self.fee.write(fee);
    }
 
    /// Internal helpers: share accounting, reserve bookkeeping and small utilities.
    #[generate_trait]
    impl PrivateFunctions of PrivateFunctionsTrait {
        /// Credits `amount` liquidity shares to `to` and raises the total supply by the same
        /// amount.
        fn _mint(ref self: ContractState, to: ContractAddress, amount: u256) {
            let holder_shares = self.balance_of.read(to);
            self.balance_of.write(to, holder_shares + amount);

            let supply = self.total_supply.read();
            self.total_supply.write(supply + amount);
        }

        /// Debits `amount` liquidity shares from `from` and lowers the total supply by the
        /// same amount.
        fn _burn(ref self: ContractState, from: ContractAddress, amount: u256) {
            let holder_shares = self.balance_of.read(from);
            self.balance_of.write(from, holder_shares - amount);

            let supply = self.total_supply.read();
            self.total_supply.write(supply - amount);
        }

        /// Stores the given values as the new reserves of token0 and token1.
        fn _update(ref self: ContractState, reserve0: u256, reserve1: u256) {
            self.reserve0.write(reserve0);
            self.reserve1.write(reserve1);
        }

        /// Returns `true` when `token` is token0 and `false` when it is token1.
        ///
        /// Panics with 'invalid token' when `token` is neither of the pool's tokens.
        #[inline(always)]
        fn select_token(self: @ContractState, token: ContractAddress) -> bool {
            let is_token0 = token == self.token0.read().contract_address;
            if !is_token0 {
                // Not token0, so the only other acceptable address is token1.
                assert(token == self.token1.read().contract_address, 'invalid token');
            }
            is_token0
        }

        /// Returns the smaller of `x` and `y`.
        #[inline(always)]
        fn min(x: u256, y: u256) -> u256 {
            if x > y {
                y
            } else {
                x
            }
        }
    }
 
    #[abi(embed_v0)]
    impl ConstantProductAmm of super::IConstantProductAmm<ContractState> {
        /// Swaps `amount_in` of `token_in` for the pool's other token and returns the amount
        /// transferred to the caller.
        ///
        /// The input tokens are pulled from the caller with `transfer_from`, so the caller
        /// must have approved this contract beforehand.
        ///
        /// Panics with 'amount in = 0' when `amount_in` is zero and with 'invalid token' when
        /// `token_in` is neither of the pool's tokens.
        fn swap(ref self: ContractState, token_in: ContractAddress, amount_in: u256) -> u256 {
            assert(amount_in > 0, 'amount in = 0');
            let sells_token0 = self.select_token(token_in);

            let token0 = self.token0.read();
            let token1 = self.token1.read();
            let reserve0 = self.reserve0.read();
            let reserve1 = self.reserve1.read();

            // Orient tokens and reserves by trade direction.
            let (input_token, output_token, input_reserve, output_reserve) = if sells_token0 {
                (token0, token1, reserve0, reserve1)
            } else {
                (token1, token0, reserve1, reserve0)
            };

            let trader = get_caller_address();
            let pool = get_contract_address();
            input_token.transfer_from(trader, pool, amount_in);

            // How much dy for dx?
            // xy = k
            // (x + dx)(y - dy) = k
            // y - dy = k / (x + dx)
            // y - k / (x + dx) = dy
            // y - xy / (x + dx) = dy
            // (yx + ydx - xy) / (x + dx) = dy
            // ydx / (x + dx) = dy

            // The fee is expressed in tenths of a percent and is taken from the input amount.
            let fee: u256 = self.fee.read().into();
            let amount_in_with_fee = amount_in * (1000 - fee) / 1000;
            let amount_out = output_reserve
                * amount_in_with_fee
                / (input_reserve + amount_in_with_fee);

            output_token.transfer(trader, amount_out);

            // Re-sync the stored reserves with the pool's actual token balances.
            self._update(token0.balance_of(pool), token1.balance_of(pool));
            amount_out
        }
 
        /// Deposits `amount0` of token0 and `amount1` of token1 from the caller and mints
        /// liquidity shares to the caller in return. Returns the number of shares minted.
        ///
        /// Both tokens are pulled with `transfer_from`, so the caller must have approved this
        /// contract beforehand.
        ///
        /// Panics with 'x / y != dx / dy' when the pool already holds reserves and the deposit
        /// does not match their ratio, and with 'shares = 0' when no shares would be minted.
        fn add_liquidity(ref self: ContractState, amount0: u256, amount1: u256) -> u256 {
            let provider = get_caller_address();
            let pool = get_contract_address();
            let token0 = self.token0.read();
            let token1 = self.token1.read();

            token0.transfer_from(provider, pool, amount0);
            token1.transfer_from(provider, pool, amount1);

            // How much dx, dy to add?
            //
            // xy = k
            // (x + dx)(y + dy) = k'
            //
            // The price must be the same before and after adding liquidity:
            // x / y = (x + dx) / (y + dy)
            //
            // x(y + dy) = y(x + dx)
            // x * dy = y * dx
            //
            // x / y = dx / dy
            // dy = y / x * dx

            let reserve0 = self.reserve0.read();
            let reserve1 = self.reserve1.read();
            let pool_is_funded = reserve0 > 0 || reserve1 > 0;
            if pool_is_funded {
                // Cross-multiplied form of x / y = dx / dy, which avoids division.
                assert(reserve0 * amount1 == reserve1 * amount0, 'x / y != dx / dy');
            }

            // How many shares to mint?
            //
            // f(x, y) = value of liquidity
            // We define f(x, y) = sqrt(xy)
            //
            // L0 = f(x, y)
            // L1 = f(x + dx, y + dy)
            // T = total shares
            // s = shares to mint
            //
            // Total shares should increase proportionally to the increase in liquidity:
            // L1 / L0 = (T + s) / T
            //
            // L1 * T = L0 * (T + s)
            //
            // (L1 - L0) * T / L0 = s

            // Claim
            // (L1 - L0) / L0 = dx / x = dy / y
            //
            // Proof
            // --- Equation 1 ---
            // (L1 - L0) / L0 = (sqrt((x + dx)(y + dy)) - sqrt(xy)) / sqrt(xy)
            //
            // dx / dy = x / y so replace dy = dx * y / x
            //
            // --- Equation 2 ---
            // Equation 1 = (sqrt(xy + 2ydx + dx^2 * y / x) - sqrt(xy)) / sqrt(xy)
            //
            // Multiply by sqrt(x) / sqrt(x)
            // Equation 2 = (sqrt(x^2y + 2xydx + dx^2 * y) - sqrt(x^2y)) / sqrt(x^2y)
            //            = (sqrt(y)(sqrt(x^2 + 2xdx + dx^2) - sqrt(x^2))) / (sqrt(y)sqrt(x^2))
            // sqrt(y) on top and bottom cancels out
            //
            // --- Equation 3 ---
            // Equation 2 = (sqrt(x^2 + 2xdx + dx^2) - sqrt(x^2)) / sqrt(x^2)
            // = (sqrt((x + dx)^2) - sqrt(x^2)) / sqrt(x^2)
            // = ((x + dx) - x) / x
            // = dx / x
            // Since dx / dy = x / y,
            // dx / x = dy / y
            //
            // Finally
            // (L1 - L0) / L0 = dx / x = dy / y

            let total_supply = self.total_supply.read();
            let shares: u256 = if total_supply == 0 {
                // First deposit: shares = f(dx, dy) = sqrt(dx * dy).
                (amount0 * amount1).sqrt().into()
            } else {
                // Later deposits: s = dx / x * T = dy / y * T; take the smaller of the two.
                let shares_by_token0 = amount0 * total_supply / reserve0;
                let shares_by_token1 = amount1 * total_supply / reserve1;
                PrivateFunctions::min(shares_by_token0, shares_by_token1)
            };
            assert(shares > 0, 'shares = 0');
            self._mint(provider, shares);

            // Re-sync the stored reserves with the pool's actual token balances.
            self._update(token0.balance_of(pool), token1.balance_of(pool));
            shares
        }
 
        /// Burns `shares` of the caller's liquidity shares and transfers the matching portion
        /// of both token balances to the caller. Returns `(amount0, amount1)`.
        ///
        /// Panics with 'amount0 or amount1 = 0' when either payout would be zero.
        fn remove_liquidity(ref self: ContractState, shares: u256) -> (u256, u256) {
            let provider = get_caller_address();
            let pool = get_contract_address();
            let token0 = self.token0.read();
            let token1 = self.token1.read();

            // Claim
            // dx, dy = amount of liquidity to remove
            // dx = s / T * x
            // dy = s / T * y
            //
            // Proof
            // Let's find dx, dy such that
            // v / L = s / T
            //
            // where
            // v = f(dx, dy) = sqrt(dxdy)
            // L = total liquidity = sqrt(xy)
            // s = shares
            // T = total supply
            //
            // --- Equation 1 ---
            // v = s / T * L
            // sqrt(dxdy) = s / T * sqrt(xy)
            //
            // The amount of liquidity removed must not change the price, so
            // dx / dy = x / y
            //
            // replace dy = dx * y / x
            // sqrt(dxdy) = sqrt(dx * dx * y / x) = dx * sqrt(y / x)
            //
            // Divide both sides of Equation 1 by sqrt(y / x)
            // dx = s / T * sqrt(xy) / sqrt(y / x)
            // = s / T * sqrt(x^2) = s / T * x
            //
            // Likewise
            // dy = s / T * y

            // The payout is computed from the live balances rather than the stored reserves:
            // bal0 >= reserve0
            // bal1 >= reserve1
            let bal0 = token0.balance_of(pool);
            let bal1 = token1.balance_of(pool);

            let total_supply = self.total_supply.read();
            let amount0 = shares * bal0 / total_supply;
            let amount1 = shares * bal1 / total_supply;
            assert(amount0 > 0 && amount1 > 0, 'amount0 or amount1 = 0');

            // Update share accounting and reserves before the tokens leave the pool.
            self._burn(provider, shares);
            self._update(bal0 - amount0, bal1 - amount1);

            token0.transfer(provider, amount0);
            token1.transfer(provider, amount1);
            (amount0, amount1)
        }
    }
}
Powered By Nethermind