# GenSim / cliport/generated_tasks/align_spheres_in_boxes.py
# Last commit: "add gensim code" (8fc2b4e), by LeroyWaa
import numpy as np
import os
import pybullet as p
import random
from cliport.tasks import primitives
from cliport.tasks.grippers import Spatula
from cliport.tasks.task import Task
from cliport.utils import utils
import numpy as np
from cliport.tasks.task import Task
from cliport.utils import utils
import pybullet as p
class AlignSpheresInBoxes(Task):
    """Pick up each sphere and place it into the box of the same color. Navigate around the blocks.

    The scene built by ``reset`` contains four colored boxes, four spheres in
    the same colors, and ten small blocks. One goal (and one language
    instruction) is registered per sphere, pairing it with the pose of the
    box that shares its color.
    """

    def __init__(self):
        """Configure the step budget and the language strings for this task."""
        super().__init__()
        self.max_steps = 20
        self.lang_template = "put the {color} sphere in the {color} box"
        self.task_completed_desc = "done aligning spheres in boxes."
        # NOTE(review): additional_reset() is inherited from Task; its effect
        # is not visible in this file.
        self.additional_reset()

    def reset(self, env):
        """Populate ``env`` with boxes, spheres and blocks, then register goals.

        Args:
            env: Environment handed to ``Task.reset``; it is used here through
                ``env.add_object`` and as an argument to ``get_random_pose``.
        """
        super().reset(env)

        palette = ['red', 'blue', 'green', 'yellow']

        # Boxes: one per color, added under the 'fixed' category. The sampled
        # pose of each box is kept because it later serves as the target pose
        # for the sphere of the same color.
        box_dims = (0.12, 0.12, 0.12)  # x, y, z extent of the asset
        target_poses = []
        for shade in palette:
            placement = self.get_random_pose(env, box_dims)
            env.add_object('box/box-template.urdf', placement, category='fixed',
                           color=utils.COLORS[shade])
            target_poses.append(placement)

        # Spheres: one per color, in the same color order as the boxes so that
        # sphere k and box k always share a color.
        ball_dims = (0.04, 0.04, 0.04)  # x, y, z extent of the asset
        ball_ids = []
        for shade in palette:
            placement = self.get_random_pose(env, ball_dims)
            ball_ids.append(
                env.add_object('sphere/sphere-template.urdf', placement,
                               color=utils.COLORS[shade]))

        # Blocks: ten uncolored small blocks. No goal refers to them, so their
        # object ids are not retained.
        cube_dims = (0.04, 0.04, 0.04)  # x, y, z extent of the asset
        for _ in range(10):
            placement = self.get_random_pose(env, cube_dims)
            env.add_object('block/small.urdf', placement)

        # Goal: each sphere is in a box of the same color. Every goal carries
        # an equal share of the reward.
        reward_share = 1 / len(ball_ids)
        for shade, ball_id, goal_pose in zip(palette, ball_ids, target_poses):
            self.add_goal(objs=[ball_id], matches=np.ones((1, 1)), targ_poses=[goal_pose],
                          replace=False, rotations=True, metric='pose', params=None,
                          step_max_reward=reward_share)
            self.lang_goals.append(self.lang_template.format(color=shade))