Merge pull request #2 from JuliaReinforcementLearning/fix_examples
fix examples
findmyway authored Sep 8, 2018
2 parents 631626b + 3db2518 commit 0771f14
Showing 4 changed files with 27 additions and 21 deletions.
2 changes: 1 addition & 1 deletion examples/cartpole.jl
@@ -1,5 +1,5 @@
# Using PyCall is rather slow. Please compare to https://github.com/JuliaReinforcementLearning/ReinforcementLearningEnvironmentClassicControl.jl/blob/master/examples/cartpole.jl
- using ReinforcementLearningEnvironmentGym
+ using ReinforcementLearningEnvironmentGym, ReinforcementLearning

env = GymEnv("CartPole-v0")
rlsetup = RLSetup(ActorCriticPolicyGradient(ns = 4, na = 2, α = .02,
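
RLSetup and ActorCriticPolicyGradient are exported by ReinforcementLearning, not by the Gym wrapper, so the example cannot load without the extra package. A minimal loading sketch, assuming both Julia packages and the Python gym module are installed:

using ReinforcementLearningEnvironmentGym, ReinforcementLearning

env = GymEnv("CartPole-v0")   # wraps the Python environment through PyCall
actionspace(env)              # a DiscreteSpace with lower bound 1 after this PR
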
6 changes: 3 additions & 3 deletions examples/cartpoleDQN.jl
@@ -1,4 +1,4 @@
- using ReinforcementLearningEnvironmentGym, Flux
+ using ReinforcementLearningEnvironmentGym, Flux, ReinforcementLearning
# List all envs

listallenvs()
@@ -14,12 +14,12 @@ learner = DQN(Chain(Dense(4, 48, relu), Dense(48, 24, relu), Dense(24, 2)),
x = RLSetup(learner, env, ConstantNumberEpisodes(10),
callbacks = [Progress(), EvaluationPerEpisode(TimeSteps()),
Visualize(wait = 0)])
- info("Before learning.")
+ @info("Before learning.")
run!(x)
pop!(x.callbacks)
x.stoppingcriterion = ConstantNumberEpisodes(400)
@time learn!(x)
x.stoppingcriterion = ConstantNumberEpisodes(10)
push!(x.callbacks, Visualize(wait = 0))
- info("After learning.")
+ @info("After learning.")
run!(x)
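
The info to @info change tracks Julia 0.7/1.0, where the old Base.info function was removed in favour of the built-in logging macros. A one-line sketch of the replacement in isolation:

@info "Before learning."   # available without extra imports on Julia >= 0.7; replaces the removed info("Before learning.")
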
30 changes: 21 additions & 9 deletions src/ReinforcementLearningEnvironmentGym.jl
@@ -1,5 +1,4 @@
module ReinforcementLearningEnvironmentGym
- export GymEnv, listallenvs, interact!, reset!, getstate, plotenv, actionspace, sample
using ReinforcementLearningBase
import ReinforcementLearningBase:interact!, reset!, getstate, plotenv, actionspace
using PyCall
@@ -13,19 +12,19 @@ end
function gymspace2jlspace(s::PyObject)
spacetype = s[:__class__][:__name__]
if spacetype == "Box" BoxSpace(s[:low], s[:high])
- elseif spacetype == "Discrete" DiscreteSpace(s[:n], 0)
+ elseif spacetype == "Discrete" DiscreteSpace(s[:n], 1)
elseif spacetype == "MultiBinary" MultiBinarySpace(s[:n])
- elseif spacetype == "MultiDiscrete" MultiDiscreteSpace(s[:nvec], 0)
+ elseif spacetype == "MultiDiscrete" MultiDiscreteSpace(s[:nvec], 1)
elseif spacetype == "Tuple" map(gymspace2jlspace, s[:spaces])
elseif spacetype == "Dict" Dict(map((k, v) -> (k, gymspace2jlspace(v)), s[:spaces]))
else error("Don't know how to convert [$(spacetype)]")
end
end
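
Raising the lower bound of DiscreteSpace and MultiDiscreteSpace from 0 to 1 matches Julia's 1-based convention; the shift back to gym's 0-based actions happens in the interact! methods added further down. A small sketch of the intended mapping, assuming DiscreteSpace(n, 1) describes the actions 1:n:

n = 2                          # CartPole-v0 has two actions
jl_actions = 1:n               # what an agent samples from DiscreteSpace(n, 1)
py_actions = jl_actions .- 1   # what interact! forwards to gym's step, i.e. 0:n-1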

- struct GymEnv <: AbstractEnv
+ struct GymEnv{Ta<:AbstractSpace, To<:AbstractSpace} <: AbstractEnv
pyobj::PyObject
- observationspace::AbstractSpace
- actionspace::AbstractSpace
+ observationspace::To
+ actionspace::Ta
state::PyObject
end
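
Parameterising GymEnv on its space types lets interact! dispatch on the action space instead of branching at run time (see the methods added below). A stand-alone sketch of the same pattern, using hypothetical space types rather than the ones from ReinforcementLearningBase:

abstract type Space end
struct DiscreteActions <: Space end
struct BoxActions <: Space end

struct Env{Ta<:Space, To<:Space}
    actionspace::Ta
    observationspace::To
end

act(env::Env{DiscreteActions}, a::Int) = a - 1   # shift a 1-based action to 0-based
act(env::Env{BoxActions}, a) = a                 # continuous actions pass through unchanged

act(Env(DiscreteActions(), BoxActions()), 1)     # returns 0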

@@ -34,25 +33,37 @@ function GymEnv(name::String)
obsspace = gymspace2jlspace(pyenv[:observation_space])
actspace = gymspace2jlspace(pyenv[:action_space])
state = PyNULL()
- GymEnv(pyenv, obsspace, actspace, state)
+ env = GymEnv(pyenv, obsspace, actspace, state)
+ reset!(env) # state needs to be set to call defaultbuffer in RL
+ env
end

function interact!(env::GymEnv, action)
pycall!(env.state, env.pyobj[:step], PyVector, action)
(observation=env.state[1], reward=env.state[2], isdone=env.state[3])
end

+ function interact!(env::GymEnv{DiscreteSpace}, action::Int)
+ pycall!(env.state, env.pyobj[:step], PyVector, action - 1)
+ (observation=env.state[1], reward=env.state[2], isdone=env.state[3])
+ end
+
+ function interact!(env::GymEnv{MultiDiscreteSpace}, action::AbstractArray{Int})
+ pycall!(env.state, env.pyobj[:step], PyVector, action .- 1)
+ (observation=env.state[1], reward=env.state[2], isdone=env.state[3])
+ end
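
The added methods subtract 1 so agents can use 1-based action indices in Julia while gym still receives 0-based ones. A hedged usage sketch, again assuming gym is importable through PyCall:

env = GymEnv("CartPole-v0")   # actionspace(env) is a two-element DiscreteSpace starting at 1
res = interact!(env, 1)       # Julia action 1 is forwarded to gym's step as 0
res.observation, res.reward, res.isdone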

"Not very useful, kept for compat"
function getstate(env::GymEnv)
if pyisinstance(env.state, PyCall.@pyglobalobj :PyTuple_Type)
(observation=env.state[1], isdone=env.state[3])
else
# env has just been reset
- (observation=env.state, isdone=false)
+ (observation=Float64.(env.state), isdone=false)
end
end

- reset!(env::GymEnv) = (observation=pycall!(env.state, env.pyobj[:reset], PyArray),)
+ reset!(env::GymEnv) = (observation=Float64.(pycall!(env.state, env.pyobj[:reset], PyArray)),)
plotenv(env::GymEnv) = env.pyobj[:render]()
actionspace(env::GymEnv) = env.actionspace
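
The two Float64.(...) wrappers turn the PyCall-backed state into a plain Julia array, so callers (and the simplified test below) see Array{Float64,1} instead of PyArray{Float64,1}. A quick check of the new behavior, under the same assumptions as above:

obs = reset!(GymEnv("CartPole-v0")).observation
typeof(obs)   # Array{Float64,1} after this change; previously a PyArray{Float64,1}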

@@ -71,4 +82,5 @@ function listallenvs(pattern = r"")
end
end

+ export GymEnv, listallenvs
end # module
10 changes: 2 additions & 8 deletions test/runtests.jl
@@ -1,9 +1,3 @@
- using ReinforcementLearningEnvironmentGym
- using PyCall
- using Test
+ using ReinforcementLearningEnvironmentGym, Test, ReinforcementLearningBase

- for x in ["CartPole-v0"]
- env = GymEnv(x)
- @test typeof(reset!(env)) == NamedTuple{(:observation,), Tuple{PyArray{Float64, 1}}}
- @test typeof(interact!(env, 1)) == NamedTuple{(:observation, :reward, :isdone), Tuple{Array{Float64, 1}, Float64, Bool}}
- end
+ test_envinterface(GymEnv("CartPole-v0"))
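
test_envinterface (presumably exported by ReinforcementLearningBase, which the new using line pulls in) replaces the hand-written type checks with a generic interface test. A possible extension, with the second environment name added purely as a hypothetical illustration:

using ReinforcementLearningEnvironmentGym, Test, ReinforcementLearningBase

for name in ("CartPole-v0", "MountainCar-v0")   # "MountainCar-v0" is hypothetical here
    test_envinterface(GymEnv(name))
end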
