1
0
mirror of https://gitlab.com/obbart/universal_robots_ros_driver.git synced 2026-04-10 10:00:48 +02:00

Major refactor

This commit is contained in:
Simon Rasmussen
2017-04-27 06:40:03 +02:00
parent 46f4e493cf
commit c59bfc78cc
22 changed files with 825 additions and 423 deletions

View File

@@ -172,6 +172,7 @@ set(${PROJECT_NAME}_SOURCES
src/ur/master_board.cpp src/ur/master_board.cpp
src/ur/rt_state.cpp src/ur/rt_state.cpp
src/ur/messages.cpp src/ur/messages.cpp
src/tcp_socket.cpp
src/ur_driver.cpp src/ur_driver.cpp
src/ur_realtime_communication.cpp src/ur_realtime_communication.cpp
src/ur_communication.cpp src/ur_communication.cpp

View File

@@ -23,6 +23,9 @@ public:
virtual void stopConsumer() virtual void stopConsumer()
{ {
} }
virtual void onTimeout()
{
}
virtual bool consume(shared_ptr<T> product) = 0; virtual bool consume(shared_ptr<T> product) = 0;
}; };
@@ -59,6 +62,13 @@ public:
con->stopConsumer(); con->stopConsumer();
} }
} }
virtual void onTimeout()
{
for(auto &con : consumers_)
{
con->onTimeout();
}
}
bool consume(shared_ptr<T> product) bool consume(shared_ptr<T> product)
{ {
@@ -93,6 +103,8 @@ template <typename T>
class Pipeline class Pipeline
{ {
private: private:
typedef std::chrono::high_resolution_clock Clock;
typedef Clock::time_point Time;
IProducer<T>& producer_; IProducer<T>& producer_;
IConsumer<T>& consumer_; IConsumer<T>& consumer_;
BlockingReaderWriterQueue<unique_ptr<T>> queue_; BlockingReaderWriterQueue<unique_ptr<T>> queue_;
@@ -129,6 +141,8 @@ private:
{ {
consumer_.setupConsumer(); consumer_.setupConsumer();
unique_ptr<T> product; unique_ptr<T> product;
Time last_pkg = Clock::now();
Time last_warn = last_pkg;
while (running_) while (running_)
{ {
// 16000us timeout was chosen because we should // 16000us timeout was chosen because we should
@@ -136,8 +150,18 @@ private:
// 8ms so double it for some error margin // 8ms so double it for some error margin
if (!queue_.wait_dequeue_timed(product, std::chrono::milliseconds(16))) if (!queue_.wait_dequeue_timed(product, std::chrono::milliseconds(16)))
{ {
Time now = Clock::now();
auto pkg_diff = now - last_pkg;
auto warn_diff = now - last_warn;
if(pkg_diff > std::chrono::seconds(1) && warn_diff > std::chrono::seconds(1))
{
last_warn = now;
consumer_.onTimeout();
}
continue; continue;
} }
last_pkg = Clock::now();
if (!consumer_.consume(std::move(product))) if (!consumer_.consume(std::move(product)))
break; break;
} }

View File

@@ -5,6 +5,7 @@
#include <condition_variable> #include <condition_variable>
#include <mutex> #include <mutex>
#include <set> #include <set>
#include <thread>
#include <ros/ros.h> #include <ros/ros.h>
#include <actionlib/server/action_server.h> #include <actionlib/server/action_server.h>
#include <actionlib/server/server_goal_handle.h> #include <actionlib/server/server_goal_handle.h>
@@ -17,7 +18,7 @@
#include "ur_modern_driver/ros/trajectory_follower.h" #include "ur_modern_driver/ros/trajectory_follower.h"
class ActionServer : public URRTPacketConsumer, public Service class ActionServer : public Service //,public URRTPacketConsumer
{ {
private: private:
typedef control_msgs::FollowJointTrajectoryAction Action; typedef control_msgs::FollowJointTrajectoryAction Action;
@@ -35,9 +36,11 @@ private:
GoalHandle curr_gh_; GoalHandle curr_gh_;
std::atomic<bool> interrupt_traj_;
std::atomic<bool> has_goal_, running_; std::atomic<bool> has_goal_, running_;
std::mutex tj_mutex_; std::mutex tj_mutex_;
std::condition_variable tj_cv_; std::condition_variable tj_cv_;
std::thread tj_thread_;
TrajectoryFollower& follower_; TrajectoryFollower& follower_;
@@ -50,20 +53,16 @@ private:
bool validateTrajectory(GoalHandle& gh, Result& res); bool validateTrajectory(GoalHandle& gh, Result& res);
bool try_execute(GoalHandle& gh, Result& res); bool try_execute(GoalHandle& gh, Result& res);
void interruptGoal(GoalHandle& gh);
std::vector<size_t> reorderMap(std::vector<std::string> goal_joints); std::vector<size_t> reorderMap(std::vector<std::string> goal_joints);
double interp_cubic(double t, double T, double p0_pos, double p1_pos, double p0_vel, double p1_vel);
void trajectoryThread(); void trajectoryThread();
template <typename U>
double toSec(U const& u)
{
return std::chrono::duration_cast<std::chrono::duration<double>>(u).count();
}
public: public:
ActionServer(TrajectoryFollower& follower, std::vector<std::string>& joint_names, double max_velocity); ActionServer(TrajectoryFollower& follower, std::vector<std::string>& joint_names, double max_velocity);
void start();
virtual void onRobotStateChange(RobotState state); virtual void onRobotStateChange(RobotState state);
}; };

View File

@@ -34,6 +34,7 @@ private:
std::map<std::string, HardwareInterface*> available_interfaces_; std::map<std::string, HardwareInterface*> available_interfaces_;
std::atomic<bool> service_enabled_; std::atomic<bool> service_enabled_;
std::atomic<uint32_t> service_cooldown_;
// helper functions to map interfaces // helper functions to map interfaces
template <typename T> template <typename T>
@@ -51,6 +52,7 @@ private:
void read(RTShared& state); void read(RTShared& state);
bool update(RTShared& state); bool update(RTShared& state);
bool write(); bool write();
void reset();
public: public:
ROSController(URCommander& commander, std::vector<std::string>& joint_names, double max_vel_change); ROSController(URCommander& commander, std::vector<std::string>& joint_names, double max_vel_change);

View File

@@ -13,6 +13,7 @@ public:
virtual bool write() = 0; virtual bool write() = 0;
virtual void start() {} virtual void start() {}
virtual void stop() {} virtual void stop() {}
virtual void reset() {}
}; };
using hardware_interface::JointHandle; using hardware_interface::JointHandle;
@@ -48,6 +49,7 @@ private:
public: public:
VelocityInterface(URCommander &commander, hardware_interface::JointStateInterface &js_interface, std::vector<std::string> &joint_names, double max_vel_change); VelocityInterface(URCommander &commander, hardware_interface::JointStateInterface &js_interface, std::vector<std::string> &joint_names, double max_vel_change);
virtual bool write(); virtual bool write();
virtual void reset();
typedef hardware_interface::VelocityJointInterface parent_type; typedef hardware_interface::VelocityJointInterface parent_type;
}; };

View File

@@ -37,6 +37,8 @@ private:
case ur_msgs::SetIO::Request::FUN_SET_FLAG: case ur_msgs::SetIO::Request::FUN_SET_FLAG:
res = commander_.setFlag(req.pin, flag); res = commander_.setFlag(req.pin, flag);
break; break;
default:
LOG_WARN("Invalid setIO function called (%d)", req.fun);
} }
return (resp.success = res); return (resp.success = res);

