Skip to content

Staking contract

The following staking contract is designed to allow users to stake tokens in exchange for reward tokens over a specified duration. Here's a quick summary of how it operates and what functionalities it supports:

Key Features:

  1. Token staking and unstaking:

    • Users can stake an ERC20 token, specified at deployment.
    • Users can withdraw their staked tokens at any time.
  2. Reward calculation and distribution:

    • The rewards are distributed as an ERC20, also specified at deployment (can be different from the staking token).
    • Rewards are calculated based on the duration of staking and the amount the user staked relative to the total staked amount by all users.
    • A user’s reward accumulates over time up until the reward period's end and can be claimed anytime by the user.
  3. Dynamic reward rates:

    • The reward rate is determined by the total amount of reward tokens over a set period (duration).
    • The reward rate can be adjusted during the rewards period if new rewards are added before the current reward period finishes.
    • Even after a reward period finishes, a new reward duration and new rewards can be set up if desired.
  4. Ownership and administration:

    • Only the owner of the contract can set the rewards amount and duration.

The following implementation is the Cairo adaptation of the Solidity by Example - Staking Rewards contract. It includes a small adaptation to keep track of the amount of total distributed reward tokens and emit an event when the remaining reward token amount reaches 0.

use starknet::ContractAddress;
 
#[starknet::interface]
pub trait IStakingContract<TContractState> {
    fn set_reward_amount(ref self: TContractState, amount: u256);
    fn set_reward_duration(ref self: TContractState, duration: u256);
    fn stake(ref self: TContractState, amount: u256);
    fn withdraw(ref self: TContractState, amount: u256);
    fn get_rewards(self: @TContractState, account: ContractAddress) -> u256;
    fn claim_rewards(ref self: TContractState);
}
 
mod Errors {
    pub const NULL_REWARDS: felt252 = 'Reward amount must be > 0';
    pub const NOT_ENOUGH_REWARDS: felt252 = 'Reward amount must be > balance';
    pub const NULL_AMOUNT: felt252 = 'Amount must be > 0';
    pub const NULL_DURATION: felt252 = 'Duration must be > 0';
    pub const UNFINISHED_DURATION: felt252 = 'Reward duration not finished';
    pub const NOT_OWNER: felt252 = 'Caller is not the owner';
    pub const NOT_ENOUGH_BALANCE: felt252 = 'Balance too low';
}
 
#[starknet::contract]
pub mod StakingContract {
    use core::starknet::event::EventEmitter;
    use core::num::traits::Zero;
    use starknet::{ContractAddress, get_caller_address, get_block_timestamp, get_contract_address};
    use openzeppelin::token::erc20::interface::{IERC20Dispatcher, IERC20DispatcherTrait};
    use starknet::storage::{
        Map, StorageMapReadAccess, StorageMapWriteAccess, StoragePointerReadAccess,
        StoragePointerWriteAccess,
    };
 
    #[storage]
    struct Storage {
        pub staking_token: IERC20Dispatcher,
        pub reward_token: IERC20Dispatcher,
        pub owner: ContractAddress,
        pub reward_rate: u256,
        pub duration: u256,
        pub current_reward_per_staked_token: u256,
        pub finish_at: u256,
        // last time an operation (staking / withdrawal / rewards claimed) was registered
        pub last_updated_at: u256,
        pub last_user_reward_per_staked_token: Map::<ContractAddress, u256>,
        pub unclaimed_rewards: Map::<ContractAddress, u256>,
        pub total_distributed_rewards: u256,
        // total amount of staked tokens
        pub total_supply: u256,
        // amount of staked tokens per user
        pub balance_of: Map::<ContractAddress, u256>,
    }
 
    #[event]
    #[derive(Copy, Drop, Debug, PartialEq, starknet::Event)]
    pub enum Event {
        Deposit: Deposit,
        Withdrawal: Withdrawal,
        RewardsFinished: RewardsFinished,
    }
 
    #[derive(Copy, Drop, Debug, PartialEq, starknet::Event)]
    pub struct Deposit {
        pub user: ContractAddress,
        pub amount: u256,
    }
 
    #[derive(Copy, Drop, Debug, PartialEq, starknet::Event)]
    pub struct Withdrawal {
        pub user: ContractAddress,
        pub amount: u256,
    }
 
    #[derive(Copy, Drop, Debug, PartialEq, starknet::Event)]
    pub struct RewardsFinished {
        pub msg: felt252,
    }
 
    #[constructor]
    fn constructor(
        ref self: ContractState,
        staking_token_address: ContractAddress,
        reward_token_address: ContractAddress,
    ) {
        self.staking_token.write(IERC20Dispatcher { contract_address: staking_token_address });
        self.reward_token.write(IERC20Dispatcher { contract_address: reward_token_address });
 
        self.owner.write(get_caller_address());
    }
 
    #[abi(embed_v0)]
    impl StakingContract of super::IStakingContract<ContractState> {
        fn set_reward_duration(ref self: ContractState, duration: u256) {
            self.only_owner();
 
            assert(duration > 0, super::Errors::NULL_DURATION);
 
            // can only set duration if the previous duration has already finished
            assert(
                self.finish_at.read() < get_block_timestamp().into(),
                super::Errors::UNFINISHED_DURATION,
            );
 
            self.duration.write(duration);
        }
 
