diff --git a/CMakeLists.txt b/CMakeLists.txt index 1de2725..be59d00 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -172,6 +172,7 @@ set(${PROJECT_NAME}_SOURCES src/ur/master_board.cpp src/ur/rt_state.cpp src/ur/messages.cpp + src/tcp_socket.cpp src/ur_driver.cpp src/ur_realtime_communication.cpp src/ur_communication.cpp diff --git a/include/ur_modern_driver/pipeline.h b/include/ur_modern_driver/pipeline.h index 9f487b1..0e5cfdb 100644 --- a/include/ur_modern_driver/pipeline.h +++ b/include/ur_modern_driver/pipeline.h @@ -23,6 +23,9 @@ public: virtual void stopConsumer() { } + virtual void onTimeout() + { + } virtual bool consume(shared_ptr product) = 0; }; @@ -59,6 +62,13 @@ public: con->stopConsumer(); } } + virtual void onTimeout() + { + for(auto &con : consumers_) + { + con->onTimeout(); + } + } bool consume(shared_ptr product) { @@ -93,6 +103,8 @@ template class Pipeline { private: + typedef std::chrono::high_resolution_clock Clock; + typedef Clock::time_point Time; IProducer& producer_; IConsumer& consumer_; BlockingReaderWriterQueue> queue_; @@ -129,6 +141,8 @@ private: { consumer_.setupConsumer(); unique_ptr product; + Time last_pkg = Clock::now(); + Time last_warn = last_pkg; while (running_) { // 16000us timeout was chosen because we should @@ -136,8 +150,18 @@ private: // 8ms so double it for some error margin if (!queue_.wait_dequeue_timed(product, std::chrono::milliseconds(16))) { + Time now = Clock::now(); + auto pkg_diff = now - last_pkg; + auto warn_diff = now - last_warn; + if(pkg_diff > std::chrono::seconds(1) && warn_diff > std::chrono::seconds(1)) + { + last_warn = now; + consumer_.onTimeout(); + } continue; } + + last_pkg = Clock::now(); if (!consumer_.consume(std::move(product))) break; } diff --git a/include/ur_modern_driver/ros/action_server.h b/include/ur_modern_driver/ros/action_server.h index dd11de4..398e4d5 100644 --- a/include/ur_modern_driver/ros/action_server.h +++ b/include/ur_modern_driver/ros/action_server.h @@ -5,6 +5,7 @@ #include #include #include +#include #include #include #include @@ -17,7 +18,7 @@ #include "ur_modern_driver/ros/trajectory_follower.h" -class ActionServer : public URRTPacketConsumer, public Service +class ActionServer : public Service //,public URRTPacketConsumer { private: typedef control_msgs::FollowJointTrajectoryAction Action; @@ -35,9 +36,11 @@ private: GoalHandle curr_gh_; + std::atomic interrupt_traj_; std::atomic has_goal_, running_; std::mutex tj_mutex_; std::condition_variable tj_cv_; + std::thread tj_thread_; TrajectoryFollower& follower_; @@ -50,20 +53,16 @@ private: bool validateTrajectory(GoalHandle& gh, Result& res); bool try_execute(GoalHandle& gh, Result& res); + void interruptGoal(GoalHandle& gh); std::vector reorderMap(std::vector goal_joints); - double interp_cubic(double t, double T, double p0_pos, double p1_pos, double p0_vel, double p1_vel); void trajectoryThread(); - template - double toSec(U const& u) - { - return std::chrono::duration_cast>(u).count(); - } public: ActionServer(TrajectoryFollower& follower, std::vector& joint_names, double max_velocity); + void start(); virtual void onRobotStateChange(RobotState state); }; \ No newline at end of file diff --git a/include/ur_modern_driver/ros/controller.h b/include/ur_modern_driver/ros/controller.h index eb940d0..3532ee6 100644 --- a/include/ur_modern_driver/ros/controller.h +++ b/include/ur_modern_driver/ros/controller.h @@ -34,6 +34,7 @@ private: std::map available_interfaces_; std::atomic service_enabled_; + std::atomic service_cooldown_; // helper functions to map interfaces template @@ -51,6 +52,7 @@ private: void read(RTShared& state); bool update(RTShared& state); bool write(); + void reset(); public: ROSController(URCommander& commander, std::vector& joint_names, double max_vel_change); diff --git a/include/ur_modern_driver/ros/hardware_interface.h b/include/ur_modern_driver/ros/hardware_interface.h index 0c6618d..325faa4 100644 --- a/include/ur_modern_driver/ros/hardware_interface.h +++ b/include/ur_modern_driver/ros/hardware_interface.h @@ -13,6 +13,7 @@ public: virtual bool write() = 0; virtual void start() {} virtual void stop() {} + virtual void reset() {} }; using hardware_interface::JointHandle; @@ -48,6 +49,7 @@ private: public: VelocityInterface(URCommander &commander, hardware_interface::JointStateInterface &js_interface, std::vector &joint_names, double max_vel_change); virtual bool write(); + virtual void reset(); typedef hardware_interface::VelocityJointInterface parent_type; }; diff --git a/include/ur_modern_driver/ros/io_service.h b/include/ur_modern_driver/ros/io_service.h index 36ce662..03d9935 100644 --- a/include/ur_modern_driver/ros/io_service.h +++ b/include/ur_modern_driver/ros/io_service.h @@ -37,6 +37,8 @@ private: case ur_msgs::SetIO::Request::FUN_SET_FLAG: res = commander_.setFlag(req.pin, flag); break; + default: + LOG_WARN("Invalid setIO function called (%d)", req.fun); } return (resp.success = res); diff --git a/include/ur_modern_driver/ros/trajectory_follower.h b/include/ur_modern_driver/ros/trajectory_follower.h index a1df97f..a16c3e5 100644 --- a/include/ur_modern_driver/ros/trajectory_follower.h +++ b/include/ur_modern_driver/ros/trajectory_follower.h @@ -5,21 +5,39 @@ #include #include #include +#include #include +#include "ur_modern_driver/log.h" #include "ur_modern_driver/ur/commander.h" #include "ur_modern_driver/ur/server.h" -#include "ur_modern_driver/ur/stream.h" + +struct TrajectoryPoint +{ + std::array positions; + std::array velocities; + std::chrono::microseconds time_from_start; + + TrajectoryPoint() + { + } + + TrajectoryPoint(std::array &pos, std::array &vel, std::chrono::microseconds tfs) + : positions(pos) + , velocities(vel) + , time_from_start(tfs) + { + } +}; class TrajectoryFollower { private: - const int32_t MULT_JOINTSTATE_ = 1000000; double servoj_time_, servoj_lookahead_time_, servoj_gain_; std::atomic running_; std::array last_positions_; URCommander &commander_; URServer server_; - URStream stream_; + int reverse_port_; std::string program_; template @@ -30,15 +48,16 @@ private: return s; } + std::string buildProgram(); bool execute(std::array &positions, bool keep_alive); + double interpolate(double t, double T, double p0_pos, double p1_pos, double p0_vel, double p1_vel); public: TrajectoryFollower(URCommander &commander, int reverse_port, bool version_3); - std::string buildProgram(bool version_3); - bool start(); bool execute(std::array &positions); + bool execute(std::vector &trajectory, std::atomic &interrupt); void stop(); - void halt(); //maybe + void interrupt(); }; \ No newline at end of file diff --git a/include/ur_modern_driver/tcp_socket.h b/include/ur_modern_driver/tcp_socket.h new file mode 100644 index 0000000..303d49e --- /dev/null +++ b/include/ur_modern_driver/tcp_socket.h @@ -0,0 +1,45 @@ +#pragma once +#include +#include +#include +#include +#include +#include + +enum class SocketState +{ + Invalid, + Connected, + Disconnected, + Closed +}; + +class TCPSocket +{ +private: + std::atomic socket_fd_; + std::atomic state_; + +protected: + virtual bool open(int socket_fd, struct sockaddr *address, size_t address_len) + { + return false; + } + + + bool setup(std::string &host, int port); + void close(); + + +public: + TCPSocket(); + virtual ~TCPSocket(); + + SocketState getState() { return state_; } + + int getSocketFD() { return socket_fd_; } + bool setSocketFD(int socket_fd); + + bool read(uint8_t* buf, size_t buf_len, size_t &read); + bool write(const uint8_t* buf, size_t buf_len, size_t &written); +}; diff --git a/include/ur_modern_driver/ur/commander.h b/include/ur_modern_driver/ur/commander.h index da92ae8..5d81698 100644 --- a/include/ur_modern_driver/ur/commander.h +++ b/include/ur_modern_driver/ur/commander.h @@ -11,18 +11,46 @@ private: protected: bool write(std::string& s); + void formatArray(std::ostringstream &out, std::array &values); public: URCommander(URStream& stream) : stream_(stream) { } - virtual bool uploadProg(std::string &s); + virtual bool speedj(std::array &speeds, double acceleration) = 0; + virtual bool setDigitalOut(uint8_t pin, bool value) = 0; + virtual bool setAnalogOut(uint8_t pin, double value) = 0; + + //shared + bool uploadProg(std::string &s); + bool stopj(double a = 10.0); + bool setToolVoltage(uint8_t voltage); + bool setFlag(uint8_t pin, bool value); + bool setPayload(double value); +}; + +class URCommander_V1_X : public URCommander +{ +public: + URCommander_V1_X(URStream& stream) : URCommander(stream) + { + } + + virtual bool speedj(std::array &speeds, double acceleration); + virtual bool setDigitalOut(uint8_t pin, bool value); + virtual bool setAnalogOut(uint8_t pin, double value); +}; + + +class URCommander_V3_X : public URCommander +{ +public: + URCommander_V3_X(URStream& stream) : URCommander(stream) + { + } + virtual bool speedj(std::array &speeds, double acceleration); - virtual bool stopj(double a = 10.0); virtual bool setDigitalOut(uint8_t pin, bool value); virtual bool setAnalogOut(uint8_t pin, double value); - virtual bool setToolVoltage(uint8_t voltage); - virtual bool setFlag(uint8_t pin, bool value); - virtual bool setPayload(double value); }; \ No newline at end of file diff --git a/include/ur_modern_driver/ur/factory.h b/include/ur_modern_driver/ur/factory.h index 83f0cd6..0b4e6ca 100644 --- a/include/ur_modern_driver/ur/factory.h +++ b/include/ur_modern_driver/ur/factory.h @@ -71,6 +71,19 @@ public: prod.teardownProducer(); } + bool isVersion3() + { + return major_version_ == 3; + } + + std::unique_ptr getCommander(URStream &stream) + { + if(major_version_ == 1) + return std::unique_ptr(new URCommander_V1_X(stream)); + else + return std::unique_ptr(new URCommander_V3_X(stream)); + } + std::unique_ptr> getStateParser() { if (major_version_ == 1) diff --git a/include/ur_modern_driver/ur/producer.h b/include/ur_modern_driver/ur/producer.h index 5907f97..b3bb475 100644 --- a/include/ur_modern_driver/ur/producer.h +++ b/include/ur_modern_driver/ur/producer.h @@ -1,4 +1,5 @@ #pragma once +#include #include "ur_modern_driver/pipeline.h" #include "ur_modern_driver/ur/parser.h" #include "ur_modern_driver/ur/stream.h" @@ -9,9 +10,10 @@ class URProducer : public IProducer private: URStream& stream_; URParser& parser_; + std::chrono::seconds timeout_; public: - URProducer(URStream& stream, URParser& parser) : stream_(stream), parser_(parser) + URProducer(URStream& stream, URParser& parser) : stream_(stream), parser_(parser), timeout_(1) { } @@ -32,24 +34,29 @@ public: { // 4KB should be enough to hold any packet received from UR uint8_t buf[4096]; - - // blocking call - ssize_t len = stream_.receive(buf, sizeof(buf)); - - // LOG_DEBUG("Read %d bytes from stream", len); - - if (len == 0) + size_t read = 0; + //expoential backoff reconnects + while(true) { - LOG_WARN("Read nothing from stream"); - return false; - } - else if (len < 0) - { - LOG_WARN("Stream closed"); - return false; + if(stream_.read(buf, sizeof(buf), read)) + { + //reset sleep amount + timeout_ = std::chrono::seconds(1); + break; + } + + if(stream_.closed()) + return false; + + LOG_WARN("Failed to read from stream, reconnecting in %ld seconds...", timeout_.count()); + std::this_thread::sleep_for(timeout_); + auto next = timeout_ * 2; + if(next <= std::chrono::seconds(120)) + timeout_ = next; } - BinParser bp(buf, static_cast(len)); + + BinParser bp(buf, read); return parser_.parse(bp, products); } }; \ No newline at end of file diff --git a/include/ur_modern_driver/ur/server.h b/include/ur_modern_driver/ur/server.h index 8740806..dba5fe5 100644 --- a/include/ur_modern_driver/ur/server.h +++ b/include/ur_modern_driver/ur/server.h @@ -6,14 +6,25 @@ #include #include #include -#include "ur_modern_driver/ur/stream.h" +#include "ur_modern_driver/tcp_socket.h" -class URServer +class URServer : private TCPSocket { private: - int socket_fd_ = -1; + int port_; + SocketState state_; + TCPSocket client_; + +protected: + virtual bool open(int socket_fd, struct sockaddr *address, size_t address_len) + { + return ::bind(socket_fd, address, address_len) == 0; + } public: URServer(int port); - URStream accept(); + std::string getIP(); + bool bind(); + bool accept(); + bool write(const uint8_t* buf, size_t buf_len, size_t &written); }; \ No newline at end of file diff --git a/include/ur_modern_driver/ur/stream.h b/include/ur_modern_driver/ur/stream.h index 16719d7..e3bbb23 100644 --- a/include/ur_modern_driver/ur/stream.h +++ b/include/ur_modern_driver/ur/stream.h @@ -5,56 +5,39 @@ #include #include #include +#include "ur_modern_driver/log.h" +#include "ur_modern_driver/tcp_socket.h" -/// Encapsulates a TCP socket -class URStream +class URStream : private TCPSocket { private: - int socket_fd_ = -1; std::string host_; int port_; + std::mutex write_mutex_, read_mutex_; - std::atomic initialized_; - std::atomic stopping_; - std::mutex send_mutex_, receive_mutex_; +protected: + virtual bool open(int socket_fd, struct sockaddr *address, size_t address_len) + { + return ::connect(socket_fd, address, address_len) == 0; + } public: - URStream() + URStream(std::string& host, int port) : host_(host), port_(port) { } - URStream(std::string& host, int port) : host_(host), port_(port), initialized_(false), stopping_(false) + bool connect() { + return TCPSocket::setup(host_, port_); + } + void disconnect() + { + LOG_INFO("Disconnecting"); + TCPSocket::close(); } - URStream(int socket_fd) : socket_fd_(socket_fd), initialized_(true), stopping_(false) - { - - } + bool closed() { return getState() == SocketState::Closed; } - URStream(URStream&& other) noexcept : socket_fd_(other.socket_fd_), host_(other.host_), initialized_(other.initialized_.load()), stopping_(other.stopping_.load()) - { - - } - - ~URStream() - { - disconnect(); - } - - URStream& operator=(URStream&& other) - { - socket_fd_ = std::move(other.socket_fd_); - host_ = std::move(other.host_); - initialized_ = std::move(other.initialized_.load()); - stopping_ = std::move(other.stopping_.load()); - return *this; - } - - bool connect(); - void disconnect(); - void reconnect(); - - ssize_t send(const uint8_t* buf, size_t buf_len); - ssize_t receive(uint8_t* buf, size_t buf_len); + bool read(uint8_t* buf, size_t buf_len, size_t &read); + bool write(const uint8_t* buf, size_t buf_len, size_t &written); }; \ No newline at end of file diff --git a/src/ros/action_server.cpp b/src/ros/action_server.cpp index 3b6637d..629a50f 100644 --- a/src/ros/action_server.cpp +++ b/src/ros/action_server.cpp @@ -15,7 +15,14 @@ ActionServer::ActionServer(TrajectoryFollower& follower, std::vector lock(tj_mutex_); + Result res; + res.error_code = -100; + res.error_string = "Goal cancelled by client"; + gh.setCanceled(res); } bool ActionServer::validate(GoalHandle& gh, Result& res) @@ -125,9 +139,17 @@ bool ActionServer::validateTrajectory(GoalHandle& gh, Result& res) } } + //todo validate start position? + return true; } +inline std::chrono::microseconds convert(const ros::Duration &dur) +{ + return std::chrono::duration_cast(std::chrono::seconds(dur.sec)) + + std::chrono::duration_cast(std::chrono::nanoseconds(dur.nsec)); +} + bool ActionServer::try_execute(GoalHandle& gh, Result& res) { if(!running_) @@ -137,27 +159,20 @@ bool ActionServer::try_execute(GoalHandle& gh, Result& res) } if(!tj_mutex_.try_lock()) { - has_goal_ = false; - //stop_trajectory(); + interrupt_traj_ = true; res.error_string = "Received another trajectory"; curr_gh_.setAborted(res, res.error_string); tj_mutex_.lock(); + //todo: make configurable + std::this_thread::sleep_for(std::chrono::milliseconds(250)); } //locked here curr_gh_ = gh; + interrupt_traj_ = false; has_goal_ = true; tj_mutex_.unlock(); tj_cv_.notify_one(); -} - -inline double ActionServer::interp_cubic(double t, double T, double p0_pos, double p1_pos, double p0_vel, double p1_vel) -{ - using std::pow; - double a = p0_pos; - double b = p0_vel; - double c = (-3 * a + 3 * p1_pos - 2 * T * b - T * p1_vel) / pow(T, 2); - double d = (2 * a - 2 * p1_pos + T * b + T * p1_vel) / pow(T, 3); - return a + b * t + c * pow(t, 2) + d * pow(t, 3); + return true; } std::vector ActionServer::reorderMap(std::vector goal_joints) @@ -179,53 +194,56 @@ std::vector ActionServer::reorderMap(std::vector goal_joint void ActionServer::trajectoryThread() { + follower_.start(); //todo check error + //as_.start(); while(running_) { std::unique_lock lk(tj_mutex_); if(!tj_cv_.wait_for(lk, std::chrono::milliseconds(100), [&]{return running_ && has_goal_;})) continue; - auto g = curr_gh_.getGoal(); - auto const& traj = g->trajectory; - auto const& points = traj.points; - size_t len = points.size(); - auto const& last_point = points[points.size() - 1]; - double end_time = last_point.time_from_start.toSec(); + LOG_DEBUG("Trajectory received and accepted"); + curr_gh_.setAccepted(); + + auto goal = curr_gh_.getGoal(); + auto mapping = reorderMap(goal->trajectory.joint_names); + std::vector trajectory(goal->trajectory.points.size()); - auto mapping = reorderMap(traj.joint_names); - std::chrono::high_resolution_clock::time_point t0, t; - t = t0 = std::chrono::high_resolution_clock::now(); - - size_t i = 0; - while(end_time >= toSec(t - t0) && has_goal_) + for(auto const& point : goal->trajectory.points) { - while(points[i].time_from_start.toSec() <= toSec(t - t0) && i < len) - i++; - - auto const& pp = points[i-1]; - auto const& p = points[i]; - - auto pp_t = pp.time_from_start.toSec(); - auto p_t =p.time_from_start.toSec(); - - std::array pos; - for(size_t j = 0; j < pos.size(); j++) + std::array pos, vel; + for(size_t i = 0; i < 6; i++) { - pos[i] = interp_cubic( - toSec(t - t0) - pp_t, - p_t - pp_t, - pp.positions[j], - p.positions[j], - pp.velocities[j], - p.velocities[j] - ); + //joint names of the goal might have a different ordering compared + //to what URScript expects so need to map between the two + size_t idx = mapping[i]; + pos[idx] = point.positions[i]; + vel[idx] = point.velocities[i]; } + trajectory.push_back(TrajectoryPoint(pos, vel, convert(point.time_from_start))); + } - follower_.execute(pos); - //std::this_thread::sleep_for(std::chrono::milliseconds((int)((servoj_time_ * 1000) / 4.))); - t = std::chrono::high_resolution_clock::now(); + Result res; + if(follower_.execute(trajectory, interrupt_traj_)) + { + //interrupted goals must be handled by interrupt trigger + if(!interrupt_traj_) + { + LOG_DEBUG("Trajectory executed successfully"); + res.error_code = Result::SUCCESSFUL; + curr_gh_.setSucceeded(res); + } + } + else + { + LOG_DEBUG("Trajectory failed"); + res.error_code = -100; + res.error_string = "Connection to robot was lost"; + curr_gh_.setAborted(res, res.error_string); } has_goal_ = false; + lk.unlock(); } + follower_.stop(); } \ No newline at end of file diff --git a/src/ros/controller.cpp b/src/ros/controller.cpp index 753f28d..79102fa 100644 --- a/src/ros/controller.cpp +++ b/src/ros/controller.cpp @@ -20,6 +20,8 @@ void ROSController::setupConsumer() void ROSController::doSwitch(const std::list& start_list, const std::list& stop_list) { + LOG_INFO("Switching hardware interface"); + if (active_interface_ != nullptr && stop_list.size() > 0) { LOG_INFO("Stopping active interface"); @@ -54,6 +56,14 @@ bool ROSController::write() return active_interface_->write(); } +void ROSController::reset() +{ + if (active_interface_ == nullptr) + return; + + active_interface_->reset(); +} + void ROSController::read(RTShared& packet) { joint_interface_.update(packet); @@ -68,17 +78,32 @@ bool ROSController::update(RTShared& state) lastUpdate_ = time; read(state); - controller_.update(time, diff); + controller_.update(time, diff, !service_enabled_); //emergency stop and such should not kill the pipeline //but still prevent writes if(!service_enabled_) + { + reset(); return true; + } + + //allow the controller to update x times before allowing writes again + if(service_cooldown_ > 0) + { + service_cooldown_ -= 1; + return true; + } return write(); } void ROSController::onRobotStateChange(RobotState state) { - service_enabled_ = (state == RobotState::Running); + bool next = (state == RobotState::Running); + if(next == service_enabled_) + return; + + service_enabled_ = next; + service_cooldown_ = 125; } \ No newline at end of file diff --git a/src/ros/hardware_interface.cpp b/src/ros/hardware_interface.cpp index 1bf0cc4..09a0798 100644 --- a/src/ros/hardware_interface.cpp +++ b/src/ros/hardware_interface.cpp @@ -1,61 +1,69 @@ #include "ur_modern_driver/ros/hardware_interface.h" +#include "ur_modern_driver/log.h" JointInterface::JointInterface(std::vector &joint_names) { - for (size_t i = 0; i < 6; i++) - { - registerHandle(hardware_interface::JointStateHandle(joint_names[i], &positions_[i], &velocities_[i], &efforts_[i])); - } + for (size_t i = 0; i < 6; i++) + { + registerHandle(hardware_interface::JointStateHandle(joint_names[i], &positions_[i], &velocities_[i], &efforts_[i])); + } } void JointInterface::update(RTShared &packet) { - positions_ = packet.q_actual; - velocities_ = packet.qd_actual; - efforts_ = packet.i_actual; + positions_ = packet.q_actual; + velocities_ = packet.qd_actual; + efforts_ = packet.i_actual; } WrenchInterface::WrenchInterface() { - registerHandle(hardware_interface::ForceTorqueSensorHandle("wrench", "", tcp_.begin(), tcp_.begin() + 3)); -} + registerHandle(hardware_interface::ForceTorqueSensorHandle("wrench", "", tcp_.begin(), tcp_.begin() + 3)); +} void WrenchInterface::update(RTShared &packet) { - tcp_ = packet.tcp_force; + tcp_ = packet.tcp_force; } VelocityInterface::VelocityInterface(URCommander &commander, hardware_interface::JointStateInterface &js_interface, std::vector &joint_names, double max_vel_change) - : commander_(commander), max_vel_change_(max_vel_change) + : commander_(commander), max_vel_change_(max_vel_change) { - for (size_t i = 0; i < 6; i++) - { - registerHandle(JointHandle(js_interface.getHandle(joint_names[i]), &velocity_cmd_[i])); - } + for (size_t i = 0; i < 6; i++) + { + registerHandle(JointHandle(js_interface.getHandle(joint_names[i]), &velocity_cmd_[i])); + } } bool VelocityInterface::write() { - for (size_t i = 0; i < 6; i++) - { - double prev = prev_velocity_cmd_[i]; - double lo = prev - max_vel_change_; - double hi = prev + max_vel_change_; - // clamp value to ±max_vel_change - prev_velocity_cmd_[i] = std::max(lo, std::min(velocity_cmd_[i], hi)); - } + for (size_t i = 0; i < 6; i++) + { + // clamp value to ±max_vel_change + double prev = prev_velocity_cmd_[i]; + double lo = prev - max_vel_change_; + double hi = prev + max_vel_change_; + prev_velocity_cmd_[i] = std::max(lo, std::min(velocity_cmd_[i], hi)); + } + return commander_.speedj(prev_velocity_cmd_, max_vel_change_); +} - return commander_.speedj(prev_velocity_cmd_, max_vel_change_); +void VelocityInterface::reset() +{ + for(auto &val : prev_velocity_cmd_) + { + val = 0; + } } PositionInterface:: PositionInterface(URCommander &commander, hardware_interface::JointStateInterface &js_interface, std::vector &joint_names) - : commander_(commander) + : commander_(commander) { - for (size_t i = 0; i < 6; i++) - { - registerHandle(JointHandle(js_interface.getHandle(joint_names[i]), &position_cmd_[i])); - } + for (size_t i = 0; i < 6; i++) + { + registerHandle(JointHandle(js_interface.getHandle(joint_names[i]), &position_cmd_[i])); + } } bool PositionInterface::write() diff --git a/src/ros/trajectory_follower.cpp b/src/ros/trajectory_follower.cpp index ca36e59..53d966a 100644 --- a/src/ros/trajectory_follower.cpp +++ b/src/ros/trajectory_follower.cpp @@ -1,83 +1,96 @@ #include #include "ur_modern_driver/ros/trajectory_follower.h" - - -TrajectoryFollower::TrajectoryFollower(URCommander &commander, int reverse_port, bool version_3) - : running_(false) - , commander_(commander) - , server_(reverse_port) - , program_(buildProgram(version_3)) -{ -} + + +static const int32_t MULT_JOINTSTATE_ = 1000000; static const std::string JOINT_STATE_REPLACE("{{JOINT_STATE_REPLACE}}"); static const std::string SERVO_J_REPLACE("{{SERVO_J_REPLACE}}"); +static const std::string SERVER_IP_REPLACE("{{SERVER_IP_REPLACE}}"); +static const std::string SERVER_PORT_REPLACE("{{SERVER_PORT_REPLACE}}"); static const std::string POSITION_PROGRAM = R"( def driverProg(): - MULT_jointstate = {{JOINT_STATE_REPLACE}} + MULT_jointstate = {{JOINT_STATE_REPLACE}} - SERVO_IDLE = 0 - SERVO_RUNNING = 1 - cmd_servo_state = SERVO_IDLE - cmd_servo_q = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0] - def set_servo_setpoint(q): - enter_critical - cmd_servo_state = SERVO_RUNNING - cmd_servo_q = q - exit_critical - end - thread servoThread(): - state = SERVO_IDLE - while True: - enter_critical - q = cmd_servo_q - do_brake = False - if (state == SERVO_RUNNING) and (cmd_servo_state == SERVO_IDLE): - do_brake = True - end - state = cmd_servo_state - cmd_servo_state = SERVO_IDLE - exit_critical - if do_brake: - stopj(1.0) - sync() - elif state == SERVO_RUNNING: - servoj(q, {{SERVO_J_REPLACE}}) - else: - sync() - end - end - end + SERVO_IDLE = 0 + SERVO_RUNNING = 1 + cmd_servo_state = SERVO_IDLE + cmd_servo_q = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0] + + def set_servo_setpoint(q): + enter_critical + cmd_servo_state = SERVO_RUNNING + cmd_servo_q = q + exit_critical + end + + thread servoThread(): + state = SERVO_IDLE + while True: + enter_critical + q = cmd_servo_q + do_brake = False + if (state == SERVO_RUNNING) and (cmd_servo_state == SERVO_IDLE): + do_brake = True + end + state = cmd_servo_state + cmd_servo_state = SERVO_IDLE + exit_critical + if do_brake: + stopj(1.0) + sync() + elif state == SERVO_RUNNING: + servoj(q, {{SERVO_J_REPLACE}}) + else: + sync() + end + end + end + + socket_open(\"{{SERVER_IP_REPLACE}}\", {{SERVER_PORT_REPLACE}}) thread_servo = run servoThread() keepalive = 1 while keepalive > 0: - params_mult = socket_read_binary_integer(6+1) - if params_mult[0] > 0: - q = [params_mult[1] / MULT_jointstate, params_mult[2] / MULT_jointstate, params_mult[3] / MULT_jointstate, params_mult[4] / MULT_jointstate, params_mult[5] / MULT_jointstate, params_mult[6] / MULT_jointstate] - keepalive = params_mult[7] - set_servo_setpoint(q) - end + params_mult = socket_read_binary_integer(6+1) + if params_mult[0] > 0: + q = [params_mult[1] / MULT_jointstate, params_mult[2] / MULT_jointstate, params_mult[3] / MULT_jointstate, params_mult[4] / MULT_jointstate, params_mult[5] / MULT_jointstate, params_mult[6] / MULT_jointstate] + keepalive = params_mult[7] + set_servo_setpoint(q) + end end sleep(.1) socket_close() kill thread_servo end )"; -std::string TrajectoryFollower::buildProgram(bool version_3) + +TrajectoryFollower::TrajectoryFollower(URCommander &commander, int reverse_port, bool version_3) + : running_(false) + , commander_(commander) + , reverse_port_(reverse_port) + , server_(reverse_port) { std::string res(POSITION_PROGRAM); - size_t js_idx = POSITION_PROGRAM.find(JOINT_STATE_REPLACE); - size_t sj_idx = POSITION_PROGRAM.find(SERVO_J_REPLACE); + res.replace(res.find(JOINT_STATE_REPLACE), JOINT_STATE_REPLACE.length(), std::to_string(MULT_JOINTSTATE_)); std::ostringstream out; out << "t=" << std::fixed << std::setprecision(4) << servoj_time_; - if(version_3) out << ", lookahead_time=" << servoj_lookahead_time_ << ", gain=" << servoj_gain_; - res.replace(js_idx, JOINT_STATE_REPLACE.length(), std::to_string(MULT_JOINTSTATE_)); - res.replace(sj_idx, SERVO_J_REPLACE.length(), out.str()); + res.replace(res.find(SERVO_J_REPLACE), SERVO_J_REPLACE.length(), out.str()); + + program_ = res; +} + +std::string TrajectoryFollower::buildProgram() +{ + std::string res(program_); + std::string IP(server_.getIP()); + LOG_INFO("Local IP: %s ", IP.c_str()); + res.replace(res.find(SERVER_IP_REPLACE), SERVER_IP_REPLACE.length(), "127.0.0.1"); + res.replace(res.find(SERVER_PORT_REPLACE), SERVER_PORT_REPLACE.length(), std::to_string(reverse_port_)); return res; } @@ -86,12 +99,31 @@ bool TrajectoryFollower::start() if(running_) return true; //not sure - //TODO - std::string prog(""); // buildProg(); - if(!commander_.uploadProg(prog)) + if(!server_.bind()) + { + LOG_ERROR("Failed to bind server"); return false; + } - stream_ = std::move(server_.accept()); //todo: pointer instead? + LOG_INFO("Uploading trajectory program to robot"); + + std::string prog(buildProgram()); + //std::string prog = "socket_open(\"127.0.0.1\", 50001)\n"; + if(!commander_.uploadProg(prog)) + { + LOG_ERROR("Program upload failed!"); + return false; + } + + LOG_INFO("Awaiting incomming robot connection"); + + if(!server_.accept()) + { + LOG_ERROR("Failed to accept incomming robot connection"); + return false; + } + + LOG_INFO("Robot successfully connected"); return (running_ = true); } @@ -115,8 +147,18 @@ bool TrajectoryFollower::execute(std::array &positions, bool keep_ali int32_t val = htobe32(static_cast(keep_alive)); append(idx, val); - ssize_t res = stream_.send(buf, sizeof(buf)); - return res > 0 && res == sizeof(buf); + size_t written; + return server_.write(buf, sizeof(buf), written); +} + +double TrajectoryFollower::interpolate(double t, double T, double p0_pos, double p1_pos, double p0_vel, double p1_vel) +{ + using std::pow; + double a = p0_pos; + double b = p0_vel; + double c = (-3 * a + 3 * p1_pos - 2 * T * b - T * p1_vel) / pow(T, 2); + double d = (2 * a - 2 * p1_pos + T * b + T * p1_vel) / pow(T, 3); + return a + b * t + c * pow(t, 2) + d * pow(t, 3); } bool TrajectoryFollower::execute(std::array &positions) @@ -124,6 +166,68 @@ bool TrajectoryFollower::execute(std::array &positions) return execute(positions, true); } +bool TrajectoryFollower::execute(std::vector &trajectory, std::atomic &interrupt) +{ + if(!running_) + return false; + + using namespace std::chrono; + typedef duration double_seconds; + typedef high_resolution_clock Clock; + typedef Clock::time_point Time; + + auto const& last = trajectory[trajectory.size()-1]; + auto& prev = trajectory[0]; + + Time t0 = Clock::now(); + Time latest = t0; + + std::array positions; + + for(auto const& point : trajectory) + { + //skip t0 + if(&point == &prev) + continue; + + auto duration = point.time_from_start - prev.time_from_start; + double d_s = duration_cast(duration).count(); + + //interpolation loop + while(!interrupt) + { + latest = Clock::now(); + auto elapsed = latest - t0; + + if(point.time_from_start <= elapsed || last.time_from_start >= elapsed) + break; + + double elapsed_s = duration_cast(elapsed - prev.time_from_start).count(); + //double prev_seconds + for(size_t j = 0; j < positions.size(); j++) + { + positions[j] = interpolate( + elapsed_s, + d_s, + prev.positions[j], + point.positions[j], + prev.velocities[j], + point.velocities[j] + ); + } + + if(!execute(positions, true)) + return false; + + std::this_thread::sleep_for(double_seconds(servoj_time_)); + } + + prev = point; + } + + return true; +} + void TrajectoryFollower::stop() { if(!running_) @@ -132,6 +236,6 @@ void TrajectoryFollower::stop() std::array empty; execute(empty, false); - stream_.disconnect(); + //server_.disconnect(); running_ = false; } \ No newline at end of file diff --git a/src/ros_main.cpp b/src/ros_main.cpp index 7f76792..efaf106 100644 --- a/src/ros_main.cpp +++ b/src/ros_main.cpp @@ -6,10 +6,12 @@ #include "ur_modern_driver/log.h" #include "ur_modern_driver/pipeline.h" +#include "ur_modern_driver/ros/action_server.h" +#include "ur_modern_driver/ros/controller.h" #include "ur_modern_driver/ros/io_service.h" #include "ur_modern_driver/ros/mb_publisher.h" -#include "ur_modern_driver/ros/controller.h" #include "ur_modern_driver/ros/rt_publisher.h" +#include "ur_modern_driver/ros/trajectory_follower.h" #include "ur_modern_driver/ros/service_stopper.h" #include "ur_modern_driver/ur/commander.h" #include "ur_modern_driver/ur/factory.h" @@ -63,7 +65,7 @@ bool parse_args(ProgArgs &args) return true; } -#include "ur_modern_driver/event_counter.h" +#include "ur_modern_driver/ur/server.h" int main(int argc, char **argv) { @@ -75,6 +77,7 @@ int main(int argc, char **argv) return EXIT_FAILURE; } + URFactory factory(args.host); vector services; @@ -84,17 +87,27 @@ int main(int argc, char **argv) URStream rt_stream(args.host, UR_RT_PORT); URProducer rt_prod(rt_stream, *rt_parser); RTPublisher rt_pub(args.prefix, args.base_frame, args.tool_frame, args.use_ros_control); - URCommander rt_commander(rt_stream); + auto rt_commander = factory.getCommander(rt_stream); vector *> rt_vec{&rt_pub}; + TrajectoryFollower traj_follower(*rt_commander, args.reverse_port, factory.isVersion3()); + ROSController *controller(nullptr); + ActionServer *action_server(nullptr); if (args.use_ros_control) { LOG_INFO("ROS control enabled"); - controller = new ROSController(rt_commander, args.joint_names, args.max_vel_change); + controller = new ROSController(*rt_commander, args.joint_names, args.max_vel_change); rt_vec.push_back(controller); services.push_back(controller); } + else + { + LOG_INFO("ActionServer enabled"); + action_server = new ActionServer(traj_follower, args.joint_names, args.max_velocity); + //rt_vec.push_back(action_server); + services.push_back(action_server); + } MultiConsumer rt_cons(rt_vec); Pipeline rt_pl(rt_prod, rt_cons); @@ -116,8 +129,11 @@ int main(int argc, char **argv) rt_pl.run(); state_pl.run(); - URCommander state_commander(state_stream); - IOService io_service(state_commander); + auto state_commander = factory.getCommander(state_stream); + IOService io_service(*state_commander); + + if(action_server) + action_server->start(); ros::spin(); diff --git a/src/tcp_socket.cpp b/src/tcp_socket.cpp new file mode 100644 index 0000000..01685bf --- /dev/null +++ b/src/tcp_socket.cpp @@ -0,0 +1,133 @@ +#include +#include +#include +#include + +#include "ur_modern_driver/log.h" +#include "ur_modern_driver/tcp_socket.h" + +TCPSocket::TCPSocket() + : socket_fd_(-1) + , state_(SocketState::Invalid) +{ +} +TCPSocket::~TCPSocket() +{ + close(); +} + +bool TCPSocket::setup(std::string &host, int port) +{ + if(state_ == SocketState::Connected) + return false; + + LOG_INFO("Setting up connection: %s:%d", host.c_str(), port); + + // gethostbyname() is deprecated so use getadderinfo() as described in: + // http://www.beej.us/guide/bgnet/output/html/multipage/syscalls.html#getaddrinfo + + const char *host_name = host.empty() ? nullptr : host.c_str(); + std::string service = std::to_string(port); + struct addrinfo hints, *result; + std::memset(&hints, 0, sizeof(hints)); + + hints.ai_family = AF_UNSPEC; + hints.ai_socktype = SOCK_STREAM; + hints.ai_flags = AI_PASSIVE; + + if (getaddrinfo(host_name, service.c_str(), &hints, &result) != 0) + { + LOG_ERROR("Failed to get address for %s:%d", host.c_str(), port); + return false; + } + + bool connected = false; + // loop through the list of addresses untill we find one that's connectable + for (struct addrinfo* p = result; p != nullptr; p = p->ai_next) + { + socket_fd_ = ::socket(p->ai_family, p->ai_socktype, p->ai_protocol); + + if (socket_fd_ != -1 && open(socket_fd_, p->ai_addr, p->ai_addrlen)) + { + connected = true; + break; + } + } + + freeaddrinfo(result); + + if(!connected) + { + state_ = SocketState::Invalid; + LOG_ERROR("Connection setup failed for %s:%d", host.c_str(), port); + } + else + { + state_ = SocketState::Connected; + LOG_INFO("Connection established for %s:%d", host.c_str(), port); + } + return connected; +} + +bool TCPSocket::setSocketFD(int socket_fd) +{ + if(state_ == SocketState::Connected) + return false; + socket_fd_ = socket_fd; + state_ = SocketState::Connected; + return true; +} + +void TCPSocket::close() +{ + if(state_ != SocketState::Connected) + return; + state_ = SocketState::Closed; + ::shutdown(socket_fd_, SHUT_RDWR); + socket_fd_ = -1; +} + +bool TCPSocket::read(uint8_t *buf, size_t buf_len, size_t &read) +{ + read = 0; + + if(state_ != SocketState::Connected) + return false; + + ssize_t res = ::recv(socket_fd_, buf, buf_len, 0); + + if(res == 0) + { + state_ = SocketState::Disconnected; + return false; + } + else if(res < 0) + return false; + + read = static_cast(res); + return true; +} + +bool TCPSocket::write(const uint8_t *buf, size_t buf_len, size_t &written) +{ + written = 0; + + if(state_ != SocketState::Connected) + return false; + + size_t remaining = buf_len; + + // handle partial sends + while (written < buf_len) + { + ssize_t sent = ::send(socket_fd_, buf + written, remaining, 0); + + if (sent <= 0) + return false; + + written += sent; + remaining -= sent; + } + + return true; +} \ No newline at end of file diff --git a/src/ur/commander.cpp b/src/ur/commander.cpp index 333bc52..0fd358e 100644 --- a/src/ur/commander.cpp +++ b/src/ur/commander.cpp @@ -3,10 +3,21 @@ bool URCommander::write(std::string& s) { - size_t len = s.size(); - const uint8_t* data = reinterpret_cast(s.c_str()); - ssize_t res = stream_.send(data, len); - return res > 0 && static_cast(res) == len; + size_t len = s.size(); + const uint8_t* data = reinterpret_cast(s.c_str()); + size_t written; + return stream_.write(data, len, written); +} + +void URCommander::formatArray(std::ostringstream &out, std::array &values) +{ + std::string mod("["); + for(auto const& val : values) + { + out << mod << val; + mod = ","; + } + out << "]"; } bool URCommander::uploadProg(std::string &s) @@ -14,59 +25,110 @@ bool URCommander::uploadProg(std::string &s) return write(s); } -bool URCommander::speedj(std::array &speeds, double acceleration) -{ - std::ostringstream out; - out << std::fixed << std::setprecision(4); - out << "speedj(["; - std::string mod; - for(auto const& val : speeds) - { - out << mod << val; - mod = ","; - } - out << "]," << acceleration << ")\n"; - std::string s(out.str()); - return write(s); -} -bool URCommander::stopj(double a) -{ - -} - -bool URCommander::setAnalogOut(uint8_t pin, double value) -{ - std::ostringstream out; - out << "set_analog_out(" << (int)pin << "," << std::fixed << std::setprecision(4) << value << ")\n"; - std::string s(out.str()); - return write(s); -} - -bool URCommander::setDigitalOut(uint8_t pin, bool value) -{ - std::ostringstream out; - out << "set_digital_out(" << (int)pin << "," << (value ? "True" : "False") << ")\n"; - std::string s(out.str()); - return write(s); -} - bool URCommander::setToolVoltage(uint8_t voltage) { - -} + if(voltage != 0 || voltage != 12 || voltage != 24) + return false; + std::ostringstream out; + out << "set_tool_voltage(" << (int)voltage << ")\n"; + std::string s(out.str()); + return write(s); +} bool URCommander::setFlag(uint8_t pin, bool value) { - std::ostringstream out; - out << "set_flag(" << (int)pin << "," << (value ? "True" : "False") << ")\n"; - std::string s(out.str()); - return write(s); + std::ostringstream out; + out << "set_flag(" << (int)pin << "," << (value ? "True" : "False") << ")\n"; + std::string s(out.str()); + return write(s); } bool URCommander::setPayload(double value) { - std::ostringstream out; - out << "set_payload(" << std::fixed << std::setprecision(4) << value << ")\n"; - std::string s(out.str()); - return write(s); -} \ No newline at end of file + std::ostringstream out; + out << "set_payload(" << std::fixed << std::setprecision(5) << value << ")\n"; + std::string s(out.str()); + return write(s); +} + +bool URCommander::stopj(double a) +{ + std::ostringstream out; + out << "stopj(" << std::fixed << std::setprecision(5) << a << ")\n"; + std::string s(out.str()); + return write(s); +} + + +bool URCommander_V1_X::speedj(std::array &speeds, double acceleration) +{ + std::ostringstream out; + out << std::fixed << std::setprecision(5); + out << "speedj("; + formatArray(out, speeds); + out << "," << acceleration << "," << 0.02 << ")\n"; + std::string s(out.str()); + return write(s); +} +bool URCommander_V1_X::setAnalogOut(uint8_t pin, double value) +{ + std::ostringstream out; + out << "set_analog_out(" << (int)pin << "," << std::fixed << std::setprecision(4) << value << ")\n"; + std::string s(out.str()); + return write(s); +} + +bool URCommander_V1_X::setDigitalOut(uint8_t pin, bool value) +{ + std::ostringstream out; + out << "set_digital_out(" << (int)pin << "," << (value ? "True" : "False") << ")\n"; + std::string s(out.str()); + return write(s); +} + + +bool URCommander_V3_X::speedj(std::array &speeds, double acceleration) +{ + std::ostringstream out; + out << std::fixed << std::setprecision(5); + out << "speedj("; + formatArray(out, speeds); + out << "," << acceleration << ")\n"; + std::string s(out.str()); + return write(s); +} +bool URCommander_V3_X::setAnalogOut(uint8_t pin, double value) +{ + std::ostringstream out; + out << "set_standard_analog_out(" << (int)pin << "," << std::fixed << std::setprecision(5) << value << ")\n"; + std::string s(out.str()); + return write(s); +} + +bool URCommander_V3_X::setDigitalOut(uint8_t pin, bool value) +{ + std::ostringstream out; + std::string func; + + if(pin < 8) + { + func = "set_standard_digital_out"; + } + else if(pin < 16) + { + func = "set_configurable_digital_out"; + pin -= 8; + } + else if(pin < 18) + { + func = "set_tool_digital_out"; + pin -= 16; + } + else + return false; + + + out << func << "(" << (int)pin << "," << (value ? "True" : "False") << ")\n"; + std::string s(out.str()); + return write(s); +} diff --git a/src/ur/server.cpp b/src/ur/server.cpp index d99f1a9..6c99e13 100644 --- a/src/ur/server.cpp +++ b/src/ur/server.cpp @@ -5,47 +5,48 @@ #include "ur_modern_driver/ur/server.h" URServer::URServer(int port) + : port_(port) { - std::string service = std::to_string(port); - struct addrinfo hints, *result; - std::memset(&hints, 0, sizeof(hints)); - - hints.ai_family = AF_UNSPEC; - hints.ai_socktype = SOCK_STREAM; - hints.ai_flags = AI_PASSIVE; - - if (getaddrinfo(nullptr, service.c_str(), &hints, &result) != 0) - { - LOG_ERROR("Failed to setup recieving server"); - return; - } - - // loop through the list of addresses untill we find one that's connectable - for (struct addrinfo* p = result; p != nullptr; p = p->ai_next) - { - socket_fd_ = socket(p->ai_family, p->ai_socktype, p->ai_protocol); - - if (socket_fd_ == -1) // socket error? - continue; - - if (bind(socket_fd_, p->ai_addr, p->ai_addrlen) != 0) - continue; - - // disable Nagle's algorithm to ensure we sent packets as fast as possible - int flag = 1; - setsockopt(socket_fd_, IPPROTO_TCP, TCP_NODELAY, &flag, sizeof(flag)); - LOG_INFO("Server awaiting connection"); - return; - } - - LOG_ERROR("Failed to setup recieving server"); - std::exit(EXIT_FAILURE); } -URStream URServer::accept() +std::string URServer::getIP() { + char buf[128]; + int res = ::gethostname(buf, sizeof(buf)); + return std::string(buf); +} + +bool URServer::bind() +{ + std::string empty; + bool res = TCPSocket::setup(empty, port_); + state_ = TCPSocket::getState(); + + if(!res) + return false; + + if(::listen(getSocketFD(), 1) < 0) + return false; + + return true; +} + +bool URServer::accept() +{ + if(state_ != SocketState::Connected || client_.getSocketFD() > 0) + return false; + struct sockaddr addr; socklen_t addr_len; - int client_fd = ::accept(socket_fd_, &addr, &addr_len); - return URStream(client_fd); + int client_fd = ::accept(getSocketFD(), &addr, &addr_len); + + if(client_fd <= 0) + return false; + + return client_.setSocketFD(client_fd); +} + +bool URServer::write(const uint8_t* buf, size_t buf_len, size_t &written) +{ + return client_.write(buf, buf_len, written); } \ No newline at end of file diff --git a/src/ur/stream.cpp b/src/ur/stream.cpp index 240b261..05617c5 100644 --- a/src/ur/stream.cpp +++ b/src/ur/stream.cpp @@ -6,139 +6,38 @@ #include "ur_modern_driver/log.h" #include "ur_modern_driver/ur/stream.h" -bool URStream::connect() +bool URStream::write(const uint8_t* buf, size_t buf_len, size_t &written) { - if (initialized_) - return false; - - LOG_INFO("Connecting to UR @ %s:%d", host_.c_str(), port_); - - // gethostbyname() is deprecated so use getadderinfo() as described in: - // http://www.beej.us/guide/bgnet/output/html/multipage/syscalls.html#getaddrinfo - - std::string service = std::to_string(port_); - struct addrinfo hints, *result; - std::memset(&hints, 0, sizeof(hints)); - - hints.ai_family = AF_UNSPEC; - hints.ai_socktype = SOCK_STREAM; - hints.ai_flags = AI_PASSIVE; - - if (getaddrinfo(host_.c_str(), service.c_str(), &hints, &result) != 0) - { - LOG_ERROR("Failed to get host name"); - return false; - } - - // loop through the list of addresses untill we find one that's connectable - for (struct addrinfo* p = result; p != nullptr; p = p->ai_next) - { - socket_fd_ = socket(p->ai_family, p->ai_socktype, p->ai_protocol); - - if (socket_fd_ == -1) // socket error? - continue; - - if (::connect(socket_fd_, p->ai_addr, p->ai_addrlen) != 0) - { - if (stopping_) - break; - else - continue; // try next addrinfo if connect fails - } - - // disable Nagle's algorithm to ensure we sent packets as fast as possible - int flag = 1; - setsockopt(socket_fd_, IPPROTO_TCP, TCP_NODELAY, &flag, sizeof(flag)); - initialized_ = true; - LOG_INFO("Connection successfully established"); - break; - } - - freeaddrinfo(result); - if (!initialized_) - LOG_ERROR("Connection failed"); - - return initialized_; + std::lock_guard lock(write_mutex_); + return TCPSocket::write(buf, buf_len, written); } -void URStream::disconnect() +bool URStream::read(uint8_t* buf, size_t buf_len, size_t &total) { - if (!initialized_ || stopping_) - return; + std::lock_guard lock(read_mutex_); - LOG_INFO("Disconnecting from %s:%d", host_.c_str(), port_); - - stopping_ = true; - close(socket_fd_); - initialized_ = false; -} - -void URStream::reconnect() -{ - disconnect(); - stopping_ = false; - connect(); -} - -ssize_t URStream::send(const uint8_t* buf, size_t buf_len) -{ - if (!initialized_) - return -1; - if (stopping_) - return 0; - - std::lock_guard lock(send_mutex_); - - size_t total = 0; - size_t remaining = buf_len; - - // TODO: handle reconnect? - // handle partial sends - while (total < buf_len) - { - ssize_t sent = ::send(socket_fd_, buf + total, remaining, 0); - if (sent <= 0) - return stopping_ ? 0 : sent; - total += sent; - remaining -= sent; - } - - return total; -} - -ssize_t URStream::receive(uint8_t* buf, size_t buf_len) -{ - if (!initialized_) - return -1; - if (stopping_) - return 0; - - std::lock_guard lock(receive_mutex_); - - size_t remainder = sizeof(int32_t); - uint8_t* buf_pos = buf; bool initial = true; + uint8_t* buf_pos = buf; + size_t remainder = sizeof(int32_t); + size_t read = 0; - do + while(remainder > 0 && TCPSocket::read(buf_pos, remainder, read)) { - ssize_t read = recv(socket_fd_, buf_pos, remainder, 0); - if (read <= 0) // failed reading from socket - return stopping_ ? 0 : read; - if (initial) { remainder = be32toh(*(reinterpret_cast(buf))); if (remainder >= (buf_len - sizeof(int32_t))) { LOG_ERROR("Packet size %zd is larger than buffer %zu, discarding.", remainder, buf_len); - return -1; + return false; } initial = false; } + total += read; buf_pos += read; remainder -= read; - } while (remainder > 0); - - return buf_pos - buf; + } + + return remainder == 0; } \ No newline at end of file