View File

@@ -5,21 +5,39 @@
#include <cstddef> #include <cstddef>
#include <cstring> #include <cstring>
#include <string> #include <string>
#include <thread>
#include <inttypes.h> #include <inttypes.h>
#include "ur_modern_driver/log.h"
#include "ur_modern_driver/ur/commander.h" #include "ur_modern_driver/ur/commander.h"
#include "ur_modern_driver/ur/server.h" #include "ur_modern_driver/ur/server.h"
#include "ur_modern_driver/ur/stream.h"
struct TrajectoryPoint
{
std::array<double, 6> positions;
std::array<double, 6> velocities;
std::chrono::microseconds time_from_start;
TrajectoryPoint()
{
}
TrajectoryPoint(std::array<double, 6> &pos, std::array<double, 6> &vel, std::chrono::microseconds tfs)
: positions(pos)
, velocities(vel)
, time_from_start(tfs)
{
}
};
class TrajectoryFollower class TrajectoryFollower
{ {
private: private:
const int32_t MULT_JOINTSTATE_ = 1000000;
double servoj_time_, servoj_lookahead_time_, servoj_gain_; double servoj_time_, servoj_lookahead_time_, servoj_gain_;
std::atomic<bool> running_; std::atomic<bool> running_;
std::array<double, 6> last_positions_; std::array<double, 6> last_positions_;
URCommander &commander_; URCommander &commander_;
URServer server_; URServer server_;
URStream stream_; int reverse_port_;
std::string program_; std::string program_;
template <typename T> template <typename T>
@@ -30,15 +48,16 @@ private:
return s; return s;
} }
std::string buildProgram();
bool execute(std::array<double, 6> &positions, bool keep_alive); bool execute(std::array<double, 6> &positions, bool keep_alive);
double interpolate(double t, double T, double p0_pos, double p1_pos, double p0_vel, double p1_vel);
public: public:
TrajectoryFollower(URCommander &commander, int reverse_port, bool version_3); TrajectoryFollower(URCommander &commander, int reverse_port, bool version_3);
std::string buildProgram(bool version_3);
bool start(); bool start();
bool execute(std::array<double, 6> &positions); bool execute(std::array<double, 6> &positions);
bool execute(std::vector<TrajectoryPoint> &trajectory, std::atomic<bool> &interrupt);
void stop(); void stop();
void halt(); //maybe void interrupt();
}; };

View File

@@ -0,0 +1,45 @@
#pragma once
#include <netdb.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <mutex>
#include <atomic>
#include <string>
enum class SocketState
{
Invalid,
Connected,
Disconnected,
Closed
};
class TCPSocket
{
private:
std::atomic<int> socket_fd_;
std::atomic<SocketState> state_;
protected:
virtual bool open(int socket_fd, struct sockaddr *address, size_t address_len)
{
return false;
}
bool setup(std::string &host, int port);
void close();
public:
TCPSocket();
virtual ~TCPSocket();
SocketState getState() { return state_; }
int getSocketFD() { return socket_fd_; }
bool setSocketFD(int socket_fd);
bool read(uint8_t* buf, size_t buf_len, size_t &read);
bool write(const uint8_t* buf, size_t buf_len, size_t &written);
};

View File

@@ -11,18 +11,46 @@ private:
protected: protected:
bool write(std::string& s); bool write(std::string& s);
void formatArray(std::ostringstream &out, std::array<double, 6> &values);
public: public:
URCommander(URStream& stream) : stream_(stream) URCommander(URStream& stream) : stream_(stream)
{ {
} }
virtual bool uploadProg(std::string &s); virtual bool speedj(std::array<double, 6> &speeds, double acceleration) = 0;
virtual bool setDigitalOut(uint8_t pin, bool value) = 0;
virtual bool setAnalogOut(uint8_t pin, double value) = 0;
//shared
bool uploadProg(std::string &s);
bool stopj(double a = 10.0);
bool setToolVoltage(uint8_t voltage);
bool setFlag(uint8_t pin, bool value);
bool setPayload(double value);
};
class URCommander_V1_X : public URCommander
{
public:
URCommander_V1_X(URStream& stream) : URCommander(stream)
{
}
virtual bool speedj(std::array<double, 6> &speeds, double acceleration);
virtual bool setDigitalOut(uint8_t pin, bool value);
virtual bool setAnalogOut(uint8_t pin, double value);
};
class URCommander_V3_X : public URCommander
{
public:
URCommander_V3_X(URStream& stream) : URCommander(stream)
{
}
virtual bool speedj(std::array<double, 6> &speeds, double acceleration); virtual bool speedj(std::array<double, 6> &speeds, double acceleration);
virtual bool stopj(double a = 10.0);
virtual bool setDigitalOut(uint8_t pin, bool value); virtual bool setDigitalOut(uint8_t pin, bool value);
virtual bool setAnalogOut(uint8_t pin, double value); virtual bool setAnalogOut(uint8_t pin, double value);
virtual bool setToolVoltage(uint8_t voltage);
virtual bool setFlag(uint8_t pin, bool value);
virtual bool setPayload(double value);
}; };

View File

@@ -71,6 +71,19 @@ public:
prod.teardownProducer(); prod.teardownProducer();
} }
bool isVersion3()
{
return major_version_ == 3;
}
std::unique_ptr<URCommander> getCommander(URStream &stream)
{
if(major_version_ == 1)
return std::unique_ptr<URCommander>(new URCommander_V1_X(stream));
else
return std::unique_ptr<URCommander>(new URCommander_V3_X(stream));
}
std::unique_ptr<URParser<StatePacket>> getStateParser() std::unique_ptr<URParser<StatePacket>> getStateParser()
{ {
if (major_version_ == 1) if (major_version_ == 1)

View File

@@ -1,4 +1,5 @@
#pragma once #pragma once
#include <chrono>
#include "ur_modern_driver/pipeline.h" #include "ur_modern_driver/pipeline.h"
#include "ur_modern_driver/ur/parser.h" #include "ur_modern_driver/ur/parser.h"
#include "ur_modern_driver/ur/stream.h" #include "ur_modern_driver/ur/stream.h"
@@ -9,9 +10,10 @@ class URProducer : public IProducer<T>
private: private:
URStream& stream_; URStream& stream_;
URParser<T>& parser_; URParser<T>& parser_;
std::chrono::seconds timeout_;
public: public:
URProducer(URStream& stream, URParser<T>& parser) : stream_(stream), parser_(parser) URProducer(URStream& stream, URParser<T>& parser) : stream_(stream), parser_(parser), timeout_(1)
{ {
} }
@@ -32,24 +34,29 @@ public:
{ {
// 4KB should be enough to hold any packet received from UR // 4KB should be enough to hold any packet received from UR
uint8_t buf[4096]; uint8_t buf[4096];
size_t read = 0;
// blocking call //expoential backoff reconnects
ssize_t len = stream_.receive(buf, sizeof(buf)); while(true)
// LOG_DEBUG("Read %d bytes from stream", len);
if (len == 0)
{ {
LOG_WARN("Read nothing from stream"); if(stream_.read(buf, sizeof(buf), read))
return false;
}
else if (len < 0)
{ {
LOG_WARN("Stream closed"); //reset sleep amount
return false; timeout_ = std::chrono::seconds(1);
break;
} }
BinParser bp(buf, static_cast<size_t>(len)); if(stream_.closed())
return false;
LOG_WARN("Failed to read from stream, reconnecting in %ld seconds...", timeout_.count());
std::this_thread::sleep_for(timeout_);
auto next = timeout_ * 2;
if(next <= std::chrono::seconds(120))
timeout_ = next;
}
BinParser bp(buf, read);
return parser_.parse(bp, products); return parser_.parse(bp, products);
} }
}; };

View File

@@ -6,14 +6,25 @@
#include <mutex> #include <mutex>
#include <atomic> #include <atomic>
#include <string> #include <string>
#include "ur_modern_driver/ur/stream.h" #include "ur_modern_driver/tcp_socket.h"
class URServer class URServer : private TCPSocket
{ {
private: private:
int socket_fd_ = -1; int port_;
SocketState state_;
TCPSocket client_;
protected:
virtual bool open(int socket_fd, struct sockaddr *address, size_t address_len)
{
return ::bind(socket_fd, address, address_len) == 0;
}
public: public:
URServer(int port); URServer(int port);
URStream accept(); std::string getIP();
bool bind();
bool accept();
bool write(const uint8_t* buf, size_t buf_len, size_t &written);
}; };

View File

@@ -5,56 +5,39 @@
#include <mutex> #include <mutex>
#include <atomic> #include <atomic>
#include <string> #include <string>
#include "ur_modern_driver/log.h"
#include "ur_modern_driver/tcp_socket.h"
/// Encapsulates a TCP socket class URStream : private TCPSocket
class URStream
{ {
private: private:
int socket_fd_ = -1;
std::string host_; std::string host_;
int port_; int port_;
std::mutex write_mutex_, read_mutex_;
std::atomic<bool> initialized_; protected:
std::atomic<bool> stopping_; virtual bool open(int socket_fd, struct sockaddr *address, size_t address_len)
std::mutex send_mutex_, receive_mutex_; {
return ::connect(socket_fd, address, address_len) == 0;
}
public: public:
URStream() URStream(std::string& host, int port) : host_(host), port_(port)
{ {
} }
URStream(std::string& host, int port) : host_(host), port_(port), initialized_(false), stopping_(false) bool connect()
{ {
return TCPSocket::setup(host_, port_);
}
void disconnect()
{
LOG_INFO("Disconnecting");
TCPSocket::close();
} }
URStream(int socket_fd) : socket_fd_(socket_fd), initialized_(true), stopping_(false) bool closed() { return getState() == SocketState::Closed; }
{
} bool read(uint8_t* buf, size_t buf_len, size_t &read);
bool write(const uint8_t* buf, size_t buf_len, size_t &written);
URStream(URStream&& other) noexcept : socket_fd_(other.socket_fd_), host_(other.host_), initialized_(other.initialized_.load()), stopping_(other.stopping_.load())
{
}
~URStream()
{
disconnect();
}
URStream& operator=(URStream&& other)
{
socket_fd_ = std::move(other.socket_fd_);
host_ = std::move(other.host_);
initialized_ = std::move(other.initialized_.load());
stopping_ = std::move(other.stopping_.load());
return *this;
}
bool connect();
void disconnect();
void reconnect();
ssize_t send(const uint8_t* buf, size_t buf_len);
ssize_t receive(uint8_t* buf, size_t buf_len);
}; };

View File

@@ -15,7 +15,14 @@ ActionServer::ActionServer(TrajectoryFollower& follower, std::vector<std::string
, state_(RobotState::Error) , state_(RobotState::Error)
, follower_(follower) , follower_(follower)
{ {
}
void ActionServer::start()
{
if(running_)
return;
running_ = true;
tj_thread_ = thread(&ActionServer::trajectoryThread, this);
} }
void ActionServer::onRobotStateChange(RobotState state) void ActionServer::onRobotStateChange(RobotState state)
@@ -34,7 +41,14 @@ void ActionServer::onGoal(GoalHandle gh)
void ActionServer::onCancel(GoalHandle gh) void ActionServer::onCancel(GoalHandle gh)
{ {
interrupt_traj_ = true;
//wait for goal to be interrupted
std::lock_guard<std::mutex> lock(tj_mutex_);
Result res;
res.error_code = -100;
res.error_string = "Goal cancelled by client";
gh.setCanceled(res);
} }
bool ActionServer::validate(GoalHandle& gh, Result& res) bool ActionServer::validate(GoalHandle& gh, Result& res)
@@ -125,9 +139,17 @@ bool ActionServer::validateTrajectory(GoalHandle& gh, Result& res)
} }
} }
//todo validate start position?
return true; return true;
} }
inline std::chrono::microseconds convert(const ros::Duration &dur)
{
return std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::seconds(dur.sec))
+ std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::nanoseconds(dur.nsec));
}
bool ActionServer::try_execute(GoalHandle& gh, Result& res) bool ActionServer::try_execute(GoalHandle& gh, Result& res)
{ {
if(!running_) if(!running_)
@@ -137,27 +159,20 @@ bool ActionServer::try_execute(GoalHandle& gh, Result& res)
} }
if(!tj_mutex_.try_lock()) if(!tj_mutex_.try_lock())
{ {
has_goal_ = false; interrupt_traj_ = true;
//stop_trajectory();
res.error_string = "Received another trajectory"; res.error_string = "Received another trajectory";
curr_gh_.setAborted(res, res.error_string); curr_gh_.setAborted(res, res.error_string);
tj_mutex_.lock(); tj_mutex_.lock();
//todo: make configurable
std::this_thread::sleep_for(std::chrono::milliseconds(250));
} }
//locked here //locked here
curr_gh_ = gh; curr_gh_ = gh;
interrupt_traj_ = false;
has_goal_ = true; has_goal_ = true;
tj_mutex_.unlock(); tj_mutex_.unlock();
tj_cv_.notify_one(); tj_cv_.notify_one();
} return true;
inline double ActionServer::interp_cubic(double t, double T, double p0_pos, double p1_pos, double p0_vel, double p1_vel)
{
using std::pow;
double a = p0_pos;
double b = p0_vel;
double c = (-3 * a + 3 * p1_pos - 2 * T * b - T * p1_vel) / pow(T, 2);
double d = (2 * a - 2 * p1_pos + T * b + T * p1_vel) / pow(T, 3);
return a + b * t + c * pow(t, 2) + d * pow(t, 3);
} }
std::vector<size_t> ActionServer::reorderMap(std::vector<std::string> goal_joints) std::vector<size_t> ActionServer::reorderMap(std::vector<std::string> goal_joints)
@@ -179,53 +194,56 @@ std::vector<size_t> ActionServer::reorderMap(std::vector<std::string> goal_joint
void ActionServer::trajectoryThread() void ActionServer::trajectoryThread()
{ {
follower_.start(); //todo check error
//as_.start();
while(running_) while(running_)
{ {
std::unique_lock<std::mutex> lk(tj_mutex_); std::unique_lock<std::mutex> lk(tj_mutex_);
if(!tj_cv_.wait_for(lk, std::chrono::milliseconds(100), [&]{return running_ && has_goal_;})) if(!tj_cv_.wait_for(lk, std::chrono::milliseconds(100), [&]{return running_ && has_goal_;}))
continue; continue;
auto g = curr_gh_.getGoal(); LOG_DEBUG("Trajectory received and accepted");
auto const& traj = g->trajectory; curr_gh_.setAccepted();
auto const& points = traj.points;
size_t len = points.size();
auto const& last_point = points[points.size() - 1];
double end_time = last_point.time_from_start.toSec();
auto mapping = reorderMap(traj.joint_names); auto goal = curr_gh_.getGoal();
std::chrono::high_resolution_clock::time_point t0, t; auto mapping = reorderMap(goal->trajectory.joint_names);
t = t0 = std::chrono::high_resolution_clock::now(); std::vector<TrajectoryPoint> trajectory(goal->trajectory.points.size());
size_t i = 0; for(auto const& point : goal->trajectory.points)
while(end_time >= toSec(t - t0) && has_goal_)
{ {
while(points[i].time_from_start.toSec() <= toSec(t - t0) && i < len) std::array<double, 6> pos, vel;
i++; for(size_t i = 0; i < 6; i++)
auto const& pp = points[i-1];
auto const& p = points[i];
auto pp_t = pp.time_from_start.toSec();
auto p_t =p.time_from_start.toSec();
std::array<double, 6> pos;
for(size_t j = 0; j < pos.size(); j++)
{ {
pos[i] = interp_cubic( //joint names of the goal might have a different ordering compared
toSec(t - t0) - pp_t, //to what URScript expects so need to map between the two
p_t - pp_t, size_t idx = mapping[i];
pp.positions[j], pos[idx] = point.positions[i];
p.positions[j], vel[idx] = point.velocities[i];
pp.velocities[j], }
p.velocities[j] trajectory.push_back(TrajectoryPoint(pos, vel, convert(point.time_from_start)));
);
} }
follower_.execute(pos); Result res;
//std::this_thread::sleep_for(std::chrono::milliseconds((int)((servoj_time_ * 1000) / 4.))); if(follower_.execute(trajectory, interrupt_traj_))
t = std::chrono::high_resolution_clock::now(); {
//interrupted goals must be handled by interrupt trigger
if(!interrupt_traj_)
{
LOG_DEBUG("Trajectory executed successfully");
res.error_code = Result::SUCCESSFUL;
curr_gh_.setSucceeded(res);
}
}
else
{
LOG_DEBUG("Trajectory failed");
res.error_code = -100;
res.error_string = "Connection to robot was lost";
curr_gh_.setAborted(res, res.error_string);
} }
has_goal_ = false; has_goal_ = false;
lk.unlock();
} }
follower_.stop();
} }

View File

@@ -20,6 +20,8 @@ void ROSController::setupConsumer()
void ROSController::doSwitch(const std::list<hardware_interface::ControllerInfo>& start_list, const std::list<hardware_interface::ControllerInfo>& stop_list) void ROSController::doSwitch(const std::list<hardware_interface::ControllerInfo>& start_list, const std::list<hardware_interface::ControllerInfo>& stop_list)
{ {
LOG_INFO("Switching hardware interface");
if (active_interface_ != nullptr && stop_list.size() > 0) if (active_interface_ != nullptr && stop_list.size() > 0)
{ {
LOG_INFO("Stopping active interface"); LOG_INFO("Stopping active interface");
@@ -54,6 +56,14 @@ bool ROSController::write()
return active_interface_->write(); return active_interface_->write();
} }
void ROSController::reset()
{
if (active_interface_ == nullptr)
return;
active_interface_->reset();
}
void ROSController::read(RTShared& packet) void ROSController::read(RTShared& packet)
{ {
joint_interface_.update(packet); joint_interface_.update(packet);
@@ -68,17 +78,32 @@ bool ROSController::update(RTShared& state)
lastUpdate_ = time; lastUpdate_ = time;
read(state); read(state);
controller_.update(time, diff); controller_.update(time, diff, !service_enabled_);
//emergency stop and such should not kill the pipeline //emergency stop and such should not kill the pipeline
//but still prevent writes //but still prevent writes
if(!service_enabled_) if(!service_enabled_)
{
reset();
return true; return true;
}
//allow the controller to update x times before allowing writes again
if(service_cooldown_ > 0)
{
service_cooldown_ -= 1;
return true;
}
return write(); return write();
} }
void ROSController::onRobotStateChange(RobotState state) void ROSController::onRobotStateChange(RobotState state)
{ {
service_enabled_ = (state == RobotState::Running); bool next = (state == RobotState::Running);
if(next == service_enabled_)
return;
service_enabled_ = next;
service_cooldown_ = 125;
} }

View File

@@ -1,4 +1,5 @@
#include "ur_modern_driver/ros/hardware_interface.h" #include "ur_modern_driver/ros/hardware_interface.h"
#include "ur_modern_driver/log.h"
JointInterface::JointInterface(std::vector<std::string> &joint_names) JointInterface::JointInterface(std::vector<std::string> &joint_names)
{ {
@@ -38,16 +39,23 @@ bool VelocityInterface::write()
{ {
for (size_t i = 0; i < 6; i++) for (size_t i = 0; i < 6; i++)
{ {
// clamp value to ±max_vel_change
double prev = prev_velocity_cmd_[i]; double prev = prev_velocity_cmd_[i];
double lo = prev - max_vel_change_; double lo = prev - max_vel_change_;
double hi = prev + max_vel_change_; double hi = prev + max_vel_change_;
// clamp value to ±max_vel_change
prev_velocity_cmd_[i] = std::max(lo, std::min(velocity_cmd_[i], hi)); prev_velocity_cmd_[i] = std::max(lo, std::min(velocity_cmd_[i], hi));
} }
return commander_.speedj(prev_velocity_cmd_, max_vel_change_); return commander_.speedj(prev_velocity_cmd_, max_vel_change_);
} }
void VelocityInterface::reset()
{
for(auto &val : prev_velocity_cmd_)
{
val = 0;
}
}
PositionInterface:: PositionInterface(URCommander &commander, hardware_interface::JointStateInterface &js_interface, std::vector<std::string> &joint_names) PositionInterface:: PositionInterface(URCommander &commander, hardware_interface::JointStateInterface &js_interface, std::vector<std::string> &joint_names)
: commander_(commander) : commander_(commander)

View File

@@ -2,15 +2,11 @@
#include "ur_modern_driver/ros/trajectory_follower.h" #include "ur_modern_driver/ros/trajectory_follower.h"
TrajectoryFollower::TrajectoryFollower(URCommander &commander, int reverse_port, bool version_3) static const int32_t MULT_JOINTSTATE_ = 1000000;
: running_(false)
, commander_(commander)
, server_(reverse_port)
, program_(buildProgram(version_3))
{
}
static const std::string JOINT_STATE_REPLACE("{{JOINT_STATE_REPLACE}}"); static const std::string JOINT_STATE_REPLACE("{{JOINT_STATE_REPLACE}}");
static const std::string SERVO_J_REPLACE("{{SERVO_J_REPLACE}}"); static const std::string SERVO_J_REPLACE("{{SERVO_J_REPLACE}}");
static const std::string SERVER_IP_REPLACE("{{SERVER_IP_REPLACE}}");
static const std::string SERVER_PORT_REPLACE("{{SERVER_PORT_REPLACE}}");
static const std::string POSITION_PROGRAM = R"( static const std::string POSITION_PROGRAM = R"(
def driverProg(): def driverProg():
MULT_jointstate = {{JOINT_STATE_REPLACE}} MULT_jointstate = {{JOINT_STATE_REPLACE}}
@@ -19,12 +15,14 @@ def driverProg():
SERVO_RUNNING = 1 SERVO_RUNNING = 1
cmd_servo_state = SERVO_IDLE cmd_servo_state = SERVO_IDLE
cmd_servo_q = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0] cmd_servo_q = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
def set_servo_setpoint(q): def set_servo_setpoint(q):
enter_critical enter_critical
cmd_servo_state = SERVO_RUNNING cmd_servo_state = SERVO_RUNNING
cmd_servo_q = q cmd_servo_q = q
exit_critical exit_critical
end end
thread servoThread(): thread servoThread():
state = SERVO_IDLE state = SERVO_IDLE
while True: while True:
@@ -48,6 +46,8 @@ def driverProg():
end end
end end
socket_open(\"{{SERVER_IP_REPLACE}}\", {{SERVER_PORT_REPLACE}})
thread_servo = run servoThread() thread_servo = run servoThread()
keepalive = 1 keepalive = 1
while keepalive > 0: while keepalive > 0:
@@ -63,21 +63,34 @@ def driverProg():
kill thread_servo kill thread_servo
end end
)"; )";
std::string TrajectoryFollower::buildProgram(bool version_3)
TrajectoryFollower::TrajectoryFollower(URCommander &commander, int reverse_port, bool version_3)
: running_(false)
, commander_(commander)
, reverse_port_(reverse_port)
, server_(reverse_port)
{ {
std::string res(POSITION_PROGRAM); std::string res(POSITION_PROGRAM);
size_t js_idx = POSITION_PROGRAM.find(JOINT_STATE_REPLACE);
size_t sj_idx = POSITION_PROGRAM.find(SERVO_J_REPLACE);
res.replace(res.find(JOINT_STATE_REPLACE), JOINT_STATE_REPLACE.length(), std::to_string(MULT_JOINTSTATE_));
std::ostringstream out; std::ostringstream out;
out << "t=" << std::fixed << std::setprecision(4) << servoj_time_; out << "t=" << std::fixed << std::setprecision(4) << servoj_time_;
if(version_3) if(version_3)
out << ", lookahead_time=" << servoj_lookahead_time_ << ", gain=" << servoj_gain_; out << ", lookahead_time=" << servoj_lookahead_time_ << ", gain=" << servoj_gain_;
res.replace(js_idx, JOINT_STATE_REPLACE.length(), std::to_string(MULT_JOINTSTATE_)); res.replace(res.find(SERVO_J_REPLACE), SERVO_J_REPLACE.length(), out.str());
res.replace(sj_idx, SERVO_J_REPLACE.length(), out.str());
program_ = res;
}
std::string TrajectoryFollower::buildProgram()
{
std::string res(program_);
std::string IP(server_.getIP());
LOG_INFO("Local IP: %s ", IP.c_str());
res.replace(res.find(SERVER_IP_REPLACE), SERVER_IP_REPLACE.length(), "127.0.0.1");
res.replace(res.find(SERVER_PORT_REPLACE), SERVER_PORT_REPLACE.length(), std::to_string(reverse_port_));
return res; return res;
} }
@@ -86,12 +99,31 @@ bool TrajectoryFollower::start()
if(running_) if(running_)
return true; //not sure return true; //not sure
//TODO if(!server_.bind())
std::string prog(""); // buildProg(); {
if(!commander_.uploadProg(prog)) LOG_ERROR("Failed to bind server");
return false; return false;
}
stream_ = std::move(server_.accept()); //todo: pointer instead? LOG_INFO("Uploading trajectory program to robot");
std::string prog(buildProgram());
//std::string prog = "socket_open(\"127.0.0.1\", 50001)\n";
if(!commander_.uploadProg(prog))
{
LOG_ERROR("Program upload failed!");
return false;
}
LOG_INFO("Awaiting incomming robot connection");
if(!server_.accept())
{
LOG_ERROR("Failed to accept incomming robot connection");
return false;
}
LOG_INFO("Robot successfully connected");
return (running_ = true); return (running_ = true);
} }
@@ -115,8 +147,18 @@ bool TrajectoryFollower::execute(std::array<double, 6> &positions, bool keep_ali
int32_t val = htobe32(static_cast<int32_t>(keep_alive)); int32_t val = htobe32(static_cast<int32_t>(keep_alive));
append(idx, val); append(idx, val);
ssize_t res = stream_.send(buf, sizeof(buf)); size_t written;
return res > 0 && res == sizeof(buf); return server_.write(buf, sizeof(buf), written);
}
double TrajectoryFollower::interpolate(double t, double T, double p0_pos, double p1_pos, double p0_vel, double p1_vel)
{
using std::pow;
double a = p0_pos;
double b = p0_vel;
double c = (-3 * a + 3 * p1_pos - 2 * T * b - T * p1_vel) / pow(T, 2);
double d = (2 * a - 2 * p1_pos + T * b + T * p1_vel) / pow(T, 3);
return a + b * t + c * pow(t, 2) + d * pow(t, 3);
} }
bool TrajectoryFollower::execute(std::array<double, 6> &positions) bool TrajectoryFollower::execute(std::array<double, 6> &positions)
@@ -124,6 +166,68 @@ bool TrajectoryFollower::execute(std::array<double, 6> &positions)
return execute(positions, true); return execute(positions, true);
} }
bool TrajectoryFollower::execute(std::vector<TrajectoryPoint> &trajectory, std::atomic<bool> &interrupt)
{
if(!running_)
return false;
using namespace std::chrono;
typedef duration<double> double_seconds;
typedef high_resolution_clock Clock;
typedef Clock::time_point Time;
auto const& last = trajectory[trajectory.size()-1];
auto& prev = trajectory[0];
Time t0 = Clock::now();
Time latest = t0;
std::array<double, 6> positions;
for(auto const& point : trajectory)
{
//skip t0
if(&point == &prev)
continue;
auto duration = point.time_from_start - prev.time_from_start;
double d_s = duration_cast<double_seconds>(duration).count();
//interpolation loop
while(!interrupt)
{
latest = Clock::now();
auto elapsed = latest - t0;
if(point.time_from_start <= elapsed || last.time_from_start >= elapsed)
break;
double elapsed_s = duration_cast<double_seconds>(elapsed - prev.time_from_start).count();
//double prev_seconds
for(size_t j = 0; j < positions.size(); j++)
{
positions[j] = interpolate(
elapsed_s,
d_s,
prev.positions[j],
point.positions[j],
prev.velocities[j],
point.velocities[j]
);
}
if(!execute(positions, true))
return false;
std::this_thread::sleep_for(double_seconds(servoj_time_));
}
prev = point;
}
return true;
}
void TrajectoryFollower::stop() void TrajectoryFollower::stop()
{ {
if(!running_) if(!running_)
@@ -132,6 +236,6 @@ void TrajectoryFollower::stop()
std::array<double, 6> empty; std::array<double, 6> empty;
execute(empty, false); execute(empty, false);
stream_.disconnect(); //server_.disconnect();
running_ = false; running_ = false;
} }

