Skip to content

milkv kinfer for zbot2 #84

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,7 @@
[submodule "buildroot"]
path = buildroot
url = https://github.com/zeroth-robotics/buildroot
[submodule "kinfer"]
path = kinfer
url = https://github.com/kscalelabs/kinfer.git
branch = milkv-infer
1 change: 1 addition & 0 deletions kinfer
Submodule kinfer added at de1445
1 change: 1 addition & 0 deletions runtime/kos_platform/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ description = "KOS platform for Zeroth-01"

[dependencies]
kos_core = { version = "0.1.2", path = "../../kos_core" }
kinfer = { path = "../../kinfer/kinfer/rust" }
async-trait = "0.1"
eyre = "0.6"
serde = { version = "1.0", features = ["derive"] }
Expand Down
56 changes: 48 additions & 8 deletions runtime/kos_platform/src/actuator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,24 +10,64 @@ use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use tokio::sync::RwLock;
use tonic::{Request, Response, Status};
use lazy_static::lazy_static;

pub struct ZBotActuator {
supervisor: Arc<RwLock<FeetechSupervisor>>,
}
pub const ZBOT_ALL_ACTUATOR_IDS: [u32; 16] = [1, 2, 3, 4, 5,
6, 7, 8, 9, 10,
11, 12, 13,
14, 15, 16];


lazy_static! {
pub static ref JOINT_NAME_TO_ID: HashMap<String, u32> = {
let mut map = HashMap::new();
// Right leg
map.insert("right_ankle_pitch".to_string(), 1);
map.insert("right_knee_pitch".to_string(), 2);
map.insert("right_hip_roll".to_string(), 3);
map.insert("right_hip_yaw".to_string(), 4);
map.insert("right_hip_pitch".to_string(), 5);

// Left leg
map.insert("left_ankle_pitch".to_string(), 6);
map.insert("left_knee_pitch".to_string(), 7);
map.insert("left_hip_roll".to_string(), 8);
map.insert("left_hip_yaw".to_string(), 9);
map.insert("left_hip_pitch".to_string(), 10);

// Right arm
map.insert("right_elbow_yaw".to_string(), 11);
map.insert("right_shoulder_yaw".to_string(), 12);
map.insert("right_shoulder_pitch".to_string(), 13);

// Left arm
map.insert("left_shoulder_pitch".to_string(), 14);
map.insert("left_shoulder_yaw".to_string(), 15);
map.insert("left_elbow_yaw".to_string(), 16);

map
};

pub static ref ID_TO_JOINT_NAME: HashMap<u32, String> = {
let mut map = HashMap::new();
for (name, &id) in JOINT_NAME_TO_ID.iter() {
map.insert(id, name.clone());
}
map
};
}

