Add dtype choice in step type/functions #256

Open
1 task done
thomashirtz opened this issue Nov 3, 2024 · 7 comments
Labels
enhancement New feature or request

Comments

@thomashirtz

thomashirtz commented Nov 3, 2024

Is your feature request related to a problem? Please describe

In a personal project I had to be very efficient with memory, and my rewards were taking up a lot of space. I needed to store the reward and the discount as float16 instead of float32, so I had to copy the types file and modify it locally. I think other use cases may need this flexibility too.
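To make the saving concrete, here is a minimal sketch that measures the storage of one reward and one discount per step in float32 versus float16 with JAX. The buffer size and the helper name `buffer_nbytes` are illustrative assumptions, not taken from the original project.

```python
import jax.numpy as jnp


def buffer_nbytes(num_steps: int, dtype) -> int:
    """Hypothetical helper: bytes needed for one reward and one discount per step."""
    rewards = jnp.zeros(num_steps, dtype=dtype)
    discounts = jnp.ones(num_steps, dtype=dtype)
    # nbytes = number of elements * bytes per element.
    return rewards.nbytes + discounts.nbytes


# Arbitrary illustrative buffer of 1,000,000 transitions.
for dtype in (jnp.float32, jnp.float16):
    print(jnp.dtype(dtype).name, buffer_nbytes(1_000_000, dtype) / 1e6, "MB")
# float32 8.0 MB
# float16 4.0 MB
```

Halving the element size halves the footprint of both arrays, which is the whole motivation for exposing the dtype.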

Describe the solution you'd like

Add an extra `dtype` parameter to the step functions. Example with one of them (`truncation`):

```python
from typing import Dict, Optional, Sequence, Union

import jax.numpy as jnp

# Array, Observation, StepType and TimeStep are defined in the existing types module.


def truncation(
    reward: Array,
    observation: Observation,
    discount: Optional[Array] = None,
    extras: Optional[Dict] = None,
    shape: Union[int, Sequence[int]] = (),
    dtype: jnp.dtype = jnp.float32,
) -> TimeStep:
    """Returns a `TimeStep` with `step_type` set to `StepType.LAST`.

    Args:
        reward: array.
        observation: array or tree of arrays.
        discount: array.
        extras: environment metric(s) or information returned by the environment but
            not observed by the agent (hence not in the observation). For example, it
            could be whether an invalid action was taken. In most environments, extras
            is None.
        shape: optional parameter to specify the shape of the rewards and discounts.
            Allows multi-agent environment compatibility. Defaults to () for
            scalar reward and discount.
        dtype: dtype of the default discount created when `discount` is None.
            Defaults to float32; e.g. jnp.float16 halves its memory footprint.
    Returns:
        TimeStep identified as the truncation of an episode.
    """
    # A truncated episode is cut, not terminated, so the default discount is ones in `dtype`.
    discount = discount if discount is not None else jnp.ones(shape, dtype=dtype)
    extras = extras or {}
    return TimeStep(
        step_type=StepType.LAST,
        reward=reward,
        discount=discount,
        observation=observation,
        extras=extras,
    )
```
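The parameter would matter most in a step function that builds its own default reward as well as its own default discount. The sketch below is hypothetical: `StepType` and `TimeStep` are minimal stand-ins (integer constants and a `NamedTuple`), not the project's real classes, and `restart_with_dtype` is an illustrative name rather than an existing function.

```python
from typing import Any, Dict, NamedTuple, Optional, Sequence, Union

import jax.numpy as jnp


class StepType:
    """Hypothetical stand-in for the project's StepType constants."""
    FIRST = 0
    MID = 1
    LAST = 2


class TimeStep(NamedTuple):
    """Hypothetical stand-in for the project's TimeStep container."""
    step_type: int
    reward: Any
    discount: Any
    observation: Any
    extras: Dict


def restart_with_dtype(
    observation: Any,
    extras: Optional[Dict] = None,
    shape: Union[int, Sequence[int]] = (),
    dtype: Any = jnp.float32,
) -> TimeStep:
    """First step of an episode: zero reward and unit discount, both in `dtype`."""
    return TimeStep(
        step_type=StepType.FIRST,
        # Both defaults follow `dtype`, so float16 halves their memory.
        reward=jnp.zeros(shape, dtype=dtype),
        discount=jnp.ones(shape, dtype=dtype),
        observation=observation,
        extras=extras or {},
    )


# Multi-agent example: 3 agents with half-precision rewards and discounts.
ts = restart_with_dtype(jnp.zeros(4), shape=(3,), dtype=jnp.float16)
print(ts.reward.dtype, ts.discount.shape)  # float16 (3,)
```

Keeping float32 as the default means existing callers see no change unless they opt in.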

I would be happy to open the PR.


Misc

  • Check for duplicate requests.

thomashirtz added the enhancement (New feature or request) label on Nov 3, 2024
@sash-a
Collaborator

sash-a commented Nov 4, 2024

I think that would be a nice addition, happy to review it 😄

@sash-a
Collaborator

sash-a commented Nov 4, 2024

To be honest, I think the discounts should probably be booleans while they are stored in the timestep, because to me they just indicate the end of an episode. Still, I think this would add nice flexibility.

@thomashirtz
Author

> To be honest, I think the discounts should probably be booleans while they are stored in the timestep, because to me they just indicate the end of an episode. Still, I think this would add nice flexibility.

I'm fine with either, as long as it doesn't take too much space. Should I add the argument with boolean as the default, or just make the discount boolean?
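For comparison with the boolean option, a minimal sketch: JAX stores `bool` arrays with one byte per element, half of float16 and a quarter of float32, and a boolean "episode continues" flag can be cast back to a float discount wherever arithmetic needs it. The flag values and the name `not_done` are illustrative.

```python
import jax.numpy as jnp

# Illustrative per-step flags: True while the episode continues.
not_done = jnp.array([True, True, False, True])

for dtype in (jnp.bool_, jnp.float16, jnp.float32):
    stored = not_done.astype(dtype)
    print(jnp.dtype(dtype).name, stored.nbytes, "bytes")
# bool 4 bytes
# float16 8 bytes
# float32 16 bytes

# Cast to float only when the discount is used in arithmetic.
discount = not_done.astype(jnp.float32)
print(discount)  # [1. 1. 0. 1.]
```

The trade-off is that a boolean can only say "continue" or "stop"; it cannot hold a fractional discount.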

@sash-a
Collaborator

sash-a commented Nov 5, 2024

My only concern is that this strays from the original dm_env API, where the discount is a float so that it can represent both the RL discount (gamma) and done.
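To illustrate the dm_env point: because the discount is a float, one array can carry termination (0.0), continuation (1.0), and any environment-side discount in between, and the agent multiplies it with its own gamma in the bootstrapped target. A minimal sketch; `gamma` and all sample values are illustrative assumptions.

```python
import jax.numpy as jnp

gamma = 0.99  # agent-side discount factor (illustrative)

rewards = jnp.array([1.0, 1.0, 1.0])
next_values = jnp.array([10.0, 10.0, 10.0])  # V(s') estimates (illustrative)
# Environment discount per step: continue, terminate, partial discount.
# A boolean "done" flag could express the first two but not 0.5.
discounts = jnp.array([1.0, 0.0, 0.5])

# One-step TD target: r + gamma * discount * V(s').
targets = rewards + gamma * discounts * next_values
print(targets)
```

A zero discount cuts the bootstrap entirely, which is exactly how the float also serves as the "done" signal.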

Let's definitely add it as an argument, but for the default I'm not sure whether boolean or float32 is best. @clement-bonnet, any thoughts on this?

@clement-bonnet
Collaborator

To my knowledge, having the discount as a float is more common than as a boolean, for the reasons you mentioned, @sash-a. I would keep it a float unless there are strong reasons to do otherwise :)

@sash-a
Collaborator

sash-a commented Nov 5, 2024

Great, then if you could add the argument with a default of float32, I'm happy to accept the PR.

@thomashirtz
Author

> Great, then if you could add the argument with a default of float32, I'm happy to accept the PR.

The PR is available for review :)
