Skip to content

Needless allocations in reward() and is_terminated() for TicTacToeEnv #1080

@hespanha

Description

@hespanha

The two functions
reward(::TicTacToeEnv,::Player)
is_terminated(::TicTacToeEnv)
result in a small but needless allocation due to a type instability in the call to get_tic_tac_toe_state_info()

To see this, you can use:

# Reproduction script: benchmark the two accessors on a freshly constructed environment.
using ReinforcementLearning
using BenchmarkTools
env = TicTacToeEnv()
# `$env` interpolates the value into the benchmark so that access to the global
# variable `env` is not part of what is being measured.
# NOTE(review): `reward` is called here with a single argument, while the method
# patched further down takes `(env, player)` — presumably the one-argument form
# forwards to it; confirm against RLBase.
display(@benchmark reward($env))
display(@benchmark is_terminated($env))

I was able to fix this problem (and save about 7% of time) with 3 small changes to TicTacToeEnv.jl. There may be other ways to fix this, but these were the simplest changes I could find.

import ReinforcementLearningEnvironments: get_tic_tac_toe_state_info
# Return the TicTacToe state-info cache `RLEnvs.TIC_TAC_TOE_STATE_INFO`, filling it on
# the first call (i.e. while it is still empty).
#
# Each entry maps a `TicTacToeEnv` to a named tuple with the fields
#   - `index`:         running counter, 1 for the root env, incremented per new env
#   - `is_terminated`: `true` unless the env still has an empty position AND no winner
#   - `winner`:        `Player(:Cross)`, `Player(:Nought)` or `nothing`
#
# The returned value carries an explicit type assertion so that callers see a concrete
# `Dict` type instead of an abstractly-typed global.
#
# NOTE(review): `walk` and `is_win` are used unqualified, exactly as in the original
# snippet; when this is pasted outside the package module they must be in scope
# (exported or imported) — confirm.
function ReinforcementLearningEnvironments.get_tic_tac_toe_state_info()
    if isempty(RLEnvs.TIC_TAC_TOE_STATE_INFO)
        @info "initializing tictactoe state info cache..."
        t = @elapsed begin
            n = 1
            root = TicTacToeEnv()
            # The root (freshly constructed) env is entry number 1.
            RLEnvs.TIC_TAC_TOE_STATE_INFO[root] =
                (index=n, is_terminated=false, winner=nothing)
            walk(root) do env
                # FIX: qualify the cache with `RLEnvs.` like every other access in this
                # function; the bare name is not guaranteed to be in scope here.
                if !haskey(RLEnvs.TIC_TAC_TOE_STATE_INFO, env)
                    n += 1
                    # presumably board[:, :, 1] flags the empty cells — TODO confirm
                    has_empty_pos = any(view(env.board, :, :, 1))
                    w = if is_win(env, Player(:Cross))
                        Player(:Cross)
                    elseif is_win(env, Player(:Nought))
                        Player(:Nought)
                    else
                        nothing
                    end
                    RLEnvs.TIC_TAC_TOE_STATE_INFO[env] = (
                        index=n,
                        # Terminated as soon as the board is full or somebody has won.
                        is_terminated=!(has_empty_pos && isnothing(w)),
                        winner=w,
                    )
                end
            end
        end
        @info "finished initializing tictactoe state info cache in $t seconds"
    end
    # CHANGE: declare type explicitly
    return RLEnvs.TIC_TAC_TOE_STATE_INFO::Dict{TicTacToeEnv,@NamedTuple{index::Int64, is_terminated::Bool, winner::Union{Nothing,Player}}}
end

import ReinforcementLearning: reward
# Reward of `player` in `env`:
#   1  if the game is over and `player` is the recorded winner,
#  -1  if the game is over and somebody else won,
#   0  if the game is still running or ended without a winner.
function RLBase.reward(env::TicTacToeEnv, player::Player)
    # CHANGE: only call get_tic_tac_toe_state_info() if necessary, i.e. when the
    # cache has not been populated yet; otherwise index the cache directly.
    cache = RLEnvs.TIC_TAC_TOE_STATE_INFO
    state = isempty(cache) ? get_tic_tac_toe_state_info()[env] : cache[env]

    # Unfinished game: nothing to hand out yet.
    state.is_terminated || return 0

    victor = state.winner
    # Finished without a winner.
    isnothing(victor) && return 0
    return victor === player ? 1 : -1
end

import ReinforcementLearning: is_terminated
# Return the `is_terminated` flag stored for `env` in the TicTacToe state-info cache.
function RLBase.is_terminated(env::TicTacToeEnv)
    # CHANGE: only call get_tic_tac_toe_state_info() if necessary, i.e. when the
    # cache is still empty and has to be built first.
    cache = RLEnvs.TIC_TAC_TOE_STATE_INFO
    info = isempty(cache) ? get_tic_tac_toe_state_info() : cache
    # The original wrote `return info_env = ...`; that local was never read, so the
    # value is returned directly.
    return info[env].is_terminated
end

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type
    No fields configured for issues without a type.

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions