From d843a990a352c3eca538ea896052acb1c3a2631a Mon Sep 17 00:00:00 2001 From: tangger Date: Mon, 7 Apr 2025 19:45:34 +0800 Subject: [PATCH] modify code structure --- .gitignore | 3 +- .../__pycache__/agilex_robot.cpython-310.pyc | Bin 8786 -> 8676 bytes .../robot_components.cpython-310.pyc | Bin 0 -> 13660 bytes .../__pycache__/rosrobot.cpython-310.pyc | Bin 10577 -> 4605 bytes .../rosrobot_factory.cpython-310.pyc | Bin 0 -> 1935 bytes collect_data/agilex.yaml | 6 + collect_data/collect_data_lerobot.py | 37 +- collect_data/rosrobot_factory.py | 26 - init_robot.bash | 2 + lerobot | 1 + lerobot_aloha/README.MD | 3 + lerobot_aloha/collect_data_lerobot.py | 461 +++++++++++ .../__pycache__/agilex_robot.cpython-310.pyc | Bin 0 -> 8689 bytes .../robot_components.cpython-310.pyc | Bin 0 -> 13668 bytes .../__pycache__/rosrobot.cpython-310.pyc | Bin 0 -> 4618 bytes .../rosrobot_factory.cpython-310.pyc | Bin 0 -> 1948 bytes .../common}/agilex_robot.py | 33 +- .../common/robot_components.py | 571 +++++++------ lerobot_aloha/common/rosrobot.py | 136 ++++ lerobot_aloha/common/rosrobot_factory.py | 59 ++ lerobot_aloha/configs/agilex.yaml | 146 ++++ lerobot_aloha/inference.py | 769 ++++++++++++++++++ lerobot_aloha/read_parquet.py | 33 + lerobot_aloha/replay_data.py | 112 +++ lerobot_aloha/test.py | 70 ++ 25 files changed, 2135 insertions(+), 333 deletions(-) create mode 100644 collect_data/__pycache__/robot_components.cpython-310.pyc create mode 100644 collect_data/__pycache__/rosrobot_factory.cpython-310.pyc delete mode 100644 collect_data/rosrobot_factory.py create mode 100644 init_robot.bash create mode 160000 lerobot create mode 100644 lerobot_aloha/README.MD create mode 100644 lerobot_aloha/collect_data_lerobot.py create mode 100644 lerobot_aloha/common/__pycache__/agilex_robot.cpython-310.pyc create mode 100644 lerobot_aloha/common/__pycache__/robot_components.cpython-310.pyc create mode 100644 lerobot_aloha/common/__pycache__/rosrobot.cpython-310.pyc create mode 100644 
lerobot_aloha/common/__pycache__/rosrobot_factory.cpython-310.pyc rename {collect_data => lerobot_aloha/common}/agilex_robot.py (94%) rename collect_data/rosrobot.py => lerobot_aloha/common/robot_components.py (50%) create mode 100644 lerobot_aloha/common/rosrobot.py create mode 100644 lerobot_aloha/common/rosrobot_factory.py create mode 100644 lerobot_aloha/configs/agilex.yaml create mode 100644 lerobot_aloha/inference.py create mode 100644 lerobot_aloha/read_parquet.py create mode 100644 lerobot_aloha/replay_data.py create mode 100644 lerobot_aloha/test.py diff --git a/.gitignore b/.gitignore index 8e255c0..a587ac0 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ cobot_magic/ -librealsense/ \ No newline at end of file +librealsense/ +data*/ \ No newline at end of file diff --git a/collect_data/__pycache__/agilex_robot.cpython-310.pyc b/collect_data/__pycache__/agilex_robot.cpython-310.pyc index 6eb27344c222b7a4f3861e95e5705c6075a879f6..b8b4a3391827d16b9f96b062e621c736a6ec35ef 100644 GIT binary patch delta 2900 zcma);No>?s7=Zn3kH@oRn=CVs4Z=1B5(+I#69Rz%VPBH67}Duv>`V-v@#NV93U!#4 zRF#UT6!cWJQn%z%QT527?Wu>JdZ>`9Ug`r<7paR>iK{u);z@NGM&)mCrVlC&Vp`!~ru>`I1CYs3mk|y!DtjYWx&;sxlEhQh+ zf_YU_^C2yi4{Kpkv5ROCkV{rHAJbw)xN~b|2eIn3I&K`-;)euX(SxrGdeEqylC*lA z3<@dr&ck&L%7Bs44Vof;_=s&5+(9?v8f=h!RTYBdFx~y|bB5R&@$djT$%&K1BmKhg z%1O}^J>jx)l{mZHAhtX)DNGSvy&}?nSJFeC_@1DLKai#f5g=1}P#90?5k1Og#rR5P zMw|uhCZQiVDS^FA$2@5!q{rBOFJuA$DU$@2s z|C6$-f|OKz^7aD%!_|`kFR;+kldlN$qq&@>dF;=D0kns?fqr=zgU{v+?t*HV-d~j9 zeHdGm91n3HCvC#x$$Its(xd`fWTut}NnH{rgFG>HeoEZxZ!a)YuA7CLK@v-Z&adDi zK2r}WprO~yik_m^CIv6pBaHhXlcxzi3IP?pu1A1=U@qLQCxrw3u8OyX@yb|W%x0Q6 z?Mp*Ws&2#OMo-n_@TfZWn&4gkRuIEhl9?vswP2^rgFPt^sixdj!Q(;EXgeGw}7A zIyeH4y1T~gywOz}DHYsO*RcyjU6w)Z5!+2?EPE`&&MF&76DuplGZ9dyQA7+;jX)O_ z#IaQmSc#1WfEFp1ibcasTgIrXMdv!y%#FE}fc$zR>kBrD*A+G#%(O4#&JXukrsHPq ze6hESJqUVR6JU0NCNaV?`aJx05dI|t#8+DdnqarpP0g#J>nplZbjO_EHbCbS;s>r~ ztdc=H*mvqsqhIYz6td|I&8NppM#*q!6=<>E(9XyTXiz?YKf 
znbgqH%JEgSEoAZrdl2gFUQ!o7l!G#iQV`dGJLd)<5mbpxT1AxoHU(RF)YxiAj!=ZM>6D@ID-9n932+V=vPM~W6kK*=-W6-M^*R%_rWGolNz1@SCmD`FesIRq{{eI9}Ph%Q)s?CwD9MC?MmfanJJ z0oSJ4F+V&?ZP*2+d~t&Ipm4V!NCNF;kLupsz6S^T5c?4L{qa@EJ*}HrSBqq_I04?F z2T&%N`5ZmSzKCxld)XiHWb@*)=0{w^GVEg7agCyLZ*~1cDbj(Ck06e+mPAX#32dH3 zoI>C+pabk=qIugoZ1Nf&hL#qg@VQL$1CplCpzI7Fm86HT&1Za;eV9m;`C1R*;5o#3 z#07+g;5EC5txJebfG=e;Zl&li;ux>;Ft%PsWDp~WEJ6pQ>M1Xkfdl=B9AXTin8>}X zJGr)IJ`R?PjY?t5JkI9?l^gUxmhjhRNqhH`B z2tUkxi6I5y<{Nt`4lj$+Vgc`#*f%ZPc3cHX#lbQv^ZqIpsKR&jakis%6-lrQt;fh9 z`?$5e%$JiFAGf71!6o7=S-UW5<|>1%X@G4^Z1FPWbYWnKyi? zU|-`mnlIyp$MAnT{2n-1CskNMIjsuBXER#$+85K3P5Z2_r&z*3n@vmbQJtTf4`0s?_r8FEE!o7C%=k)7BE|_*7FWeo z#5HLq#Z@&`xF%f9(@ot=nJF)Arai+n2>Q#e;$_T?kcpc$v!GMl4PMU730>07dj+!~ zWYz8S4j0X$S29aNp}8BqvRM|g?)G~F=75liTQMsl_@Fs>Sh7=g`if+yoq@};xydFc zrJ8a5i_v{a+nMV)H4^EE&K;yJyNR;;qGQ<(ttF!55zlHmQF_vEwb@B#F=2Yb_C1Hu z`Ro7rGN(f>Fgu6f7d|WVo9PPwf$W@OYFC=oY_%d?Cbo7q5k9%B(Qj+$7?bVPWdhq( zPQun~}1z-&_Cn+=ASdS9^97XC?Z6yuZ1tBKD`iB<^WKLi^r24u4;E z_ykPT^wwP)ytQ_7BE&6BKnuRzguUToc|idyvRDv7YOf_0l44`p%8`WEFI{=D*S91Q z2{y52GW-^nC@z-31T5^_QlhKa`HIv{j!Sdjz?^c8(B~kaVi(3G7$=v*?Y&nxu-_K1 zwmM!96O6mLM#A#ikP9l?I9u+jb{~8-TZ2(QT#N5%7&hL_H8Q7xpSB!+JA~KPJ1vB_ z|3Ca2!P43|Y-#QVSUL@^pR6tdjzMp%%^B!FL_KB<&bY2ch+QBoAkM~d zN$o=AYGuAH`86r?Uz6pm<@z&L-F2uxEC4RJ!x?iuiZhH+sI*8&X#66R{1!k|H#a3<-pEhQqG&W+r-SL8X`YcX9CrPqI5Av`2-p=EDf=(cwN4&t#6gL?!f*8pyuSt*a ztHs?>Q;9C3pG)IYrOWfV^$ zP9a7SNBPjeQ11^YifkPN$uuau_3Ab(G*hQLaPAdAtwR5ZvN*2^Qlzk!=R}#nq0{w- z-)S?4(pS;s&xpSwUPHW&cmpAlSV!p$VizEi8!uTbF2i3?De`WiWFs8JG@^-^0n|!V zB(H^v!-%s87vUi$xn9|pUx|ZRrL%lrWh*(t&s0WhJ}T~h=tRYa6|fF<>i$&d&~p}R z`R$+iXO&Cj^o?DE56I770H+_}CrRHe{sT87Ntdo3pr~FK^LI~(;XMCnc=yv6VJ2Q+ z9g~&yVh=PC?feo~Mz)Y4K00!=hiQ(~Q%*bdX{6c?v!LDNM$n0pal36tH&H6_2$JEs z5vdKoJ>6=?jed7|BBkNBg1A0ubj5!&qTQn;^}}F3Qchv9<+gtwo?3^o`uKc`ROaJj zlio>olny#mpqg=L7-`FL+!0f_CRVoA7Z>m}-?RAundBEXKcvRh} zPTQ_q4XnyRNJoe@)pI(@$UI;Du!pS)w_PW2ok+1;4HkD|50?rM=MmQs*mbLx7lHaw zmj&RPA>07ar9raejDEiSs?1 
oKFq;8BAzDF$35TfxXylTo^UHbmvnObJ3ulZXW*L`M14vAAH*Q2g8%>k diff --git a/collect_data/__pycache__/robot_components.cpython-310.pyc b/collect_data/__pycache__/robot_components.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2f08b674b1d988609a8ae9cc5ada5cc9c779115c GIT binary patch literal 13660 zcmcgzU635tRqns(>FJrB{ndI`lEv6o5{+Z89NS4iksT?rXn5bG~!0`V$kbhTj)X{>N(PK~4KtdddDwyqv=0e;0|+ zgx=9=^w;RJd%kh4Z2&oU2A{8gB*R;?2d|4Bm>OgtrpjX7M&5%6Kc|Z4Pgf zVhV3lcsn5Ub6R!!2Ij0SU)Q6Y@GcI#svbFKE`3V2#j01eBI}`cGldT6x zmX&6|yV38V%#Vsq?pANLJ6_Yk?2{kmiTh{p_zOt3v>k0r-_^IYI-VV4R|D6G&PMoE zw-Wz8(`yIqMyDNm>8EH`$!uD(->p1z`mskV^u4rKnmH|3{S)aMjhUI=iOLg=V677L z)d;F49XUJRa@CC7dZ*tI_02}N6WOY5G*!QZxe0Kyu1dgZ%lA4hIe`UU*W?WT{P>f{ z*TClEgOx!q7#x4}ne)dx9@DMf==9eb$D92QXcg2kPUE<$7!q`81L9P#w?T(`y-c*T zkd*QzJ)Fz1X(?$E+0qLJvNu9|qPv5KY5()c1faaouNq&S*fs)v$JjD<^`|uh`_R%X z?W(D3=d_P#{(;N74zB3hGsGJMyfMSu9t9aHp!2cNR;6vE+HdyVXPY4 z`j-BX_H20@OWe}dwM)933v^*_8Lw!s8CUfe^ozEzu&`Ws$5Syc#}-Gt=y-wG^L#(b zHyd3~HvGtJ$ZmMsV{PAW_f{(s`k*bngf0~y{KRfKQCY4UauTzNT#P>O9+t8%V-nP+ zlV0_>IfjpO<7An_5Vzv-OGq4$#nsEY@vil5czA?3W8xTD4QzlBQOIEcQqfGOE69)_ z3^Zxe)?=}Np1Epr&P;Pwn#-j*^oZL*!N$1)<%Iicp6I=!?`oj^#IG6K=C-wkF$TJP zV9Np@j4kt|cJUtG?#1x8(;fPlCs(o4u~FJxuY$=BNYnTgD9!#TThcCRX7667#h<7= zxo0%i8el7QnW%Kyy`kb*j?C7EA5C-{&(&!Pe&FG?=xwz9zVPa8@mw@H@V&b5HiET! zn{~QF1uG5Ti*lFR!t0-nN>B!Y?04$H=7v|zM{Y|(YU2voz;?B=7UJEf%2;!3`HQLTs!c(KFd)xr?$_}p-HVA6RN(^zy~0c7raeOf&^_$fT9^l z?JB;cXi+E19K{g+5hPCD(p}SqOeO#9|Cw@0#*8c0JK>S6gzX(hei%bEY5OHV%BNeA zHZ*O#9c)Bs8PXwoc=sb78I7za=UG=Sx1o@Q*J=zpurG~HXQk1+kZd6n^hnePeKg9) zEyorsY5SkV;2+Ej+k;0SaiegCv*Y(S91TjLD-0Z%@D z4h-#f2fcQa9R-%w5zRx)pyv$Do8%4xq1ClOzX}1V8*NHWPN6CAH6vV2_jbk&pf5Gl z6w>N6vsVWwDT5Y>9hHpCrM;H1t4K*plG?~#VUI2<`hyj}DcdWa^dom;u+nM!Ysf~MCBSoE3Y+Jx-;_4JFinN`>`*>1+Vhbqp)pDi?c{oZD` zKk%!TA_>WCltX9uAqDnAsBJ=n2T+@lXcFdVf?Fk@djdsl%pIew&lurCh880YZ@dk? 
zaX!ig{f%}ra^YVOyt)s^FDeGj4W4?ZvDu^@m;7iK(8Cf%UPGX*vzpm5VP#;cUdnCT zTQPYF;s$J`p)xf0cE5VmpC4I&D$P}$SCFi{@I)f!iJ_r!IAR00M0p z>p6K$*gak3(bu9L7Z?dw1M$~VJl3_SWV(2>~4En~NdpYH1hHi#g^RbI@Be`t9&C!vh(P z4cp3(pfSYeQ}_eTA%78n#G?zau|O zh0BzjqhzR7hr0C_P(j@?MslvSke@4-3zF)Zct(D-El%O_zlCJnND9*)MDNw(ba+R_GT=4Uhc75p;vfqyXW!0d4{WlJEjikoJ<`! zWJ;kOHq1)L)G(6#RZ3`HP%*zw`8txwgiL0(FD*uX4#}7C_+-6o%_jW%j%`SCY-Bpt zaXHv{D_f$KeNd z-uOeKgl84?9D4ePk=Tx{J9^oi!DGXh1M+^SoVWEU9mpGc5wWj$qo%x2$YO-<6s90L zkqTChuePI6v03nPW#AJlNs09=sfO;X7^OCV^Qr$XYG-SSCGfooG^70fJa`Zfv_RM8^ zLJehl=98Wm5&wjmDTIqLc@^C<&Y)i_lZXu^{Vc_J+rGm4wK(`wmn&G-c~}jO0xVY@ zRt+qTl4Xjm8Ld4!Q13P_AUa51egl6LqLpld>^f$!>f2$ zh-R>hNG1(%X69zu9_E4C|1139sC2%Upp0;g1U2;~X#)x<=yi&kY4Q+q6euJf!Z7TD zKMBQ|M>=?9p42Fo7RS@(V=9yMWoTS;S)r?li8~}U904?3hQmGFn@mSm0)&pE*G8?m`=a(m z{fkBmR1wCD&bGN_LdwkWLyr)|BFquXS;Bn;Hkv_SZdMTcU+aqz7SN!iB}s4k&m#$s zT+^j<%~;p2X^UFVLQba~fxj*NDm>x~Ie81&Sk@@|4z>^Zu}{MQokJKV$ZbBTYukW_ z;ap(v_bXA9xa!~=dB5nh2-&l}WB$j*!nj5GjXol#tst@+8;AjmsPOP}P3#p!M3viW!;eI7NvrQ9_b!5bxQw?zG8#&Y4rOmJbaD-^kkOabqwQXDd#{nm-S=1bH#W!ynmF* z`)$ObJ{^+M@UnSE{gJ$IO^YM&|!jN@c#&QDul`TpJlw3jr=bab1 zNd=>}KuVPT9UeyRWy1woLlr_vbpJkEaYT^G;H68u9WQN&Y;xoio4!dheQHr#6uZ#F zgBib18M;afuzessA98X)9fPRtR8K0Dt(#SKyoZyT>uQCOlN5|CGY(ZGwLhQ*)8?}b z*R)Mnw~X-M2uC?!ku;aTiKZj$BsoYS>sRo?+EPWqKM9M6I5B0Jcmsnw%?1@8BdaJdf-FX2>EZya3i8e$SUXhIZMjkxzz z(i@g<8uXnB4?Q9KO`Ky?u!j@1l&!7*McZ_xh;YY+~<9-u*$$!K8xNNv1N(<9LjKK zMIAm>wD7~DtWJUA(DfLC_B|FfnWy{}`qDF$(DdK)!+CsIHqatH(qKG7dXdh8aYm?s zG_nbhUZwHIv?F=7|A~i#fdqof=9E5f%p2h?@3ETqgS`~~BkNDjObl}Gi}Ar|JRYVA zhq7r%SYh5&u9ktvZi`VZtsw(e>%%^BA2<*lRuE|a9)*NqC1Vf6ep0^?hS3&u@UZs( z2VO9pdn#iu%nOFRTbX}`z=vZ9Y%!3%6E)TjK&`L=uH}k!FW6=D8m|Y-y2F;DLi#MM9JuF`S zJ|!DSSn}%fYm`O&3uo%MA5t71FUdckmUmJzM9(g@o>AzTIGs4U3OX-SLxFZp zS4c5N6Vu7Nx@ncZ!6hBRi`M2RfER%vffsW-c8V7vFMIH!{ydZ><3;2~@S=4ZQrm;Y@he(KpY~@vu+>@ zH4ca9egyF$21183wl&dysgsCXW^>9@qjMWCMM4QcLI*(Yv{>%v4PL<rc;F|4sM()V->7ruEB(u}8C 
z3KugRU^1uV2e=`!OY$tC(E8-TDV28(CE;||FXxt#1JAtQYzLd8PJ2=s*c8{SFD$t(k$9eh!~LCK#`!eYw_i|y+uJ4sAFicHxsV8zS2d$VLC6lRR) z%Qtb^M+L@)0FJz4KN0~@ay+EFtqPd%a+lW2hK8uVEdLZ~ zvX7ZG-=qd(1o>l>Jw*f|Z3%)f5rBe@j_+m`eD%5EKH}reshssZ5^TmVik%N-zn|MZk6$d_>5z z>GFB0Ok1UypB;e2b5HIN#jkYTNnLtQKY=sFAp}aArwiVv0}gT2=KeO8EdLBi620Ip zpO$7sg!eW1=VV?uJZBg9lF{Ng%TlEA7Fy$X^i-6mJB|H8u=m0dZQ=B=hIR;DBOyvu30++IO^=(`X-NChUvQ)^~yPAJ}q$RF-ZW#+;-vvun(IStFN@&Z+0=exRl!LQg zz@^aJx9z~?x8jz)h?8PrU(n<)@#y(28xd+}v>qje*K%sCZQ!e(e(^Z>FUf8GJ25iq zL}t`29@QNgHLl-h)EbXgmS2y*KZ&^&F}JVtD5YU3`cAwSuMR8A$z5(y<~icJ9IVz9 z!5x%>PVxXBi%f?9o{G!UjW$mCaMiWL;W!pwb%0q}Yh0o;lUA$UY|~YH1krp%<@~g7 zn%OJXF^^AF9&0?;?hd*YzDtYyvzO_6?H*lD?)SuSEE3yvyeIkg_}Ha+x}tHSa$XL+ zO1p(|WEY2%{-D|PJmHDEDra%{%vJO$QGtY|~B9E%$n;$94KYdk~R@pKa$N{-g<`>jEzvss;1hZ@#- zEIu-O{mYRZ^yyA_WIaiK!sGZ%RKzKwzcvU&|1u6paQ*>8^d#@_L#S6ivY^}@lz5aJ zqojod9!JOXHj>L|-=QKGu7wQvmU3j#F%zyGtEF=p!iAkq9oNppriy}&TEfUyjMyM8 zath;>U0nV(a7oW7;p&g2TV~0e(~nY4H*Pw1NzcAb=?A$DwZx(LKU*J+BLf+ev828L z$S$wr13-Mlw|(ygJw+8{(m041eNueCmyk%(k&>K; zTB+B#G%Tm{N%?U~_@WaBM2D-QS?#xo$T_Jl9OH-%D>;b7(Mo`&C4B)$E|e>q<&Tx` JFF!xi{XeHq3%CFP literal 0 HcmV?d00001 diff --git a/collect_data/__pycache__/rosrobot.cpython-310.pyc b/collect_data/__pycache__/rosrobot.cpython-310.pyc index 17b14c621dc04ee8462a852b0cdb57cb454eb37a..67ba2c53f3acc25791ce51ef7dbcbcb439cb02fd 100644 GIT binary patch literal 4605 zcmb7H+ix7z8K2wk&R%?XZjMQYgg_R8UD3Oh(paIil@cR42*kXMX7?Pg$DP^Po|%Dq zl}b_Ukf_ZK0U?Fb*r*{1QWH=m5*$b#`#;QUGi&>)s(7jT)bBg9J6^Ayx}DWI`k)*%W260(p>YI?_Zn10Afhek_%GR_Bk7VO>$0Qhij&e) zVmzMK(=aaE8AsJsC#z?joSt*?dfqAM1yT69KopXi7f5PO(z}RwLMWzJVQ*ogsD#Q( zR@o2b@#<`-zEblow`$sQ<+$N|>8;y>VjJe9FQCA2>m)p3A(j zo43Zxeqj38?&a;5Oy4|iR?P}!W#rcWrJ->IiU;rnD(E5+bcq}y5|xQe6`Gn!?-HgI zjyt7iU?dGA8LC2`O0x4QJxg;WNAmN+oT%r4E|4xx7l7_2J)G_Wx|j5Ex*O8CYJw$eMdI0EQ@&u;`f!;&*a(Wwio@nIBc}3q&_L2Sb zg1&={kOOolAl(H>pC-=$(jjOmbT^d&y|hb!b4lbNDZ;8@7_%gP5zv8Em zhN8?#FfZ+A*0ohX+ia)dU!8-MM((7rkY5x;0safk|1SLJV5Q(1FwSYsukpvF=npp^ 
zEiFaYmX^P`|BtWFM}MAQzWny;sSlUmyBVGPIJ)us==_z{Q}}IHMVLIMd?{_GxP=f_k-*Tvdt62so!1RT~OT;*beyn_YX&>U5Aba zB|sF69(&`}QJXTfJj1lzX>+vf+BPlw2C^Fk27GL!1{P}=oRxuKy?&^2YC$ZBz3>|p zRaphP9$V)~YicO}iq3{Cgq4eTqf1xf8;MSR zU6dG34dn{;i!$2|G#kQ>453qIVEg=takUc%hoQpaNlsb*9BO?*%2V%&Zb+9zzbD=iUc>2eAK1sOo+3soNI) zAIzgm??m_J+L-5RLCLnf>BsTj4=kTec*pK}WNpostWmRKfiHh_53D@CeK>kA41}4o z2^M8;e(a}U#%RL2D?WB}!IpftC47Nt0U$_AWI&|*L z(M-UZgSHcKPGaDKXT%vWKuZAamjWD)FQakyi)iW7c9dZshev45pKbyKx09JRJ=6CY zzb^I+0@JYSunin|5em1BopOCg8}9ZUHz04>Y0FGDySA1P`L)GDgxmOULPUi{dBg_O zF2uEk8=!0{M1r174%#mH+ETQ1BfjJ?239gX8Y7zkmv#^_prW(qu+p&VkQ@&D0{e?l z)!U)2ZM6Yvr_}IX8&~f{mp^^<%|nQn?YEm~UjWGByRaP0tTg1>f#wkO9mjs|X_{y{ z7H)}VH*ByDO+j^g2zshjmcL|#J+m@!}-R|#m2`E8&|$+{Pq2bVur^HJr33w`KvCW z{G#Tm8924IWVp?Db1a#BvtiQRH5JhzgE zb&^0SBst#8(6B%jMpGUj6)p?R@?r@=t)O(UhlucCNc|R~_Y7uz@H`;O1qu35|5L$X zgn+pgevO5%8W(P_C0Z+IKZ|~MbNRh%jdPb3NOS>3q@5CyCD~e&-suEiD zi|J5p%@5N-4f2&(pllH4aS?{;T8&vSRf@y&3pmiBzc{%B#ls{ko$43+MO7LU>jSao ztV^zzVG4Xt9SJj-%(;PIMxnJnq!X8bc~?_V;v&x^>i>G(w0sQ{L0dCzTf+=XiygZ* zZN5o0)7NZjdcO9;cQl839@v~Wk_av#Hs-KRG-KLv^EZ|gI2v|nzN@{39IdM6QAkIL zr{$DNuI8CC!y8ygv|#XoD)xr z5V9}{h(lI%`F!Iq*Q0wkAVUY8>oCgtffGTgY@41(J48w6rU?W{HAb6Q}b42za@S~lED!g=RCi7^Roq$zd%`8*43q}~42v5vI z71-4%^Jf`m%4{E2`?10+WITVxrLj|>z>ls7g zvmfy>aIf&j8N-e>jpG!CL6(OKpJX2TDODH}2L~l3C;qP`r>ObCzM-Ceh3{1C4wYHc zvBR`KTLUYLf$&44U|0&PpF*Yg#-Xh2IyJWnTIcbXBYz$82O)n4^2Z0^i3lIcBh>}IglGd^t32{t4oWuzuIg)HEwQJLkq*yT~E(+o(IUi7h*gU!)t?mfOl|n(<4u-5>qk>7V}S{_0E(ko?tlGVP?v$m)0Q zE*1|_olMEuyXQXl-h0kH=R4=3+}j(K@E1z_=lEaUCrSTCnf5;wnLT*yNhC~SazRSd zTPeszC9TL_nVMEnrWUkfC>;`IT0t*{)8V3#Hj0sSq!>*{MR};uQ;el!#olzU$m@l8 zu`k`{kCR9zz9+FTGfqj&m{ZdIOg<>3BFp%;G;&#XL(Do>wNkQcJTSSB=h?WG(p>Gq ze9m#zy`?GFc)a4|%cX3=)gH~;PD*h@4;QoJmK%MfoG&>Covb5zKhDZU%i&Ww8o2#W zM`jNmdq|cfOG?X3N-OL>rdTRdEzJrY(Fdh5jcF`&DwGb3a-GR2p-F3Lqt%LehqP#!HiBW_QwTpG)dXU6gc%hj@c+|H34HUEgF=K(zS9wgJ!yfiH@$TCs# zwBl3Bf+R{*|E>A&kpI@{Ek3PO-^4Ic)BkI{_DTKE-u>c}PwH=eviLXG{{8RH)qinn 
z@#3E@&;ESzrFZLRUa!CND_@!}sf-o$6>$D;+kMKEJrMaIW#n2lW?DHZENaT8yN^ZY)zMXIaJ@*j<$= zT72Bf@M_7)7p)BScN3X>Des_3u~IHsCC7GGdnGLEWHYwIt2w92E!#CjJ6^WkUO|mn zS+7Drp5?`AC8fHWZ576NFZPDVkzAJeZFv3mzHJkb|83P#e7Cyo(Wee=%NEKL*-XLW z#K3L2a-m@5oD6+&8$KXPhblN<8E-0?7>2QF$Y?k!8~%SxuEm7}3^kj%I=3*VQcgkB z<{z=%z$4hd3xosaGkH#Vp>IZU|HbE=x}l z>rs|ct7|Ys{p|eGr=NDS$d&mBDTs2#D!JNJwpef@c6Q7XOmKCN3n`V;2x6C9c}(n( zm#n4YJ|ye$*rXa_N?hK2qn7kn*~ZI`6^781)T&tFEC^a(RpQVH2ffuQSg$J9tp=19 zP@#ZA55FDsi%&&}V$s*aGuWwlc|m$mdb;;V%8WXrO=FCz%y&;~)6x-TTHPfb+b#}B zPjw@vQNMVm{_zK0@>xG~)f4LxR~@TBn2OmG84^d^v5@VtDtWuitW2JraQmvZm0?!J znaJdYUNxv-G;3RKXfn^N@_sjlU3GZ5ka4Ce79>78#xcSCu{D%yM#6^A@dk=&)@%|lbk;sds;9<`=0iEN>2b+a&m z_9h;i_P|gyIjTnG*1!HgiKv0RGOK-A8}wOw^TEu)A_RX#lS~T^n;;x%4$3Pd=P6Z0 zszRA*Z7OOtngzP2G5=Bh^SQ-qr=UqjurLjku0ASkn%iSnN9`QXk6PSzqm}AtA#YC* zoycH;c^pg2S0R^WZ%gP{w+U@OviIV#>6{sgF6%N5=Nf-)UHYYCVJbS%T*ShLKropp zwxAgwAo>2lhQI2Vs%_rZfG}x4e8p{jtwDB2OIMD$k6iVah~(i z0=6QfFomh7s8o4SniG1(FZ!d8YppLB|fhR3^5YMkkIg*$^ZIEm@0tcR#UsR4kDPLBW z|0xV}BfC13Nx>>{>Uvpmm7(pfRw60r6j(R9ivh%u&Ab008lBp&m1^7CCAQ1{4ectF zbJ>D@|4^&1&@1~pCQSm}oU~19M@%{neJRfAa@UJ5)w493r&#hdUvyKV3 z@3L*8v6ygd22gZbzQttqO0`k}pzWL@C*!)(v<{)%rtg@yW-y00An|WeLM?AO){|7Z zhmsT}!3;)HG4JrI(Bwjv+&+7%l!M_I&j7?%En9Gl-;Hi=zo_!I+^Zzq2(|9@0om*$ zzCuj(tH}6xb$lmHpi{7T5@mbw2;HAhlX5~OZ3s2+Wg?uE2b82jB?%eoVEz9gZ5=Ek zoj^8TC`l3kM>>Ktc0`$nHD}5Klp<4^b{_6g(+1#;Q!nzm{|*b=5WEyUovN*((6KWF`hu|AUvZZoqs#vbtBPqiTp%0K- zL@0ggE4ae zPwXk^q_K_&)!Sr~fRdpq9hd;GH|9TvVr*f`rPJ@#&%V3((wmJl7n@kyQ=35-PLeUQ zOw$7+^n8f{zykys&GtzbkBU(r?C_yM`L$~yvL=;o*$yH6(%s@)U)pMvlVj{u>2 zJ*C=y5O|WwC4m!gG_-Mvh2D}E6ukAfWEOr65b21>DGhXIu&TCz}0_g;+n1Q?g3NsV-c;#`Cj{*PdT|^&|L%i_d>t|Lj=- zGxq~m5`axd`1okXp2z~Ls>}Zh_ASpsz?Ny5Far97l^>sQ+|cnnb0z>_CbFfHRk*Ao z5<&RZ4T-L9h|Z&3TU8#lEk2ng_%bBMw}(8at5?C69pB`Cf^qqeD9KPlCt#UmrZ%`j zS5F-hIwVsp6lLr1@*kk=vgfs#lMVP!E?^Tqw zf&u{kk4`1nC8iycKtdi!!apM8NrDHAV67AfW)(mQd|v=U93v?R@_Cts1ou^@3%(Pi z3NOffk3UM7dci=`C~$H*MhWt5qZq*`5sYH_qeR3gsu)G>8YRfLjS}4j90QnyC)G0o 
z?$UX5!((vtoO^cy&6J%ow>jy3K8mE)-_AC#wJ~V{nB<+Q+CY0{v((kCv9U6DYRUFm zuk5nc;;*X`ZXzb9pugyYM<^*$Qlg~G&>cqED|iIPfjU(aK)JE&F&(thO((9y)mhhW z?m9J32ct>h28}_sJA8~b0rSYQ}jTa~hwv?BVZ7Q&1RMoZ6AENA+cm%F$ zE;MnyQ*gUfX2*Kyyafhi9WrDTNWh>{JfVbo!d$};0;5U*!c!#(Ce=TA$5$dW^OU?y zdy_T_`I438nGuKH^1q4IuI^M>zuLjmd+OL;gt!5en?Vd4<7iVfy(H%5)ZP zDpuse@I(yJcc0(aXzLrgsc)p&7fLxw$4NMbOIJQa0L@d!9>!UkeJ&j*OxQx&`zgSi zx4l5a@)A~c2);v9;KGcZuUT%7ldA~4Q^-!G!@ic{d$HoQ-qe~#o7Q{`HE#4^;8pVo z6|Bw_vq#}c3$GNfLSZzUJIc4A&V!~CM0JVZ<0$(FJT^ruk^uD@Sfl{7DELqQE^c(@ zp;P~>w#h$fw?bRDW@5WuC9BBPqY57aJ75n5TQ~S@)dXwqb~LaOOayCyZ|1?k1!!8h zo?*dKqh)K4#aZt`5|HLdJN%KrI0(oSzyCjHPM-{9L3l=_%H+!5tJF;#Ddp=ODw$RU{V^kt)7SZWd0~Bwiz_ zo|Yoz52M60o+JkR1SQniO?v#Np*ooFMxJPfT=;f0^kWXLK`qJByHU|c1i8EA;|X;5 z7d$q3ib*XAT;)lG;Y$dGnUV-Nb(R`oFAo{Iju!u2+tML3^t~3uw%cHMhwPAQ3sm(! zWWWK5Z(!t{5RW#AqYIi2K0`>~JN%^lSja#0zUm|O=c_=9=y?3{RA9iYbd(FvgESkj zp;ZTqs1z}8&o89-9dGS_53=s6R_{&*tf(2kk1oH#W0OWsAc{vvP7|A&_?>a1c6-N$ z31!>bv5s9LH5u#@Sye{_bm-PWPzqA`5QIsCGS$12IHFQW50S6|@3e^a8mT&lb9W@f zw=x}gH;ld-^*A;tS{OvL2&6w+rGWC%m8^x+&F#egZpHO(~dF>4B1P~gLBQ6-Jnc{G^! zI;~<;)#_cS@;o!ow$xvJW9jN#{q(1}fT@4@cKyOhFQ|2s9FQWQeY6P~*EXJ!xp780 zBAvlyS4jgU6E!0fmYu_?IU3@(0mXVZQt~`ZCvg@R6+n!dD6mVvmmCoY4CLqpGPU)NbVl}M$CsM^W}Y;Q9MsOkMs`S$HuW7y z5M&Ax6q?rq34$fRNrL2O5EB#<1gb-Vw7tA&^5-#?lp(}~KLVPcrR2wyoTEgjC;r^D1F;;wL7S&NYJQH`h2BFuz|qe~nXnj!pez|5I5xp`vJ!3zar}6VpfQ?POI~MbO2F`J>jdxrBz?;H;1!%$E0icf(fZL z>8`=_0@3C~cG5Dlj#;n(G|ipcO@w-E#E$HtAnq3sH~)sf=i#vVSoTD|SS^}VU^->< zIDIi+qU)}5iP`3;g|L8Sj^`)oQmk4D#x(I4WDbb?^se72jF!uVVH42UGV@~?m~Q_~ zyPC^c7PHvb%>B3~5moeUWHxP_VH*B%9n?LV}B6QH@2s0gYE5v z@78R+O!mIEzP+AMUFQ%xF~kOchID;YupF#VtlhT9s)fQ-D&g5d?GP64s-^OAS9i*Y zUFTfw0NqNT!e`tb_?z}b)nVo1B>{o35hc%qBs1$7TtXS1{Y?=?8UQh(f@M_zb485i z5*0-eOyQ2WLw7a0_{OEUw`x&_C_$l+;Z~MSc~-YeT@c*9ht$O$5b+2usue?tD=}aw zP1e*XE-^L}B`eo;-NRLWt~V~P_1YlC6~{ko_tavZ%Jlf!0e|p`;12uO6gI#m9+c!s zIi(0ZwRGXb`U@}DFP>}s_1pE2-T{)Ld&15uiP|5)sm>L$wr$yFv5NCS0MROf!zKX? 
zarwVBg`2=tBW35|W}U^vW$sEi-;D(QN|>Rz2{e>ipSO&l?v^XzZ($`^!A88k_~@D+ zy7}zA#S8Cu&;8C@&)t{*tttNJ7-z+KhW$1a^XGqoR<&Mlj(&Mcru%D(FCpNAkOqGV zWNI`0hK78|Ova67GQ~2h7Kn~z0LHQfuSVEoPMV&4&T;(00Y9*ajG&iY$>uCRh#K+3 z6i$Yi_fxWi5`s3I;4G)J!=I$&FeN{vgknqlaY~+`&4#|Rk$?5Z-w5Wf|l?KN*HdsXz#&e4B4F)Ogd3X^?v~g$K84W diff --git a/collect_data/__pycache__/rosrobot_factory.cpython-310.pyc b/collect_data/__pycache__/rosrobot_factory.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fc77f3f060f97ed3a2fef5ec4731239873e6c566 GIT binary patch literal 1935 zcma)7&2Jk;6rb5IukCdj2tkb$p* zAPG$c3F!e+(Ui!!Q~{y12&q6K%KtG}VkbFrmG*r-ufMOEBXQ1VID`$VtDkgJG=%N=y)xFxTC$MchjHRoULDRwzxw$G@SRKtqRM6h->j1w>PX>8AWF(q~s+I|$4tCa== zwE|Z!M4nG|Kk(w18;m9m7EbvpyyM9bx|namj0Zrploq;3ZTuJMU|_55HmqjmIv`YVi(}+?>9freUgDKl6sNXw3d>j>IFw#NnMhTvfEdWm4^InZ zdYFwp@!x;e{{6+ zL+6W~&aKVP{$A(yt>aG@Pm`mwyLIsRf05%?%9*PYD~`^^nn*N=-Z4%3{Jbh*xkC< zU0vzy-aUP0S|10$u65S$biaPkUHz;d9LH=*(v%yXY(d0hPZ&K8iW6MRgi-t?AVQo{haSb%V$wT&>WA6@X;y86gDk zcOTr?!P*|WlNNI2p$H&SrYiV~$H>xm5tvmrYa9RWRl4uo~V!lVCK zs-arg!Z~#izo5(yovD#@M-fjI2e|3Fp;x1>%X6+gz>&JgV#>2qaDH6(e9CO$a}vMQ#3NW1gvg5Q zfhN+hrYg!PR&p<5X=B5}>v<5x0d`U7FM;3&paH#^iB7;oA7$dtkBI Robot: - """ - 根据配置文件自动创建合适的机器人实例 - Args: - config_file: 配置文件路径 - args: 运行时参数 - """ - with open(config_file, 'r') as f: - config = yaml.safe_load(f) - - robot_type = config.get('robot_type', 'agilex') - - if robot_type == 'agilex': - return AgilexRobot(config_file, args) - # 可扩展其他机器人类型 - else: - raise ValueError(f"Unsupported robot type: {robot_type}") diff --git a/init_robot.bash b/init_robot.bash new file mode 100644 index 0000000..f77c8ca --- /dev/null +++ b/init_robot.bash @@ -0,0 +1,2 @@ +source ~/ros_noetic/devel_isolated/setup.bash +cd cobot_magic/remote_control-x86-can-v2 && ./tools/can.sh && ./tools/jgl_2follower.sh \ No newline at end of file diff --git a/lerobot b/lerobot new file mode 160000 index 0000000..1c873df --- /dev/null +++ b/lerobot @@ -0,0 +1 @@ +Subproject commit 
1c873df5c0dd4dd9a81cbd90e07dd95a272ee3f7 diff --git a/lerobot_aloha/README.MD b/lerobot_aloha/README.MD new file mode 100644 index 0000000..9e4d14a --- /dev/null +++ b/lerobot_aloha/README.MD @@ -0,0 +1,3 @@ +python collect_data.py --robot.type=aloha --control.type=record --control.fps=30 --control.single_task="Grasp a lego block and put it in the bin." --control.repo_id=tangger/test --control.num_episodes=1 --control.root=./data + +python lerobot/scripts/train.py --dataset.repo_id=maic/move_tube_on_scale --policy.type=act --output_dir=outputs/train/act_move_tube_on_scale --job_name=act_move_tube_on_scale --policy.device=cuda --wandb.enable=true --dataset.root=/home/ubuntu/LYT/aloha_lerobot/data1 \ No newline at end of file diff --git a/lerobot_aloha/collect_data_lerobot.py b/lerobot_aloha/collect_data_lerobot.py new file mode 100644 index 0000000..8ee0a52 --- /dev/null +++ b/lerobot_aloha/collect_data_lerobot.py @@ -0,0 +1,461 @@ +import logging +import time +from dataclasses import asdict +from pprint import pformat +from pprint import pprint + +# from safetensors.torch import load_file, save_file +from lerobot.common.datasets.lerobot_dataset import LeRobotDataset +from lerobot.common.policies.factory import make_policy +from lerobot.common.robot_devices.control_configs import ( + CalibrateControlConfig, + ControlPipelineConfig, + RecordControlConfig, + RemoteRobotConfig, + ReplayControlConfig, + TeleoperateControlConfig, +) +from lerobot.common.robot_devices.control_utils import ( + # init_keyboard_listener, + record_episode, + stop_recording, + is_headless +) +from lerobot.common.robot_devices.robots.utils import Robot, make_robot_from_config +from lerobot.common.robot_devices.utils import busy_wait, safe_disconnect +from lerobot.common.utils.utils import has_method, init_logging, log_say +from lerobot.common.utils.utils import get_safe_torch_device +from contextlib import nullcontext +from copy import copy +import torch +import rospy +import cv2 +from 
lerobot.configs import parser +from common.agilex_robot import AgilexRobot +from common.rosrobot_factory import RobotFactory + + +######################################################################################## +# Control modes +######################################################################################## + + +def predict_action(observation, policy, device, use_amp): + observation = copy(observation) + with ( + torch.inference_mode(), + torch.autocast(device_type=device.type) if device.type == "cuda" and use_amp else nullcontext(), + ): + # Convert to pytorch format: channel first and float32 in [0,1] with batch dimension + for name in observation: + if "image" in name: + observation[name] = observation[name].type(torch.float32) / 255 + observation[name] = observation[name].permute(2, 0, 1).contiguous() + observation[name] = observation[name].unsqueeze(0) + observation[name] = observation[name].to(device) + + # Compute the next action with the policy + # based on the current observation + action = policy.select_action(observation) + + # Remove batch dimension + action = action.squeeze(0) + + # Move to cpu, if not already the case + action = action.to("cpu") + + return action + +def control_loop( + robot, + control_time_s=None, + teleoperate=False, + display_cameras=False, + dataset: LeRobotDataset | None = None, + events=None, + policy = None, + fps: int | None = None, + single_task: str | None = None, +): + # TODO(rcadene): Add option to record logs + # if not robot.is_connected: + # robot.connect() + + if events is None: + events = {"exit_early": False} + + if control_time_s is None: + control_time_s = float("inf") + + if dataset is not None and single_task is None: + raise ValueError("You need to provide a task as argument in `single_task`.") + + if dataset is not None and fps is not None and dataset.fps != fps: + raise ValueError(f"The dataset fps should be equal to requested fps ({dataset['fps']} != {fps}).") + + timestamp = 0 + 
start_episode_t = time.perf_counter() + rate = rospy.Rate(fps) + print_flag = True + while timestamp < control_time_s and not rospy.is_shutdown(): + # print(timestamp < control_time_s) + # print(rospy.is_shutdown()) + start_loop_t = time.perf_counter() + + if teleoperate: + observation, action = robot.teleop_step() + if observation is None or action is None: + if print_flag: + print("sync data fail, retrying...\n") + print_flag = False + rate.sleep() + continue + else: + # pass + observation = robot.capture_observation() + if policy is not None: + pred_action = predict_action( + observation, policy, get_safe_torch_device(policy.config.device), policy.config.use_amp + ) + # Action can eventually be clipped using `max_relative_target`, + # so action actually sent is saved in the dataset. + action = robot.send_action(pred_action) + action = {"action": action} + + if dataset is not None: + frame = {**observation, **action, "task": single_task} + dataset.add_frame(frame) + + # if display_cameras and not is_headless(): + # image_keys = [key for key in observation if "image" in key] + # for key in image_keys: + # if "depth" in key: + # pass + # else: + # cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR)) + + # print(1) + # cv2.waitKey(1) + + if display_cameras and not is_headless(): + image_keys = [key for key in observation if "image" in key] + + # 获取屏幕分辨率(假设屏幕分辨率为 1920x1080,可以根据实际情况调整) + screen_width = 1920 + screen_height = 1080 + + # 计算窗口的排列方式 + num_images = len(image_keys) + max_columns = int(screen_width / 640) # 假设每个窗口宽度为 640 + rows = (num_images + max_columns - 1) // max_columns # 计算需要的行数 + columns = min(num_images, max_columns) # 实际使用的列数 + + # 遍历所有图像键并显示 + for idx, key in enumerate(image_keys): + if "depth" in key: + continue # 跳过深度图像 + + # 将图像从 RGB 转换为 BGR 格式 + image = cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR) + + # 创建窗口 + cv2.imshow(key, image) + + # 计算窗口位置 + window_width = 640 + window_height = 480 + row = idx // 
max_columns + col = idx % max_columns + x_position = col * window_width + y_position = row * window_height + + # 移动窗口到指定位置 + cv2.moveWindow(key, x_position, y_position) + + # 等待 1 毫秒以处理事件 + cv2.waitKey(1) + + if fps is not None: + dt_s = time.perf_counter() - start_loop_t + busy_wait(1 / fps - dt_s) + + dt_s = time.perf_counter() - start_loop_t + # log_control_info(robot, dt_s, fps=fps) + + timestamp = time.perf_counter() - start_episode_t + if events["exit_early"]: + events["exit_early"] = False + break + + +def init_keyboard_listener(): + # Allow to exit early while recording an episode or resetting the environment, + # by tapping the right arrow key '->'. This might require a sudo permission + # to allow your terminal to monitor keyboard events. + events = {} + events["exit_early"] = False + events["record_start"] = False + events["rerecord_episode"] = False + events["stop_recording"] = False + + if is_headless(): + logging.warning( + "Headless environment detected. On-screen cameras display and keyboard inputs will not be available." + ) + listener = None + return listener, events + + # Only import pynput if not in a headless environment + from pynput import keyboard + + def on_press(key): + try: + if key == keyboard.Key.right: + print("Right arrow key pressed. Exiting loop...") + events["exit_early"] = True + events["record_start"] = False + elif key == keyboard.Key.left: + print("Left arrow key pressed. Exiting loop and rerecord the last episode...") + events["rerecord_episode"] = True + events["exit_early"] = True + elif key == keyboard.Key.esc: + print("Escape key pressed. Stopping data recording...") + events["stop_recording"] = True + events["exit_early"] = True + elif key == keyboard.Key.up: + print("Up arrow pressed. 
Start data recording...") + events["record_start"] = True + + + except Exception as e: + print(f"Error handling key press: {e}") + + listener = keyboard.Listener(on_press=on_press) + listener.start() + + return listener, events + + +def stop_recording(robot, listener, display_cameras): + + if not is_headless(): + if listener is not None: + listener.stop() + + if display_cameras: + cv2.destroyAllWindows() + + +def record_episode( + robot, + dataset, + events, + episode_time_s, + display_cameras, + policy, + fps, + single_task, +): + control_loop( + robot=robot, + control_time_s=episode_time_s, + display_cameras=display_cameras, + dataset=dataset, + events=events, + policy=policy, + fps=fps, + teleoperate=policy is None, + single_task=single_task, + ) + + +def record( + robot, + cfg +) -> LeRobotDataset: + # TODO(rcadene): Add option to record logs + if cfg.resume: + dataset = LeRobotDataset( + cfg.repo_id, + root=cfg.root, + ) + if len(robot.cameras) > 0: + dataset.start_image_writer( + num_processes=cfg.num_image_writer_processes, + num_threads=cfg.num_image_writer_threads_per_camera * len(robot.cameras), + ) + # sanity_check_dataset_robot_compatibility(dataset, robot, cfg.fps, cfg.video) + else: + # Create empty dataset or load existing saved episodes + # sanity_check_dataset_name(cfg.repo_id, cfg.policy) + dataset = LeRobotDataset.create( + cfg.repo_id, + cfg.fps, + root=cfg.root, + robot=None, + features=robot.features, + use_videos=cfg.video, + image_writer_processes=cfg.num_image_writer_processes, + image_writer_threads=cfg.num_image_writer_threads_per_camera * len(robot.cameras), + ) + + # Load pretrained policy + policy = None if cfg.policy is None else make_policy(cfg.policy, ds_meta=dataset.meta) + # policy = None + + # if not robot.is_connected: + # robot.connect() + + listener, events = init_keyboard_listener() + + # Execute a few seconds without recording to: + # 1. teleoperate the robot to move it in starting position if no policy provided, + # 2. 
give times to the robot devices to connect and start synchronizing, + # 3. place the cameras windows on screen + enable_teleoperation = policy is None + log_say("Warmup record", cfg.play_sounds) + print() + print(f"开始记录轨迹,共需要记录{cfg.num_episodes}条\n每条轨迹的最长时间为{cfg.episode_time_s}frame\n按右方向键代表当前轨迹结束录制\n按上方面键代表当前轨迹开始录制\n按左方向键代表当前轨迹重新录制\n按ESC方向键代表退出轨迹录制\n") + # warmup_record(robot, events, enable_teleoperation, cfg.warmup_time_s, cfg.display_cameras, cfg.fps) + + # if has_method(robot, "teleop_safety_stop"): + # robot.teleop_safety_stop() + + recorded_episodes = 0 + while True: + if recorded_episodes >= cfg.num_episodes: + break + + # if events["record_start"]: + log_say(f"Recording episode {dataset.num_episodes}", cfg.play_sounds) + pprint(f"Recording episode {dataset.num_episodes}, total episodes is {cfg.num_episodes}") + record_episode( + robot=robot, + dataset=dataset, + events=events, + episode_time_s=cfg.episode_time_s, + display_cameras=cfg.display_cameras, + policy=policy, + fps=cfg.fps, + single_task=cfg.single_task, + ) + + # Execute a few seconds without recording to give time to manually reset the environment + # Current code logic doesn't allow to teleoperate during this time. 
+ # TODO(rcadene): add an option to enable teleoperation during reset + # Skip reset for the last episode to be recorded + if not events["stop_recording"] and ( + (recorded_episodes < cfg.num_episodes - 1) or events["rerecord_episode"] + ): + log_say("Reset the environment", cfg.play_sounds) + pprint("Reset the environment, stop recording") + # reset_environment(robot, events, cfg.reset_time_s, cfg.fps) + + if events["rerecord_episode"]: + log_say("Re-record episode", cfg.play_sounds) + pprint("Re-record episode") + events["rerecord_episode"] = False + events["exit_early"] = False + dataset.clear_episode_buffer() + continue + + dataset.save_episode() + recorded_episodes += 1 + + if events["stop_recording"]: + break + + log_say("Stop recording", cfg.play_sounds, blocking=True) + stop_recording(robot, listener, cfg.display_cameras) + + if cfg.push_to_hub: + dataset.push_to_hub(tags=cfg.tags, private=cfg.private) + + log_say("Exiting", cfg.play_sounds) + return dataset + + +def replay( + robot: AgilexRobot, + cfg, +): + # TODO(rcadene, aliberts): refactor with control_loop, once `dataset` is an instance of LeRobotDataset + # TODO(rcadene): Add option to record logs + + dataset = LeRobotDataset(cfg.repo_id, root=cfg.root, episodes=[cfg.episode]) + actions = dataset.hf_dataset.select_columns("action") + + # if not robot.is_connected: + # robot.connect() + + log_say("Replaying episode", cfg.play_sounds, blocking=True) + for idx in range(dataset.num_frames): + start_episode_t = time.perf_counter() + + action = actions[idx]["action"] + robot.send_action(action) + + dt_s = time.perf_counter() - start_episode_t + busy_wait(1 / cfg.fps - dt_s) + + dt_s = time.perf_counter() - start_episode_t + # log_control_info(robot, dt_s, fps=cfg.fps) + + +import argparse +def get_arguments(): + parser = argparse.ArgumentParser() + args = parser.parse_args() + args.fps = 30 + args.resume = False + args.repo_id = "move_the_bottle_from_the_right_to_the_scale_right" + args.root = "./data5" + 
args.episode = 0 # replay episode + args.num_image_writer_processes = 0 + args.num_image_writer_threads_per_camera = 4 + args.video = True + args.num_episodes = 100 + args.episode_time_s = 30000 + args.play_sounds = False + args.display_cameras = True + args.single_task = "move the bottle from the right to the scale right" + args.use_depth_image = False + args.use_base = False + args.push_to_hub = False + args.policy = None + # args.teleoprate = True + args.control_type = "record" + # args.control_type = "replay" + return args + + + + +# @parser.wrap() +def control_robot(cfg): + # 使用工厂模式创建机器人实例 + robot = RobotFactory.create(config_file="/home/ubuntu/LYT/lerobot_aloha/lerobot_aloha/configs/agilex.yaml", args=cfg) + + if cfg.control_type == "record": + record(robot, cfg) + elif cfg.control_type == "replay": + replay(robot, cfg) + + +if __name__ == "__main__": + cfg = get_arguments() + control_robot(cfg) + # control_robot() + # 使用工厂模式创建机器人实例 + # robot = RobotFactory.create(config_file="/home/ubuntu/LYT/aloha_lerobot/collect_data/agilex.yaml", args=cfg) + # print(robot.features.items()) + # print([key for key, ft in robot.features.items() if ft["dtype"] == "video"]) + # record(robot, cfg) + # capture = robot.capture_observation() + # import torch + # torch.save(capture, "test.pt") + # action = torch.tensor([[ 0.0277, 0.0167, 0.0142, -0.1628, 0.1473, -0.0296, 0.0238, -0.1094, + # 0.0109, 0.0139, -0.1591, -0.1490, -0.1650, -0.0980]], + # device='cpu') + # robot.send_action(action.squeeze(0)) + # print() \ No newline at end of file diff --git a/lerobot_aloha/common/__pycache__/agilex_robot.cpython-310.pyc b/lerobot_aloha/common/__pycache__/agilex_robot.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..84750e2179c4db3f4f04842c879dbd4196c0f312 GIT binary patch literal 8689 zcmbtZU2GiJb)Ns7o&6)16iNM98rg2FbrQ<4lO|SFTe9TFNgT?R)VP^g9SnD7m%G~C zS>BnYBx2T)O*t-FD^XF`Me8_S21y#BNRa|Tr4~lr$3Dd`eJCbSbNHIT#un-e8F#Y!ZV>0dhd$piYoQAKktWC5;XeD 
z;Ncjq;39}8SfVRnz@jI6il=&7P1-H&7HX0!x$+go%J3H5%c9L_=BRbFJmgkf72J7O zb2C>2tAN3P8-1yd;$~d~twpc2B47e?7+^lm9j`Z9-bHR?tT8X_u}+mtU;kxMpm%rz zS8!Gqgb;7Zm97Yp&=d#Ks&G2JAZa%?T3G|_R`INfAf`FzxUdl-|^PI|KhE0e=8OB z-M_y5?F#3mFJqM~!bzPGu1k!3#N2f; zmRnxux)e*Z55?l*b%9Nw{?3WR3x3->+?(%p!rtMhzjF3)%VXpRZKvfgI9r~oe!J~= z4m;cv*_>~-dnwKyt$Q8sVwXKOO*T=ekS^~LmJf}%tl7F3+BN31z1gEJzv{Gt$7YA! zmhY*efjQ=a6PO)8G{I~(+D_dwT_<$rb{s#;Ll-L$DH0hYLIX?4B$al7?4`mKh?VR1 zx?Pyk@@k=#PajyLz7X<>TAr#pouLV&+^jLbZ7yqO=yw~{U$KM+dj=PO6LoqM2Kh=C zL7|!Y-9MBTCA#N1<2^LZ{u0wa~W%uezP>v2nD8K!$1WK6aRt=pdRP zNyWR`-Q~&$P4BSpuUI`?KPCk7^^@Cb^`zcf^dQPz4`*X56rRVdfjj#c=r$mUVrn6}&$2zuI(j~h-L-j4& zCC&~vDs86Vd61cmE-K=FNfGtDF76Y@L`{_5)6~3_7p1$}dz!XKoRnrnT{7iKYP+k+ zMKLK&%2VhoDUv9O+AqtQNl_Ok@9vXySr^OV6rQH;meD*WmhYC;Q*f&meE-)%!1E=&(5A#SPk4cdUg%?j(oE)D5hAZM1H+5y`N5C z6FHJcdqxtl%*((guzB$!w`TyFj&%oq!al@LkxyqgHraP$t8HAv4dQPIeFc{a4skHX zam-=Fc8%MV-53q;@h`w^z&gM&K#QB(f}cP_c`%W1xVcqep>T4rE6S{jz)JZ@dtG<~ z(SGqCz$YVye-?Ko)p&d-jo_Q!Stqaz;nK}o^>C^=4J=c@KkH`kG=qEbY6d8W?9o@n z!EV>!R+WZU{ZC!!XK@+0a<~YYs;PB0tSf9cAF!_WzScdPwDyZpVXz0-PDGDaqwF9T z?rrXiip_g)-|voJmH4_chV=(F)#rxw=B9d*S>`goALWNUn`8o8Jr!l4^(MxCFx)%1 zH_BrC_tE(I3`_tU5%T)SVLMtz+9%PTL;F7u+jE=RccDFx_S?hu{7AdXu@~0kHHkeq zx+3WrGi+>KMW`o~X{4)cTaEycI^pP-!pY{rsMNe4|6gl9;Mw?Kq{U27B@Ys?=t70~RRkMu1601=2?roZr7$0j1$|X?chL^ft zthF61LUw0-v=4Au=-rtdty=Pc)EnWYJLS<@uU7LJAWgDBH}~N12;C|b_AxT;<3whO zJVfLZL>?yc2$4?``4kbtUyKkW`!tcwkcxeV>K-L>gve)!90iHh&}Y>Jf;c{KO|RX( z#D0SsA0zTOab$a)0655721>#RSr}YsgbT6M>Bg$Vn1lRK*AE&rsd1*(@}2OJhhu$x z-~fZ^U;zmOpT)VV->G7rI=JEt^MVBj^QndxIC5+?I-bK~@nW2DI`v+Q%FC58c8tt< zg2)LXCqZHzQ_45Dl^faV>^G@R21;a~qmn@wkH489&Qp148E;9g%Ony4b{u^yjte>5 z%ktMj#jPMOH?pY%bPB!;GNUCRZw5e5lRrGTQ)ERf1NafJD@%ZI@(8dauy}W$Og!{o zZ_%Z^HY3tMFGJsDP1+-lUNcGp01en;;4i(W$i^>qeR+C^&A7>P(7BKDb~N3=e`6y7 zoB)yygjInOSgzO;Se}l@2+}ex1(%9?q_PGghc92^U;?R&0rGmQBGPU0KtXRkR0mp^ zY3fm?nZ@036$-rlEP`~7!Xi>n8yxu&4kJ>)fxx(lqcFmei&StRG=ih(sst-s?F~86 z9O*N{iiBhROe8ZOI!Y1j!a)-$o=9OAd5ywrbkwNq1zw}_nk;H00*C474q|lqejy!G 
zkt~i17)2=p)O2%EK7m=VB2Q(vFtnnWTJbW%=NS8`Ta2V7f&CQHk$_93-_NCwrP9Z_ zbikIuxLbmba#V(nu@N2PTXc-4I{pSa%IuC?<~r^`2hvFWiBN4$L=(+P{3+p-l%|_A z2-j-zJ|LeU2_=RnojhLu38#y6sb`M3F`igR&i1EIokl_ZsE38qJOLXV8CID1QRF^ip4xEKyaoascFBb62%Sc!(W#q!N1IN^O_ChJhJ2DT z6nWZd1l|@Erztse(pzd&y*WQT;deS-HT2vk8S~i@^Yb;6A6d;tl1G8gqDaDNJr3-(q7KnaE7)U+5rTxQ^fKNk>4Uhr*!stkT}nKBq&n z&%OX+4Tm6(Saq zuM*)w;yEgvBSNVWoTZ&Gi85cNIv!Sjmr4$ic_LLJE)fr;GMPB$8dW?=WPwP7NRtS8 z(>T|p4+V)2whHqeVjt21G-LR_fZ1XjW{~e@br5ST<--&utd>CuizJ9}&u*1d7Ct>G zQ$Cril&wYLCWd4ZrFf#6tYmv+CI1`^0nL(P(4kt95??zh890#7$T&u$MAsfgx~$+} zu1Lu>DeIDf=Y2|%d+#wS%SDng;&T;s`M{{p-NpiT8rrK8$oj90!q5RXu>2d`)P9II4ef#OK|3W179oF%ehD%mh+Jc9VFd=`?|ot66B%` z|B@l&EPeqk{Mkhv9A$n`i12j*?yE?BTn)7))RrP6Dn_Llf#|fbD)+~tF=*8WW2vhj zUlID_VR=P}#_b85Cw2{{R#9$F*wdH`{W~J+XCh;;n`-T1lmp!p{Dp$&fRK zbG;`Tr_iu?Fw)qMLUFJk`VU0fD$ZcYScn4?^FS`oa9+ZM51q&vZ*fKm?l@we<>+YC!;Ro1*R1hg&LlTxAZ^wsXmx_snXmfs7w?)y%-Dc_Cc- zf>%8sU<4iW1=@ zbH^MSa8B?6u!DG_?OcS)zND(F(-i)R3rS(vH|3rF!IK4lT-anb%KNSlW z<$~?X!Psa7jSeDVr|NNcA8W}nqD*k;0M^UZZZB37Sbskm%0mWE)yl9S*09KGje2sR z*c>Keg{IZx+f~%8|+){{q`b)e6s~A{S3bxecct8{OVKsun!V zLnbXJfypM36k#gmB#@97vGoImbE>`&_~LGK*bCj+w# zjNRt*8(6$PS)LmqDlSxejG=EDpK|Mu*y;Gzj<@D|tC&}R0gO!F5?XSw0Wg4P z9gf3+1Qex_@73~|F);}#WdsyS97BM}%Xu+RVL~n=nB;NSkoVO@gfhfV>N86dQx0{VQU3wy-bs(xYWS>bC&fwxt5B_~Rr zSD-aUN5EDKmZd@7O6+YKhxk0!JC7+#+cP?|v(bR5g*kL*0& K6aDFFQvP4_f6XHR literal 0 HcmV?d00001 diff --git a/lerobot_aloha/common/__pycache__/robot_components.cpython-310.pyc b/lerobot_aloha/common/__pycache__/robot_components.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..00c1081212d5e02631996565b970a853a3b0bbcc GIT binary patch literal 13668 zcmcgzU635tRqns(>FJrB{ndI`lEv7T6O9wE9NS4iksT>=Xn5bG~!0{D}!y!|#hH|6{fDpr-vRy$t_Myqv-l`~Zp2 zgx=9=^w;RS0afh4X?YoU2A{8gB*R;?2d|4Bm>OgtrpjX7M&5%6Kc|Z4Pgf zVhV3lcsn5Ub6R!!2Ij0SU)STD@GkbfsvbLME`3V2#j01eV(X!HGmOpCz0KHpY$I&@ zy+$Xt9%%<*)rfPS?KW1u*!`T}?uF;VM##M%6MokVgLE3TIoHIWkzT9m{D zT1xGq;FDsH- 
zR+@fy!|$Oih>K0`R&TXCUemzr(;wxD`)BY33rM!K9c@eB)wi@do*iRX1J{VoM)Wne zlKei~YlrPdryY6Or)XB`Y+BOqR-Qin=pz;SUe+tioR+J>iR_KW9Gc#V%Hxf2trGfb z1XYubJUrfV)r{SG$8U)GW~1AQZPhlOs$as~1h`pOCE&Cbc%7Dlq?9k| z(cBQbmeMA%ExoWWdm}s``a5`-2S1NY2o4DSs`0gnZ6nloj4fkVe@Zj3A1%$&u9~`b zPWy-!9Js9O;Et|6O*}HdBQv_~5s)JxXGXLUcEq|Q!dd=?*NZJCTp?(*JSJq-l+&01 z6DrnQOevMjQZcQNge0JyP%<3-=-bi!$T~A_EH`zuNq@9Nq6{x_2^paaW7XK!xAcd! zXUf}H;+D3qUDD-Ts0(w;_?GsHaaDg_zi0~!3(G}!JelxxY;nYkju(18F9_m%v(fcr zBZ$q0>_)de+75ztZ?z(!7uwQG=~4;2CheLNmF21-Co!AY#pr$S(^C2}Awg|A?Nv{j zWB4RDPL?SQaVwsngv0?^T)nItZ(DCihewDrCYG_)z(yDmg&YPTRn2s|f(!{lUz0X% zK^6<>o2w@0%q(YRxm=b*kE9(GZIUZcPPi}UiQYTZd$DwuqqG)_=~^7Kd9l6A>4d-pso{zT=8J)^Pq zp#9e9iAty48z_+F*lcYC@kF=rY@MbMgdSdt-bOp{g;#HjXXD9!;MIk<5w6wStlJ$b zSZM@aoV(N(o_{tjK`DgN@6^N14X>Jy-Ij#Z)S*c|c^gPBZ>QuABypKKt36Lc$UCX# z14yue@~a*UMkwQ4JM_8%%TYC_w$IO@Nv)g^s=m>{2OyIdyiH7k1Z_-!;u%QoD!!y> zQ76hA#Sp;}Bu?JaUDJh3rT^^znQ}?Sj4Re#(UD;Z+dGW>5Qb>d_Dg@1&$c3MXxex? z*oe?Fq(k)R?uR`x9>W@)XI;76hGG_8tI_Yk&NMoml}7VIx`j;8BUK;t(Kw&999yiU z?SB$aK(?oBM3ciSz#!KRc>p!A=*!hf(yl#XPaUeh-t|SlL-|s@ezD)^BsKYZUHDBD zn?We$96C$dNRqS^>nuq-ROXSy4u}heKpqOq7&q8xG(CAYRnxqN=1N9Fveu(rUBnX* zwH(dT%f><1%`M~?79`a*@r?ZFL#OZrYe?S3PnZ|B2hTv{VCx-GfUVDqTSN(E6P^P< zpwO}tAHWq0X*t@?h*`8PibXLe4!~l%v?(ZcXS)OD4HEmEwj_haiK>lg#?VHaSyGsEp!ZHC0e?uwvvD+xC`4B z8tW!%;0wM;E;Lcij%LS*MW0!&O-MheXJ4$%tiqnjb~DjERB;}DZK2Wgdz)RqA5<+x z5|Y_ChtBXr3haeY+mr?mpms>2NtojaZk2xSaTKvJcZ{+=V?+x>v>0i4<4x#|^KmZp zH`>kEg`eH`>H!?UxEMA!c%Z>3Qo0UEFtDV3_4Dd z=_Y@?K4|e2#tTSiScaje|Lp&@@?_A{Ew;3$+YK_^O_GfCM{0;3rU8lHfCgAU+g)D& z{K)zKosD zL3Mlc$`?*VR;6-xB{I;0BwMxU7byGk_uu^6Z@>Q4??lrWc=WeiwUwmiJN|0B*YaZr zW=-jmPoS^73rTFX;kc#-lLgp>0FLv`#SnG1vJybR!@1^{Gl-y6rM=2pQrVQuDD0do3Y{IRJtrg$z$j?#XG9~9I z8K~8PZv6#RP`4o?IagZ9&lSrBNp(#;BR|>}r|<;dMKW$Ah3Sz|gzAKOv&CUHj~2}1 zVd07*${gerWG;wX#iZwoDX)kg(_#juu_R{2+zT4uwC$Avv?m#c8{&YNN8c%N%b=8? 
zIr=S#Mbu31Q*%%pLd}fGL-LNqCAL)Oyk6kTAiDq20q*u&2^J1iW#w1`Mm`Le3H~Th z2Y0Hn=8lUJ8b7g=@9dv_5wuZ;m;e$GxdG=K=6Flrf&1UqcInL`gYyFIEbAfdDw*vZ z*MoXoZ==SR$Fb6+)lQ!FAGdV4C0r10SwPEzJRQ13?-yyLqHxb4E`zlfqn9(2JY>Tp zTkd3SS%WO|{!D%azajd-KNxpl_Ba{x5J)wmqR5+}T}vYS6H*K0y*%R(<_tNY^`icD z<04$%>&A_Lgf2q_$p*6Mcs)s7uN$$kbWd!t3owv)(iE|KQXr~C(DOm!&`+rSf)-VW zg(LGNmQHqj7^L9Ar9m%-7iUMNYhsLay6pY)bjwOd3T?}{q)G4{QwMZTrw$x4rO*x= zW~F0l7)kyrB{VOnm|v%S9Z76LCWp2!Ek=GG$ye|MWW8+7Cj9!AZAfx#WH#1vwWM4Q zvW@ZtCG0R0+2hGzv)2URT7_2*#~|S4DK$6ERcek&B@ax?P5KeEQXQ`+;Rkl!_(P+F zXBG7mE#PPlT9Pm+pW6V_v*waon0!3IB5jY6@8+D}yhP;b_NgK~6^}PIn z?{QsMKF#RL<7qY6pk441O_)4c$FTJhtyy8gbj^D@j4YezPlk+_&1!~WARFdSBti#u z1EISEgfMmxlM(u&7E*iw>V;I!i^i@=b?b=HkokVaI7vjy%Tr{?(eqT+GdI)|YG|lu zKJ9rCaZsq4LbRBWSJ7?A84PM=60w1#pQTuDJ5YGPmIQz5as|se539jZfaR*gs)40Z zvP`ixqqWBe>fOc#LS^AsW|QR_H2X;tog+M*t0%;cyT3Cex9X0HNb}cF=Q-lt%>$rXeiGl8pWr4Vrw zII*n4W8H(M&geg+uh^h`2L1jD51%IhJ(;Dq9mBX;$~jQYW&N1$Tru8`?jL3Hej9OM zIn!0K6aoA)#2-w6zNiH3Hz>!m=py$+ngvP3lwp;k+u!?(dtg3s{0xn93<<)n*k%b= zOKQudmI4Zr(N7YGKl&-!SLw8~pc(o6d`zRFFl1eqv78WpWs8zNC6|!EdFMrLQo*P# zkP>Bohli1S*>FMDNQF>R-M`OP91&zPc-hi!$4eU`n;!inrf-r=pIX!w#V+*d;E>;^ z3|*xK*xr|30696JjzQFRs;3po*3GIq-ot6lb+y9SNejl7nS?6R+8@(`Y4cfzYucu( zTSjzngrgj=NSn*wMAH#=k{o1^^{aScZKUGd_!8Gexy)1-ZIX57rt8V4FAo1(xB)6|nRMu@s!J z*Sz-XS{Uapw?(*yQ<1d>U~=cWi6c@9ddE5L8t2Hj4We0p1+nc*ERst+ez2tC-*z7u z@FO3S1lb){xz0qp^e@>E0#oC=KFMZcMYY3kO``efL)N8#G$Yq4SA`tUp^rY;LvRK{k@tP!r7!Q-$2N7!DhQ4_AoE3f!}sIpUSj#>_;IC^7=-;kjjC9(o0D z96=V=1yfdVKfq>iZzhl(X9aKq^81y-d;WF?y@!-Fn}bR=&TWY&_@i1QjFf-(~jiT{wE#|1`-G^n^XF{F>ge-yu)hV3-(g@kE}mAGYQDOFTn?+@pzaf9Li=P zVTE}!xmpGuyDdSrw1y$DS|9Y0`@n(tu!2DQ_b8+cD;awj_HO+~7)D#r!Nc1BA9%rZ z?x~EuFfSPLZe{)%0v}Ewu*E?3F4R~%0JXvbtU^3(2e5TlXJ}=uJAk_h{JqGRb%w^d zCmF|P;a&b&#}ODG!oI@o3+Z<$Mz?OQyF=f*Hv!FeChP*)htS+OyLjWt-^1eN?^CjY zge9*ozfM`izi_6G`ys`_@sj)lYIzqW1N7`t>luZfiPNd0tDy5TH9mG! 
z)_k@+F4ixh1Fnl{0mZ3jtQo|16l!*WpkeUensJ6poSdqic?)&_9o;i7`bse9bNdUG z{1!el;vbNNq)$M5j-VsVi~^3hGQq&{Xa*cH_Q1n`7t}~Pfe1zlD4^(7bSTuW=?W<( zXkt2jS2wP*H@KuDc+uMY1n?pdB=BNx$IkE~KW9l-%R1h6u z@bQPp55R}KOs{m>CTYG32^x~INQ45o?nJ%)fM1n3NJnP|PAJOy`~dLXH_jo3d@K(M zW*sQqOrD4hi4>_%eulow3h@=nu~Ppgav!Acl8G%FIHtSpJ-;=g8;C>1dDabNp~m45 z-H#wXBtYo!kZn!1U*;s@w%MHW)acyC%aBkCkkA28J1ds^d4pFlhkI#7q_GI_Fx*kM ziB6}E$I^QMo*epSfamY=0`0*XkIoIY;bb62fonkI0q-K((Bg|%>3drr_{2%S{H`t^ zO!L%d8y6ojhMD$B#>n2YF&qvI(zgp3gR~mwc?_IhU<_+)jO;xd!-a2Ngf!!+mZHTW z4ltQB@&nuu*(G_FP-s4Ra7yJ}LrF9}?3Z)P*nwy6H{0RnsMDU7#Yx_h%^BNS$g=gT*7 z*+&J&1_H33nC($%nFx?ZeqSP>gYGkhTOtmf?U?1`E=G4yugZ_xVP3A@+AW|RWxQBQn zX+A@99B}S;QSlQzY$X!hGwWLAgY(6M?vaC%>YI2*e&j8l!V{c8vLEIVmS^L}5Bd9a z)Hxv>+yW{9=~#?=azHvIaG@MeF+Cn?gbf8uWc2a<5DIR`O zh2Fkxhc3Srx9mln6bt);CVz=X&u`g?P&=dbC@H*>Q)6udU-k5h$GLw=Zu8%Xkx?gx zM&063-H}m~`h7;N@n~iFjpX~2m|GEZ`v#9v8kC~%#4E|_u(F)o~!k5b|yAe6m--Q#!TT@4cWWse()z2NB~O$3Ruw@JCcSm^NPG6%ISPkew-4%=)?ih!K!Fh`z<1NPO1yXIHJQ!4kB^15@2abU%-(I<;rIHW99qH I&&_oI4|VDe^Z)<= literal 0 HcmV?d00001 diff --git a/lerobot_aloha/common/__pycache__/rosrobot.cpython-310.pyc b/lerobot_aloha/common/__pycache__/rosrobot.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9f5fa9360352140c89676bdb7ff593a9c675f0a5 GIT binary patch literal 4618 zcmb7H*>4-i8Q+WK^3W~c*Cvh=GjXJ%C#nXi4X0@f*a~0=39>JXC1*sfy31u}caw^6 z0mG7GqxBUuaonV-rJ#=O0JReoh#fhxU;8)gYrT~8)Ca#5ed_nkkQ60RRlCFtzd66( zeAj&A_w?i>_~eBrmB05%(m$}#`Ki!23?+C2Dk2ftl}!9sT-j4h#Zyhy(@f1vnJHO} zr_D4`iRNZJ-PFCTne}pJ&dZy5uV5Br>E{wjk@T!Y(ld(LL*$cEF|!PNOXEc?)?RVS zVXPjhPRIIdwb1dawi|25oFH75VM;L-=TG=0KRoJJC!9*0JK+aMh+PXQiwk_r-R)Kq$h zG@%iVq-IlQ21e2_lA${E=_Ct%S(+m`lAo1kWHS$Rf%I^?0CX?u<8%+u{bYdCy+99= zAx`%JJxsQ6x*zDRWE-akfZk4aaC#8von#lMhk)Kq_HcR_=)Ghgr?-$7i9w#4)y%Er zX|jJ-GPjWta)53Jq&ooVv*bBIx)WL&-9=SEuk4WETnafzim+-ojO>Bd^W+6+?d9qj zjZ<8D<8i*s_O0u9#$#cpnWnPx0P{kJS z*TNLiP_!8Z=B2~Tsoc&@%AJzt^7E1`!GEFo--G`gtP))Z#yO+;HU6}a 
z{PEV~g@xq$!s3??{`t*?=!YU$qH71l8af`ea5d9$LZ117WeXE$S?tLK!o(T+ZJf1RCeHHM zE)Xzr%M&&p>rT}PE!ZON;nW1R!-&x!PQPO_FRB&QSPQ5-!FpgV>xJsJ#J0oVzkfP9 z>3eiEDgmNs^w?XkkGhnh=UKMvPugo*W#9As>L@Vcb4F_5zLv#VT6m0L5UQMBkPC7@ z{Dx&+)q$=HFCA%3jn!Y#={U_OfUIt8S%?#?L7J0K$zV|mL=ZWG*d;Apx|dwODm0Uv z{-&rf?2Xk54T~z<1~l7+9T`F=?8puIQAM=_2lqg=8%m%+)vMIA-0VkSc^RWm>q8m~ zn)wqv9)g*|Kqf~6&=fF^zR{S!(YSO2%nyL{Lw}quF^5#>ZCPMJL3eOBpo!%Pa@W&4vKjKA64g{7ZK}UA}g9>Gata+k6m}f--YTlm+Z5n7;Ph`=Rdu_CF0(eL$SL zW6}S?Jh}W{a(||Sd9D_fTql^^i1#zV^7%Gz5vE60)@;feElUV|@#Fj8?Lzx-bkszg zDcj&#cK2iBkw@FCd&IGu3pVAurOg+ZHa`7g>D-w%Ys7@%(d*9lSs4C)n=jm9;8r6z zdsC(>hfO|yn9N*WzIHu1|6%K}eZh3KY&l-V`gKGj+MW4x!1%{)=3AP&UcGP~4h-x( zsH4e3PzP%#@{GbD6i>-h;DC+<)~`f3ER@l>_hqv1StrVPfWsrS7EU$+Liou{+JPO0 zjH`yBxP&O4J!APb?9ZJ5wkSyF3N{$7vk{wWiY#c&b5iy{lvqM;QlN=8G0teB_>YY$m zwps^ut<+Gjjca$4E1y07_7U8ho!V_%Uj)d6E*uXttZrytgpH0P8XjqyXx1#;6b;7k zRcJcyc#u~MP%rMw_??AE?TTjvwmQ+qGc2rX^}2O%Fg$Z-K+%eONMs>X%ADGyw7FE1 zlIGGKxe3aoophXd?WBm+iyzH4Zp}A7dDOW2b>nXzj2AQfzA(k5jTu1ICzLDN90=fc ztXUn=nW4d$yVvzwD^b+Uxqih8Lb(0AM}=Y7;!P+4US~Nar{?5dwRJtWl89B3Kq<6y zy;q^(fGn)0JU}X37MSe`2|=x)bg+lW@Onu77H;n;%>E>NEnHNfAMHOCy@U`j=fbZs z_jTjqot1=Z>D=eZ?`|!=f4y=3N-LXLeDH_HpI6c^hSod2jNpoLoOY^&R>NXC)?4%A zbX0>hNeGl-MBpNf)3q9NV5%an=kMXbnmZaNm!Jfge5F%^@}R6M!*YE{Sk9^>YZ0cv z_sroqgGrqqg=G|4>qEL75(?003QC^mnMD0xZ`w|1U_$5`w(A<0*cl?Q8st#Qiwfs;|G?Ak#<1H+3z?2cD+sRy=YK!RvZFvv6x7H`Y( z(6}G|vNcA@&lvMrx8M_J+w0CCFhr;;^*&)r?Xe9{!hQ$~-~@Z&@A3Ty;$4$G{Cx4! 
z-L=^FZ;ZXh{4yBN!6N$+?9&xzra9+boG=Dey_qXD`-}28=m`_&@uJKyvnlSw{J#nk zv8*^}S)NZK7wLipmzV7}N7Mi=e)Q8=hZj%HW&w?}*4z!RiA3g5n{mSyo?^f PW0)>&mvwo!JgEK$+{d>~ literal 0 HcmV?d00001 diff --git a/lerobot_aloha/common/__pycache__/rosrobot_factory.cpython-310.pyc b/lerobot_aloha/common/__pycache__/rosrobot_factory.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d058a8e54a09126695051a28deee8c9407c5fb9b GIT binary patch literal 1948 zcma)7%WoS+7@ygf*Y+k2gaDBuv`8HAB^De_gc|Yar7}_)sJfTcW_Rqk?8D5iTM{`E zB%!GwAw57Uni4seDj<{;A+?Z*@_)>gc$1vDae;4UQ|F-yG1h+mX6HBG{N{UY)Yn%) zFot&js^82Z^e3O39tun@!V`xeJmg^?S=m3fFw7-i3S>(T6iW$IOARzj3vyNtBTx1e zUk~zDUg)Y{2#Qt_qf5xsyxb!4ato4W5Tn!Q>7=+RVI=f?rhE7bDix04OjIqYOnJBA zCQP~9h!ZA{hcit7pqVtH(D9ivInyLZ7=R_r)QM;+NxV=@~~+*qm}=;+ausx=~nb z)LRr_QOF7cXD4t^R?KwQcjB07lq4-0R$UqH@pN#VKkvX3_k(ClZFC*A@fZXTs7b9& zcv&!yNBYZF_ttW{_VD1|VtRkK``y<5?(V_Qd)>r5vWrG#1DuLz~A4O%F|Im z%B`tZn6%25uT7SHl5ukUXxce9bE6=L!ZMBG3{%U(GuniN%8nwGyDXMK^yg=XCU}I7 zV>kK~!HQXY>>8vPQyg&0nFwgGnWvbDz#1qCc;CYd(rVwFoZybsmN)U7aute4K?+)s zF$(oJP&{-KW5j>g;AtxznU~O}G^axRl(^hR9o173jhBkh)ni?(-#~PXy0U%r#md3&+v(cU{@$a* zr60RrZgp?3clUO?cWxhjHh(%vy4xH3PySbuoN`yiJafR~o$iyx9s{0G+?y5i@Q*u( zo6G6lN8Q!=^xKoH;D{c4LS+BBgcK30m>KE+pnKMm(}>AALmxSQi@Zl^M8VpF`E=uc zy1dlgzIWPYRww(vu5?%Lrr+$O%b)kiF-)r>U76;HR>TaEk#;tlefh&MZZ(?`O`zmO zHk!P^E9UIrs5ySaY52~RPsBU$N2$P+s7XSm%s7G1^08ARwjVhjlj|fAIZL01Orv=a ztZ>4eDL(I+%Ex3TE@|`v#|(1u0tez6)3V@XmBmX+)D?Y>FKePxkzvqdjyEnCc;p{L z)?@=4xF8SUSESj&Gc}U!DB)_UpXs(8Isvh5R)LE4cq>t<@m=dIotLk*S5tK zps;vf3rI2@!OfuXDX%5Uiz&OF3L4WPN16_e3CmZ(e%y3iLSN;W6aGTum!SOf#tn#G zs3Qd{vLuaQsqh*WYxG5qdI?0SpYp5H*Fi81q5(ac@=ic`Kc)OjNI3yP3L3_77ZLOz qx+2Q92&D)^W?e+(8N8T2mbv6@{^^Q2LPt6t0mmA+UpC$x!2bfu#yvg& literal 0 HcmV?d00001 diff --git a/collect_data/agilex_robot.py b/lerobot_aloha/common/agilex_robot.py similarity index 94% rename from collect_data/agilex_robot.py rename to lerobot_aloha/common/agilex_robot.py index 28a701e..0dfca78 100644 --- a/collect_data/agilex_robot.py +++ 
b/lerobot_aloha/common/agilex_robot.py @@ -1,17 +1,13 @@ -import yaml import cv2 import numpy as np import collections import dm_env import argparse from typing import Dict, List, Any, Optional -from collections import deque import rospy -from cv_bridge import CvBridge from std_msgs.msg import Header -from sensor_msgs.msg import Image, JointState -from nav_msgs.msg import Odometry -from rosrobot import Robot +from sensor_msgs.msg import JointState +from .rosrobot import Robot import torch import time @@ -40,9 +36,12 @@ class AgilexRobot(Robot): # print("can not get data from puppet topic") # return None - if len(self.sync_arm_queues['puppet_left']) == 0 or len(self.sync_arm_queues['puppet_right']) == 0: - print("can not get data from puppet topic") - return None + # 检查必要的机械臂数据是否可用 + required_arms = ['puppet_left', 'puppet_right'] + for arm_name in required_arms: + if arm_name not in self.sync_arm_queues or len(self.sync_arm_queues[arm_name]) == 0: + print(f"can not get data from {arm_name} topic") + return None # 计算最小时间戳 timestamps = [ @@ -330,12 +329,18 @@ class AgilexRobot(Robot): Returns: The actual action that was sent (may be clipped if safety checks are implemented) """ - # if not hasattr(self, 'puppet_arm_publishers'): - # # Initialize publishers on first call - # self._init_action_publishers() + # 默认速度和力矩值 + last_velocity = [-0.010990142822265625, -0.010990142822265625, -0.03296661376953125, + 0.010990142822265625, -0.010990142822265625, -0.010990142822265625, + -0.010990142822265625, -0.010990142822265625, -0.010990142822265625, + -0.03296661376953125, -0.010990142822265625, -0.010990142822265625, + -0.03296661376953125, -0.03296661376953125] - last_velocity = [-0.010990142822265625, -0.010990142822265625, -0.03296661376953125, 0.010990142822265625, -0.010990142822265625, -0.010990142822265625, -0.010990142822265625, -0.010990142822265625, -0.010990142822265625, -0.03296661376953125, -0.010990142822265625, -0.010990142822265625, -0.03296661376953125, 
-0.03296661376953125] - last_effort = [-0.021978378295898438, 0.2417583465576172, 4.320878982543945, 3.6527481079101562, -0.013187408447265625, -0.013187408447265625, 0.0, -0.010990142822265625, -0.010990142822265625, -0.03296661376953125, -0.010990142822265625, -0.010990142822265625, -0.03296661376953125, -0.03296661376953125] + last_effort = [-0.021978378295898438, 0.2417583465576172, 4.320878982543945, + 3.6527481079101562, -0.013187408447265625, -0.013187408447265625, + 0.0, -0.010990142822265625, -0.010990142822265625, + -0.03296661376953125, -0.010990142822265625, -0.010990142822265625, + -0.03296661376953125, -0.03296661376953125] # Convert tensor to numpy array if needed if isinstance(action, torch.Tensor): diff --git a/collect_data/rosrobot.py b/lerobot_aloha/common/robot_components.py similarity index 50% rename from collect_data/rosrobot.py rename to lerobot_aloha/common/robot_components.py index 3a2de20..7669f71 100644 --- a/collect_data/rosrobot.py +++ b/lerobot_aloha/common/robot_components.py @@ -8,29 +8,38 @@ from nav_msgs.msg import Odometry import argparse -class Robot: - def __init__(self, config_file: str, args: Optional[argparse.Namespace] = None): +class RobotConfig: + """Configuration management for robot components""" + + def __init__(self, config_file: str): """ - 机器人基类,处理通用初始化逻辑 + Initialize robot configuration from YAML file + Args: - config_file: YAML配置文件路径 - args: 运行时参数 + config_file: Path to YAML configuration file """ - self._load_config(config_file) - self._merge_runtime_args(args) - self._init_components() - self._init_data_structures() - self.init_ros() - self.init_features() - self.warmup() - - def _load_config(self, config_file: str) -> None: - """加载YAML配置文件""" + self.config = self._load_yaml(config_file) + self._validate_config() + + def _load_yaml(self, config_file: str) -> Dict[str, Any]: + """Load configuration from YAML file""" with open(config_file, 'r') as f: - self.config = yaml.safe_load(f) - - def 
_merge_runtime_args(self, args: Optional[argparse.Namespace]) -> None: - """合并运行时参数到配置""" + return yaml.safe_load(f) + + def _validate_config(self) -> None: + """Validate configuration completeness""" + required_sections = ['cameras', 'arm'] + for section in required_sections: + if section not in self.config: + raise ValueError(f"Missing required config section: {section}") + + def merge_runtime_args(self, args: Optional[argparse.Namespace]) -> None: + """ + Merge runtime arguments into configuration + + Args: + args: Runtime arguments from command line + """ if args is None: return @@ -47,217 +56,56 @@ class Robot: for key, value in runtime_params.items(): if value is not None: self.config[key] = value + + def get(self, key: str, default=None) -> Any: + """Get configuration value with default fallback""" + return self.config.get(key, default) - def _init_components(self) -> None: - """初始化核心组件""" + +class RosAdapter: + """Adapter for ROS communication""" + + def __init__(self, config: RobotConfig): + """ + Initialize ROS adapter + + Args: + config: Robot configuration + """ + self.config = config self.bridge = CvBridge() self.subscribers = {} self.publishers = {} - self._validate_config() - - def _validate_config(self) -> None: - """验证配置完整性""" - required_sections = ['cameras', 'arm'] - for section in required_sections: - if section not in self.config: - raise ValueError(f"Missing required config section: {section}") - - def _init_data_structures(self) -> None: - """初始化数据结构模板方法""" - # 相机数据 - self.cameras = self.config.get('cameras', {}) - self.sync_img_queues = {name: deque(maxlen=2000) for name in self.cameras} - # 深度数据 - self.use_depth_image = self.config.get('use_depth_image', False) - if self.use_depth_image: - self.sync_depth_queues = { - name: deque(maxlen=2000) - for name, cam in self.cameras.items() - if 'depth_topic_name' in cam - } + def init_ros_node(self, node_name: str = None) -> None: + """Initialize ROS node""" + if node_name is None: + node_name = 
self.config.get('ros_node_name', 'generic_robot_node') + + rospy.init_node(node_name, anonymous=True) - # 机械臂数据 - self.arms = self.config.get('arm', {}) - if self.config.get('control_type', '') != 'record': - # 如果不是录制模式,则仅初始化从机械臂数据队列 - self.sync_arm_queues = {name: deque(maxlen=2000) for name in self.arms if 'puppet' in name} - else: - self.sync_arm_queues = {name: deque(maxlen=2000) for name in self.arms} - - # 机器人基座数据 - self.use_robot_base = self.config.get('use_robot_base', False) - if self.use_robot_base: - self.sync_base_queue = deque(maxlen=2000) - - def init_ros(self) -> None: - """初始化ROS订阅的模板方法""" - rospy.init_node( - f"{self.config.get('ros_node_name', 'generic_robot_node')}", - anonymous=True + def create_subscriber(self, topic: str, msg_type, callback, queue_size: int = 1000, tcp_nodelay: bool = True): + """Create a ROS subscriber""" + subscriber = rospy.Subscriber( + topic, + msg_type, + callback, + queue_size=queue_size, + tcp_nodelay=tcp_nodelay ) + return subscriber - self._setup_camera_subscribers() - self._setup_arm_subscribers_publishers() - self._setup_base_subscriber() - self._log_ros_status() - - def init_features(self): - """ - 根据YAML配置自动生成features结构 - """ - self.features = {} + def create_publisher(self, topic: str, msg_type, queue_size: int = 10): + """Create a ROS publisher""" + publisher = rospy.Publisher( + topic, + msg_type, + queue_size=queue_size + ) + return publisher - # 初始化相机特征 - self._init_camera_features() - - # 初始化机械臂特征 - self._init_state_features() - - self._init_action_features() - - # 初始化基座特征(如果启用) - if self.use_robot_base: - self._init_base_features() - import pprint - pprint.pprint(self.features, indent=4) - - - def _init_camera_features(self): - """处理所有相机特征""" - for cam_name, cam_config in self.cameras.items(): - # 普通图像 - self.features[f"observation.images.{cam_name}"] = { - "dtype": "video" if self.config.get("video", False) else "image", - "shape": cam_config.get("rgb_shape", [480, 640, 3]), - "names": ["height", "width", 
"channel"], - # "video_info": { - # "video.fps": cam_config.get("fps", 30.0), - # "video.codec": cam_config.get("codec", "av1"), - # "video.pix_fmt": cam_config.get("pix_fmt", "yuv420p"), - # "video.is_depth_map": False, - # "has_audio": False - # } - } - - if self.config.get("use_depth_image", False): - self.features[f"observation.images.depth_{cam_name}"] = { - "dtype": "uint16", - "shape": (cam_config.get("width", 480), cam_config.get("height", 640), 1), - "names": ["height", "width"], - } - - - def _init_state_features(self): - state = self.config.get('state', {}) - # 状态特征 - self.features["observation.state"] = { - "dtype": "float32", - "shape": (len(state.get('motors', "")),), - "names": {"motors": state.get('motors', "")} - } - - if self.config.get('velocity'): - velocity = self.config.get('velocity', "") - self.features["observation.velocity"] = { - "dtype": "float32", - "shape": (len(velocity.get('motors', "")),), - "names": {"motors": velocity.get('motors', "")} - } - - if self.config.get('effort'): - effort = self.config.get('effort', "") - self.features["observation.effort"] = { - "dtype": "float32", - "shape": (len(effort.get('motors', "")),), - "names": {"motors": effort.get('motors', "")} - } - - - - def _init_action_features(self): - action = self.config.get('action', {}) - # 状态特征 - self.features["action"] = { - "dtype": "float32", - "shape": (len(action.get('motors', "")),), - "names": {"motors": action.get('motors', "")} - } - - def _init_base_features(self): - """处理基座特征""" - self.features["observation.base_vel"] = { - "dtype": "float32", - "shape": (2,), - "names": ["linear_x", "angular_z"] - } - - - def _setup_camera_subscribers(self) -> None: - """设置相机订阅者""" - for cam_name, cam_config in self.cameras.items(): - if 'img_topic_name' in cam_config: - self.subscribers[f"camera_{cam_name}"] = rospy.Subscriber( - cam_config['img_topic_name'], - Image, - self._make_camera_callback(cam_name, is_depth=False), - queue_size=1000, - tcp_nodelay=True - ) - - 
if self.use_depth_image and 'depth_topic_name' in cam_config: - self.subscribers[f"depth_{cam_name}"] = rospy.Subscriber( - cam_config['depth_topic_name'], - Image, - self._make_camera_callback(cam_name, is_depth=True), - queue_size=1000, - tcp_nodelay=True - ) - - def _setup_arm_subscribers_publishers(self) -> None: - """设置机械臂订阅者""" - # 当为record模式时,主从机械臂都需要订阅 - # 否则只订阅从机械臂,但向主机械臂发布 - if self.config.get('control_type', '') == 'record': - for arm_name, arm_config in self.arms.items(): - if 'topic_name' in arm_config: - self.subscribers[f"arm_{arm_name}"] = rospy.Subscriber( - arm_config['topic_name'], - JointState, - self._make_arm_callback(arm_name), - queue_size=1000, - tcp_nodelay=True - ) - else: - for arm_name, arm_config in self.arms.items(): - if 'puppet' in arm_name: - self.subscribers[f"arm_{arm_name}"] = rospy.Subscriber( - arm_config['topic_name'], - JointState, - self._make_arm_callback(arm_name), - queue_size=1000, - tcp_nodelay=True - ) - if 'master' in arm_name: - self.publishers[f"arm_{arm_name}"] = rospy.Publisher( - arm_config['topic_name'], - JointState, - queue_size=10 - ) - - def _setup_base_subscriber(self) -> None: - """设置基座订阅者""" - if self.use_robot_base and 'robot_base' in self.config: - self.subscribers['base'] = rospy.Subscriber( - self.config['robot_base']['topic_name'], - Odometry, - self.robot_base_callback, - queue_size=1000, - tcp_nodelay=True - ) - - def _log_ros_status(self) -> None: - """记录ROS状态""" + def log_status(self) -> None: + """Log ROS connection status""" rospy.loginfo("\n=== ROS订阅状态 ===") rospy.loginfo(f"已初始化节点: {rospy.get_name()}") rospy.loginfo("活跃的订阅者:") @@ -265,8 +113,74 @@ class Robot: rospy.loginfo(f" - {topic}: {'活跃' if sub.impl else '未连接'}") rospy.loginfo("=================") + +class RobotSensors: + """Management of robot sensors (cameras, depth sensors)""" + + def __init__(self, config: RobotConfig, ros_adapter: RosAdapter): + """ + Initialize robot sensors + + Args: + config: Robot configuration + ros_adapter: 
ROS communication adapter + """ + self.config = config + self.ros_adapter = ros_adapter + self.bridge = ros_adapter.bridge + + # Camera data + self.cameras = config.get('cameras', {}) + self.sync_img_queues = {name: deque(maxlen=2000) for name in self.cameras} + + # Depth data + self.use_depth_image = config.get('use_depth_image', False) + if self.use_depth_image: + self.sync_depth_queues = { + name: deque(maxlen=2000) + for name, cam in self.cameras.items() + if 'depth_topic_name' in cam + } + + # Robot base data + self.use_robot_base = config.get('use_robot_base', False) + if self.use_robot_base: + self.sync_base_queue = deque(maxlen=2000) + + def setup_subscribers(self) -> None: + """Set up ROS subscribers for sensors""" + self._setup_camera_subscribers() + if self.use_robot_base: + self._setup_base_subscriber() + + def _setup_camera_subscribers(self) -> None: + """Set up camera subscribers""" + for cam_name, cam_config in self.cameras.items(): + if 'img_topic_name' in cam_config: + self.ros_adapter.subscribers[f"camera_{cam_name}"] = self.ros_adapter.create_subscriber( + cam_config['img_topic_name'], + Image, + self._make_camera_callback(cam_name, is_depth=False) + ) + + if self.use_depth_image and 'depth_topic_name' in cam_config: + self.ros_adapter.subscribers[f"depth_{cam_name}"] = self.ros_adapter.create_subscriber( + cam_config['depth_topic_name'], + Image, + self._make_camera_callback(cam_name, is_depth=True) + ) + + def _setup_base_subscriber(self) -> None: + """Set up base subscriber""" + if 'robot_base' in self.config.config: + self.ros_adapter.subscribers['base'] = self.ros_adapter.create_subscriber( + self.config.get('robot_base')['topic_name'], + Odometry, + self.robot_base_callback + ) + def _make_camera_callback(self, cam_name: str, is_depth: bool = False): - """生成相机回调函数工厂方法""" + """Generate camera callback factory method""" def callback(msg): try: target_queue = ( @@ -281,8 +195,105 @@ class Robot: rospy.logerr(f"Camera {cam_name} callback error: 
{str(e)}") return callback + def robot_base_callback(self, msg): + """Base callback default implementation""" + if len(self.sync_base_queue) >= 2000: + self.sync_base_queue.popleft() + self.sync_base_queue.append(msg) + + def init_features(self) -> Dict[str, Any]: + """Initialize sensor features""" + features = {} + + # Initialize camera features + self._init_camera_features(features) + + # Initialize base features (if enabled) + if self.use_robot_base: + self._init_base_features(features) + + return features + + def _init_camera_features(self, features: Dict[str, Any]) -> None: + """Process all camera features""" + for cam_name, cam_config in self.cameras.items(): + # Regular images + features[f"observation.images.{cam_name}"] = { + "dtype": "video" if self.config.get("video", False) else "image", + "shape": cam_config.get("rgb_shape", [480, 640, 3]), + "names": ["height", "width", "channel"], + } + + if self.config.get("use_depth_image", False): + features[f"observation.images.depth_{cam_name}"] = { + "dtype": "uint16", + "shape": (cam_config.get("width", 480), cam_config.get("height", 640), 1), + "names": ["height", "width"], + } + + def _init_base_features(self, features: Dict[str, Any]) -> None: + """Process base features""" + features["observation.base_vel"] = { + "dtype": "float32", + "shape": (2,), + "names": ["linear_x", "angular_z"] + } + + +class RobotActuators: + """Management of robot actuators (arms, base)""" + + def __init__(self, config: RobotConfig, ros_adapter: RosAdapter): + """ + Initialize robot actuators + + Args: + config: Robot configuration + ros_adapter: ROS communication adapter + """ + self.config = config + self.ros_adapter = ros_adapter + + # Arm data + self.arms = config.get('arm', {}) + if config.get('control_type', '') != 'record': + # If not in record mode, only initialize puppet arm queues + self.sync_arm_queues = {name: deque(maxlen=2000) for name in self.arms if 'puppet' in name} + else: + self.sync_arm_queues = {name: 
deque(maxlen=2000) for name in self.arms} + + def setup_subscribers_publishers(self) -> None: + """Set up ROS subscribers and publishers for actuators""" + self._setup_arm_subscribers_publishers() + + def _setup_arm_subscribers_publishers(self) -> None: + """Set up arm subscribers and publishers""" + # When in record mode, subscribe to both master and puppet arms + # Otherwise only subscribe to puppet arms, but publish to master arms + if self.config.get('control_type', '') == 'record': + for arm_name, arm_config in self.arms.items(): + if 'topic_name' in arm_config: + self.ros_adapter.subscribers[f"arm_{arm_name}"] = self.ros_adapter.create_subscriber( + arm_config['topic_name'], + JointState, + self._make_arm_callback(arm_name) + ) + else: + for arm_name, arm_config in self.arms.items(): + if 'puppet' in arm_name: + self.ros_adapter.subscribers[f"arm_{arm_name}"] = self.ros_adapter.create_subscriber( + arm_config['topic_name'], + JointState, + self._make_arm_callback(arm_name) + ) + if 'master' in arm_name: + self.ros_adapter.publishers[f"arm_{arm_name}"] = self.ros_adapter.create_publisher( + arm_config['topic_name'], + JointState + ) + def _make_arm_callback(self, arm_name: str): - """生成机械臂回调函数工厂方法""" + """Generate arm callback factory method""" def callback(msg): try: if len(self.sync_arm_queues[arm_name]) >= 2000: @@ -292,17 +303,74 @@ class Robot: rospy.logerr(f"Arm {arm_name} callback error: {str(e)}") return callback - def robot_base_callback(self, msg): - """基座回调默认实现""" - if len(self.sync_base_queue) >= 2000: - self.sync_base_queue.popleft() - self.sync_base_queue.append(msg) + def init_features(self) -> Dict[str, Any]: + """Initialize actuator features""" + features = {} + + # Initialize arm features + self._init_state_features(features) + self._init_action_features(features) + + return features + + def _init_state_features(self, features: Dict[str, Any]) -> None: + """Initialize state features""" + state = self.config.get('state', {}) + # State features 
+ features["observation.state"] = { + "dtype": "float32", + "shape": (len(state.get('motors', "")),), + "names": {"motors": state.get('motors', "")} + } - def warmup(self, timeout: float = 10.0) -> bool: - """Wait until all data queues have at least 20 messages. + if self.config.get('velocity'): + velocity = self.config.get('velocity', "") + features["observation.velocity"] = { + "dtype": "float32", + "shape": (len(velocity.get('motors', "")),), + "names": {"motors": velocity.get('motors', "")} + } + + if self.config.get('effort'): + effort = self.config.get('effort', "") + features["observation.effort"] = { + "dtype": "float32", + "shape": (len(effort.get('motors', "")),), + "names": {"motors": effort.get('motors', "")} + } + + def _init_action_features(self, features: Dict[str, Any]) -> None: + """Initialize action features""" + action = self.config.get('action', {}) + features["action"] = { + "dtype": "float32", + "shape": (len(action.get('motors', "")),), + "names": {"motors": action.get('motors', "")} + } + + +class RobotDataManager: + """Management of robot data collection and synchronization""" + + def __init__(self, config: RobotConfig, sensors: RobotSensors, actuators: RobotActuators): + """ + Initialize robot data manager Args: - timeout: Maximum time to wait in seconds before giving up + config: Robot configuration + sensors: Robot sensors component + actuators: Robot actuators component + """ + self.config = config + self.sensors = sensors + self.actuators = actuators + + def warmup(self, timeout: float = 10.0) -> bool: + """ + Wait until all data queues have sufficient messages + + Args: + timeout: Maximum time to wait in seconds Returns: bool: True if warmup succeeded, False if timed out @@ -323,31 +391,24 @@ class Robot: all_ready = True # Check camera image queues - for cam_name in self.cameras: - if len(self.sync_img_queues[cam_name]) < 50: - rospy.loginfo(f"Waiting for camera {cam_name} (current: {len(self.sync_img_queues[cam_name])}/50)") + for 
cam_name in self.sensors.cameras: + if len(self.sensors.sync_img_queues[cam_name]) < 50: + rospy.loginfo(f"Waiting for camera {cam_name} (current: {len(self.sensors.sync_img_queues[cam_name])}/50)") all_ready = False break # Check depth queues if enabled - if self.use_depth_image: - for cam_name in self.sync_depth_queues: - if len(self.sync_depth_queues[cam_name]) < 50: - rospy.loginfo(f"Waiting for depth camera {cam_name} (current: {len(self.sync_depth_queues[cam_name])}/50)") + if self.sensors.use_depth_image: + for cam_name in self.sensors.sync_depth_queues: + if len(self.sensors.sync_depth_queues[cam_name]) < 50: + rospy.loginfo(f"Waiting for depth camera {cam_name} (current: {len(self.sensors.sync_depth_queues[cam_name])}/50)") all_ready = False break - # # Check arm queues - # for arm_name in self.arms: - # if len(self.sync_arm_queues[arm_name]) < 20: - # rospy.loginfo(f"Waiting for arm {arm_name} (current: {len(self.sync_arm_queues[arm_name])}/20)") - # all_ready = False - # break - # Check base queue if enabled - if self.use_robot_base: - if len(self.sync_base_queue) < 20: - rospy.loginfo(f"Waiting for base (current: {len(self.sync_base_queue)}/20)") + if self.sensors.use_robot_base: + if len(self.sensors.sync_base_queue) < 20: + rospy.loginfo(f"Waiting for base (current: {len(self.sensors.sync_base_queue)}/20)") all_ready = False # If all queues are ready, return success @@ -357,16 +418,4 @@ class Robot: rate.sleep() - return False - - - - - - def get_frame(self) -> Optional[Dict[str, Any]]: - """获取同步帧数据的模板方法""" - raise NotImplementedError("Subclasses must implement get_frame()") - - def process(self) -> tuple: - """主处理循环的模板方法""" - raise NotImplementedError("Subclasses must implement process()") + return False \ No newline at end of file diff --git a/lerobot_aloha/common/rosrobot.py b/lerobot_aloha/common/rosrobot.py new file mode 100644 index 0000000..30f80dd --- /dev/null +++ b/lerobot_aloha/common/rosrobot.py @@ -0,0 +1,136 @@ +import yaml +from typing 
import Dict, Any, Optional, List +import argparse +from .robot_components import RobotConfig, RosAdapter, RobotSensors, RobotActuators, RobotDataManager + + +class Robot: + def __init__(self, config_file: str, args: Optional[argparse.Namespace] = None): + """ + 机器人基类,处理通用初始化逻辑 + Args: + config_file: YAML配置文件路径 + args: 运行时参数 + """ + # 初始化组件 + self.config = RobotConfig(config_file) + self.config.merge_runtime_args(args) + self.ros_adapter = RosAdapter(self.config) + self.sensors = RobotSensors(self.config, self.ros_adapter) + self.actuators = RobotActuators(self.config, self.ros_adapter) + self.data_manager = RobotDataManager(self.config, self.sensors, self.actuators) + + # 初始化ROS和特征 + self.init_ros() + self.init_features() + self.warmup() + + def get(self, key: str, default=None) -> Any: + """获取配置值""" + return self.config.get(key, default) + + @property + def bridge(self): + """获取CV桥接器""" + return self.ros_adapter.bridge + + @property + def subscribers(self): + """获取订阅者""" + return self.ros_adapter.subscribers + + @property + def publishers(self): + """获取发布者""" + return self.ros_adapter.publishers + + @property + def cameras(self): + """获取相机配置""" + return self.sensors.cameras + + @property + def arms(self): + """获取机械臂配置""" + return self.actuators.arms + + @property + def sync_img_queues(self): + """获取图像队列""" + return self.sensors.sync_img_queues + + @property + def sync_depth_queues(self): + """获取深度图像队列""" + return self.sensors.sync_depth_queues if hasattr(self.sensors, 'sync_depth_queues') else {} + + @property + def sync_arm_queues(self): + """获取机械臂队列""" + return self.actuators.sync_arm_queues + + @property + def sync_base_queue(self): + """获取基座队列""" + return self.sensors.sync_base_queue if hasattr(self.sensors, 'sync_base_queue') else None + + @property + def use_depth_image(self): + """是否使用深度图像""" + return self.sensors.use_depth_image + + @property + def use_robot_base(self): + """是否使用机器人基座""" + return self.sensors.use_robot_base + + def init_ros(self) -> None: 
+ """初始化ROS订阅的模板方法""" + self.ros_adapter.init_ros_node() + + # 设置传感器和执行器的订阅者和发布者 + self.sensors.setup_subscribers() + self.actuators.setup_subscribers_publishers() + + # 记录ROS状态 + self.ros_adapter.log_status() + + def init_features(self): + """ + 根据YAML配置自动生成features结构 + """ + # 合并传感器和执行器的特征 + self.features = {} + self.features.update(self.sensors.init_features()) + self.features.update(self.actuators.init_features()) + + import pprint + pprint.pprint(self.features, indent=4) + + + + + + + def warmup(self, timeout: float = 10.0) -> bool: + """Wait until all data queues have at least 20 messages. + + Args: + timeout: Maximum time to wait in seconds before giving up + + Returns: + bool: True if warmup succeeded, False if timed out + """ + return self.data_manager.warmup(timeout) + + + + + + def get_frame(self) -> Optional[Dict[str, Any]]: + """获取同步帧数据的模板方法""" + raise NotImplementedError("Subclasses must implement get_frame()") + + def process(self) -> tuple: + """主处理循环的模板方法""" + raise NotImplementedError("Subclasses must implement process()") diff --git a/lerobot_aloha/common/rosrobot_factory.py b/lerobot_aloha/common/rosrobot_factory.py new file mode 100644 index 0000000..9409586 --- /dev/null +++ b/lerobot_aloha/common/rosrobot_factory.py @@ -0,0 +1,59 @@ +import yaml +import argparse +from typing import Dict, List, Any, Optional, Type +from .rosrobot import Robot +from .agilex_robot import AgilexRobot + + +class RobotFactory: + """Factory for creating robot instances based on configuration""" + + # 注册表,用于存储可用的机器人类型 + _registry = {} + + @classmethod + def register(cls, robot_type: str, robot_class: Type[Robot]) -> None: + """ + 注册新的机器人类型 + + Args: + robot_type: 机器人类型标识符 + robot_class: 机器人类实现 + """ + cls._registry[robot_type] = robot_class + + @classmethod + def create(cls, config_file: str, args: Optional[argparse.Namespace] = None) -> Robot: + """ + 根据配置文件自动创建合适的机器人实例 + + Args: + config_file: 配置文件路径 + args: 运行时参数 + + Returns: + Robot: 创建的机器人实例 + + Raises: + 
ValueError: 如果指定的机器人类型不受支持 + """ + with open(config_file, 'r') as f: + config = yaml.safe_load(f) + + robot_type = config.get('robot_type', 'agilex') + + # 如果注册表为空,注册默认机器人类型 + if not cls._registry: + cls.register('agilex', AgilexRobot) + cls.register('aloha_agilex', AgilexRobot) # 别名支持 + + # 从注册表中查找机器人类 + if robot_type in cls._registry: + return cls._registry[robot_type](config_file, args) + else: + raise ValueError(f"Unsupported robot type: {robot_type}. Available types: {list(cls._registry.keys())}") + + +# 注册可用的机器人类型 +RobotFactory.register('agilex', AgilexRobot) +RobotFactory.register('aloha_agilex', AgilexRobot) # 别名支持 diff --git a/lerobot_aloha/configs/agilex.yaml b/lerobot_aloha/configs/agilex.yaml new file mode 100644 index 0000000..703b7e2 --- /dev/null +++ b/lerobot_aloha/configs/agilex.yaml @@ -0,0 +1,146 @@ +robot_type: aloha_agilex +ros_node_name: record_episodes +cameras: + cam_front: + img_topic_name: /camera_f/color/image_raw + depth_topic_name: /camera_f/depth/image_raw + width: 480 + height: 640 + rgb_shape: [480, 640, 3] + cam_left: + img_topic_name: /camera_l/color/image_raw + depth_topic_name: /camera_l/depth/image_raw + rgb_shape: [480, 640, 3] + width: 480 + height: 640 + cam_right: + img_topic_name: /camera_r/color/image_raw + depth_topic_name: /camera_r/depth/image_raw + rgb_shape: [480, 640, 3] + width: 480 + height: 640 + cam_high: + img_topic_name: /camera/color/image_raw + depth_topic_name: /camera/depth/image_rect_raw + rgb_shape: [480, 640, 3] + width: 480 + height: 640 + +arm: + master_left: + topic_name: /master/joint_left + motors: [ + "left_joint0", + "left_joint1", + "left_joint2", + "left_joint3", + "left_joint4", + "left_joint5", + "left_none" + ] + master_right: + topic_name: /master/joint_right + motors: [ + "right_joint0", + "right_joint1", + "right_joint2", + "right_joint3", + "right_joint4", + "right_joint5", + "right_none" + ] + puppet_left: + topic_name: /puppet/joint_left + motors: [ + "left_joint0", + "left_joint1", + 
"left_joint2", + "left_joint3", + "left_joint4", + "left_joint5", + "left_none" + ] + puppet_right: + topic_name: /puppet/joint_right + motors: [ + "right_joint0", + "right_joint1", + "right_joint2", + "right_joint3", + "right_joint4", + "right_joint5", + "right_none" + ] + +# follow the joint name in ros +state: + motors: [ + "left_joint0", + "left_joint1", + "left_joint2", + "left_joint3", + "left_joint4", + "left_joint5", + "left_none", + "right_joint0", + "right_joint1", + "right_joint2", + "right_joint3", + "right_joint4", + "right_joint5", + "right_none" + ] + +velocity: + motors: [ + "left_joint0", + "left_joint1", + "left_joint2", + "left_joint3", + "left_joint4", + "left_joint5", + "left_none", + "right_joint0", + "right_joint1", + "right_joint2", + "right_joint3", + "right_joint4", + "right_joint5", + "right_none" + ] + +effort: + motors: [ + "left_joint0", + "left_joint1", + "left_joint2", + "left_joint3", + "left_joint4", + "left_joint5", + "left_none", + "right_joint0", + "right_joint1", + "right_joint2", + "right_joint3", + "right_joint4", + "right_joint5", + "right_none" + ] + +action: + motors: [ + "left_joint0", + "left_joint1", + "left_joint2", + "left_joint3", + "left_joint4", + "left_joint5", + "left_none", + "right_joint0", + "right_joint1", + "right_joint2", + "right_joint3", + "right_joint4", + "right_joint5", + "right_none" + ] diff --git a/lerobot_aloha/inference.py b/lerobot_aloha/inference.py new file mode 100644 index 0000000..34f7f52 --- /dev/null +++ b/lerobot_aloha/inference.py @@ -0,0 +1,769 @@ +#!/home/lin/software/miniconda3/envs/aloha/bin/python +# -- coding: UTF-8 +""" +#!/usr/bin/python3 +""" + +import torch +import numpy as np +import os +import pickle +import argparse +from einops import rearrange +import collections +from collections import deque + +import rospy +from std_msgs.msg import Header +from geometry_msgs.msg import Twist +from sensor_msgs.msg import JointState, Image +from nav_msgs.msg import Odometry +from 
cv_bridge import CvBridge +import time +import threading +import math +import threading + + + + +import sys +sys.path.append("./") + +SEED = 42 +torch.manual_seed(SEED) +np.random.seed(SEED) + +task_config = {'camera_names': ['cam_high', 'cam_left_wrist', 'cam_right_wrist']} + +inference_thread = None +inference_lock = threading.Lock() +inference_actions = None +inference_timestep = None + + +def actions_interpolation(args, pre_action, actions, stats): + steps = np.concatenate((np.array(args.arm_steps_length), np.array(args.arm_steps_length)), axis=0) + pre_process = lambda s_qpos: (s_qpos - stats['qpos_mean']) / stats['qpos_std'] + post_process = lambda a: a * stats['action_std'] + stats['action_mean'] + result = [pre_action] + post_action = post_process(actions[0]) + # print("pre_action:", pre_action[7:]) + # print("actions_interpolation1:", post_action[:, 7:]) + max_diff_index = 0 + max_diff = -1 + for i in range(post_action.shape[0]): + diff = 0 + for j in range(pre_action.shape[0]): + if j == 6 or j == 13: + continue + diff += math.fabs(pre_action[j] - post_action[i][j]) + if diff > max_diff: + max_diff = diff + max_diff_index = i + + for i in range(max_diff_index, post_action.shape[0]): + step = max([math.floor(math.fabs(result[-1][j] - post_action[i][j])/steps[j]) for j in range(pre_action.shape[0])]) + inter = np.linspace(result[-1], post_action[i], step+2) + result.extend(inter[1:]) + while len(result) < args.chunk_size+1: + result.append(result[-1]) + result = np.array(result)[1:args.chunk_size+1] + # print("actions_interpolation2:", result.shape, result[:, 7:]) + result = pre_process(result) + result = result[np.newaxis, :] + return result + + +def get_model_config(args): + # 设置随机种子,你可以确保在相同的初始条件下,每次运行代码时生成的随机数序列是相同的。 + set_seed(1) + + # 如果是ACT策略 + # fixed parameters + if args.policy_class == 'ACT': + policy_config = {'lr': args.lr, + 'lr_backbone': args.lr_backbone, + 'backbone': args.backbone, + 'masks': args.masks, + 'weight_decay': args.weight_decay, 
+ 'dilation': args.dilation, + 'position_embedding': args.position_embedding, + 'loss_function': args.loss_function, + 'chunk_size': args.chunk_size, # 查询 + 'camera_names': task_config['camera_names'], + 'use_depth_image': args.use_depth_image, + 'use_robot_base': args.use_robot_base, + 'kl_weight': args.kl_weight, # kl散度权重 + 'hidden_dim': args.hidden_dim, # 隐藏层维度 + 'dim_feedforward': args.dim_feedforward, + 'enc_layers': args.enc_layers, + 'dec_layers': args.dec_layers, + 'nheads': args.nheads, + 'dropout': args.dropout, + 'pre_norm': args.pre_norm + } + elif args.policy_class == 'CNNMLP': + policy_config = {'lr': args.lr, + 'lr_backbone': args.lr_backbone, + 'backbone': args.backbone, + 'masks': args.masks, + 'weight_decay': args.weight_decay, + 'dilation': args.dilation, + 'position_embedding': args.position_embedding, + 'loss_function': args.loss_function, + 'chunk_size': 1, # 查询 + 'camera_names': task_config['camera_names'], + 'use_depth_image': args.use_depth_image, + 'use_robot_base': args.use_robot_base + } + + elif args.policy_class == 'Diffusion': + policy_config = {'lr': args.lr, + 'lr_backbone': args.lr_backbone, + 'backbone': args.backbone, + 'masks': args.masks, + 'weight_decay': args.weight_decay, + 'dilation': args.dilation, + 'position_embedding': args.position_embedding, + 'loss_function': args.loss_function, + 'chunk_size': args.chunk_size, # 查询 + 'camera_names': task_config['camera_names'], + 'use_depth_image': args.use_depth_image, + 'use_robot_base': args.use_robot_base, + 'observation_horizon': args.observation_horizon, + 'action_horizon': args.action_horizon, + 'num_inference_timesteps': args.num_inference_timesteps, + 'ema_power': args.ema_power + } + else: + raise NotImplementedError + + config = { + 'ckpt_dir': args.ckpt_dir, + 'ckpt_name': args.ckpt_name, + 'ckpt_stats_name': args.ckpt_stats_name, + 'episode_len': args.max_publish_step, + 'state_dim': args.state_dim, + 'policy_class': args.policy_class, + 'policy_config': policy_config, 
+ 'temporal_agg': args.temporal_agg, + 'camera_names': task_config['camera_names'], + } + return config + + +def make_policy(policy_class, policy_config): + if policy_class == 'ACT': + policy = ACTPolicy(policy_config) + elif policy_class == 'CNNMLP': + policy = CNNMLPPolicy(policy_config) + elif policy_class == 'Diffusion': + policy = DiffusionPolicy(policy_config) + else: + raise NotImplementedError + return policy + + +def get_image(observation, camera_names): + curr_images = [] + for cam_name in camera_names: + curr_image = rearrange(observation['images'][cam_name], 'h w c -> c h w') + + curr_images.append(curr_image) + curr_image = np.stack(curr_images, axis=0) + curr_image = torch.from_numpy(curr_image / 255.0).float().cuda().unsqueeze(0) + return curr_image + + +def get_depth_image(observation, camera_names): + curr_images = [] + for cam_name in camera_names: + curr_images.append(observation['images_depth'][cam_name]) + curr_image = np.stack(curr_images, axis=0) + curr_image = torch.from_numpy(curr_image / 255.0).float().cuda().unsqueeze(0) + return curr_image + + +def inference_process(args, config, ros_operator, policy, stats, t, pre_action): + global inference_lock + global inference_actions + global inference_timestep + print_flag = True + pre_pos_process = lambda s_qpos: (s_qpos - stats['qpos_mean']) / stats['qpos_std'] + pre_action_process = lambda next_action: (next_action - stats["action_mean"]) / stats["action_std"] + rate = rospy.Rate(args.publish_rate) + while True and not rospy.is_shutdown(): + result = ros_operator.get_frame() + if not result: + if print_flag: + print("syn fail") + print_flag = False + rate.sleep() + continue + print_flag = True + (img_front, img_left, img_right, img_front_depth, img_left_depth, img_right_depth, + puppet_arm_left, puppet_arm_right, robot_base) = result + obs = collections.OrderedDict() + image_dict = dict() + + image_dict[config['camera_names'][0]] = img_front + image_dict[config['camera_names'][1]] = img_left + 
image_dict[config['camera_names'][2]] = img_right + + + obs['images'] = image_dict + + if args.use_depth_image: + image_depth_dict = dict() + image_depth_dict[config['camera_names'][0]] = img_front_depth + image_depth_dict[config['camera_names'][1]] = img_left_depth + image_depth_dict[config['camera_names'][2]] = img_right_depth + obs['images_depth'] = image_depth_dict + + obs['qpos'] = np.concatenate( + (np.array(puppet_arm_left.position), np.array(puppet_arm_right.position)), axis=0) + obs['qvel'] = np.concatenate( + (np.array(puppet_arm_left.velocity), np.array(puppet_arm_right.velocity)), axis=0) + obs['effort'] = np.concatenate( + (np.array(puppet_arm_left.effort), np.array(puppet_arm_right.effort)), axis=0) + if args.use_robot_base: + obs['base_vel'] = [robot_base.twist.twist.linear.x, robot_base.twist.twist.angular.z] + obs['qpos'] = np.concatenate((obs['qpos'], obs['base_vel']), axis=0) + else: + obs['base_vel'] = [0.0, 0.0] + # qpos_numpy = np.array(obs['qpos']) + + # 归一化处理qpos 并转到cuda + qpos = pre_pos_process(obs['qpos']) + qpos = torch.from_numpy(qpos).float().cuda().unsqueeze(0) + # 当前图像curr_image获取图像 + curr_image = get_image(obs, config['camera_names']) + curr_depth_image = None + if args.use_depth_image: + curr_depth_image = get_depth_image(obs, config['camera_names']) + start_time = time.time() + all_actions = policy(curr_image, curr_depth_image, qpos) + end_time = time.time() + print("model cost time: ", end_time -start_time) + inference_lock.acquire() + inference_actions = all_actions.cpu().detach().numpy() + if pre_action is None: + pre_action = obs['qpos'] + # print("obs['qpos']:", obs['qpos'][7:]) + if args.use_actions_interpolation: + inference_actions = actions_interpolation(args, pre_action, inference_actions, stats) + inference_timestep = t + inference_lock.release() + break + + +def model_inference(args, config, ros_operator, save_episode=True): + global inference_lock + global inference_actions + global inference_timestep + global 
inference_thread + set_seed(1000) + + # 1 创建模型数据 继承nn.Module + policy = make_policy(config['policy_class'], config['policy_config']) + # print("model structure\n", policy.model) + + # 2 加载模型权重 + ckpt_path = os.path.join(config['ckpt_dir'], config['ckpt_name']) + state_dict = torch.load(ckpt_path) + new_state_dict = {} + for key, value in state_dict.items(): + if key in ["model.is_pad_head.weight", "model.is_pad_head.bias"]: + continue + if key in ["model.input_proj_next_action.weight", "model.input_proj_next_action.bias"]: + continue + new_state_dict[key] = value + loading_status = policy.deserialize(new_state_dict) + if not loading_status: + print("ckpt path not exist") + return False + + # 3 模型设置为cuda模式和验证模式 + policy.cuda() + policy.eval() + + # 4 加载统计值 + stats_path = os.path.join(config['ckpt_dir'], config['ckpt_stats_name']) + # 统计的数据 # 加载action_mean, action_std, qpos_mean, qpos_std 14维 + with open(stats_path, 'rb') as f: + stats = pickle.load(f) + + # 数据预处理和后处理函数定义 + pre_process = lambda s_qpos: (s_qpos - stats['qpos_mean']) / stats['qpos_std'] + post_process = lambda a: a * stats['action_std'] + stats['action_mean'] + + max_publish_step = config['episode_len'] + chunk_size = config['policy_config']['chunk_size'] + + # 发布基础的姿态 + left0 = [-0.00133514404296875, 0.00209808349609375, 0.01583099365234375, -0.032616615295410156, -0.00286102294921875, 0.00095367431640625, 3.557830810546875] + right0 = [-0.00133514404296875, 0.00438690185546875, 0.034523963928222656, -0.053597450256347656, -0.00476837158203125, -0.00209808349609375, 3.557830810546875] + left1 = [-0.00133514404296875, 0.00209808349609375, 0.01583099365234375, -0.032616615295410156, -0.00286102294921875, 0.00095367431640625, -0.3393220901489258] + right1 = [-0.00133514404296875, 0.00247955322265625, 0.01583099365234375, -0.032616615295410156, -0.00286102294921875, 0.00095367431640625, -0.3397035598754883] + + ros_operator.puppet_arm_publish_continuous(left0, right0) + input("Enter any key to continue 
:") + ros_operator.puppet_arm_publish_continuous(left1, right1) + action = None + # 推理 + with torch.inference_mode(): + while True and not rospy.is_shutdown(): + # 每个回合的步数 + t = 0 + max_t = 0 + rate = rospy.Rate(args.publish_rate) + if config['temporal_agg']: + all_time_actions = np.zeros([max_publish_step, max_publish_step + chunk_size, config['state_dim']]) + while t < max_publish_step and not rospy.is_shutdown(): + # start_time = time.time() + # query policy + if config['policy_class'] == "ACT": + if t >= max_t: + pre_action = action + inference_thread = threading.Thread(target=inference_process, + args=(args, config, ros_operator, + policy, stats, t, pre_action)) + inference_thread.start() + inference_thread.join() + inference_lock.acquire() + if inference_actions is not None: + inference_thread = None + all_actions = inference_actions + inference_actions = None + max_t = t + args.pos_lookahead_step + if config['temporal_agg']: + all_time_actions[[t], t:t + chunk_size] = all_actions + inference_lock.release() + if config['temporal_agg']: + actions_for_curr_step = all_time_actions[:, t] + actions_populated = np.all(actions_for_curr_step != 0, axis=1) + actions_for_curr_step = actions_for_curr_step[actions_populated] + k = 0.01 + exp_weights = np.exp(-k * np.arange(len(actions_for_curr_step))) + exp_weights = exp_weights / exp_weights.sum() + exp_weights = exp_weights[:, np.newaxis] + raw_action = (actions_for_curr_step * exp_weights).sum(axis=0, keepdims=True) + else: + if args.pos_lookahead_step != 0: + raw_action = all_actions[:, t % args.pos_lookahead_step] + else: + raw_action = all_actions[:, t % chunk_size] + else: + raise NotImplementedError + action = post_process(raw_action[0]) + left_action = action[:7] # 取7维度 + right_action = action[7:14] + ros_operator.puppet_arm_publish(left_action, right_action) # puppet_arm_publish_continuous_thread + if args.use_robot_base: + vel_action = action[14:16] + ros_operator.robot_base_publish(vel_action) + t += 1 + # 
end_time = time.time() + # print("publish: ", t) + # print("time:", end_time - start_time) + # print("left_action:", left_action) + # print("right_action:", right_action) + rate.sleep() + + +class RosOperator: + def __init__(self, args): + self.robot_base_deque = None + self.puppet_arm_right_deque = None + self.puppet_arm_left_deque = None + self.img_front_deque = None + self.img_right_deque = None + self.img_left_deque = None + self.img_front_depth_deque = None + self.img_right_depth_deque = None + self.img_left_depth_deque = None + self.bridge = None + self.puppet_arm_left_publisher = None + self.puppet_arm_right_publisher = None + self.robot_base_publisher = None + self.puppet_arm_publish_thread = None + self.puppet_arm_publish_lock = None + self.args = args + self.ctrl_state = False + self.ctrl_state_lock = threading.Lock() + self.init() + self.init_ros() + + def init(self): + self.bridge = CvBridge() + self.img_left_deque = deque() + self.img_right_deque = deque() + self.img_front_deque = deque() + self.img_left_depth_deque = deque() + self.img_right_depth_deque = deque() + self.img_front_depth_deque = deque() + self.puppet_arm_left_deque = deque() + self.puppet_arm_right_deque = deque() + self.robot_base_deque = deque() + self.puppet_arm_publish_lock = threading.Lock() + self.puppet_arm_publish_lock.acquire() + + def puppet_arm_publish(self, left, right): + joint_state_msg = JointState() + joint_state_msg.header = Header() + joint_state_msg.header.stamp = rospy.Time.now() # 设置时间戳 + joint_state_msg.name = ['joint0', 'joint1', 'joint2', 'joint3', 'joint4', 'joint5', 'joint6'] # 设置关节名称 + joint_state_msg.position = left + self.puppet_arm_left_publisher.publish(joint_state_msg) + joint_state_msg.position = right + self.puppet_arm_right_publisher.publish(joint_state_msg) + + def robot_base_publish(self, vel): + vel_msg = Twist() + vel_msg.linear.x = vel[0] + vel_msg.linear.y = 0 + vel_msg.linear.z = 0 + vel_msg.angular.x = 0 + vel_msg.angular.y = 0 + 
vel_msg.angular.z = vel[1] + self.robot_base_publisher.publish(vel_msg) + + def puppet_arm_publish_continuous(self, left, right): + rate = rospy.Rate(self.args.publish_rate) + left_arm = None + right_arm = None + while True and not rospy.is_shutdown(): + if len(self.puppet_arm_left_deque) != 0: + left_arm = list(self.puppet_arm_left_deque[-1].position) + if len(self.puppet_arm_right_deque) != 0: + right_arm = list(self.puppet_arm_right_deque[-1].position) + if left_arm is None or right_arm is None: + rate.sleep() + continue + else: + break + left_symbol = [1 if left[i] - left_arm[i] > 0 else -1 for i in range(len(left))] + right_symbol = [1 if right[i] - right_arm[i] > 0 else -1 for i in range(len(right))] + flag = True + step = 0 + while flag and not rospy.is_shutdown(): + if self.puppet_arm_publish_lock.acquire(False): + return + left_diff = [abs(left[i] - left_arm[i]) for i in range(len(left))] + right_diff = [abs(right[i] - right_arm[i]) for i in range(len(right))] + flag = False + for i in range(len(left)): + if left_diff[i] < self.args.arm_steps_length[i]: + left_arm[i] = left[i] + else: + left_arm[i] += left_symbol[i] * self.args.arm_steps_length[i] + flag = True + for i in range(len(right)): + if right_diff[i] < self.args.arm_steps_length[i]: + right_arm[i] = right[i] + else: + right_arm[i] += right_symbol[i] * self.args.arm_steps_length[i] + flag = True + joint_state_msg = JointState() + joint_state_msg.header = Header() + joint_state_msg.header.stamp = rospy.Time.now() # 设置时间戳 + joint_state_msg.name = ['joint0', 'joint1', 'joint2', 'joint3', 'joint4', 'joint5', 'joint6'] # 设置关节名称 + joint_state_msg.position = left_arm + self.puppet_arm_left_publisher.publish(joint_state_msg) + joint_state_msg.position = right_arm + self.puppet_arm_right_publisher.publish(joint_state_msg) + step += 1 + print("puppet_arm_publish_continuous:", step) + rate.sleep() + + def puppet_arm_publish_linear(self, left, right): + num_step = 100 + rate = rospy.Rate(200) + + left_arm = 
None + right_arm = None + + while True and not rospy.is_shutdown(): + if len(self.puppet_arm_left_deque) != 0: + left_arm = list(self.puppet_arm_left_deque[-1].position) + if len(self.puppet_arm_right_deque) != 0: + right_arm = list(self.puppet_arm_right_deque[-1].position) + if left_arm is None or right_arm is None: + rate.sleep() + continue + else: + break + + traj_left_list = np.linspace(left_arm, left, num_step) + traj_right_list = np.linspace(right_arm, right, num_step) + + for i in range(len(traj_left_list)): + traj_left = traj_left_list[i] + traj_right = traj_right_list[i] + traj_left[-1] = left[-1] + traj_right[-1] = right[-1] + joint_state_msg = JointState() + joint_state_msg.header = Header() + joint_state_msg.header.stamp = rospy.Time.now() # 设置时间戳 + joint_state_msg.name = ['joint0', 'joint1', 'joint2', 'joint3', 'joint4', 'joint5', 'joint6'] # 设置关节名称 + joint_state_msg.position = traj_left + self.puppet_arm_left_publisher.publish(joint_state_msg) + joint_state_msg.position = traj_right + self.puppet_arm_right_publisher.publish(joint_state_msg) + rate.sleep() + + def puppet_arm_publish_continuous_thread(self, left, right): + if self.puppet_arm_publish_thread is not None: + self.puppet_arm_publish_lock.release() + self.puppet_arm_publish_thread.join() + self.puppet_arm_publish_lock.acquire(False) + self.puppet_arm_publish_thread = None + self.puppet_arm_publish_thread = threading.Thread(target=self.puppet_arm_publish_continuous, args=(left, right)) + self.puppet_arm_publish_thread.start() + + def get_frame(self): + if len(self.img_left_deque) == 0 or len(self.img_right_deque) == 0 or len(self.img_front_deque) == 0 or \ + (self.args.use_depth_image and (len(self.img_left_depth_deque) == 0 or len(self.img_right_depth_deque) == 0 or len(self.img_front_depth_deque) == 0)): + return False + if self.args.use_depth_image: + frame_time = min([self.img_left_deque[-1].header.stamp.to_sec(), self.img_right_deque[-1].header.stamp.to_sec(), 
self.img_front_deque[-1].header.stamp.to_sec(), + self.img_left_depth_deque[-1].header.stamp.to_sec(), self.img_right_depth_deque[-1].header.stamp.to_sec(), self.img_front_depth_deque[-1].header.stamp.to_sec()]) + else: + frame_time = min([self.img_left_deque[-1].header.stamp.to_sec(), self.img_right_deque[-1].header.stamp.to_sec(), self.img_front_deque[-1].header.stamp.to_sec()]) + + if len(self.img_left_deque) == 0 or self.img_left_deque[-1].header.stamp.to_sec() < frame_time: + return False + if len(self.img_right_deque) == 0 or self.img_right_deque[-1].header.stamp.to_sec() < frame_time: + return False + if len(self.img_front_deque) == 0 or self.img_front_deque[-1].header.stamp.to_sec() < frame_time: + return False + if len(self.puppet_arm_left_deque) == 0 or self.puppet_arm_left_deque[-1].header.stamp.to_sec() < frame_time: + return False + if len(self.puppet_arm_right_deque) == 0 or self.puppet_arm_right_deque[-1].header.stamp.to_sec() < frame_time: + return False + if self.args.use_depth_image and (len(self.img_left_depth_deque) == 0 or self.img_left_depth_deque[-1].header.stamp.to_sec() < frame_time): + return False + if self.args.use_depth_image and (len(self.img_right_depth_deque) == 0 or self.img_right_depth_deque[-1].header.stamp.to_sec() < frame_time): + return False + if self.args.use_depth_image and (len(self.img_front_depth_deque) == 0 or self.img_front_depth_deque[-1].header.stamp.to_sec() < frame_time): + return False + if self.args.use_robot_base and (len(self.robot_base_deque) == 0 or self.robot_base_deque[-1].header.stamp.to_sec() < frame_time): + return False + + while self.img_left_deque[0].header.stamp.to_sec() < frame_time: + self.img_left_deque.popleft() + img_left = self.bridge.imgmsg_to_cv2(self.img_left_deque.popleft(), 'passthrough') + + while self.img_right_deque[0].header.stamp.to_sec() < frame_time: + self.img_right_deque.popleft() + img_right = self.bridge.imgmsg_to_cv2(self.img_right_deque.popleft(), 'passthrough') + + while 
self.img_front_deque[0].header.stamp.to_sec() < frame_time: + self.img_front_deque.popleft() + img_front = self.bridge.imgmsg_to_cv2(self.img_front_deque.popleft(), 'passthrough') + + while self.puppet_arm_left_deque[0].header.stamp.to_sec() < frame_time: + self.puppet_arm_left_deque.popleft() + puppet_arm_left = self.puppet_arm_left_deque.popleft() + + while self.puppet_arm_right_deque[0].header.stamp.to_sec() < frame_time: + self.puppet_arm_right_deque.popleft() + puppet_arm_right = self.puppet_arm_right_deque.popleft() + + img_left_depth = None + if self.args.use_depth_image: + while self.img_left_depth_deque[0].header.stamp.to_sec() < frame_time: + self.img_left_depth_deque.popleft() + img_left_depth = self.bridge.imgmsg_to_cv2(self.img_left_depth_deque.popleft(), 'passthrough') + + img_right_depth = None + if self.args.use_depth_image: + while self.img_right_depth_deque[0].header.stamp.to_sec() < frame_time: + self.img_right_depth_deque.popleft() + img_right_depth = self.bridge.imgmsg_to_cv2(self.img_right_depth_deque.popleft(), 'passthrough') + + img_front_depth = None + if self.args.use_depth_image: + while self.img_front_depth_deque[0].header.stamp.to_sec() < frame_time: + self.img_front_depth_deque.popleft() + img_front_depth = self.bridge.imgmsg_to_cv2(self.img_front_depth_deque.popleft(), 'passthrough') + + robot_base = None + if self.args.use_robot_base: + while self.robot_base_deque[0].header.stamp.to_sec() < frame_time: + self.robot_base_deque.popleft() + robot_base = self.robot_base_deque.popleft() + + return (img_front, img_left, img_right, img_front_depth, img_left_depth, img_right_depth, + puppet_arm_left, puppet_arm_right, robot_base) + + def img_left_callback(self, msg): + if len(self.img_left_deque) >= 2000: + self.img_left_deque.popleft() + self.img_left_deque.append(msg) + + def img_right_callback(self, msg): + if len(self.img_right_deque) >= 2000: + self.img_right_deque.popleft() + self.img_right_deque.append(msg) + + def 
img_front_callback(self, msg): + if len(self.img_front_deque) >= 2000: + self.img_front_deque.popleft() + self.img_front_deque.append(msg) + + def img_left_depth_callback(self, msg): + if len(self.img_left_depth_deque) >= 2000: + self.img_left_depth_deque.popleft() + self.img_left_depth_deque.append(msg) + + def img_right_depth_callback(self, msg): + if len(self.img_right_depth_deque) >= 2000: + self.img_right_depth_deque.popleft() + self.img_right_depth_deque.append(msg) + + def img_front_depth_callback(self, msg): + if len(self.img_front_depth_deque) >= 2000: + self.img_front_depth_deque.popleft() + self.img_front_depth_deque.append(msg) + + def puppet_arm_left_callback(self, msg): + if len(self.puppet_arm_left_deque) >= 2000: + self.puppet_arm_left_deque.popleft() + self.puppet_arm_left_deque.append(msg) + + def puppet_arm_right_callback(self, msg): + if len(self.puppet_arm_right_deque) >= 2000: + self.puppet_arm_right_deque.popleft() + self.puppet_arm_right_deque.append(msg) + + def robot_base_callback(self, msg): + if len(self.robot_base_deque) >= 2000: + self.robot_base_deque.popleft() + self.robot_base_deque.append(msg) + + def ctrl_callback(self, msg): + self.ctrl_state_lock.acquire() + self.ctrl_state = msg.data + self.ctrl_state_lock.release() + + def get_ctrl_state(self): + self.ctrl_state_lock.acquire() + state = self.ctrl_state + self.ctrl_state_lock.release() + return state + + def init_ros(self): + rospy.init_node('joint_state_publisher', anonymous=True) + rospy.Subscriber(self.args.img_left_topic, Image, self.img_left_callback, queue_size=1000, tcp_nodelay=True) + rospy.Subscriber(self.args.img_right_topic, Image, self.img_right_callback, queue_size=1000, tcp_nodelay=True) + rospy.Subscriber(self.args.img_front_topic, Image, self.img_front_callback, queue_size=1000, tcp_nodelay=True) + if self.args.use_depth_image: + rospy.Subscriber(self.args.img_left_depth_topic, Image, self.img_left_depth_callback, queue_size=1000, tcp_nodelay=True) + 
rospy.Subscriber(self.args.img_right_depth_topic, Image, self.img_right_depth_callback, queue_size=1000, tcp_nodelay=True) + rospy.Subscriber(self.args.img_front_depth_topic, Image, self.img_front_depth_callback, queue_size=1000, tcp_nodelay=True) + rospy.Subscriber(self.args.puppet_arm_left_topic, JointState, self.puppet_arm_left_callback, queue_size=1000, tcp_nodelay=True) + rospy.Subscriber(self.args.puppet_arm_right_topic, JointState, self.puppet_arm_right_callback, queue_size=1000, tcp_nodelay=True) + rospy.Subscriber(self.args.robot_base_topic, Odometry, self.robot_base_callback, queue_size=1000, tcp_nodelay=True) + self.puppet_arm_left_publisher = rospy.Publisher(self.args.puppet_arm_left_cmd_topic, JointState, queue_size=10) + self.puppet_arm_right_publisher = rospy.Publisher(self.args.puppet_arm_right_cmd_topic, JointState, queue_size=10) + self.robot_base_publisher = rospy.Publisher(self.args.robot_base_cmd_topic, Twist, queue_size=10) + + +def get_arguments(): + parser = argparse.ArgumentParser() + parser.add_argument('--ckpt_dir', action='store', type=str, help='ckpt_dir', required=True) + parser.add_argument('--task_name', action='store', type=str, help='task_name', default='aloha_mobile_dummy', required=False) + parser.add_argument('--max_publish_step', action='store', type=int, help='max_publish_step', default=10000, required=False) + parser.add_argument('--ckpt_name', action='store', type=str, help='ckpt_name', default='policy_best.ckpt', required=False) + parser.add_argument('--ckpt_stats_name', action='store', type=str, help='ckpt_stats_name', default='dataset_stats.pkl', required=False) + parser.add_argument('--policy_class', action='store', type=str, help='policy_class, capitalize', default='ACT', required=False) + parser.add_argument('--batch_size', action='store', type=int, help='batch_size', default=8, required=False) + parser.add_argument('--seed', action='store', type=int, help='seed', default=0, required=False) + 
parser.add_argument('--num_epochs', action='store', type=int, help='num_epochs', default=2000, required=False) + parser.add_argument('--lr', action='store', type=float, help='lr', default=1e-5, required=False) + parser.add_argument('--weight_decay', type=float, help='weight_decay', default=1e-4, required=False) + parser.add_argument('--dilation', action='store_true', + help="If true, we replace stride with dilation in the last convolutional block (DC5)", required=False) + parser.add_argument('--position_embedding', default='sine', type=str, choices=('sine', 'learned'), + help="Type of positional embedding to use on top of the image features", required=False) + parser.add_argument('--masks', action='store_true', + help="Train segmentation head if the flag is provided") + parser.add_argument('--kl_weight', action='store', type=int, help='KL Weight', default=10, required=False) + parser.add_argument('--hidden_dim', action='store', type=int, help='hidden_dim', default=512, required=False) + parser.add_argument('--dim_feedforward', action='store', type=int, help='dim_feedforward', default=3200, required=False) + parser.add_argument('--temporal_agg', action='store', type=bool, help='temporal_agg', default=True, required=False) + + parser.add_argument('--state_dim', action='store', type=int, help='state_dim', default=14, required=False) + parser.add_argument('--lr_backbone', action='store', type=float, help='lr_backbone', default=1e-5, required=False) + parser.add_argument('--backbone', action='store', type=str, help='backbone', default='resnet18', required=False) + parser.add_argument('--loss_function', action='store', type=str, help='loss_function l1 l2 l1+l2', default='l1', required=False) + parser.add_argument('--enc_layers', action='store', type=int, help='enc_layers', default=4, required=False) + parser.add_argument('--dec_layers', action='store', type=int, help='dec_layers', default=7, required=False) + parser.add_argument('--nheads', action='store', type=int, 
help='nheads', default=8, required=False) + parser.add_argument('--dropout', default=0.1, type=float, help="Dropout applied in the transformer", required=False) + parser.add_argument('--pre_norm', action='store_true', required=False) + + parser.add_argument('--img_front_topic', action='store', type=str, help='img_front_topic', + default='/camera_f/color/image_raw', required=False) + parser.add_argument('--img_left_topic', action='store', type=str, help='img_left_topic', + default='/camera_l/color/image_raw', required=False) + parser.add_argument('--img_right_topic', action='store', type=str, help='img_right_topic', + default='/camera_r/color/image_raw', required=False) + + parser.add_argument('--img_front_depth_topic', action='store', type=str, help='img_front_depth_topic', + default='/camera_f/depth/image_raw', required=False) + parser.add_argument('--img_left_depth_topic', action='store', type=str, help='img_left_depth_topic', + default='/camera_l/depth/image_raw', required=False) + parser.add_argument('--img_right_depth_topic', action='store', type=str, help='img_right_depth_topic', + default='/camera_r/depth/image_raw', required=False) + + parser.add_argument('--puppet_arm_left_cmd_topic', action='store', type=str, help='puppet_arm_left_cmd_topic', + default='/master/joint_left', required=False) + parser.add_argument('--puppet_arm_right_cmd_topic', action='store', type=str, help='puppet_arm_right_cmd_topic', + default='/master/joint_right', required=False) + parser.add_argument('--puppet_arm_left_topic', action='store', type=str, help='puppet_arm_left_topic', + default='/puppet/joint_left', required=False) + parser.add_argument('--puppet_arm_right_topic', action='store', type=str, help='puppet_arm_right_topic', + default='/puppet/joint_right', required=False) + + parser.add_argument('--robot_base_topic', action='store', type=str, help='robot_base_topic', + default='/odom_raw', required=False) + parser.add_argument('--robot_base_cmd_topic', action='store', 
type=str, help='robot_base_topic', + default='/cmd_vel', required=False) + parser.add_argument('--use_robot_base', action='store', type=bool, help='use_robot_base', + default=False, required=False) + parser.add_argument('--publish_rate', action='store', type=int, help='publish_rate', + default=40, required=False) + parser.add_argument('--pos_lookahead_step', action='store', type=int, help='pos_lookahead_step', + default=0, required=False) + parser.add_argument('--chunk_size', action='store', type=int, help='chunk_size', + default=32, required=False) + parser.add_argument('--arm_steps_length', action='store', type=float, help='arm_steps_length', + default=[0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.2], required=False) + + parser.add_argument('--use_actions_interpolation', action='store', type=bool, help='use_actions_interpolation', + default=False, required=False) + parser.add_argument('--use_depth_image', action='store', type=bool, help='use_depth_image', + default=False, required=False) + + # for Diffusion + parser.add_argument('--observation_horizon', action='store', type=int, help='observation_horizon', default=1, required=False) + parser.add_argument('--action_horizon', action='store', type=int, help='action_horizon', default=8, required=False) + parser.add_argument('--num_inference_timesteps', action='store', type=int, help='num_inference_timesteps', default=10, required=False) + parser.add_argument('--ema_power', action='store', type=int, help='ema_power', default=0.75, required=False) + args = parser.parse_args() + return args + + +def main(): + args = get_arguments() + ros_operator = RosOperator(args) + config = get_model_config(args) + model_inference(args, config, ros_operator, save_episode=True) + + +if __name__ == '__main__': + main() +# python act/inference.py --ckpt_dir ~/train0314/ \ No newline at end of file diff --git a/lerobot_aloha/read_parquet.py b/lerobot_aloha/read_parquet.py new file mode 100644 index 0000000..577a1e3 --- /dev/null +++ 
# --- lerobot_aloha/read_parquet.py ---
import pandas as pd

def read_and_print_parquet_row(file_path, row_index=0):
    """Read a Parquet file and print the row at ``row_index``.

    Args:
        file_path (str): path to the Parquet file.
        row_index (int): index of the row to print (defaults to the first row).

    Prints an error message (in Chinese, matching the original UI strings)
    instead of raising when the file is missing or the index is out of range.
    """
    try:
        df = pd.read_parquet(file_path)

        # Reject an out-of-range row index with a readable message.
        if row_index >= len(df):
            print(f"错误: 行索引 {row_index} 超出范围(文件共有 {len(df)} 行)")
            return

        print(f"文件: {file_path}")
        print(f"第 {row_index} 行数据:\n{'-'*30}")
        print(df.iloc[row_index])

    except FileNotFoundError:
        print(f"错误: 文件 {file_path} 不存在")
    except Exception as e:
        # Best-effort CLI tool: report, don't crash.
        print(f"读取失败: {str(e)}")

# Example usage (the original also assigned an unused file_path variable here;
# removed as dead code).
if __name__ == "__main__":
    read_and_print_parquet_row("/home/jgl20/LYT/work/data/data/chunk-000/episode_000000.parquet", row_index=0)  # print row 0

# --- lerobot_aloha/replay_data.py ---
#coding=utf-8
import os
import numpy as np
import cv2
import h5py
import argparse
import rospy

from cv_bridge import CvBridge
from std_msgs.msg import Header
from sensor_msgs.msg import Image, JointState
from geometry_msgs.msg import Twist
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset


def main(args):
    """Replay the joint trajectory of one recorded LeRobot episode.

    Streams the episode's ``action`` sequence to the two master-arm topics at
    a fixed 50 Hz, linearly interpolating 5 intermediate steps between
    consecutive recorded samples so the arms move smoothly.
    """
    rospy.init_node("replay_node")
    bridge = CvBridge()  # kept for parity with the recording script (unused here)

    master_arm_left_publisher = rospy.Publisher(args.master_arm_left_topic, JointState, queue_size=10)
    master_arm_right_publisher = rospy.Publisher(args.master_arm_right_topic, JointState, queue_size=10)

    dataset = LeRobotDataset(args.repo_id, root=args.root, episodes=[args.episode])
    actions = dataset.hf_dataset.select_columns("action")
    velocitys = dataset.hf_dataset.select_columns("observation.velocity")
    efforts = dataset.hf_dataset.select_columns("observation.effort")

    joint_state_msg = JointState()
    joint_state_msg.header = Header()
    joint_state_msg.name = ['joint0', 'joint1', 'joint2', 'joint3', 'joint4', 'joint5', '']  # set joint names (7th entry presumably the gripper — TODO confirm)
    twist_msg = Twist()

    # Seed the interpolation with a neutral starting sample: 14 values,
    # left-arm joints 0-6 followed by right-arm joints 0-6.
    last_action = [-0.00019073486328125, 0.00934600830078125, 0.01354217529296875, -0.01049041748046875, -0.00057220458984375, -0.00057220458984375, -0.00526118278503418, -0.00095367431640625, 0.00705718994140625, 0.01239776611328125, -0.00705718994140625, -0.00019073486328125, -0.00057220458984375, -0.009171326644718647]
    last_velocity = [-0.010990142822265625, -0.010990142822265625, -0.03296661376953125, 0.010990142822265625, -0.010990142822265625, -0.010990142822265625, -0.010990142822265625, -0.010990142822265625, -0.010990142822265625, -0.03296661376953125, -0.010990142822265625, -0.010990142822265625, -0.03296661376953125, -0.03296661376953125]
    last_effort = [-0.021978378295898438, 0.2417583465576172, 4.320878982543945, 3.6527481079101562, -0.013187408447265625, -0.013187408447265625, 0.0, -0.010990142822265625, -0.010990142822265625, -0.03296661376953125, -0.010990142822265625, -0.010990142822265625, -0.03296661376953125, -0.03296661376953125]

    # BUGFIX(review): the original created rospy.Rate(args.fps) and then
    # immediately shadowed it with rospy.Rate(50); the effective rate is 50 Hz,
    # so keep that and drop the dead assignment.
    rate = rospy.Rate(50)
    for idx in range(len(actions)):
        if rospy.is_shutdown():
            break
        action = actions[idx]['action'].detach().cpu().numpy()
        velocity = velocitys[idx]['observation.velocity'].detach().cpu().numpy()
        effort = efforts[idx]['observation.effort'].detach().cpu().numpy()

        # Linear interpolation (5 steps) between previous and current sample.
        new_actions = np.linspace(last_action, action, 5)
        new_velocitys = np.linspace(last_velocity, velocity, 5)
        new_efforts = np.linspace(last_effort, effort, 5)
        last_action = action
        last_velocity = velocity
        last_effort = effort

        # BUGFIX(review): the original computed new_velocitys/new_efforts and
        # never used them, and published the LEFT-arm velocity slice ([:7]) on
        # the RIGHT arm; publish the interpolated values with correct slices.
        for act, vel, eff in zip(new_actions, new_velocitys, new_efforts):
            print(np.round(act[:7], 4))
            joint_state_msg.header.stamp = rospy.Time.now()  # stamp each command

            joint_state_msg.position = act[:7]
            joint_state_msg.velocity = vel[:7]
            joint_state_msg.effort = eff[:7]
            master_arm_left_publisher.publish(joint_state_msg)

            joint_state_msg.position = act[7:]
            joint_state_msg.velocity = vel[7:]
            joint_state_msg.effort = eff[7:]
            master_arm_right_publisher.publish(joint_state_msg)

            if rospy.is_shutdown():
                break
            rate.sleep()


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    args = parser.parse_args()
    # Hard-coded replay configuration (was CLI in an earlier revision).
    args.repo_id = "tangger/test"
    args.root = "/home/ubuntu/LYT/aloha_lerobot/data1"
    args.episode = 1  # replay episode
    args.master_arm_left_topic = "/master/joint_left"
    args.master_arm_right_topic = "/master/joint_right"
    args.fps = 30  # NOTE(review): main() publishes at a fixed 50 Hz regardless

    main(args)
    # python collect_data.py --max_timesteps 500 --is_compress --episode_idx 0
# --- lerobot_aloha/test.py ---
from lerobot.common.policies.act.modeling_act import ACTPolicy
from lerobot.common.robot_devices.utils import busy_wait
import time
import argparse
from agilex_robot import AgilexRobot
import torch


def get_arguments():
    """Build the (hard-coded) run configuration for the test rollout."""
    parser = argparse.ArgumentParser()
    args = parser.parse_args()
    args.fps = 30
    args.resume = False
    args.repo_id = "tangger/test"
    args.root = "./data2"
    args.num_image_writer_processes = 0
    args.num_image_writer_threads_per_camera = 4
    args.video = True
    args.num_episodes = 50
    args.episode_time_s = 30000
    args.play_sounds = False
    args.display_cameras = True
    args.single_task = "test test"
    args.use_depth_image = False
    args.use_base = False
    args.push_to_hub = False
    args.policy = None
    args.teleoprate = False
    return args


def main():
    """Run a fixed-duration closed-loop rollout of a pretrained ACT policy.

    FIX(review): the original executed all of this at module top level (so it
    ran on import); wrapped in main() behind a __main__ guard, and the control
    frequency now reuses cfg.fps instead of a duplicated hard-coded 30.
    """
    cfg = get_arguments()
    robot = AgilexRobot(config_file="/home/ubuntu/LYT/aloha_lerobot/collect_data/agilex.yaml", args=cfg)
    inference_time_s = 360
    fps = cfg.fps  # control-loop frequency (same value as the original's 30)
    device = "cuda"  # TODO: On Mac, use "mps" or "cpu"

    ckpt_path = "/home/ubuntu/LYT/lerobot/outputs/train/act_move_tube_on_scale/checkpoints/last/pretrained_model"
    policy = ACTPolicy.from_pretrained(ckpt_path)
    policy.to(device)

    for _ in range(inference_time_s * fps):
        start_time = time.perf_counter()

        # Read the follower state and access the frames from the cameras.
        observation = robot.capture_observation()
        if observation is None:
            print("Observation is None, skipping...")
            continue

        # Convert images to float32 CHW in [0, 1], add a batch dimension and
        # move them to the device.
        # NOTE(review): non-image entries are NOT moved to `device`, matching
        # the original behaviour — confirm the policy expects that.
        for name in observation:
            if "image" in name:
                frame = observation[name].type(torch.float32) / 255
                observation[name] = frame.permute(2, 0, 1).contiguous().unsqueeze(0).to(device)

        # Compute the next action, drop the batch dimension, bring it back to
        # the CPU and send it to the robot.
        action = policy.select_action(observation)
        action = action.squeeze(0).to("cpu")
        robot.send_action(action)

        # Busy-wait out the remainder of the control period.
        dt_s = time.perf_counter() - start_time
        busy_wait(1 / fps - dt_s)


if __name__ == '__main__':
    main()