View File

@@ -6,10 +6,12 @@
#include "ur_modern_driver/log.h" #include "ur_modern_driver/log.h"
#include "ur_modern_driver/pipeline.h" #include "ur_modern_driver/pipeline.h"
#include "ur_modern_driver/ros/action_server.h"
#include "ur_modern_driver/ros/controller.h"
#include "ur_modern_driver/ros/io_service.h" #include "ur_modern_driver/ros/io_service.h"
#include "ur_modern_driver/ros/mb_publisher.h" #include "ur_modern_driver/ros/mb_publisher.h"
#include "ur_modern_driver/ros/controller.h"
#include "ur_modern_driver/ros/rt_publisher.h" #include "ur_modern_driver/ros/rt_publisher.h"
#include "ur_modern_driver/ros/trajectory_follower.h"
#include "ur_modern_driver/ros/service_stopper.h" #include "ur_modern_driver/ros/service_stopper.h"
#include "ur_modern_driver/ur/commander.h" #include "ur_modern_driver/ur/commander.h"
#include "ur_modern_driver/ur/factory.h" #include "ur_modern_driver/ur/factory.h"
@@ -63,7 +65,7 @@ bool parse_args(ProgArgs &args)
return true; return true;
} }
#include "ur_modern_driver/event_counter.h" #include "ur_modern_driver/ur/server.h"
int main(int argc, char **argv) int main(int argc, char **argv)
{ {
@@ -75,6 +77,7 @@ int main(int argc, char **argv)
return EXIT_FAILURE; return EXIT_FAILURE;
} }
URFactory factory(args.host); URFactory factory(args.host);
vector<Service*> services; vector<Service*> services;
@@ -84,17 +87,27 @@ int main(int argc, char **argv)
URStream rt_stream(args.host, UR_RT_PORT); URStream rt_stream(args.host, UR_RT_PORT);
URProducer<RTPacket> rt_prod(rt_stream, *rt_parser); URProducer<RTPacket> rt_prod(rt_stream, *rt_parser);
RTPublisher rt_pub(args.prefix, args.base_frame, args.tool_frame, args.use_ros_control); RTPublisher rt_pub(args.prefix, args.base_frame, args.tool_frame, args.use_ros_control);
URCommander rt_commander(rt_stream); auto rt_commander = factory.getCommander(rt_stream);
vector<IConsumer<RTPacket> *> rt_vec{&rt_pub}; vector<IConsumer<RTPacket> *> rt_vec{&rt_pub};
TrajectoryFollower traj_follower(*rt_commander, args.reverse_port, factory.isVersion3());
ROSController *controller(nullptr); ROSController *controller(nullptr);
ActionServer *action_server(nullptr);
if (args.use_ros_control) if (args.use_ros_control)
{ {
LOG_INFO("ROS control enabled"); LOG_INFO("ROS control enabled");
controller = new ROSController(rt_commander, args.joint_names, args.max_vel_change); controller = new ROSController(*rt_commander, args.joint_names, args.max_vel_change);
rt_vec.push_back(controller); rt_vec.push_back(controller);
services.push_back(controller); services.push_back(controller);
} }
else
{
LOG_INFO("ActionServer enabled");
action_server = new ActionServer(traj_follower, args.joint_names, args.max_velocity);
//rt_vec.push_back(action_server);
services.push_back(action_server);
}
MultiConsumer<RTPacket> rt_cons(rt_vec); MultiConsumer<RTPacket> rt_cons(rt_vec);
Pipeline<RTPacket> rt_pl(rt_prod, rt_cons); Pipeline<RTPacket> rt_pl(rt_prod, rt_cons);
@@ -116,8 +129,11 @@ int main(int argc, char **argv)
rt_pl.run(); rt_pl.run();
state_pl.run(); state_pl.run();
URCommander state_commander(state_stream); auto state_commander = factory.getCommander(state_stream);
IOService io_service(state_commander); IOService io_service(*state_commander);
if(action_server)
action_server->start();
ros::spin(); ros::spin();