impl ZBotActuator {
pub async fn new() -> Result<Self> {
let mut supervisor = FeetechSupervisor::new()?;

// Add the servo with ID 1

let actuator_list = [1, 2, 3, 4, 5,
6, 7, 8, 9, 10,
11, 12, 13,
14, 15, 16];
for id in actuator_list {
// Add the servos
for id in ZBOT_ALL_ACTUATOR_IDS {
supervisor
.add_servo(id, FeetechActuatorType::Sts3215)
.add_servo(id as u8, FeetechActuatorType::Sts3215)
.await?;
}

Expand Down
266 changes: 266 additions & 0 deletions runtime/kos_platform/src/inference.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,266 @@
use async_trait::async_trait;
use eyre::Result;
use kos_core::{
google_proto::longrunning::Operation,
hal::{Inference, InferenceState},
kos_proto::{
common::{ActionResponse, Error, ErrorCode},
inference::*,
},
};
use kinfer::{ModelRunner, MilkVModelRunner, kinfer_proto::ProtoIO};
use std::{sync::Arc, path::Path};
use tokio::sync::RwLock;
use tracing::{debug, error, info};

use crate::{actuator::{ZBotActuator, ZBOT_ALL_ACTUATOR_IDS}, imu::ZBotIMU};

const MILKV_STANDING_MODEL_PATH: &str = "/root/models/ppo_standing.cvimodel";

pub struct ZBotInference {
model: Arc<RwLock<Option<MilkVModelRunner>>>,
imu: Arc<ZBotIMU>,
actuator: Arc<ZBotActuator>,
state: Arc<RwLock<InferenceState>>,
}

impl ZBotInference {
pub fn new(imu: Arc<ZBotIMU>, actuator: Arc<ZBotActuator>) -> Self {
let model = match MilkVModelRunner::new(MILKV_STANDING_MODEL_PATH) {
Ok(model) => Some(model),
Err(e) => {
error!("Failed to load default model: {}", e);
None
}
};
Self {
model: Arc::new(RwLock::new(model)),
imu,
actuator,
state: Arc::new(RwLock::new(InferenceState::Stopped)),
}
}

async fn get_sensor_data(&self) -> Result<(IMUData, Vec<ActuatorStateResponse>)> {
let imu_data = self.imu.get_values().await?;
let actuator_data = self.actuator.get_actuators_state(ZBOT_ALL_ACTUATOR_IDS).await?;
Ok((imu_data, actuator_data))
}

async fn pack_inputs_to_proto(
&self,
imu_data: IMUData,
actuator_data: Vec<ActuatorStateResponse>,
) -> Result<ProtoIO> {
let model = self.model.read().await;
let model = model.as_ref().ok_or_else(|| eyre!("No model loaded"))?;

let input_schema = model.input_schema()?;
let mut proto_values = Vec::new();

for value_schema in input_schema.values {
match value_schema.value_type {
Some(ValueType::JointPositions(ref joint_positions_schema)) => {
// Create position map for quick lookup from id
let position_map: HashMap<u32, f64> = actuator_data.iter()
.filter_map(|state| state.position.map(|pos| (state.actuator_id, pos)))
.collect();

// Pack joint positions into proto value with the proper order
let joint_positions = JointPositionsValue {
values: joint_positions_schema.joint_names.iter()
.map(|joint_name| {
let value = JOINT_NAME_TO_ID.get(joint_name)
.and_then(|&id| position_map.get(&id))
.map(|&pos| pos as f32)
.unwrap_or(0.0);

JointPositionValue {
joint_name: joint_name.clone(),
value,
unit: joint_positions_schema.unit, // Use schema-defined unit
}
})
.collect(),
};

let proto_value = ProtoValue {
value: Some(EnumValue::JointPositions(joint_positions))
};

proto_values.push((value_schema.value_name, proto_value));
}
Some(ValueType::Imu(ref imu_schema)) => {
let mut imu_values = Vec::new();
if imu_schema.use_accelerometer {
imu_values.extend_from_slice(&[
imu_data.accel_x as f32,
imu_data.accel_y as f32,
imu_data.accel_z as f32,
]);
}
if imu_schema.use_gyroscope {
imu_values.extend_from_slice(&[
imu_data.gyro_x as f32,
imu_data.gyro_y as f32,
imu_data.gyro_z as f32,
]);
}
if imu_schema.use_magnetometer {
imu_values.extend_from_slice(&[
imu_data.mag_x as f32,
imu_data.mag_y as f32,
imu_data.mag_z as f32,
]);
}

proto_values.push((value_schema.value_name, imu_values));
},
Some(ValueType::VectorCommand(ref vector_command_schema)) => {
// Default command vector of specified dimension
// TODO: replace with actual command vector (teleop?)
let command = vec![0.0f32; vector_command_schema.dimensions as usize];
proto_values.push((value_schema.value_name, command));
},
_ => {
error!("Unsupported input value type in schema");
return Err(eyre!("Unsupported input value type"));
}
}
}

Ok(ProtoIO {
values: proto_values,
})
}

async fn unpack_outputs_from_proto(&self, output: ProtoIO) -> Result<Vec<ActuatorCommand>> {
let model = self.model.read().await;
let model = model.as_ref().ok_or_else(|| eyre::eyre!("No model loaded"))?;

let output_schema = model.output_schema()?;
let mut actuator_commands = Vec::new();

for (value_schema, proto_value) in output_schema.values.iter().zip(output.values.iter()) {
match &value_schema.value_type {
Some(ValueType::JointCommands(ref joint_commands_schema)) => {
if let Some(EnumValue::JointCommands(commands)) = &proto_value.value {
// Convert joint commands values to actuator commands
for joint_cmd in &commands.values {
// Look up actuator ID from joint name
if let Some(&actuator_id) = JOINT_NAME_TO_ID.get(&joint_cmd.joint_name) {
actuator_commands.push(ActuatorCommand {
actuator_id,
position: Some(joint_cmd.position as f64),
velocity: None, // TODO: Velocity control not implemented yet
});
}
}
}
}
_ => {
error!("Unsupported output value type in schema");
return Err(eyre!("Unsupported output value type"));
}
}
}

Ok(actuator_commands)
}

async fn run_inference_cycle(&self) -> Result<()> {
let model = self.model.read().await;
let model = model.as_ref().ok_or_else(|| eyre::eyre!("No model loaded"))?;

let (imu_data, actuator_data) = self.get_sensor_data().await?;

let inputs = self.pack_inputs_to_proto(imu_data, actuator_data).await?;
let outputs = model.run(inputs)?;
let actuator_commands = self.unpack_outputs_from_proto(outputs).await?;

self.actuator.command_actuators(actuator_commands).await?;

Ok(())
}
}

#[async_trait]
impl Inference for ZBotInference {
async fn load_model(&self, path: String) -> Result<ActionResponse> {
info!("Loading model from: {}", path);

let model = match MilkVModelRunner::new(Path::new(&path)) {
Ok(model) => model,
Err(e) => {
error!("Failed to load model: {}", e);
return Ok(ActionResponse {
success: false,
error: Some(Error {
code: ErrorCode::InvalidArgument as i32,
message: format!("Failed to load model: {}", e),
}),
});
}
};

let mut model_guard = self.model.write().await;
*model_guard = Some(model);

Ok(ActionResponse {
success: true,
error: None,
})
}

async fn start(&self) -> Result<ActionResponse> {
let mut state = self.state.write().await;

if *state == InferenceState::Running {
return Ok(ActionResponse {
success: false,
error: Some(Error {
code: ErrorCode::AlreadyExists as i32,
message: "Inference is already running".to_string(),
}),
});
}

// Start inference loop in background task
let self_clone = Arc::new(self.clone());
tokio::spawn(async move {
loop {
let state = self_clone.state.read().await;
if *state != InferenceState::Running {
break;
}

if let Err(e) = self_clone.run_inference_cycle().await {
error!("Inference cycle failed: {}", e);
}

tokio::time::sleep(std::time::Duration::from_millis(10)).await;
}
});

*state = InferenceState::Running;

Ok(ActionResponse {
success: true,
error: None,
})
}

async fn stop(&self) -> Result<ActionResponse> {
let mut state = self.state.write().await;
*state = InferenceState::Stopped;

Ok(ActionResponse {
success: true,
error: None,
})
}

async fn get_state(&self) -> Result<InferenceState> {
Ok(*self.state.read().await)
}
}
8 changes: 7 additions & 1 deletion runtime/kos_platform/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
mod actuator;
mod firmware;
mod imu;
mod inference;

pub use actuator::*;
pub use firmware::*;
pub use imu::*;

pub use inference::*;
use kos_core::kos_proto::actuator::actuator_service_server::ActuatorServiceServer;
use kos_core::kos_proto::imu::imu_service_server::ImuServiceServer;
use kos_core::services::{ActuatorServiceImpl, IMUServiceImpl, OperationsServiceImpl};
Expand Down Expand Up @@ -66,6 +67,11 @@ impl Platform for ZBotPlatform {
}
}

let inference = ZBotInference::new(imu, actuator);
services.push(ServiceEnum::Inference(InferenceServiceServer::new(
InferenceServiceImpl::new(Arc::new(inference)),
)));

Ok(services)
})
}
Expand Down
1 change: 0 additions & 1 deletion runtime/src/controller.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,4 +106,3 @@ pub async fn run(model: Arc<Model>, robot: Arc<Robot>) -> Result<()> {
}



2 changes: 1 addition & 1 deletion runtime/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,4 @@ async fn main() -> Result<()> {

// run controller
controller::run(model_arc, Arc::new(robot)).context("Controller encountered an error")
}
}