        fn set_reward_amount(ref self: ContractState, amount: u256) {
            self.only_owner();
            self.update_rewards(Zero::zero());
 
            assert(amount > 0, super::Errors::NULL_REWARDS);
            assert(self.duration.read() > 0, super::Errors::NULL_DURATION);
 
            let block_timestamp: u256 = get_block_timestamp().into();
 
            let rate = if self.finish_at.read() < block_timestamp {
                amount / self.duration.read()
            } else {
                let remaining_rewards = self.reward_rate.read()
                    * (self.finish_at.read() - block_timestamp);
                (remaining_rewards + amount) / self.duration.read()
            };
 
            assert(
                self.reward_token.read().balance_of(get_contract_address()) >= rate
                    * self.duration.read(),
                super::Errors::NOT_ENOUGH_REWARDS,
            );
 
            self.reward_rate.write(rate);
 
            // even if the previous reward duration has not finished, we reset the finish_at
            // variable
            self.finish_at.write(block_timestamp + self.duration.read());
            self.last_updated_at.write(block_timestamp);
 
            // reset total distributed rewards
            self.total_distributed_rewards.write(0);
        }
 
        fn stake(ref self: ContractState, amount: u256) {
            assert(amount > 0, super::Errors::NULL_AMOUNT);
 
            let user = get_caller_address();
            self.update_rewards(user);
 
            self.balance_of.write(user, self.balance_of.read(user) + amount);
            self.total_supply.write(self.total_supply.read() + amount);
            self.staking_token.read().transfer_from(user, get_contract_address(), amount);
 
            self.emit(Deposit { user, amount });
        }
 
        fn withdraw(ref self: ContractState, amount: u256) {
            assert(amount > 0, super::Errors::NULL_AMOUNT);
 
            let user = get_caller_address();
 
            assert(
                self.staking_token.read().balance_of(user) >= amount,
                super::Errors::NOT_ENOUGH_BALANCE,
            );
 
            self.update_rewards(user);
 
            self.balance_of.write(user, self.balance_of.read(user) - amount);
            self.total_supply.write(self.total_supply.read() - amount);
            self.staking_token.read().transfer(user, amount);
 
            self.emit(Withdrawal { user, amount });
        }
 
        fn get_rewards(self: @ContractState, account: ContractAddress) -> u256 {
            self.unclaimed_rewards.read(account) + self.compute_new_rewards(account)
        }
 
        fn claim_rewards(ref self: ContractState) {
            let user = get_caller_address();
            self.update_rewards(user);
 
            let rewards = self.unclaimed_rewards.read(user);
 
            if rewards > 0 {
                self.unclaimed_rewards.write(user, 0);
                self.reward_token.read().transfer(user, rewards);
            }
        }
    }
 
    #[generate_trait]
    impl PrivateFunctions of PrivateFunctionsTrait {
        // call this function every time a user (including owner) performs a state-modifying action
        fn update_rewards(ref self: ContractState, account: ContractAddress) {
            self
                .current_reward_per_staked_token
                .write(self.compute_current_reward_per_staked_token());
 
            self.last_updated_at.write(self.last_time_applicable());
 
            if account.is_non_zero() {
                self.distribute_user_rewards(account);
 
                self
                    .last_user_reward_per_staked_token
                    .write(account, self.current_reward_per_staked_token.read());
 
                self.send_rewards_finished_event();
            }
        }
 
        fn distribute_user_rewards(ref self: ContractState, account: ContractAddress) {
            // compute earned rewards since last update for the user `account`
            let user_rewards = self.get_rewards(account);
            self.unclaimed_rewards.write(account, user_rewards);
 
            // track amount of total rewards distributed
            self
                .total_distributed_rewards
                .write(self.total_distributed_rewards.read() + user_rewards);
        }
 
        fn send_rewards_finished_event(ref self: ContractState) {
            // check whether we should send a RewardsFinished event
            if self.last_updated_at.read() == self.finish_at.read() {
                let total_rewards = self.reward_rate.read() * self.duration.read();
 
                if total_rewards != 0 && self.total_distributed_rewards.read() == total_rewards {
                    // owner should set up NEW rewards into the contract
                    self.emit(RewardsFinished { msg: 'Rewards all distributed' });
                } else {
                    // owner should set up rewards into the contract (or add duration by setting up
                    // rewards)
                    self.emit(RewardsFinished { msg: 'Rewards not active yet' });
                }
            }
        }
 
        fn compute_current_reward_per_staked_token(self: @ContractState) -> u256 {
            if self.total_supply.read() == 0 {
                self.current_reward_per_staked_token.read()
            } else {
                self.current_reward_per_staked_token.read()
                    + self.reward_rate.read()
                        * (self.last_time_applicable() - self.last_updated_at.read())
                        / self.total_supply.read()
            }
        }
 
        fn compute_new_rewards(self: @ContractState, account: ContractAddress) -> u256 {
            self.balance_of.read(account)
                * (self.current_reward_per_staked_token.read()
                    - self.last_user_reward_per_staked_token.read(account))
        }
 
        #[inline(always)]
        fn last_time_applicable(self: @ContractState) -> u256 {
            Self::min(self.finish_at.read(), get_block_timestamp().into())
        }
 
        #[inline(always)]
        fn min(x: u256, y: u256) -> u256 {
            if (x <= y) {
                x
            } else {
                y
            }
        }
 
        fn only_owner(self: @ContractState) {
            let caller = get_caller_address();
            assert(caller == self.owner.read(), super::Errors::NOT_OWNER);
        }
    }
}
Powered By Nethermind