133
src/tcp_socket.cpp Normal file
View File

@@ -0,0 +1,133 @@
#include <endian.h>
#include <netinet/tcp.h>
#include <unistd.h>
#include <cstring>
#include "ur_modern_driver/log.h"
#include "ur_modern_driver/tcp_socket.h"
TCPSocket::TCPSocket()
: socket_fd_(-1)
, state_(SocketState::Invalid)
{
}
TCPSocket::~TCPSocket()
{
close();
}
bool TCPSocket::setup(std::string &host, int port)
{
if(state_ == SocketState::Connected)
return false;
LOG_INFO("Setting up connection: %s:%d", host.c_str(), port);
// gethostbyname() is deprecated so use getadderinfo() as described in:
// http://www.beej.us/guide/bgnet/output/html/multipage/syscalls.html#getaddrinfo
const char *host_name = host.empty() ? nullptr : host.c_str();
std::string service = std::to_string(port);
struct addrinfo hints, *result;
std::memset(&hints, 0, sizeof(hints));
hints.ai_family = AF_UNSPEC;
hints.ai_socktype = SOCK_STREAM;
hints.ai_flags = AI_PASSIVE;
if (getaddrinfo(host_name, service.c_str(), &hints, &result) != 0)
{
LOG_ERROR("Failed to get address for %s:%d", host.c_str(), port);
return false;
}
bool connected = false;
// loop through the list of addresses untill we find one that's connectable
for (struct addrinfo* p = result; p != nullptr; p = p->ai_next)
{
socket_fd_ = ::socket(p->ai_family, p->ai_socktype, p->ai_protocol);
if (socket_fd_ != -1 && open(socket_fd_, p->ai_addr, p->ai_addrlen))
{
connected = true;
break;
}
}
freeaddrinfo(result);
if(!connected)
{
state_ = SocketState::Invalid;
LOG_ERROR("Connection setup failed for %s:%d", host.c_str(), port);
}
else
{
state_ = SocketState::Connected;
LOG_INFO("Connection established for %s:%d", host.c_str(), port);
}
return connected;
}
bool TCPSocket::setSocketFD(int socket_fd)
{
if(state_ == SocketState::Connected)
return false;
socket_fd_ = socket_fd;
state_ = SocketState::Connected;
return true;
}
void TCPSocket::close()
{
if(state_ != SocketState::Connected)
return;
state_ = SocketState::Closed;
::shutdown(socket_fd_, SHUT_RDWR);
socket_fd_ = -1;
}
bool TCPSocket::read(uint8_t *buf, size_t buf_len, size_t &read)
{
read = 0;
if(state_ != SocketState::Connected)
return false;
ssize_t res = ::recv(socket_fd_, buf, buf_len, 0);
if(res == 0)
{
state_ = SocketState::Disconnected;
return false;
}
else if(res < 0)
return false;
read = static_cast<size_t>(res);
return true;
}
bool TCPSocket::write(const uint8_t *buf, size_t buf_len, size_t &written)
{
written = 0;
if(state_ != SocketState::Connected)
return false;
size_t remaining = buf_len;
// handle partial sends
while (written < buf_len)
{
ssize_t sent = ::send(socket_fd_, buf + written, remaining, 0);
if (sent <= 0)
return false;
written += sent;
remaining -= sent;
}
return true;
}

