Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
547 changes: 547 additions & 0 deletions alignit/cameras/realsense.py

Large diffs are not rendered by default.

7 changes: 7 additions & 0 deletions alignit/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,13 @@ class ModelConfig:
default="alignnet_model.pth",
metadata={"help": "Path to save/load trained model"},
)
use_depth_input: bool = field(
default=True,
metadata={"help": "Whether to use depth input for the model"}
)
depth_hidden_dim: int = field(
default=128, metadata={"help": "Output dimension of depth CNN"}
)


@dataclass
Expand Down
45 changes: 22 additions & 23 deletions alignit/infere.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,19 @@
import time
import draccus
from alignit.config import InferConfig

import torch

from alignit.models.alignnet import AlignNet
from alignit.utils.zhou import sixd_se3
from alignit.utils.tfs import print_pose, are_tfs_close
from alignit.robots.xarmsim import XarmSim
from alignit.robots.xarm import Xarm


@draccus.wrap()
def main(cfg: InferConfig):
"""Run inference/alignment using configuration parameters."""
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# load model from file
net = AlignNet(
backbone_name=cfg.model.backbone,
backbone_weights=cfg.model.backbone_weights,
Expand All @@ -28,48 +24,54 @@ def main(cfg: InferConfig):
vector_hidden_dim=cfg.model.vector_hidden_dim,
output_dim=cfg.model.output_dim,
feature_agg=cfg.model.feature_agg,
use_depth_input=cfg.model.use_depth_input,
)
net.load_state_dict(torch.load(cfg.model.path, map_location=device))
net.to(device)
net.eval()

robot = XarmSim()

# Set initial pose from config
start_pose = t3d.affines.compose(
[0.23, 0, 0.25], t3d.euler.euler2mat(np.pi, 0, 0), [1, 1, 1]
)
robot.servo_to_pose(start_pose, lin_tol=1e-2)

robot.servo_to_pose(start_pose, lin_tol=1e-2)
iteration = 0
iterations_within_tolerance = 0
ang_tol_rad = np.deg2rad(cfg.ang_tolerance)

try:
while True:
observation = robot.get_observation()
images = [observation["camera.rgb"].astype(np.float32) / 255.0]
rgb_image = observation["camera.rgb"].astype(np.float32) / 255.0
depth_image = observation["camera.rgb.depth"].astype(np.float32) / 1000.0
rgb_image_tensor = (
torch.from_numpy(np.array(rgb_image))
.permute(2, 0, 1) # (H, W, C) -> (C, H, W)
.unsqueeze(0)
.to(device)
)