View File

@@ -5,8 +5,19 @@ bool URCommander::write(std::string& s)
{ {
size_t len = s.size(); size_t len = s.size();
const uint8_t* data = reinterpret_cast<const uint8_t*>(s.c_str()); const uint8_t* data = reinterpret_cast<const uint8_t*>(s.c_str());
ssize_t res = stream_.send(data, len); size_t written;
return res > 0 && static_cast<size_t>(res) == len; return stream_.write(data, len, written);
}
void URCommander::formatArray(std::ostringstream &out, std::array<double, 6> &values)
{
std::string mod("[");
for(auto const& val : values)
{
out << mod << val;
mod = ",";
}
out << "]";
} }
bool URCommander::uploadProg(std::string &s) bool URCommander::uploadProg(std::string &s)
@@ -14,48 +25,17 @@ bool URCommander::uploadProg(std::string &s)
return write(s); return write(s);
} }
bool URCommander::speedj(std::array<double, 6> &speeds, double acceleration)
{
std::ostringstream out;
out << std::fixed << std::setprecision(4);
out << "speedj([";
std::string mod;
for(auto const& val : speeds)
{
out << mod << val;
mod = ",";
}
out << "]," << acceleration << ")\n";
std::string s(out.str());
return write(s);
}
bool URCommander::stopj(double a)
{
}
bool URCommander::setAnalogOut(uint8_t pin, double value)
{
std::ostringstream out;
out << "set_analog_out(" << (int)pin << "," << std::fixed << std::setprecision(4) << value << ")\n";
std::string s(out.str());
return write(s);
}
bool URCommander::setDigitalOut(uint8_t pin, bool value)
{
std::ostringstream out;
out << "set_digital_out(" << (int)pin << "," << (value ? "True" : "False") << ")\n";
std::string s(out.str());
return write(s);
}
bool URCommander::setToolVoltage(uint8_t voltage) bool URCommander::setToolVoltage(uint8_t voltage)
{ {
if(voltage != 0 || voltage != 12 || voltage != 24)
return false;
std::ostringstream out;
out << "set_tool_voltage(" << (int)voltage << ")\n";
std::string s(out.str());
return write(s);
} }
bool URCommander::setFlag(uint8_t pin, bool value) bool URCommander::setFlag(uint8_t pin, bool value)
{ {
std::ostringstream out; std::ostringstream out;
@@ -66,7 +46,89 @@ bool URCommander::setFlag(uint8_t pin, bool value)
bool URCommander::setPayload(double value) bool URCommander::setPayload(double value)
{ {
std::ostringstream out; std::ostringstream out;
out << "set_payload(" << std::fixed << std::setprecision(4) << value << ")\n"; out << "set_payload(" << std::fixed << std::setprecision(5) << value << ")\n";
std::string s(out.str());
return write(s);
}
bool URCommander::stopj(double a)
{
std::ostringstream out;
out << "stopj(" << std::fixed << std::setprecision(5) << a << ")\n";
std::string s(out.str());
return write(s);
}
bool URCommander_V1_X::speedj(std::array<double, 6> &speeds, double acceleration)
{
std::ostringstream out;
out << std::fixed << std::setprecision(5);
out << "speedj(";
formatArray(out, speeds);
out << "," << acceleration << "," << 0.02 << ")\n";
std::string s(out.str());
return write(s);
}
bool URCommander_V1_X::setAnalogOut(uint8_t pin, double value)
{
std::ostringstream out;
out << "set_analog_out(" << (int)pin << "," << std::fixed << std::setprecision(4) << value << ")\n";
std::string s(out.str());
return write(s);
}
bool URCommander_V1_X::setDigitalOut(uint8_t pin, bool value)
{
std::ostringstream out;
out << "set_digital_out(" << (int)pin << "," << (value ? "True" : "False") << ")\n";
std::string s(out.str());
return write(s);
}
bool URCommander_V3_X::speedj(std::array<double, 6> &speeds, double acceleration)
{
std::ostringstream out;
out << std::fixed << std::setprecision(5);
out << "speedj(";
formatArray(out, speeds);
out << "," << acceleration << ")\n";
std::string s(out.str());
return write(s);
}
bool URCommander_V3_X::setAnalogOut(uint8_t pin, double value)
{
std::ostringstream out;
out << "set_standard_analog_out(" << (int)pin << "," << std::fixed << std::setprecision(5) << value << ")\n";
std::string s(out.str());
return write(s);
}
bool URCommander_V3_X::setDigitalOut(uint8_t pin, bool value)
{
std::ostringstream out;
std::string func;
if(pin < 8)
{
func = "set_standard_digital_out";
}
else if(pin < 16)
{
func = "set_configurable_digital_out";
pin -= 8;
}
else if(pin < 18)
{
func = "set_tool_digital_out";
pin -= 16;
}
else
return false;
out << func << "(" << (int)pin << "," << (value ? "True" : "False") << ")\n";
std::string s(out.str()); std::string s(out.str());
return write(s); return write(s);
} }

View File

@@ -5,47 +5,48 @@
#include "ur_modern_driver/ur/server.h" #include "ur_modern_driver/ur/server.h"
URServer::URServer(int port) URServer::URServer(int port)
: port_(port)
{ {
std::string service = std::to_string(port);
struct addrinfo hints, *result;
std::memset(&hints, 0, sizeof(hints));
hints.ai_family = AF_UNSPEC;
hints.ai_socktype = SOCK_STREAM;
hints.ai_flags = AI_PASSIVE;
if (getaddrinfo(nullptr, service.c_str(), &hints, &result) != 0)
{
LOG_ERROR("Failed to setup recieving server");
return;
}
// loop through the list of addresses untill we find one that's connectable
for (struct addrinfo* p = result; p != nullptr; p = p->ai_next)
{
socket_fd_ = socket(p->ai_family, p->ai_socktype, p->ai_protocol);
if (socket_fd_ == -1) // socket error?
continue;
if (bind(socket_fd_, p->ai_addr, p->ai_addrlen) != 0)
continue;
// disable Nagle's algorithm to ensure we sent packets as fast as possible
int flag = 1;
setsockopt(socket_fd_, IPPROTO_TCP, TCP_NODELAY, &flag, sizeof(flag));
LOG_INFO("Server awaiting connection");
return;
}
LOG_ERROR("Failed to setup recieving server");
std::exit(EXIT_FAILURE);
} }
URStream URServer::accept() std::string URServer::getIP()
{ {
char buf[128];
int res = ::gethostname(buf, sizeof(buf));
return std::string(buf);
}
bool URServer::bind()
{
std::string empty;
bool res = TCPSocket::setup(empty, port_);
state_ = TCPSocket::getState();
if(!res)
return false;
if(::listen(getSocketFD(), 1) < 0)
return false;
return true;
}
bool URServer::accept()
{
if(state_ != SocketState::Connected || client_.getSocketFD() > 0)
return false;
struct sockaddr addr; struct sockaddr addr;
socklen_t addr_len; socklen_t addr_len;
int client_fd = ::accept(socket_fd_, &addr, &addr_len); int client_fd = ::accept(getSocketFD(), &addr, &addr_len);
return URStream(client_fd);
if(client_fd <= 0)
return false;
return client_.setSocketFD(client_fd);
}
bool URServer::write(const uint8_t* buf, size_t buf_len, size_t &written)
{
return client_.write(buf, buf_len, written);
} }

View File

@@ -6,139 +6,38 @@
#include "ur_modern_driver/log.h" #include "ur_modern_driver/log.h"
#include "ur_modern_driver/ur/stream.h" #include "ur_modern_driver/ur/stream.h"
bool URStream::connect() bool URStream::write(const uint8_t* buf, size_t buf_len, size_t &written)
{ {
if (initialized_) std::lock_guard<std::mutex> lock(write_mutex_);
return false; return TCPSocket::write(buf, buf_len, written);
LOG_INFO("Connecting to UR @ %s:%d", host_.c_str(), port_);
// gethostbyname() is deprecated so use getadderinfo() as described in:
// http://www.beej.us/guide/bgnet/output/html/multipage/syscalls.html#getaddrinfo
std::string service = std::to_string(port_);
struct addrinfo hints, *result;
std::memset(&hints, 0, sizeof(hints));
hints.ai_family = AF_UNSPEC;
hints.ai_socktype = SOCK_STREAM;
hints.ai_flags = AI_PASSIVE;
if (getaddrinfo(host_.c_str(), service.c_str(), &hints, &result) != 0)
{
LOG_ERROR("Failed to get host name");
return false;
}
// loop through the list of addresses untill we find one that's connectable
for (struct addrinfo* p = result; p != nullptr; p = p->ai_next)
{
socket_fd_ = socket(p->ai_family, p->ai_socktype, p->ai_protocol);
if (socket_fd_ == -1) // socket error?
continue;
if (::connect(socket_fd_, p->ai_addr, p->ai_addrlen) != 0)
{
if (stopping_)
break;
else
continue; // try next addrinfo if connect fails
}
// disable Nagle's algorithm to ensure we sent packets as fast as possible
int flag = 1;
setsockopt(socket_fd_, IPPROTO_TCP, TCP_NODELAY, &flag, sizeof(flag));
initialized_ = true;
LOG_INFO("Connection successfully established");
break;
}
freeaddrinfo(result);
if (!initialized_)
LOG_ERROR("Connection failed");
return initialized_;
} }
void URStream::disconnect() bool URStream::read(uint8_t* buf, size_t buf_len, size_t &total)
{ {
if (!initialized_ || stopping_) std::lock_guard<std::mutex> lock(read_mutex_);
return;
LOG_INFO("Disconnecting from %s:%d", host_.c_str(), port_);
stopping_ = true;
close(socket_fd_);
initialized_ = false;
}
void URStream::reconnect()
{
disconnect();
stopping_ = false;
connect();
}
ssize_t URStream::send(const uint8_t* buf, size_t buf_len)
{
if (!initialized_)
return -1;
if (stopping_)
return 0;
std::lock_guard<std::mutex> lock(send_mutex_);
size_t total = 0;
size_t remaining = buf_len;
// TODO: handle reconnect?
// handle partial sends
while (total < buf_len)
{
ssize_t sent = ::send(socket_fd_, buf + total, remaining, 0);
if (sent <= 0)
return stopping_ ? 0 : sent;
total += sent;
remaining -= sent;
}
return total;
}
ssize_t URStream::receive(uint8_t* buf, size_t buf_len)
{
if (!initialized_)
return -1;
if (stopping_)
return 0;
std::lock_guard<std::mutex> lock(receive_mutex_);
size_t remainder = sizeof(int32_t);
uint8_t* buf_pos = buf;
bool initial = true; bool initial = true;
uint8_t* buf_pos = buf;
size_t remainder = sizeof(int32_t);
size_t read = 0;
do while(remainder > 0 && TCPSocket::read(buf_pos, remainder, read))
{ {
ssize_t read = recv(socket_fd_, buf_pos, remainder, 0);
if (read <= 0) // failed reading from socket
return stopping_ ? 0 : read;
if (initial) if (initial)
{ {
remainder = be32toh(*(reinterpret_cast<int32_t*>(buf))); remainder = be32toh(*(reinterpret_cast<int32_t*>(buf)));
if (remainder >= (buf_len - sizeof(int32_t))) if (remainder >= (buf_len - sizeof(int32_t)))
{ {
LOG_ERROR("Packet size %zd is larger than buffer %zu, discarding.", remainder, buf_len); LOG_ERROR("Packet size %zd is larger than buffer %zu, discarding.", remainder, buf_len);
return -1; return false;
} }
initial = false; initial = false;
} }
total += read;
buf_pos += read; buf_pos += read;
remainder -= read; remainder -= read;
} while (remainder > 0); }
return buf_pos - buf; return remainder == 0;
} }