# Convert images to tensor and reshape from HWC to CHW format
images_tensor = (
torch.from_numpy(np.array(images))
.permute(0, 3, 1, 2)
.unsqueeze(0)
depth_image_tensor = (
torch.from_numpy(np.array(depth_image))
.unsqueeze(0) # Add channel dimension: (1, H, W)
.unsqueeze(0) # Add batch dimension: (1, 1, H, W)
.to(device)
)

if cfg.debug_output:
print(f"Max pixel value: {torch.max(images_tensor)}")
rgb_images_batch = rgb_image_tensor.unsqueeze(1)
depth_images_batch = depth_image_tensor.unsqueeze(1)


start = time.time()
with torch.no_grad():
relative_action = net(images_tensor)
relative_action = net(rgb_images_batch, depth_images=depth_images_batch)
relative_action = relative_action.squeeze(0).cpu().numpy()
relative_action = sixd_se3(relative_action)

if cfg.debug_output:
print_pose(relative_action)

relative_action [:3,:3] = relative_action[:3,:3] @ relative_action[:3,:3] @ relative_action[:3,:3]

# Check convergence
if are_tfs_close(
relative_action, lin_tol=cfg.lin_tolerance, ang_tol=ang_tol_rad
Expand All @@ -82,10 +84,10 @@ def main(cfg: InferConfig):
print("Alignment achieved - stopping.")
break

target_pose = robot.pose() @ relative_action
action = robot.pose() @ relative_action
iteration += 1
action = {
"pose": target_pose,
"pose": action,
"gripper.pos": 1.0,
}
robot.send_action(action)
Expand All @@ -96,12 +98,9 @@ def main(cfg: InferConfig):
break

time.sleep(10.0)

except KeyboardInterrupt:
print("\nExiting...")

robot.disconnect()


if __name__ == "__main__":
main()
67 changes: 57 additions & 10 deletions alignit/models/alignnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,13 @@
class AlignNet(nn.Module):
def __init__(
self,
backbone_name="efficientnet_b0",
backbone_name="efficientnet_b0",
backbone_weights="DEFAULT",
use_vector_input=True,
use_depth_input=True, # NEW
fc_layers=[256, 128],
vector_hidden_dim=64,
depth_hidden_dim=128, # NEW
output_dim=7,
feature_agg="mean",
):
Expand All @@ -23,27 +25,39 @@ def __init__(
:param vector_hidden_dim: output dim of the vector MLP
:param output_dim: final output vector size
:param feature_agg: 'mean' or 'max' across image views
:param use_depth_input: whether to accept depth input
:param depth_hidden_dim: output dim of the depth MLP
"""
super().__init__()
self.use_vector_input = use_vector_input
self.use_depth_input = use_depth_input
self.feature_agg = feature_agg

# CNN backbone
self.backbone, self.image_feature_dim = self._build_backbone(
backbone_name, backbone_weights
)

# Linear projection of image features
self.image_fc = nn.Sequential(
nn.Linear(self.image_feature_dim, fc_layers[0]), nn.ReLU()
)

if use_depth_input:
self.depth_cnn = nn.Sequential(
nn.Conv2d(1, 8, 3, padding=1), nn.ReLU(),
nn.Conv2d(8, 16, 3, padding=1), nn.ReLU(),
nn.AdaptiveAvgPool2d(1),
)
self.depth_fc = nn.Sequential(
nn.Linear(16, depth_hidden_dim), nn.ReLU()
)
input_dim = fc_layers[0] + depth_hidden_dim
else:
input_dim = fc_layers[0]

# Optional vector input processing
if use_vector_input:
self.vector_fc = nn.Sequential(nn.Linear(1, vector_hidden_dim), nn.ReLU())
input_dim = fc_layers[0] + vector_hidden_dim
else:
input_dim = fc_layers[0]
input_dim += vector_hidden_dim

# Fully connected layers
layers = []
Expand Down Expand Up @@ -81,10 +95,11 @@ def aggregate_image_features(self, feats):
else:
raise ValueError("Invalid aggregation type")

def forward(self, rgb_images, vector_inputs=None):
def forward(self, rgb_images, vector_inputs=None, depth_images=None):
"""
:param rgb_images: Tensor of shape (B, N, 3, H, W)
:param vector_inputs: List of tensors of shape (L_i,) or None
:param depth_images: Tensor of shape (B, N, 1, H, W) or None
:return: Tensor of shape (B, output_dim)
"""
B, N, C, H, W = rgb_images.shape
Expand All @@ -93,15 +108,47 @@ def forward(self, rgb_images, vector_inputs=None):
image_feats = self.aggregate_image_features(feats)
image_feats = self.image_fc(image_feats)

features = [image_feats]

if self.use_depth_input and depth_images is not None:
depth = depth_images.view(B * N, 1, H, W)
depth_feats = self.depth_cnn(depth).view(B, N, -1)
depth_feats = self.aggregate_image_features(depth_feats)
depth_feats = self.depth_fc(depth_feats)
features.append(depth_feats)

if self.use_vector_input and vector_inputs is not None:
vec_feats = []
for vec in vector_inputs:
vec = vec.unsqueeze(1) # (L, 1)
pooled = self.vector_fc(vec).mean(dim=0) # (D,)
vec_feats.append(pooled)
vec_feats = torch.stack(vec_feats, dim=0)
fused = torch.cat([image_feats, vec_feats], dim=1)
else:
fused = image_feats
features.append(vec_feats)

fused = torch.cat(features, dim=1)
print("Fused shape:", fused.shape)

return self.head(fused) # (B, output_dim)


if __name__ == "__main__":
    import time

    # Quick smoke test + latency micro-benchmark for AlignNet.
    batch_size = 1
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = AlignNet(
        backbone_name="efficientnet_b0", use_vector_input=True, use_depth_input=True
    ).to(device)
    # Put the model in eval mode so dropout/batch-norm behave as at inference
    # time; otherwise the timing does not reflect deployment latency.
    model.eval()

    rgb_images = torch.randn(batch_size, 4, 3, 224, 224).to(device)
    depth_images = torch.randn(batch_size, 4, 1, 224, 224).to(device)
    vector_inputs = [torch.randn(10).to(device) for _ in range(batch_size)]

    # Disable autograd bookkeeping: we only measure forward-pass latency.
    with torch.no_grad():
        # Warm-up call, excluded from timing (CUDA context creation and
        # cudnn autotuning would otherwise skew the first iteration).
        output = model(rgb_images, vector_inputs, depth_images)

        start_time = time.time()
        for _ in range(100):
            output = model(rgb_images, vector_inputs, depth_images)
        end_time = time.time()

    duration_ms = ((end_time - start_time) / 100) * 1000
    print(f"Inference time: {duration_ms:.3f} ms")
    print(f"Optimal for {1000 / duration_ms:.2f} fps")
    print("Output shape:", output.shape)
61 changes: 30 additions & 31 deletions alignit/record.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
import os
import shutil

import transforms3d as t3d
import numpy as np
from scipy.spatial.transform import Rotation as R
from alignit.robots.xarm import Xarm
from alignit.utils.zhou import se3_sixd
import argparse # Added for command line arguments
import draccus
from alignit.config import RecordConfig
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Revert the import order, these changes should not exist

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  • import draccus
  • from alignit.config import RecordConfig

These are needed in order to use config parameters, no?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, but they were present before

from alignit.robots.xarmsim import XarmSim
from datasets import (
Dataset,
Features,
Expand All @@ -14,13 +19,6 @@
concatenate_datasets,
)

from alignit.robots.xarmsim import XarmSim
from alignit.robots.xarm import Xarm
from alignit.utils.zhou import se3_sixd
import draccus
from alignit.config import RecordConfig


def generate_spiral_trajectory(start_pose, cfg):
"""Generate spiral trajectory using configuration parameters."""
trajectory = []
Expand Down Expand Up @@ -76,13 +74,12 @@ def generate_spiral_trajectory(start_pose, cfg):

return trajectory


@draccus.wrap()
def main(cfg: RecordConfig):
"""Record alignment dataset using configuration parameters."""
robot = XarmSim()
features = Features(
{"images": Sequence(Image()), "action": Sequence(Value("float32"))}
{"images": Sequence(Image()), "action": Sequence(Value("float32")), "depth": Sequence(Image())}
)

for episode in range(cfg.episodes):
Expand All @@ -91,35 +88,36 @@ def main(cfg: RecordConfig):
robot.servo_to_pose(pose_alignment_target, lin_tol=0.015, ang_tol=0.015)

robot.servo_to_pose(
pose_alignment_target,
lin_tol=cfg.lin_tol_alignment,
ang_tol=cfg.ang_tol_alignment,
pose_alignment_target,
lin_tol=cfg.lin_tol_alignment,
ang_tol=cfg.ang_tol_alignment,
)

trajectory = generate_spiral_trajectory(pose_start, cfg.trajectory)


print(f"Generated trajectory with {len(trajectory)} poses for episode {episode+1}",flush=True)
frames = []
for pose in trajectory:
robot.servo_to_pose(
pose, lin_tol=cfg.lin_tol_trajectory, ang_tol=cfg.ang_tol_trajectory
)
current_pose = robot.pose()

action_pose = np.linalg.inv(current_pose) @ pose_alignment_target
action_sixd = se3_sixd(action_pose)

observation = robot.get_observation()
frame = {
"images": [observation["camera.rgb"].copy()],
"action": action_sixd,
}
frames.append(frame)

robot.servo_to_pose(
pose, lin_tol=cfg.lin_tol_trajectory, ang_tol=cfg.ang_tol_trajectory
)
current_pose = robot.pose()

action_pose = np.linalg.inv(current_pose) @ pose_alignment_target
action_sixd = se3_sixd(action_pose)

observation = robot.get_observation()
frame = {
"images": [observation["camera.rgb"].copy()],
"action": action_sixd,
"depth": [observation["camera.rgb.depth"].copy()],
}
frames.append(frame)
print(f"Episode {episode+1} completed with {len(frames)} frames.")

episode_dataset = Dataset.from_list(frames, features=features)
if episode == 0:
combined_dataset = episode_dataset
print("Writing do dataset to disk for the first time.")
else:
previous_dataset = load_from_disk(cfg.dataset.path)
previous_dataset = previous_dataset.cast(features)
Expand All @@ -131,9 +129,10 @@ def main(cfg: RecordConfig):
if os.path.exists(cfg.dataset.path):
shutil.rmtree(cfg.dataset.path)
shutil.move(temp_path, cfg.dataset.path)
print(f"Saved dataset to {cfg.dataset.path}")

robot.disconnect()


if __name__ == "__main__":
main()
main()
6 changes: 4 additions & 2 deletions alignit/robots/xarm.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,12 @@ def send_action(self, action):
self.robot.send_action(action)

def get_observation(self):
rgb_image = self.camera.read()
rgb_image,depth_image,acqusition_time = self.camera.async_read()

return {
"camera.rgb": rgb_image,
"rgb": rgb_image,
"depth": depth_image,
"timestamp": acqusition_time
}

def disconnect(self):
Expand Down
Loading