init commit
This commit is contained in:
16
.gitignore
vendored
Normal file
16
.gitignore
vendored
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
__pycache__
|
||||||
|
build
|
||||||
|
debug
|
||||||
|
dist
|
||||||
|
*.egg
|
||||||
|
.vscode
|
||||||
|
output
|
||||||
|
*.err
|
||||||
|
*.out
|
||||||
|
.cursor
|
||||||
|
assets
|
||||||
|
curobo
|
||||||
|
panda_drake
|
||||||
|
tests/*.log
|
||||||
|
tests/*.txt
|
||||||
|
polygons.png
|
||||||
51
.pre-commit-config.yaml
Normal file
51
.pre-commit-config.yaml
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
exclude: ^doc/
|
||||||
|
|
||||||
|
repos:
|
||||||
|
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||||
|
rev: v3.1.0
|
||||||
|
hooks:
|
||||||
|
- id: trailing-whitespace
|
||||||
|
- id: end-of-file-fixer
|
||||||
|
- id: check-added-large-files
|
||||||
|
args: ['--maxkb 100']
|
||||||
|
exclude: 'workflows/simbox/tools/grasp/example/.*|workflows/simbox/tools/grasp/pyarmor_runtime_000000/.*|workflows/simbox/example_assets/.*|workflows/simbox/tools/rigid_obj/example/.*'
|
||||||
|
- id: check-json
|
||||||
|
- id: check-docstring-first
|
||||||
|
- id: check-yaml
|
||||||
|
- id: debug-statements
|
||||||
|
- id: mixed-line-ending
|
||||||
|
- repo: https://github.com/pycqa/isort
|
||||||
|
rev: '5.12.0'
|
||||||
|
hooks:
|
||||||
|
- id: isort
|
||||||
|
name: isort
|
||||||
|
files: "\\.(py)$"
|
||||||
|
args:
|
||||||
|
- --profile=black
|
||||||
|
- repo: https://github.com/psf/black
|
||||||
|
rev: '22.3.0'
|
||||||
|
hooks:
|
||||||
|
- id: black
|
||||||
|
args:
|
||||||
|
- --line-length=120
|
||||||
|
- --preview
|
||||||
|
- repo: https://github.com/pycqa/flake8
|
||||||
|
rev: '3.9.2'
|
||||||
|
hooks:
|
||||||
|
- id: flake8
|
||||||
|
args:
|
||||||
|
- --max-line-length=120
|
||||||
|
- --ignore=E203,W503
|
||||||
|
- repo: https://github.com/PyCQA/pylint/
|
||||||
|
rev: 'v2.15.0'
|
||||||
|
hooks:
|
||||||
|
- id: pylint
|
||||||
|
name: pylint
|
||||||
|
entry: pylint
|
||||||
|
language: system
|
||||||
|
types: [python]
|
||||||
|
args:
|
||||||
|
[
|
||||||
|
'--rcfile=.pylintrc',
|
||||||
|
'--disable=C0103,C0114,C0415,W0212,W0235,W0238'
|
||||||
|
]
|
||||||
435
.pylintrc
Normal file
435
.pylintrc
Normal file
@@ -0,0 +1,435 @@
|
|||||||
|
# This Pylint rcfile contains a best-effort configuration to uphold the
|
||||||
|
# best-practices and style described in the Google Python style guide:
|
||||||
|
# https://google.github.io/styleguide/pyguide.html
|
||||||
|
#
|
||||||
|
# Its canonical open-source location is:
|
||||||
|
# https://google.github.io/styleguide/pylintrc
|
||||||
|
|
||||||
|
[MASTER]
|
||||||
|
|
||||||
|
# Files or directories to be skipped. They should be base names, not paths.
|
||||||
|
ignore=third_party
|
||||||
|
|
||||||
|
# Files or directories matching the regex patterns are skipped. The regex
|
||||||
|
# matches against base names, not paths.
|
||||||
|
ignore-patterns=
|
||||||
|
|
||||||
|
# Pickle collected data for later comparisons.
|
||||||
|
persistent=no
|
||||||
|
|
||||||
|
# List of plugins (as comma separated values of python modules names) to load,
|
||||||
|
# usually to register additional checkers.
|
||||||
|
load-plugins=
|
||||||
|
|
||||||
|
# Use multiple processes to speed up Pylint.
|
||||||
|
jobs=4
|
||||||
|
|
||||||
|
# Allow loading of arbitrary C extensions. Extensions are imported into the
|
||||||
|
# active Python interpreter and may run arbitrary code.
|
||||||
|
unsafe-load-any-extension=no
|
||||||
|
|
||||||
|
|
||||||
|
[MESSAGES CONTROL]
|
||||||
|
|
||||||
|
# Only show warnings with the listed confidence levels. Leave empty to show
|
||||||
|
# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED
|
||||||
|
confidence=
|
||||||
|
|
||||||
|
# Enable the message, report, category or checker with the given id(s). You can
|
||||||
|
# either give multiple identifier separated by comma (,) or put this option
|
||||||
|
# multiple time (only on the command line, not in the configuration file where
|
||||||
|
# it should appear only once). See also the "--disable" option for examples.
|
||||||
|
#enable=
|
||||||
|
|
||||||
|
# Disable the message, report, category or checker with the given id(s). You
|
||||||
|
# can either give multiple identifiers separated by comma (,) or put this
|
||||||
|
# option multiple times (only on the command line, not in the configuration
|
||||||
|
# file where it should appear only once).You can also use "--disable=all" to
|
||||||
|
# disable everything first and then reenable specific checks. For example, if
|
||||||
|
# you want to run only the similarities checker, you can use "--disable=all
|
||||||
|
# --enable=similarities". If you want to run only the classes checker, but have
|
||||||
|
# no Warning level messages displayed, use"--disable=all --enable=classes
|
||||||
|
# --disable=W"
|
||||||
|
disable=abstract-method,
|
||||||
|
apply-builtin,
|
||||||
|
arguments-differ,
|
||||||
|
attribute-defined-outside-init,
|
||||||
|
backtick,
|
||||||
|
bad-option-value,
|
||||||
|
basestring-builtin,
|
||||||
|
buffer-builtin,
|
||||||
|
c-extension-no-member,
|
||||||
|
consider-using-enumerate,
|
||||||
|
cmp-builtin,
|
||||||
|
cmp-method,
|
||||||
|
coerce-builtin,
|
||||||
|
coerce-method,
|
||||||
|
delslice-method,
|
||||||
|
div-method,
|
||||||
|
duplicate-code,
|
||||||
|
eq-without-hash,
|
||||||
|
execfile-builtin,
|
||||||
|
file-builtin,
|
||||||
|
filter-builtin-not-iterating,
|
||||||
|
fixme,
|
||||||
|
getslice-method,
|
||||||
|
global-statement,
|
||||||
|
hex-method,
|
||||||
|
idiv-method,
|
||||||
|
implicit-str-concat,
|
||||||
|
import-error,
|
||||||
|
import-self,
|
||||||
|
import-star-module-level,
|
||||||
|
inconsistent-return-statements,
|
||||||
|
input-builtin,
|
||||||
|
intern-builtin,
|
||||||
|
invalid-str-codec,
|
||||||
|
locally-disabled,
|
||||||
|
long-builtin,
|
||||||
|
long-suffix,
|
||||||
|
map-builtin-not-iterating,
|
||||||
|
misplaced-comparison-constant,
|
||||||
|
missing-function-docstring,
|
||||||
|
metaclass-assignment,
|
||||||
|
next-method-called,
|
||||||
|
next-method-defined,
|
||||||
|
no-absolute-import,
|
||||||
|
no-else-break,
|
||||||
|
no-else-continue,
|
||||||
|
no-else-raise,
|
||||||
|
no-else-return,
|
||||||
|
no-init, # added
|
||||||
|
no-member,
|
||||||
|
no-name-in-module,
|
||||||
|
no-self-use,
|
||||||
|
nonzero-method,
|
||||||
|
oct-method,
|
||||||
|
old-division,
|
||||||
|
old-ne-operator,
|
||||||
|
old-octal-literal,
|
||||||
|
old-raise-syntax,
|
||||||
|
parameter-unpacking,
|
||||||
|
print-statement,
|
||||||
|
raising-string,
|
||||||
|
range-builtin-not-iterating,
|
||||||
|
raw_input-builtin,
|
||||||
|
rdiv-method,
|
||||||
|
reduce-builtin,
|
||||||
|
relative-import,
|
||||||
|
reload-builtin,
|
||||||
|
round-builtin,
|
||||||
|
setslice-method,
|
||||||
|
signature-differs,
|
||||||
|
standarderror-builtin,
|
||||||
|
suppressed-message,
|
||||||
|
sys-max-int,
|
||||||
|
too-few-public-methods,
|
||||||
|
too-many-ancestors,
|
||||||
|
too-many-arguments,
|
||||||
|
too-many-boolean-expressions,
|
||||||
|
too-many-branches,
|
||||||
|
too-many-instance-attributes,
|
||||||
|
too-many-locals,
|
||||||
|
too-many-nested-blocks,
|
||||||
|
too-many-public-methods,
|
||||||
|
too-many-return-statements,
|
||||||
|
too-many-statements,
|
||||||
|
trailing-newlines,
|
||||||
|
unichr-builtin,
|
||||||
|
unicode-builtin,
|
||||||
|
unnecessary-pass,
|
||||||
|
unpacking-in-except,
|
||||||
|
useless-else-on-loop,
|
||||||
|
useless-object-inheritance,
|
||||||
|
useless-suppression,
|
||||||
|
using-cmp-argument,
|
||||||
|
wrong-import-order,
|
||||||
|
xrange-builtin,
|
||||||
|
zip-builtin-not-iterating,
|
||||||
|
R0917, # too-many-positional-arguments
|
||||||
|
W1203, # Use lazy % formatting in logging functions
|
||||||
|
W0707, # Consider explicitly re-raising using 'except StopIteration as exc'
|
||||||
|
C0115, # missing-class-docstring
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
[REPORTS]
|
||||||
|
|
||||||
|
# Set the output format. Available formats are text, parseable, colorized, msvs
|
||||||
|
# (visual studio) and html. You can also give a reporter class, eg
|
||||||
|
# mypackage.mymodule.MyReporterClass.
|
||||||
|
output-format=colorized
|
||||||
|
|
||||||
|
# Tells whether to display a full report or only the messages
|
||||||
|
reports=no
|
||||||
|
|
||||||
|
# Python expression which should return a note less than 10 (10 is the highest
|
||||||
|
# note). You have access to the variables errors warning, statement which
|
||||||
|
# respectively contain the number of errors / warnings messages and the total
|
||||||
|
# number of statements analyzed. This is used by the global evaluation report
|
||||||
|
# (RP0004).
|
||||||
|
evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)
|
||||||
|
|
||||||
|
# Template used to display messages. This is a python new-style format string
|
||||||
|
# used to format the message information. See doc for all details
|
||||||
|
#msg-template=
|
||||||
|
|
||||||
|
|
||||||
|
[BASIC]
|
||||||
|
|
||||||
|
# Good variable names which should always be accepted, separated by a comma
|
||||||
|
good-names=main,_
|
||||||
|
|
||||||
|
# Bad variable names which should always be refused, separated by a comma
|
||||||
|
bad-names=
|
||||||
|
|
||||||
|
# Colon-delimited sets of names that determine each other's naming style when
|
||||||
|
# the name regexes allow several styles.
|
||||||
|
name-group=
|
||||||
|
|
||||||
|
# Include a hint for the correct naming format with invalid-name
|
||||||
|
include-naming-hint=no
|
||||||
|
|
||||||
|
# List of decorators that produce properties, such as abc.abstractproperty. Add
|
||||||
|
# to this list to register other decorators that produce valid properties.
|
||||||
|
property-classes=abc.abstractproperty,cached_property.cached_property,cached_property.threaded_cached_property,cached_property.cached_property_with_ttl,cached_property.threaded_cached_property_with_ttl
|
||||||
|
|
||||||
|
# Regular expression matching correct function names
|
||||||
|
function-rgx=^(?:(?P<exempt>setUp|tearDown|setUpModule|tearDownModule)|(?P<camel_case>_?[A-Z][a-zA-Z0-9]*)|(?P<snake_case>_?[a-z][a-z0-9_]*))$
|
||||||
|
|
||||||
|
# Regular expression matching correct variable names
|
||||||
|
variable-rgx=^[a-z][a-z0-9_]*$
|
||||||
|
|
||||||
|
# Regular expression matching correct constant names
|
||||||
|
const-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$
|
||||||
|
|
||||||
|
# Regular expression matching correct attribute names
|
||||||
|
attr-rgx=^_{0,2}[a-z][a-z0-9_]*$
|
||||||
|
|
||||||
|
# Regular expression matching correct argument names
|
||||||
|
argument-rgx=^[a-z][a-z0-9_]*$
|
||||||
|
|
||||||
|
# Regular expression matching correct class attribute names
|
||||||
|
class-attribute-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$
|
||||||
|
|
||||||
|
# Regular expression matching correct inline iteration names
|
||||||
|
inlinevar-rgx=^[a-z][a-z0-9_]*$
|
||||||
|
|
||||||
|
# Regular expression matching correct class names
|
||||||
|
class-rgx=^_?[A-Z][a-zA-Z0-9]*$
|
||||||
|
|
||||||
|
# Regular expression matching correct module names
|
||||||
|
module-rgx=^(_?[a-z][a-z0-9_]*|__init__)$
|
||||||
|
|
||||||
|
# Regular expression matching correct method names
|
||||||
|
method-rgx=(?x)^(?:(?P<exempt>_[a-z0-9_]+__|runTest|setUp|tearDown|setUpTestCase|tearDownTestCase|setupSelf|tearDownClass|setUpClass|(test|assert)_*[A-Z0-9][a-zA-Z0-9_]*|next)|(?P<camel_case>_{0,2}[A-Z][a-zA-Z0-9_]*)|(?P<snake_case>_{0,2}[a-z][a-z0-9_]*))$
|
||||||
|
|
||||||
|
# Regular expression which should only match function or class names that do
|
||||||
|
# not require a docstring.
|
||||||
|
no-docstring-rgx=(__.*__|main|test.*|.*test|.*Test)$
|
||||||
|
|
||||||
|
# Minimum line length for functions/classes that require docstrings, shorter
|
||||||
|
# ones are exempt.
|
||||||
|
docstring-min-length=10
|
||||||
|
|
||||||
|
|
||||||
|
[TYPECHECK]
|
||||||
|
|
||||||
|
# List of decorators that produce context managers, such as
|
||||||
|
# contextlib.contextmanager. Add to this list to register other decorators that
|
||||||
|
# produce valid context managers.
|
||||||
|
contextmanager-decorators=contextlib.contextmanager,contextlib2.contextmanager
|
||||||
|
|
||||||
|
# Tells whether missing members accessed in mixin class should be ignored. A
|
||||||
|
# mixin class is detected if its name ends with "mixin" (case insensitive).
|
||||||
|
ignore-mixin-members=yes
|
||||||
|
|
||||||
|
# List of module names for which member attributes should not be checked
|
||||||
|
# (useful for modules/projects where namespaces are manipulated during runtime
|
||||||
|
# and thus existing member attributes cannot be deduced by static analysis. It
|
||||||
|
# supports qualified module names, as well as Unix pattern matching.
|
||||||
|
ignored-modules=
|
||||||
|
|
||||||
|
# List of class names for which member attributes should not be checked (useful
|
||||||
|
# for classes with dynamically set attributes). This supports the use of
|
||||||
|
# qualified names.
|
||||||
|
ignored-classes=optparse.Values,thread._local,_thread._local
|
||||||
|
|
||||||
|
# List of members which are set dynamically and missed by pylint inference
|
||||||
|
# system, and so shouldn't trigger E1101 when accessed. Python regular
|
||||||
|
# expressions are accepted.
|
||||||
|
generated-members=
|
||||||
|
|
||||||
|
|
||||||
|
[FORMAT]
|
||||||
|
|
||||||
|
# Maximum number of characters on a single line.
|
||||||
|
max-line-length=120
|
||||||
|
|
||||||
|
# TODO(https://github.com/PyCQA/pylint/issues/3352): Direct pylint to exempt
|
||||||
|
# lines made too long by directives to pytype.
|
||||||
|
|
||||||
|
# Regexp for a line that is allowed to be longer than the limit.
|
||||||
|
ignore-long-lines=(?x)(
|
||||||
|
^\s*(\#\ )?<?https?://\S+>?$|
|
||||||
|
^\s*(from\s+\S+\s+)?import\s+.+$)
|
||||||
|
|
||||||
|
# Allow the body of an if to be on the same line as the test if there is no
|
||||||
|
# else.
|
||||||
|
single-line-if-stmt=yes
|
||||||
|
|
||||||
|
# Maximum number of lines in a module
|
||||||
|
max-module-lines=99999
|
||||||
|
|
||||||
|
# String used as indentation unit. The internal Google style guide mandates 2
|
||||||
|
# spaces. Google's externaly-published style guide says 4, consistent with
|
||||||
|
# PEP 8. Here, we use 2 spaces, for conformity with many open-sourced Google
|
||||||
|
# projects (like TensorFlow).
|
||||||
|
indent-string=' '
|
||||||
|
|
||||||
|
# Number of spaces of indent required inside a hanging or continued line.
|
||||||
|
indent-after-paren=4
|
||||||
|
|
||||||
|
# Expected format of line ending, e.g. empty (any line ending), LF or CRLF.
|
||||||
|
expected-line-ending-format=
|
||||||
|
|
||||||
|
|
||||||
|
[MISCELLANEOUS]
|
||||||
|
|
||||||
|
# List of note tags to take in consideration, separated by a comma.
|
||||||
|
notes=TODO
|
||||||
|
|
||||||
|
|
||||||
|
[STRING]
|
||||||
|
|
||||||
|
# This flag controls whether inconsistent-quotes generates a warning when the
|
||||||
|
# character used as a quote delimiter is used inconsistently within a module.
|
||||||
|
check-quote-consistency=yes
|
||||||
|
|
||||||
|
|
||||||
|
[VARIABLES]
|
||||||
|
|
||||||
|
# Tells whether we should check for unused import in __init__ files.
|
||||||
|
init-import=no
|
||||||
|
|
||||||
|
# A regular expression matching the name of dummy variables (i.e. expectedly
|
||||||
|
# not used).
|
||||||
|
dummy-variables-rgx=^\*{0,2}(_$|unused_|dummy_)
|
||||||
|
|
||||||
|
# List of additional names supposed to be defined in builtins. Remember that
|
||||||
|
# you should avoid to define new builtins when possible.
|
||||||
|
additional-builtins=
|
||||||
|
|
||||||
|
# List of strings which can identify a callback function by name. A callback
|
||||||
|
# name must start or end with one of those strings.
|
||||||
|
callbacks=cb_,_cb
|
||||||
|
|
||||||
|
# List of qualified module names which can have objects that can redefine
|
||||||
|
# builtins.
|
||||||
|
redefining-builtins-modules=six,six.moves,past.builtins,future.builtins,functools
|
||||||
|
|
||||||
|
|
||||||
|
[LOGGING]
|
||||||
|
|
||||||
|
# Logging modules to check that the string format arguments are in logging
|
||||||
|
# function parameter format
|
||||||
|
logging-modules=logging,absl.logging,tensorflow.io.logging
|
||||||
|
|
||||||
|
|
||||||
|
[SIMILARITIES]
|
||||||
|
|
||||||
|
# Minimum lines number of a similarity.
|
||||||
|
min-similarity-lines=4
|
||||||
|
|
||||||
|
# Ignore comments when computing similarities.
|
||||||
|
ignore-comments=yes
|
||||||
|
|
||||||
|
# Ignore docstrings when computing similarities.
|
||||||
|
ignore-docstrings=yes
|
||||||
|
|
||||||
|
# Ignore imports when computing similarities.
|
||||||
|
ignore-imports=no
|
||||||
|
|
||||||
|
|
||||||
|
[SPELLING]
|
||||||
|
|
||||||
|
# Spelling dictionary name. Available dictionaries: none. To make it working
|
||||||
|
# install python-enchant package.
|
||||||
|
spelling-dict=
|
||||||
|
|
||||||
|
# List of comma separated words that should not be checked.
|
||||||
|
spelling-ignore-words=
|
||||||
|
|
||||||
|
# A path to a file that contains private dictionary; one word per line.
|
||||||
|
spelling-private-dict-file=
|
||||||
|
|
||||||
|
# Tells whether to store unknown words to indicated private dictionary in
|
||||||
|
# --spelling-private-dict-file option instead of raising a message.
|
||||||
|
spelling-store-unknown-words=no
|
||||||
|
|
||||||
|
|
||||||
|
[IMPORTS]
|
||||||
|
|
||||||
|
# Deprecated modules which should not be used, separated by a comma
|
||||||
|
deprecated-modules=regsub,
|
||||||
|
TERMIOS,
|
||||||
|
Bastion,
|
||||||
|
rexec,
|
||||||
|
sets
|
||||||
|
|
||||||
|
# Create a graph of every (i.e. internal and external) dependencies in the
|
||||||
|
# given file (report RP0402 must not be disabled)
|
||||||
|
import-graph=
|
||||||
|
|
||||||
|
# Create a graph of external dependencies in the given file (report RP0402 must
|
||||||
|
# not be disabled)
|
||||||
|
ext-import-graph=
|
||||||
|
|
||||||
|
# Create a graph of internal dependencies in the given file (report RP0402 must
|
||||||
|
# not be disabled)
|
||||||
|
int-import-graph=
|
||||||
|
|
||||||
|
# Force import order to recognize a module as part of the standard
|
||||||
|
# compatibility libraries.
|
||||||
|
known-standard-library=
|
||||||
|
|
||||||
|
# Force import order to recognize a module as part of a third party library.
|
||||||
|
known-third-party=enchant, absl
|
||||||
|
|
||||||
|
# Analyse import fallback blocks. This can be used to support both Python 2 and
|
||||||
|
# 3 compatible code, which means that the block might have code that exists
|
||||||
|
# only in one or another interpreter, leading to false positives when analysed.
|
||||||
|
analyse-fallback-blocks=no
|
||||||
|
|
||||||
|
|
||||||
|
[CLASSES]
|
||||||
|
|
||||||
|
# List of method names used to declare (i.e. assign) instance attributes.
|
||||||
|
defining-attr-methods=__init__,
|
||||||
|
__new__,
|
||||||
|
setUp
|
||||||
|
|
||||||
|
# List of member names, which should be excluded from the protected access
|
||||||
|
# warning.
|
||||||
|
exclude-protected=_asdict,
|
||||||
|
_fields,
|
||||||
|
_replace,
|
||||||
|
_source,
|
||||||
|
_make
|
||||||
|
|
||||||
|
# List of valid names for the first argument in a class method.
|
||||||
|
valid-classmethod-first-arg=cls,
|
||||||
|
class_
|
||||||
|
|
||||||
|
# List of valid names for the first argument in a metaclass class method.
|
||||||
|
valid-metaclass-classmethod-first-arg=mcs
|
||||||
|
|
||||||
|
|
||||||
|
[EXCEPTIONS]
|
||||||
|
|
||||||
|
# Exceptions that will emit a warning when being caught. Defaults to
|
||||||
|
# "Exception"
|
||||||
|
overgeneral-exceptions=StandardError,
|
||||||
|
Exception,
|
||||||
|
BaseException
|
||||||
57
README.md
Normal file
57
README.md
Normal file
@@ -0,0 +1,57 @@
|
|||||||
|
<div align="center">
|
||||||
|
|
||||||
|
# InternDataEngine: A simulation-based data generation engine designed for robotic learning.
|
||||||
|
|
||||||
|
</div>
|
||||||
|
|
||||||
|
[](https://arxiv.org/abs/2511.16651)
|
||||||
|
[](https://arxiv.org/abs/2601.21449)
|
||||||
|
[](https://arxiv.org/abs/2510.13778)
|
||||||
|
[](https://huggingface.co/datasets/InternRobotics/InternData-A1)
|
||||||
|
[](https://huggingface.co/datasets/InternRobotics/InternData-M1)
|
||||||
|
[](#)
|
||||||
|
|
||||||
|
## 💻 About
|
||||||
|
|
||||||
|
InternDataEngine is a data-centric engine for embodied AI that powers large-scale model training and iteration.
|
||||||
|
Built on NVIDIA Isaac Sim, it unifies high-fidelity physical interaction from InternData-A1, semantic task and scene generation from InternData-M1, and high-throughput scheduling from the Nimbus framework to deliver realistic, task-aligned, and massively scalable robotic manipulation data.
|
||||||
|
|
||||||
|
- **More realistic physical interaction**: Unified simulation of rigid, articulated, deformable, and fluid objects across single-arm, dual-arm, and humanoid robots, enabling long-horizon, skill-composed manipulation that better supports sim-to-real transfer.
|
||||||
|
- **More task-aligned data generation**: LLM-driven task and instruction generation with task-oriented scene graphs (ToSG), producing structured scenes and rich multi-modal annotations (boxes, keypoints, trajectories) for complex instruction-following and spatial reasoning.
|
||||||
|
- **More efficient large-scale production**: Nimbus-powered asynchronous pipelines that decouple planning, rendering, and storage, achieving 2–3× end-to-end throughput, cluster-level load balancing and fault tolerance for billion-scale data generation.
|
||||||
|
|
||||||
|
## 📢 Latest News 🔥
|
||||||
|
|
||||||
|
- **[2026/03]** We release the InternDataEngine codebase, which includes the core modules: InternData-A1, Nimbus, and InternData-M1.
|
||||||
|
|
||||||
|
## 🚀 Quickstart
|
||||||
|
|
||||||
|
Please refer to the [Installation](TBD) and [Usage](TBD) to start the installation and run your first synthetic data generation task.
|
||||||
|
|
||||||
|
For more details, please check [Documentation](TBD).
|
||||||
|
|
||||||
|
## License and Citation
|
||||||
|
All the code within this repo are under [CC BY-NC-SA 4.0](https://creativecommons.org/licenses/by-nc-sa/4.0/). Please consider citing our papers if it helps your research.
|
||||||
|
|
||||||
|
```BibTeX
|
||||||
|
@article{tian2025interndata,
|
||||||
|
title={Interndata-a1: Pioneering high-fidelity synthetic data for pre-training generalist policy},
|
||||||
|
author={Tian, Yang and Yang, Yuyin and Xie, Yiman and Cai, Zetao and Shi, Xu and Gao, Ning and Liu, Hangxu and Jiang, Xuekun and Qiu, Zherui and Yuan, Feng and others},
|
||||||
|
journal={arXiv preprint arXiv:2511.16651},
|
||||||
|
year={2025}
|
||||||
|
}
|
||||||
|
|
||||||
|
@article{he2026nimbus,
|
||||||
|
title={Nimbus: A Unified Embodied Synthetic Data Generation Framework},
|
||||||
|
author={He, Zeyu and Zhang, Yuchang and Zhou, Yuanzhen and Tao, Miao and Li, Hengjie and Tian, Yang and Zeng, Jia and Wang, Tai and Cai, Wenzhe and Chen, Yilun and others},
|
||||||
|
journal={arXiv preprint arXiv:2601.21449},
|
||||||
|
year={2026}
|
||||||
|
}
|
||||||
|
|
||||||
|
@article{chen2025internvla,
|
||||||
|
title={Internvla-m1: A spatially guided vision-language-action framework for generalist robot policy},
|
||||||
|
author={Chen, Xinyi and Chen, Yilun and Fu, Yanwei and Gao, Ning and Jia, Jiaya and Jin, Weiyang and Li, Hao and Mu, Yao and Pang, Jiangmiao and Qiao, Yu and others},
|
||||||
|
journal={arXiv preprint arXiv:2510.13778},
|
||||||
|
year={2025}
|
||||||
|
}
|
||||||
|
```
|
||||||
62
configs/simbox/de_pipe_template.yaml
Normal file
62
configs/simbox/de_pipe_template.yaml
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
name: simbox_pipe_dynamic_actor
|
||||||
|
load_stage:
|
||||||
|
scene_loader: # Scene loader (plan process)
|
||||||
|
type: env_loader
|
||||||
|
args:
|
||||||
|
workflow_type: SimBoxDualWorkFlow
|
||||||
|
cfg_path: workflows/simbox/core/configs/tasks/example/sort_the_rubbish.yaml # Task config path
|
||||||
|
simulator:
|
||||||
|
physics_dt: 1/30 # Physics update rate
|
||||||
|
rendering_dt: 1/30 # Render update rate
|
||||||
|
stage_units_in_meters: 1.0 # Stage unit scale
|
||||||
|
headless: True # Headless mode (no GUI); set false for visual debugging
|
||||||
|
anti_aliasing: 0 # Anti-aliasing level
|
||||||
|
layout_random_generator: # Scene randomization
|
||||||
|
type: env_randomizer
|
||||||
|
args:
|
||||||
|
random_num: 3 # Number of random samples per task
|
||||||
|
strict_mode: true # true: output count must equal random_num
|
||||||
|
plan_stage:
|
||||||
|
seq_planner: # Trajectory planner
|
||||||
|
type: env_planner
|
||||||
|
dump_stage:
|
||||||
|
dumper: # Serialize plan results for render process
|
||||||
|
type: env
|
||||||
|
dedump_stage:
|
||||||
|
dedumper: # Deserialize plan results in render process
|
||||||
|
type: de
|
||||||
|
scene_loader: # Scene loader (render process)
|
||||||
|
type: env_loader
|
||||||
|
args:
|
||||||
|
workflow_type: SimBoxDualWorkFlow
|
||||||
|
cfg_path: workflows/simbox/core/configs/tasks/example/sort_the_rubbish.yaml # Must match load_stage cfg_path
|
||||||
|
simulator:
|
||||||
|
physics_dt: 1/30 # Physics update rate
|
||||||
|
rendering_dt: 1/30 # Render update rate
|
||||||
|
rendering_interval: 5 # Render every N physics steps
|
||||||
|
headless: true # Headless mode
|
||||||
|
multi_gpu: true # Enable multi-GPU for render workers
|
||||||
|
layout_random_generator:
|
||||||
|
type: env_randomizer
|
||||||
|
seq_planner: # Reads serialized trajectories
|
||||||
|
type: env_reader
|
||||||
|
render_stage:
|
||||||
|
renderer: # Renderer
|
||||||
|
type: env_renderer
|
||||||
|
store_stage:
|
||||||
|
writer: # Data writer
|
||||||
|
type: env_writer
|
||||||
|
args:
|
||||||
|
batch_async: false # Sync writes (safer for pipe mode)
|
||||||
|
output_dir: output/${name}/ # Output directory
|
||||||
|
stage_pipe:
|
||||||
|
stage_num: [3, 3] # Number of pipeline stages [plan, render]
|
||||||
|
stage_dev: ["gpu", "gpu"] # Device type per stage
|
||||||
|
worker_num: [2, 1] # Number of workers per stage [plan, render]
|
||||||
|
worker_schedule: True # Enable dynamic worker scheduling
|
||||||
|
safe_threshold: 100 # Max pending items before throttling
|
||||||
|
status_timeouts: # Worker timeout thresholds (seconds, -1 = infinite)
|
||||||
|
idle: 360 # Max idle time before restart
|
||||||
|
ready: -1 # Max time in ready state
|
||||||
|
running: 600 # Max running time per task
|
||||||
|
monitor_check_interval: 120 # Health check interval (seconds)
|
||||||
30
configs/simbox/de_plan_and_render_template.yaml
Normal file
30
configs/simbox/de_plan_and_render_template.yaml
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
name: simbox_plan_and_render
|
||||||
|
load_stage:
|
||||||
|
scene_loader: # Scene loader
|
||||||
|
type: env_loader
|
||||||
|
args:
|
||||||
|
workflow_type: SimBoxDualWorkFlow
|
||||||
|
cfg_path: workflows/simbox/core/configs/tasks/example/sort_the_rubbish.yaml # Task config path
|
||||||
|
simulator:
|
||||||
|
physics_dt: 1/30 # Physics update rate
|
||||||
|
rendering_dt: 1/30 # Render update rate
|
||||||
|
stage_units_in_meters: 1.0 # Stage unit scale
|
||||||
|
headless: True # Headless mode (no GUI); set false for visual debugging
|
||||||
|
anti_aliasing: 0 # Anti-aliasing level
|
||||||
|
layout_random_generator: # Scene randomization
|
||||||
|
type: env_randomizer
|
||||||
|
args:
|
||||||
|
random_num: 6 # Number of random samples per task
|
||||||
|
strict_mode: true # true: output count must equal random_num
|
||||||
|
plan_stage:
|
||||||
|
seq_planner: # Trajectory planner
|
||||||
|
type: env_planner
|
||||||
|
render_stage:
|
||||||
|
renderer: # Renderer
|
||||||
|
type: env_renderer
|
||||||
|
store_stage:
|
||||||
|
writer: # Data writer
|
||||||
|
type: env_writer
|
||||||
|
args:
|
||||||
|
batch_async: true # Async writes (better perf, more memory)
|
||||||
|
output_dir: output/${name}/ # Output directory
|
||||||
27
configs/simbox/de_plan_template.yaml
Normal file
27
configs/simbox/de_plan_template.yaml
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
name: simbox_plan
|
||||||
|
load_stage:
|
||||||
|
scene_loader: # Scene loader
|
||||||
|
type: env_loader
|
||||||
|
args:
|
||||||
|
workflow_type: SimBoxDualWorkFlow
|
||||||
|
cfg_path: workflows/simbox/core/configs/tasks/example/sort_the_rubbish.yaml # Task config path
|
||||||
|
simulator:
|
||||||
|
physics_dt: 1/30 # Physics update rate
|
||||||
|
rendering_dt: 1/30 # Render update rate
|
||||||
|
stage_units_in_meters: 1.0 # Stage unit scale
|
||||||
|
headless: True # Headless mode (no GUI)
|
||||||
|
anti_aliasing: 0 # Anti-aliasing level
|
||||||
|
layout_random_generator: # Scene randomization
|
||||||
|
type: env_randomizer
|
||||||
|
args:
|
||||||
|
random_num: 6 # Number of random samples per task
|
||||||
|
strict_mode: true # true: output count must equal random_num
|
||||||
|
plan_stage:
|
||||||
|
seq_planner: # Trajectory planner (plan only, no rendering)
|
||||||
|
type: env_planner
|
||||||
|
store_stage:
|
||||||
|
writer: # Data writer
|
||||||
|
type: env_writer
|
||||||
|
args:
|
||||||
|
batch_async: true # Async writes (better perf, more memory)
|
||||||
|
seq_output_dir: output/${name}/ # Trajectory output directory
|
||||||
27
configs/simbox/de_plan_with_render_template.yaml
Normal file
27
configs/simbox/de_plan_with_render_template.yaml
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
name: simbox_plan_with_render
|
||||||
|
load_stage:
|
||||||
|
scene_loader: # Scene loader
|
||||||
|
type: env_loader
|
||||||
|
args:
|
||||||
|
workflow_type: SimBoxDualWorkFlow
|
||||||
|
cfg_path: workflows/simbox/core/configs/tasks/example/sort_the_rubbish.yaml # Task config path
|
||||||
|
simulator:
|
||||||
|
physics_dt: 1/30 # Physics update rate
|
||||||
|
rendering_dt: 1/30 # Render update rate
|
||||||
|
stage_units_in_meters: 1.0 # Stage unit scale
|
||||||
|
headless: True # Headless mode (no GUI); set false for visual debugging
|
||||||
|
anti_aliasing: 0 # Anti-aliasing level
|
||||||
|
layout_random_generator: # Scene randomization
|
||||||
|
type: env_randomizer
|
||||||
|
args:
|
||||||
|
random_num: 6 # Number of random samples per task
|
||||||
|
strict_mode: true # true: output count must equal random_num
|
||||||
|
plan_with_render_stage:
|
||||||
|
plan_with_render: # Plan and render in a single stage
|
||||||
|
type: plan_with_render
|
||||||
|
store_stage:
|
||||||
|
writer: # Data writer
|
||||||
|
type: env_writer
|
||||||
|
args:
|
||||||
|
batch_async: true # Async writes (better perf, more memory)
|
||||||
|
output_dir: output/${name}/ # Output directory
|
||||||
29
configs/simbox/de_render_template.yaml
Normal file
29
configs/simbox/de_render_template.yaml
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
name: simbox_render
|
||||||
|
load_stage:
|
||||||
|
scene_loader: # Scene loader
|
||||||
|
type: env_loader
|
||||||
|
args:
|
||||||
|
workflow_type: SimBoxDualWorkFlow
|
||||||
|
cfg_path: workflows/simbox/core/configs/tasks/example/sort_the_rubbish.yaml # Task config path
|
||||||
|
simulator:
|
||||||
|
physics_dt: 1/30 # Physics update rate
|
||||||
|
rendering_dt: 1/30 # Render update rate
|
||||||
|
stage_units_in_meters: 1.0 # Stage unit scale
|
||||||
|
headless: True # Headless mode (no GUI)
|
||||||
|
anti_aliasing: 0 # Anti-aliasing level
|
||||||
|
layout_random_generator: # Reads pre-planned trajectories from input_dir
|
||||||
|
type: env_randomizer
|
||||||
|
args:
|
||||||
|
input_dir: output/simbox_plan/BananaBaseTask/plan # Path to plan output from de_plan_template
|
||||||
|
plan_stage:
|
||||||
|
seq_planner: # Reads serialized trajectories (render only, no planning)
|
||||||
|
type: env_reader
|
||||||
|
render_stage:
|
||||||
|
renderer: # Renderer
|
||||||
|
type: env_renderer
|
||||||
|
store_stage:
|
||||||
|
writer: # Data writer
|
||||||
|
type: env_writer
|
||||||
|
args:
|
||||||
|
batch_async: true # Async writes (better perf, more memory)
|
||||||
|
output_dir: output/${name}/ # Output directory
|
||||||
283
deps/world_toolkit/world_recorder/__init__.py
vendored
Normal file
283
deps/world_toolkit/world_recorder/__init__.py
vendored
Normal file
@@ -0,0 +1,283 @@
|
|||||||
|
import pickle
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from omni.isaac.core.articulations.articulation import Articulation
|
||||||
|
from omni.isaac.core.robots.robot import Robot
|
||||||
|
from omni.isaac.core.utils.numpy.transformations import get_local_from_world
|
||||||
|
from omni.isaac.core.utils.prims import get_prim_at_path, get_prim_parent
|
||||||
|
from omni.isaac.core.utils.xforms import get_world_pose
|
||||||
|
from pxr import Gf, Usd, UsdGeom
|
||||||
|
|
||||||
|
from workflows.utils.utils import get_link
|
||||||
|
|
||||||
|
|
||||||
|
class WorldRecorder:
|
||||||
|
"""
|
||||||
|
WorldRecorder handles recording and replaying simulation states.
|
||||||
|
|
||||||
|
Two modes are supported:
|
||||||
|
- step_replay=False: Records prim poses for fast geometric replay
|
||||||
|
- step_replay=True: Records robot joint positions and object world poses
|
||||||
|
for physics-accurate replay via set_joint_positions / set_world_pose
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, world, robots: list, objs: list, step_replay: bool = False):
|
||||||
|
self.world = world
|
||||||
|
self.stage = world.stage
|
||||||
|
self.xform_prims = []
|
||||||
|
self.prim_poses = []
|
||||||
|
self.prim_visibilities = []
|
||||||
|
self.replay_counter = 0
|
||||||
|
self.num_steps = 0
|
||||||
|
|
||||||
|
for robot_name, robot in robots.items():
|
||||||
|
if not isinstance(robot, Robot):
|
||||||
|
raise TypeError(
|
||||||
|
f"Robot '{robot_name}' must be an instance of omni.isaac.core.robots.robot.Robot "
|
||||||
|
f"or its subclass, got {type(robot).__name__} instead."
|
||||||
|
)
|
||||||
|
|
||||||
|
self.robots = robots
|
||||||
|
self.objs = objs
|
||||||
|
self.step_replay = step_replay
|
||||||
|
|
||||||
|
self._initialize_xform_prims()
|
||||||
|
|
||||||
|
if self.step_replay:
|
||||||
|
print("use joint_position to replay, WorldRecorder will record joint positions and object poses")
|
||||||
|
self.robot_joint_data = {name: [] for name in robots}
|
||||||
|
self.object_state_data = {name: [] for name in objs}
|
||||||
|
else:
|
||||||
|
print("use prim poses to replay, WorldRecorder will record prim poses for replay")
|
||||||
|
|
||||||
|
def _initialize_xform_prims(self):
|
||||||
|
self.robots_prim = []
|
||||||
|
robots_prim_paths = [robot.prim_path for robot in self.robots.values()]
|
||||||
|
for robot_prim_path in robots_prim_paths:
|
||||||
|
robot_prim_path = get_prim_at_path(robot_prim_path)
|
||||||
|
link_dict = get_link(robot_prim_path)
|
||||||
|
robots_paths = list(link_dict.values())
|
||||||
|
print(robots_paths)
|
||||||
|
self.robots_prim.extend(robots_paths)
|
||||||
|
|
||||||
|
self.objects_prim = []
|
||||||
|
record_objects = self.objs
|
||||||
|
for _, obj in record_objects.items():
|
||||||
|
object_prim_path = obj.prim
|
||||||
|
if isinstance(obj, Articulation):
|
||||||
|
link_dict = get_link(object_prim_path)
|
||||||
|
object_paths = list(link_dict.values())
|
||||||
|
self.objects_prim.extend(object_paths)
|
||||||
|
else:
|
||||||
|
self.objects_prim.append(object_prim_path)
|
||||||
|
self.xform_prims.extend(self.robots_prim)
|
||||||
|
self.xform_prims.extend(self.objects_prim)
|
||||||
|
print(f"Found {len(self.xform_prims)} xformable prims")
|
||||||
|
|
||||||
|
def record(self):
|
||||||
|
"""Record current frame state."""
|
||||||
|
if self.step_replay:
|
||||||
|
for robot_name, robot in self.robots.items():
|
||||||
|
joint_positions = robot.get_joint_positions()
|
||||||
|
self.robot_joint_data[robot_name].append(joint_positions)
|
||||||
|
|
||||||
|
for obj_name, obj in self.objs.items():
|
||||||
|
translation, orientation = obj.get_world_pose()
|
||||||
|
state = {
|
||||||
|
"translation": translation,
|
||||||
|
"orientation": orientation,
|
||||||
|
}
|
||||||
|
if isinstance(obj, Articulation):
|
||||||
|
state["joint_positions"] = obj.get_joint_positions()
|
||||||
|
self.object_state_data[obj_name].append(state)
|
||||||
|
|
||||||
|
frame_visibilities = []
|
||||||
|
for prim in self.xform_prims:
|
||||||
|
visibility = prim.GetAttribute("visibility").Get()
|
||||||
|
frame_visibilities.append(visibility)
|
||||||
|
self.prim_visibilities.append(frame_visibilities)
|
||||||
|
|
||||||
|
self.num_steps += 1
|
||||||
|
else:
|
||||||
|
frame_poses = []
|
||||||
|
frame_visibilities = []
|
||||||
|
for prim in self.xform_prims:
|
||||||
|
world_pose = get_world_pose(prim.GetPath().pathString)
|
||||||
|
frame_poses.append([world_pose[0].tolist(), world_pose[1].tolist()])
|
||||||
|
|
||||||
|
visibility = prim.GetAttribute("visibility").Get()
|
||||||
|
frame_visibilities.append(visibility)
|
||||||
|
|
||||||
|
self.prim_poses.append(frame_poses)
|
||||||
|
self.prim_visibilities.append(frame_visibilities)
|
||||||
|
self.num_steps += 1
|
||||||
|
|
||||||
|
def warmup(self):
|
||||||
|
"""Internal warmup logic for different modes."""
|
||||||
|
if self.step_replay:
|
||||||
|
print("Warming up simulation (joint_position mode)...")
|
||||||
|
if self.num_steps > 0:
|
||||||
|
self._replay_from_joint_positions(increment_counter=False)
|
||||||
|
for _ in range(10):
|
||||||
|
self.world.step(render=True)
|
||||||
|
self.world.get_observations()
|
||||||
|
print("Warmup completed. Starting replay...")
|
||||||
|
else:
|
||||||
|
print("Warming up (prim poses mode)...")
|
||||||
|
if self.num_steps > 0:
|
||||||
|
self._replay_from_prim_poses(increment_counter=False)
|
||||||
|
for _ in range(10):
|
||||||
|
self.world.render()
|
||||||
|
self.world.get_observations()
|
||||||
|
print("Warmup completed. Starting replay...")
|
||||||
|
|
||||||
|
def get_total_steps(self):
|
||||||
|
return self.num_steps
|
||||||
|
|
||||||
|
def replay(self):
|
||||||
|
"""
|
||||||
|
Unified replay interface. Automatically selects the appropriate replay method
|
||||||
|
based on mode setting.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if replay is complete, False otherwise
|
||||||
|
"""
|
||||||
|
if self.step_replay:
|
||||||
|
return self._replay_from_joint_positions()
|
||||||
|
else:
|
||||||
|
return self._replay_from_prim_poses()
|
||||||
|
|
||||||
|
def _replay_from_joint_positions(self, increment_counter: bool = True):
|
||||||
|
"""
|
||||||
|
Replay from recorded joint position / world pose data.
|
||||||
|
Uses world.step(render=True) for proper physics and joint constraints.
|
||||||
|
"""
|
||||||
|
if self.num_steps == 0:
|
||||||
|
print("No steps to replay")
|
||||||
|
return True
|
||||||
|
|
||||||
|
if self.replay_counter == 0:
|
||||||
|
print(f"Starting replay of {self.num_steps} steps from joint position data...")
|
||||||
|
|
||||||
|
if self.replay_counter < self.num_steps:
|
||||||
|
self._apply_recorded_states()
|
||||||
|
self.world.step(render=True)
|
||||||
|
if increment_counter:
|
||||||
|
self.replay_counter += 1
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
print("Replay complete")
|
||||||
|
return True
|
||||||
|
|
||||||
|
def _replay_from_prim_poses(self, increment_counter: bool = True):
|
||||||
|
if self.replay_counter == 0:
|
||||||
|
print(f"Re-found {len(self.xform_prims)} xformable prims")
|
||||||
|
if self.replay_counter < self.num_steps:
|
||||||
|
frame_poses = self.prim_poses[self.replay_counter]
|
||||||
|
frame_visibilities = self.prim_visibilities[self.replay_counter]
|
||||||
|
for prim, world_pose, frame_visibility in zip(self.xform_prims, frame_poses, frame_visibilities):
|
||||||
|
parent_transforms = np.array(
|
||||||
|
[UsdGeom.Xformable(get_prim_parent(prim)).ComputeLocalToWorldTransform(Usd.TimeCode.Default())]
|
||||||
|
)
|
||||||
|
translations, orientations = get_local_from_world(
|
||||||
|
parent_transforms, np.array([world_pose[0]]), np.array([world_pose[1]])
|
||||||
|
)
|
||||||
|
|
||||||
|
properties = prim.GetPropertyNames()
|
||||||
|
translation = Gf.Vec3d(*translations[0].tolist())
|
||||||
|
if "xformOp:translate" in properties:
|
||||||
|
xform_op = prim.GetAttribute("xformOp:translate")
|
||||||
|
xform_op.Set(translation)
|
||||||
|
|
||||||
|
if "xformOp:orient" in properties:
|
||||||
|
xform_op = prim.GetAttribute("xformOp:orient")
|
||||||
|
if xform_op.GetTypeName() == "quatf":
|
||||||
|
rotq = Gf.Quatf(*orientations[0].tolist())
|
||||||
|
else:
|
||||||
|
rotq = Gf.Quatd(*orientations[0].tolist())
|
||||||
|
xform_op.Set(rotq)
|
||||||
|
|
||||||
|
if frame_visibility == UsdGeom.Tokens.invisible:
|
||||||
|
prim.GetAttribute("visibility").Set("invisible")
|
||||||
|
else:
|
||||||
|
prim.GetAttribute("visibility").Set("inherited")
|
||||||
|
|
||||||
|
self.world.render()
|
||||||
|
if increment_counter:
|
||||||
|
self.replay_counter += 1
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
print("Replay complete")
|
||||||
|
return True
|
||||||
|
|
||||||
|
def _apply_recorded_states(self):
|
||||||
|
"""Apply recorded robot joint positions, object world poses, and visibilities for the current step."""
|
||||||
|
step = self.replay_counter
|
||||||
|
|
||||||
|
for robot_name, robot in self.robots.items():
|
||||||
|
joint_positions = self.robot_joint_data[robot_name][step]
|
||||||
|
robot.set_joint_positions(positions=joint_positions)
|
||||||
|
|
||||||
|
for obj_name, obj in self.objs.items():
|
||||||
|
state = self.object_state_data[obj_name][step]
|
||||||
|
obj.set_world_pose(state["translation"], state["orientation"])
|
||||||
|
|
||||||
|
if "joint_positions" in state and state["joint_positions"] is not None:
|
||||||
|
obj.set_joint_positions(state["joint_positions"])
|
||||||
|
|
||||||
|
frame_visibilities = self.prim_visibilities[step]
|
||||||
|
for prim, frame_visibility in zip(self.xform_prims, frame_visibilities):
|
||||||
|
if frame_visibility == UsdGeom.Tokens.invisible:
|
||||||
|
prim.GetAttribute("visibility").Set("invisible")
|
||||||
|
else:
|
||||||
|
prim.GetAttribute("visibility").Set("inherited")
|
||||||
|
|
||||||
|
def dumps(self):
|
||||||
|
"""Serialize recorder data based on mode."""
|
||||||
|
if self.step_replay:
|
||||||
|
record_data = {
|
||||||
|
"mode": "joint_position",
|
||||||
|
"num_steps": self.num_steps,
|
||||||
|
"robot_joint_data": self.robot_joint_data,
|
||||||
|
"object_state_data": self.object_state_data,
|
||||||
|
"prim_visibilities": self.prim_visibilities,
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
record_data = {
|
||||||
|
"mode": "prim_pose",
|
||||||
|
"num_steps": self.num_steps,
|
||||||
|
"prim_poses": self.prim_poses,
|
||||||
|
"prim_visibilities": self.prim_visibilities,
|
||||||
|
}
|
||||||
|
|
||||||
|
return pickle.dumps(record_data)
|
||||||
|
|
||||||
|
def loads(self, data):
|
||||||
|
"""Deserialize recorder data based on mode."""
|
||||||
|
record_data = pickle.loads(data)
|
||||||
|
mode = record_data.get("mode", "prim_pose")
|
||||||
|
|
||||||
|
if mode == "prim_pose" and not self.step_replay:
|
||||||
|
self.num_steps = record_data["num_steps"]
|
||||||
|
self.prim_poses = record_data["prim_poses"]
|
||||||
|
self.prim_visibilities = record_data["prim_visibilities"]
|
||||||
|
elif mode == "joint_position" and self.step_replay:
|
||||||
|
self.num_steps = record_data["num_steps"]
|
||||||
|
self.robot_joint_data = record_data["robot_joint_data"]
|
||||||
|
self.object_state_data = record_data["object_state_data"]
|
||||||
|
self.prim_visibilities = record_data["prim_visibilities"]
|
||||||
|
else:
|
||||||
|
mode_name = "prim_pose" if not self.step_replay else "joint_position"
|
||||||
|
raise ValueError(f"Mode mismatch: data is '{mode}', recorder is '{mode_name}'")
|
||||||
|
|
||||||
|
return record_data
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
self.num_steps = 0
|
||||||
|
self.replay_counter = 0
|
||||||
|
self.prim_poses = []
|
||||||
|
self.prim_visibilities = []
|
||||||
|
if self.step_replay:
|
||||||
|
self.robot_joint_data = {name: [] for name in self.robots}
|
||||||
|
self.object_state_data = {name: [] for name in self.objs}
|
||||||
|
print("WorldRecorder reset")
|
||||||
49
launcher.py
Normal file
49
launcher.py
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
# pylint: disable=C0413
|
||||||
|
# flake8: noqa: E402
|
||||||
|
|
||||||
|
from nimbus.utils.utils import init_env
|
||||||
|
|
||||||
|
init_env()
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
from nimbus import run_data_engine
|
||||||
|
from nimbus.utils.config_processor import ConfigProcessor
|
||||||
|
from nimbus.utils.flags import set_debug_mode, set_random_seed
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("--config", required=True, help="path to config file")
|
||||||
|
parser.add_argument("--random_seed", help="random seed")
|
||||||
|
parser.add_argument("--debug", action="store_true", help="enable debug mode: all errors raised immediately")
|
||||||
|
args, extras = parser.parse_known_args()
|
||||||
|
|
||||||
|
processor = ConfigProcessor()
|
||||||
|
|
||||||
|
try:
|
||||||
|
config = processor.process_config(args.config, cli_args=extras)
|
||||||
|
except ValueError as e:
|
||||||
|
print(f"Configuration Error: {e}")
|
||||||
|
print(f"\n Available configuration paths can be found in: {args.config}")
|
||||||
|
print(" Use dot notation to override nested values, e.g.:")
|
||||||
|
print(" --stage_pipe.worker_num='[2,4]'")
|
||||||
|
print(" --load_stage.layout_random_generator.args.random_num=500")
|
||||||
|
return 1
|
||||||
|
|
||||||
|
processor.print_final_config(config)
|
||||||
|
|
||||||
|
if args.debug:
|
||||||
|
set_debug_mode(True)
|
||||||
|
|
||||||
|
if args.random_seed is not None:
|
||||||
|
set_random_seed(int(args.random_seed))
|
||||||
|
|
||||||
|
try:
|
||||||
|
run_data_engine(config, args.random_seed)
|
||||||
|
except Exception as e:
|
||||||
|
raise e
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
16
nimbus/__init__.py
Normal file
16
nimbus/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
import ray
|
||||||
|
|
||||||
|
from nimbus.utils.types import STAGE_PIPE
|
||||||
|
|
||||||
|
from .data_engine import DataEngine, DistPipeDataEngine
|
||||||
|
|
||||||
|
|
||||||
|
def run_data_engine(config, master_seed=None):
|
||||||
|
import nimbus_extension # noqa: F401 pylint: disable=unused-import
|
||||||
|
|
||||||
|
if STAGE_PIPE in config:
|
||||||
|
ray.init(num_gpus=1)
|
||||||
|
data_engine = DistPipeDataEngine(config, master_seed=master_seed)
|
||||||
|
else:
|
||||||
|
data_engine = DataEngine(config, master_seed=master_seed)
|
||||||
|
data_engine.run()
|
||||||
0
nimbus/components/data/__init__.py
Normal file
0
nimbus/components/data/__init__.py
Normal file
71
nimbus/components/data/camera.py
Normal file
71
nimbus/components/data/camera.py
Normal file
@@ -0,0 +1,71 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class C2W:
|
||||||
|
"""
|
||||||
|
Represents a camera-to-world transformation matrix.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
matrix (List[float]): A list of 16 floats representing the 4x4 transformation matrix in row-major order.
|
||||||
|
"""
|
||||||
|
|
||||||
|
matrix: List[float]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Camera:
|
||||||
|
"""
|
||||||
|
Represents a single camera pose in the trajectory.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
trajectory (List[C2W]): List of C2W transformations for this camera pose.
|
||||||
|
intrinsic (Optional[List[float]]): 3x3 camera intrinsic matrix: [[fx, 0, cx], [0, fy, cy], [0, 0, 1]].
|
||||||
|
extrinsic (Optional[List[float]]): 4x4 tobase_extrinsic matrix representing the camera mounting offset
|
||||||
|
relative to the robot base (height + pitch).
|
||||||
|
length (Optional[int]): Length of the trajectory in number of frames.
|
||||||
|
depths (Optional[list[np.ndarray]]): List of depth images captured by this camera.
|
||||||
|
rgbs (Optional[list[np.ndarray]]): List of RGB images captured by this camera.
|
||||||
|
uv_tracks (Optional[Dict[str, Any]]): UV tracking data in the format
|
||||||
|
{mesh_name: {"per_frame": list, "width": W, "height": H}}.
|
||||||
|
uv_mesh_names (Optional[List[str]]): List of mesh names being tracked in the UV tracking data.
|
||||||
|
"""
|
||||||
|
|
||||||
|
trajectory: List[C2W]
|
||||||
|
intrinsic: List[float] = None
|
||||||
|
extrinsic: List[float] = None
|
||||||
|
length: int = None
|
||||||
|
depths: list[np.ndarray] = None
|
||||||
|
rgbs: list[np.ndarray] = None
|
||||||
|
uv_tracks: Optional[Dict[str, Any]] = None
|
||||||
|
uv_mesh_names: Optional[List[str]] = None
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
if self.length is not None:
|
||||||
|
return self.length
|
||||||
|
self._check_length()
|
||||||
|
self.length = len(self.trajectory)
|
||||||
|
return len(self.trajectory)
|
||||||
|
|
||||||
|
def _check_length(self):
|
||||||
|
if self.depths is not None and len(self.depths) != len(self.trajectory):
|
||||||
|
raise ValueError("Length of depths does not match length of trajectory")
|
||||||
|
if self.rgbs is not None and len(self.rgbs) != len(self.trajectory):
|
||||||
|
raise ValueError("Length of rgbs does not match length of trajectory")
|
||||||
|
if self.uv_tracks is not None:
|
||||||
|
for mesh_name, track_data in self.uv_tracks.items():
|
||||||
|
if len(track_data["per_frame"]) != len(self.trajectory):
|
||||||
|
raise ValueError(f"Length of uv_tracks for mesh {mesh_name} does not match length of trajectory")
|
||||||
|
|
||||||
|
def append_rgb(self, rgb_image: np.ndarray):
|
||||||
|
if self.rgbs is None:
|
||||||
|
self.rgbs = []
|
||||||
|
self.rgbs.append(rgb_image)
|
||||||
|
|
||||||
|
def append_depth(self, depth_image: np.ndarray):
|
||||||
|
if self.depths is None:
|
||||||
|
self.depths = []
|
||||||
|
self.depths.append(depth_image)
|
||||||
95
nimbus/components/data/iterator.py
Normal file
95
nimbus/components/data/iterator.py
Normal file
@@ -0,0 +1,95 @@
|
|||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from abc import abstractmethod
|
||||||
|
from collections.abc import Iterator
|
||||||
|
from typing import Generic, TypeVar
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
|
# pylint: disable=E0102
|
||||||
|
class Iterator(Iterator, Generic[T]):
|
||||||
|
def __init__(self, max_retry=3):
|
||||||
|
self._next_calls = 0.0
|
||||||
|
self._next_total_time = 0.0
|
||||||
|
self._init_time_costs = 0.0
|
||||||
|
self._init_times = 0
|
||||||
|
self._frame_compute_time = 0.0
|
||||||
|
self._frame_compute_frames = 0.0
|
||||||
|
self._frame_io_time = 0.0
|
||||||
|
self._frame_io_frames = 0.0
|
||||||
|
self._wait_time = 0.0
|
||||||
|
self._seq_num = 0.0
|
||||||
|
self._seq_time = 0.0
|
||||||
|
self.logger = logging.getLogger("de_logger")
|
||||||
|
self.max_retry = max_retry
|
||||||
|
self.retry_num = 0
|
||||||
|
|
||||||
|
def record_init_time(self, time_costs):
|
||||||
|
self._init_times += 1
|
||||||
|
self._init_time_costs += time_costs
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __next__(self):
|
||||||
|
start_time = time.time()
|
||||||
|
try:
|
||||||
|
result = self._next()
|
||||||
|
except StopIteration:
|
||||||
|
self._log_statistics()
|
||||||
|
raise
|
||||||
|
end_time = time.time()
|
||||||
|
self._next_calls += 1
|
||||||
|
self._next_total_time += end_time - start_time
|
||||||
|
return result
|
||||||
|
|
||||||
|
def collect_compute_frame_info(self, length, time_costs):
|
||||||
|
self._frame_compute_frames += length
|
||||||
|
self._frame_compute_time += time_costs
|
||||||
|
|
||||||
|
def collect_io_frame_info(self, length, time_costs):
|
||||||
|
self._frame_io_frames += length
|
||||||
|
self._frame_io_time += time_costs
|
||||||
|
|
||||||
|
def collect_wait_time_info(self, time_costs):
|
||||||
|
self._wait_time += time_costs
|
||||||
|
|
||||||
|
def collect_seq_info(self, length, time_costs):
|
||||||
|
self._seq_num += length
|
||||||
|
self._seq_time += time_costs
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def _next(self):
|
||||||
|
raise NotImplementedError("Subclasses should implement this method.")
|
||||||
|
|
||||||
|
def _log_statistics(self):
|
||||||
|
class_name = self.__class__.__name__
|
||||||
|
self.logger.info(
|
||||||
|
f"{class_name}: Next method called {self._next_calls} times, total time:"
|
||||||
|
f" {self._next_total_time:.6f} seconds"
|
||||||
|
)
|
||||||
|
if self._init_time_costs > 0:
|
||||||
|
self.logger.info(
|
||||||
|
f"{class_name}: Init time: {self._init_time_costs:.6f} seconds, init {self._init_times} times"
|
||||||
|
)
|
||||||
|
if self._frame_compute_time > 0:
|
||||||
|
avg_compute_time = self._frame_compute_time / self._frame_compute_frames
|
||||||
|
self.logger.info(
|
||||||
|
f"{class_name}: compute frame num: {self._frame_compute_frames}, total time:"
|
||||||
|
f" {self._frame_compute_time:.6f} seconds, average time: {avg_compute_time:.6f} seconds per frame"
|
||||||
|
)
|
||||||
|
if self._frame_io_frames > 0:
|
||||||
|
avg_io_time = self._frame_io_time / self._frame_io_frames
|
||||||
|
self.logger.info(
|
||||||
|
f"{class_name}: io frame num: {self._frame_io_frames}, total time: {self._frame_io_time:.6f} seconds,"
|
||||||
|
f" average time: {avg_io_time:.6f} seconds per frame"
|
||||||
|
)
|
||||||
|
if self._wait_time > 0:
|
||||||
|
self.logger.info(f"{class_name}: wait time: {self._wait_time:.6f} seconds")
|
||||||
|
if self._seq_time > 0:
|
||||||
|
avg_seq_time = self._seq_time / self._seq_num
|
||||||
|
self.logger.info(
|
||||||
|
f"{class_name}: seq num: {self._seq_num:.6f}, total time: {self._seq_time:.6f} seconds, average time:"
|
||||||
|
f" {avg_seq_time:.6f} seconds per sequence"
|
||||||
|
)
|
||||||
119
nimbus/components/data/observation.py
Normal file
119
nimbus/components/data/observation.py
Normal file
@@ -0,0 +1,119 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import imageio
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from nimbus.components.data.camera import Camera
|
||||||
|
|
||||||
|
|
||||||
|
class Observations:
|
||||||
|
"""
|
||||||
|
Represents a single observation of a scene, which may include multiple camera trajectories and associated data.
|
||||||
|
Each observation is identified by a unique name and index, and can contain multiple Camera items that capture
|
||||||
|
different viewpoints or modalities of the same scene.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
scene_name (str): The name of the scene associated with this observation.
|
||||||
|
index (str): The index or ID of this observation within the scene.
|
||||||
|
length (int): Optional total length of the observation. Calculated from camera trajectories if not provided.
|
||||||
|
data (dict): Optional dictionary for storing additional arbitrary data, such as metadata or annotations.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, scene_name: str, index: str, length: int = None, data: dict = None):
|
||||||
|
self.scene_name = scene_name
|
||||||
|
self.obs_name = scene_name + "_" + index
|
||||||
|
self.index = index
|
||||||
|
self.cam_items = []
|
||||||
|
self.length = length
|
||||||
|
self.data = data
|
||||||
|
|
||||||
|
def __getstate__(self):
|
||||||
|
state = self.__dict__.copy()
|
||||||
|
return state
|
||||||
|
|
||||||
|
def __setstate__(self, state):
|
||||||
|
self.__dict__.update(state)
|
||||||
|
|
||||||
|
def append_cam(self, item: Camera):
|
||||||
|
self.cam_items.append(item)
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
if self.length is not None:
|
||||||
|
return self.length
|
||||||
|
self.length = 0
|
||||||
|
for cam in self.cam_items:
|
||||||
|
self.length += len(cam)
|
||||||
|
return self.length
|
||||||
|
|
||||||
|
def get_length(self):
|
||||||
|
return len(self)
|
||||||
|
|
||||||
|
def flush_to_disk(self, path, video_fps=10):
|
||||||
|
path_to_save = os.path.join(path, "trajectory_" + self.index)
|
||||||
|
print(f"obs {self.obs_name} try to save path in {path_to_save}")
|
||||||
|
os.makedirs(path_to_save, exist_ok=True)
|
||||||
|
|
||||||
|
# Single camera: save in root directory
|
||||||
|
if len(self.cam_items) == 1:
|
||||||
|
cam = self.cam_items[0]
|
||||||
|
self._save_camera_data(path_to_save, cam, video_fps)
|
||||||
|
# Multiple cameras: save in camera_0/, camera_1/, etc.
|
||||||
|
else:
|
||||||
|
for idx, cam in enumerate(self.cam_items):
|
||||||
|
camera_dir = os.path.join(path_to_save, f"camera_{idx}")
|
||||||
|
os.makedirs(camera_dir, exist_ok=True)
|
||||||
|
self._save_camera_data(camera_dir, cam, video_fps)
|
||||||
|
|
||||||
|
def _save_camera_data(self, save_dir, cam: Camera, video_fps):
|
||||||
|
"""Helper method to save camera visualization data (rgbs, depths) to a directory."""
|
||||||
|
# Save RGB and depth images if available
|
||||||
|
if cam.rgbs is not None and len(cam.rgbs) > 0:
|
||||||
|
rgb_images_path = os.path.join(save_dir, "rgb/")
|
||||||
|
os.makedirs(rgb_images_path, exist_ok=True)
|
||||||
|
|
||||||
|
fps_path = os.path.join(save_dir, "fps.mp4")
|
||||||
|
|
||||||
|
for idx, rgb_item in enumerate(cam.rgbs):
|
||||||
|
rgb_filename = os.path.join(rgb_images_path, f"{idx}.jpg")
|
||||||
|
cv2.imwrite(rgb_filename, cv2.cvtColor(rgb_item, cv2.COLOR_BGR2RGB))
|
||||||
|
|
||||||
|
imageio.mimwrite(fps_path, cam.rgbs, fps=video_fps)
|
||||||
|
|
||||||
|
if cam.depths is not None and len(cam.depths) > 0:
|
||||||
|
depth_images_path = os.path.join(save_dir, "depth/")
|
||||||
|
os.makedirs(depth_images_path, exist_ok=True)
|
||||||
|
|
||||||
|
depth_path = os.path.join(save_dir, "depth.mp4")
|
||||||
|
|
||||||
|
# Create a copy for video (8-bit version)
|
||||||
|
depth_video_frames = []
|
||||||
|
for idx, depth_item in enumerate(cam.depths):
|
||||||
|
depth_filename = os.path.join(depth_images_path, f"{idx}.png")
|
||||||
|
cv2.imwrite(depth_filename, depth_item)
|
||||||
|
depth_video_frames.append((depth_item >> 8).astype(np.uint8))
|
||||||
|
|
||||||
|
imageio.mimwrite(depth_path, depth_video_frames, fps=video_fps)
|
||||||
|
|
||||||
|
# Save UV tracking visualizations if available
|
||||||
|
if cam.uv_tracks is not None and cam.uv_mesh_names is not None and cam.rgbs is not None:
|
||||||
|
num_frames = len(cam.rgbs)
|
||||||
|
try:
|
||||||
|
from nimbus_extension.components.render.brpc_utils.point_tracking import (
|
||||||
|
make_uv_overlays_and_video,
|
||||||
|
)
|
||||||
|
except ImportError as e:
|
||||||
|
raise ImportError(
|
||||||
|
"UV tracking visualization requires nimbus_extension. "
|
||||||
|
"Please add `import nimbus_extension` before running the pipeline."
|
||||||
|
) from e
|
||||||
|
|
||||||
|
make_uv_overlays_and_video(
|
||||||
|
cam.rgbs,
|
||||||
|
cam.uv_tracks,
|
||||||
|
cam.uv_mesh_names,
|
||||||
|
start_frame=0,
|
||||||
|
end_frame=num_frames,
|
||||||
|
fps=video_fps,
|
||||||
|
path_to_save=save_dir,
|
||||||
|
)
|
||||||
39
nimbus/components/data/package.py
Normal file
39
nimbus/components/data/package.py
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
import pickle
|
||||||
|
|
||||||
|
|
||||||
|
class Package:
|
||||||
|
"""
|
||||||
|
A class representing a data package that can be serialized and deserialized for pipeline.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: The actual data contained in the package, which can be of any type.
|
||||||
|
task_id (int): The ID of the task associated with this package.
|
||||||
|
task_name (str): The name of the task associated with this package.
|
||||||
|
stop_sig (bool): Whether this package signals the pipeline to stop.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, data, task_id: int = -1, task_name: str = None, stop_sig: bool = False):
|
||||||
|
self.is_ser = False
|
||||||
|
self.data = data
|
||||||
|
self.task_id = task_id
|
||||||
|
self.task_name = task_name
|
||||||
|
self.stop_sig = stop_sig
|
||||||
|
|
||||||
|
def serialize(self):
|
||||||
|
assert self.is_ser is False, "data is already serialized"
|
||||||
|
self.data = pickle.dumps(self.data)
|
||||||
|
self.is_ser = True
|
||||||
|
|
||||||
|
def deserialize(self):
|
||||||
|
assert self.is_ser is True, "data is already deserialized"
|
||||||
|
self.data = pickle.loads(self.data)
|
||||||
|
self.is_ser = False
|
||||||
|
|
||||||
|
def is_serialized(self):
|
||||||
|
return self.is_ser
|
||||||
|
|
||||||
|
def get_data(self):
|
||||||
|
return self.data
|
||||||
|
|
||||||
|
def should_stop(self):
|
||||||
|
return self.stop_sig is True
|
||||||
69
nimbus/components/data/scene.py
Normal file
69
nimbus/components/data/scene.py
Normal file
@@ -0,0 +1,69 @@
|
|||||||
|
class Scene:
|
||||||
|
"""
|
||||||
|
Represents a loaded scene in the simulation environment, holding workflow context and task execution state.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name (str): The name of the scene or task.
|
||||||
|
pcd: Point cloud data associated with the scene.
|
||||||
|
scale (float): Scale factor for the scene geometry.
|
||||||
|
materials: Material data for the scene.
|
||||||
|
textures: Texture data for the scene.
|
||||||
|
floor_heights: Floor height information for the scene.
|
||||||
|
wf: The task workflow instance managing this scene.
|
||||||
|
task_id (int): The index of the current task within the workflow.
|
||||||
|
task_exec_num (int): The execution count for the current task, used for task repetition tracking.
|
||||||
|
simulation_app: The Isaac Sim SimulationApp instance.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
name: str = None,
|
||||||
|
pcd=None,
|
||||||
|
scale: float = 1.0,
|
||||||
|
materials=None,
|
||||||
|
textures=None,
|
||||||
|
floor_heights=None,
|
||||||
|
wf=None,
|
||||||
|
task_id: int = None,
|
||||||
|
task_exec_num: int = 1,
|
||||||
|
simulation_app=None,
|
||||||
|
):
|
||||||
|
self.name = name
|
||||||
|
self.pcd = pcd
|
||||||
|
self.materials = materials
|
||||||
|
self.textures = textures
|
||||||
|
self.floor_heights = floor_heights
|
||||||
|
self.scale = scale
|
||||||
|
self.wf = wf
|
||||||
|
self.simulation_app = simulation_app
|
||||||
|
self.task_id = task_id
|
||||||
|
self.plan_info = None
|
||||||
|
self.generate_success = False
|
||||||
|
self.task_exec_num = task_exec_num
|
||||||
|
|
||||||
|
def __getstate__(self):
|
||||||
|
state = self.__dict__.copy()
|
||||||
|
del state["pcd"]
|
||||||
|
return state
|
||||||
|
|
||||||
|
def __setstate__(self, state):
|
||||||
|
self.__dict__.update(state)
|
||||||
|
self.pcd = None
|
||||||
|
|
||||||
|
def add_plan_info(self, plan_info):
|
||||||
|
self.plan_info = plan_info
|
||||||
|
|
||||||
|
def flush_to_disk(self, path):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def load_from_disk(self, path):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def update_generate_status(self, success):
|
||||||
|
self.generate_success = success
|
||||||
|
|
||||||
|
def get_generate_status(self):
|
||||||
|
return self.generate_success
|
||||||
|
|
||||||
|
def update_task_exec_num(self, num):
|
||||||
|
self.task_exec_num = num
|
||||||
145
nimbus/components/data/sequence.py
Normal file
145
nimbus/components/data/sequence.py
Normal file
@@ -0,0 +1,145 @@
|
|||||||
|
import json
|
||||||
|
import os
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import open3d as o3d
|
||||||
|
|
||||||
|
from nimbus.components.data.camera import C2W, Camera
|
||||||
|
|
||||||
|
|
||||||
|
class Sequence:
|
||||||
|
"""
|
||||||
|
Represents a camera trajectory sequence with associated metadata.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
scene_name (str): The name of the scene (e.g., room identifier).
|
||||||
|
index (str): The index or ID of this sequence within the scene.
|
||||||
|
length (int): Optional explicit sequence length. Calculated from camera trajectories if not provided.
|
||||||
|
data (dict): Optional additional arbitrary data associated with the sequence.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, scene_name: str, index: str, length: int = None, data: dict = None):
|
||||||
|
self.scene_name = scene_name
|
||||||
|
self.seq_name = scene_name + "_" + index
|
||||||
|
self.index = index
|
||||||
|
self.cam_items: list[Camera] = []
|
||||||
|
self.path_pcd = None
|
||||||
|
self.length = length
|
||||||
|
self.data = data
|
||||||
|
|
||||||
|
def __getstate__(self):
|
||||||
|
state = self.__dict__.copy()
|
||||||
|
state["path_pcd_color"] = np.asarray(state["path_pcd"].colors)
|
||||||
|
state["path_pcd"] = o3d.io.write_point_cloud_to_bytes(state["path_pcd"], "mem::xyz")
|
||||||
|
return state
|
||||||
|
|
||||||
|
def __setstate__(self, state):
|
||||||
|
self.__dict__.update(state)
|
||||||
|
self.path_pcd = o3d.io.read_point_cloud_from_bytes(state["path_pcd"], "mem::xyz")
|
||||||
|
self.path_pcd.colors = o3d.utility.Vector3dVector(state["path_pcd_color"])
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
if self.length is not None:
|
||||||
|
return self.length
|
||||||
|
self.length = 0
|
||||||
|
for cam in self.cam_items:
|
||||||
|
self.length += len(cam)
|
||||||
|
return self.length
|
||||||
|
|
||||||
|
def append_cam(self, item: Camera):
|
||||||
|
self.cam_items.append(item)
|
||||||
|
|
||||||
|
def update_pcd(self, path_pcd):
|
||||||
|
self.path_pcd = path_pcd
|
||||||
|
|
||||||
|
def get_length(self):
|
||||||
|
return len(self)
|
||||||
|
|
||||||
|
def flush_to_disk(self, path):
|
||||||
|
path_to_save = os.path.join(path, "trajectory_" + self.index)
|
||||||
|
print(f"seq {self.seq_name} try to save path in {path_to_save}")
|
||||||
|
os.makedirs(path_to_save, exist_ok=True)
|
||||||
|
if self.path_pcd is not None:
|
||||||
|
pcd_path = os.path.join(path_to_save, "path.ply")
|
||||||
|
o3d.io.write_point_cloud(pcd_path, self.path_pcd)
|
||||||
|
|
||||||
|
# Single camera: save in root directory
|
||||||
|
if len(self.cam_items) == 1:
|
||||||
|
cam = self.cam_items[0]
|
||||||
|
camera_trajectory_list = [t.matrix for t in cam.trajectory]
|
||||||
|
save_dict = {
|
||||||
|
"camera_intrinsic": cam.intrinsic if cam.intrinsic is not None else None,
|
||||||
|
"camera_extrinsic": cam.extrinsic if cam.extrinsic is not None else None,
|
||||||
|
"camera_trajectory": camera_trajectory_list,
|
||||||
|
}
|
||||||
|
traj_path = os.path.join(path_to_save, "data.json")
|
||||||
|
json_object = json.dumps(save_dict, indent=4)
|
||||||
|
with open(traj_path, "w", encoding="utf-8") as outfile:
|
||||||
|
outfile.write(json_object)
|
||||||
|
# Multiple cameras: save in camera_0/, camera_1/, etc.
|
||||||
|
else:
|
||||||
|
for idx, cam in enumerate(self.cam_items):
|
||||||
|
camera_dir = os.path.join(path_to_save, f"camera_{idx}")
|
||||||
|
os.makedirs(camera_dir, exist_ok=True)
|
||||||
|
camera_trajectory_list = [t.matrix for t in cam.trajectory]
|
||||||
|
save_dict = {
|
||||||
|
"camera_intrinsic": cam.intrinsic if cam.intrinsic is not None else None,
|
||||||
|
"camera_extrinsic": cam.extrinsic if cam.extrinsic is not None else None,
|
||||||
|
"camera_trajectory": camera_trajectory_list,
|
||||||
|
}
|
||||||
|
traj_path = os.path.join(camera_dir, "data.json")
|
||||||
|
json_object = json.dumps(save_dict, indent=4)
|
||||||
|
with open(traj_path, "w", encoding="utf-8") as outfile:
|
||||||
|
outfile.write(json_object)
|
||||||
|
|
||||||
|
def load_from_disk(self, path):
|
||||||
|
print(f"seq {self.seq_name} try to load path from {path}")
|
||||||
|
|
||||||
|
pcd_path = os.path.join(path, "path.ply")
|
||||||
|
if os.path.exists(pcd_path):
|
||||||
|
self.path_pcd = o3d.io.read_point_cloud(pcd_path)
|
||||||
|
|
||||||
|
# Clear existing camera items
|
||||||
|
self.cam_items = []
|
||||||
|
|
||||||
|
# Check if single camera format (data.json in root)
|
||||||
|
traj_path = os.path.join(path, "data.json")
|
||||||
|
if os.path.exists(traj_path):
|
||||||
|
with open(traj_path, "r", encoding="utf-8") as infile:
|
||||||
|
data = json.load(infile)
|
||||||
|
|
||||||
|
camera_trajectory_list = []
|
||||||
|
for trajectory in data["camera_trajectory"]:
|
||||||
|
camera_trajectory_list.append(C2W(matrix=trajectory))
|
||||||
|
|
||||||
|
cam = Camera(
|
||||||
|
trajectory=camera_trajectory_list,
|
||||||
|
intrinsic=data.get("camera_intrinsic"),
|
||||||
|
extrinsic=data.get("camera_extrinsic"),
|
||||||
|
)
|
||||||
|
self.cam_items.append(cam)
|
||||||
|
else:
|
||||||
|
# Multiple camera format (camera_0/, camera_1/, etc.)
|
||||||
|
idx = 0
|
||||||
|
while True:
|
||||||
|
camera_dir = os.path.join(path, f"camera_{idx}")
|
||||||
|
camera_json = os.path.join(camera_dir, "data.json")
|
||||||
|
if not os.path.exists(camera_json):
|
||||||
|
break
|
||||||
|
|
||||||
|
with open(camera_json, "r", encoding="utf-8") as infile:
|
||||||
|
data = json.load(infile)
|
||||||
|
|
||||||
|
camera_trajectory_list = []
|
||||||
|
for trajectory in data["camera_trajectory"]:
|
||||||
|
camera_trajectory_list.append(C2W(matrix=trajectory))
|
||||||
|
|
||||||
|
cam = Camera(
|
||||||
|
trajectory=camera_trajectory_list,
|
||||||
|
intrinsic=data.get("camera_intrinsic"),
|
||||||
|
extrinsic=data.get("camera_extrinsic"),
|
||||||
|
)
|
||||||
|
self.cam_items.append(cam)
|
||||||
|
idx += 1
|
||||||
|
|
||||||
|
assert len(self.cam_items) > 0, f"No camera data found in {path}"
|
||||||
7
nimbus/components/dedump/__init__.py
Normal file
7
nimbus/components/dedump/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
from nimbus.components.data.iterator import Iterator
|
||||||
|
|
||||||
|
dedumper_dict = {}
|
||||||
|
|
||||||
|
|
||||||
|
def register(type_name: str, cls: Iterator):
|
||||||
|
dedumper_dict[type_name] = cls
|
||||||
7
nimbus/components/dump/__init__.py
Normal file
7
nimbus/components/dump/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
from .base_dumper import BaseDumper
|
||||||
|
|
||||||
|
dumper_dict = {}
|
||||||
|
|
||||||
|
|
||||||
|
def register(type_name: str, cls: BaseDumper):
|
||||||
|
dumper_dict[type_name] = cls
|
||||||
82
nimbus/components/dump/base_dumper.py
Normal file
82
nimbus/components/dump/base_dumper.py
Normal file
@@ -0,0 +1,82 @@
|
|||||||
|
import time
|
||||||
|
from abc import abstractmethod
|
||||||
|
|
||||||
|
from pympler import asizeof
|
||||||
|
|
||||||
|
from nimbus.components.data.iterator import Iterator
|
||||||
|
from nimbus.components.data.package import Package
|
||||||
|
from nimbus.utils.utils import unpack_iter_data
|
||||||
|
|
||||||
|
|
||||||
|
class BaseDumper(Iterator):
|
||||||
|
def __init__(self, data_iter, output_queue, max_queue_num=1):
|
||||||
|
super().__init__()
|
||||||
|
self.data_iter = data_iter
|
||||||
|
self.scene = None
|
||||||
|
self.output_queue = output_queue
|
||||||
|
self.total_case = 0
|
||||||
|
self.success_case = 0
|
||||||
|
self.max_queue_num = max_queue_num
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def _next(self):
|
||||||
|
try:
|
||||||
|
data = next(self.data_iter)
|
||||||
|
scene, seq, obs = unpack_iter_data(data)
|
||||||
|
self.total_case += 1
|
||||||
|
if scene is not None:
|
||||||
|
if self.scene is not None and (
|
||||||
|
scene.task_id != self.scene.task_id
|
||||||
|
or scene.name != self.scene.name
|
||||||
|
or scene.task_exec_num != self.scene.task_exec_num
|
||||||
|
):
|
||||||
|
self.logger.info(
|
||||||
|
f"Scene {self.scene.name} generate finish, success rate: {self.success_case}/{self.total_case}"
|
||||||
|
)
|
||||||
|
self.total_case = 1
|
||||||
|
self.success_case = 0
|
||||||
|
self.scene = scene
|
||||||
|
if obs is None and seq is None:
|
||||||
|
self.logger.info(f"generate failed, skip once! success rate: {self.success_case}/{self.total_case}")
|
||||||
|
if self.scene is not None:
|
||||||
|
self.scene.update_generate_status(success=False)
|
||||||
|
return None
|
||||||
|
io_start_time = time.time()
|
||||||
|
if self.output_queue is not None:
|
||||||
|
obj = self.dump(seq, obs)
|
||||||
|
pack = Package(obj, task_id=scene.task_id, task_name=scene.name)
|
||||||
|
pack.serialize()
|
||||||
|
|
||||||
|
wait_time = time.time()
|
||||||
|
while self.output_queue.qsize() >= self.max_queue_num:
|
||||||
|
time.sleep(1)
|
||||||
|
end_time = time.time()
|
||||||
|
self.collect_wait_time_info(end_time - wait_time)
|
||||||
|
|
||||||
|
st = time.time()
|
||||||
|
self.output_queue.put(pack)
|
||||||
|
ed = time.time()
|
||||||
|
self.logger.info(f"put time: {ed - st}, data size: {asizeof.asizeof(obj)}")
|
||||||
|
else:
|
||||||
|
obj = self.dump(seq, obs)
|
||||||
|
self.success_case += 1
|
||||||
|
self.scene.update_generate_status(success=True)
|
||||||
|
self.collect_seq_info(1, time.time() - io_start_time)
|
||||||
|
except StopIteration:
|
||||||
|
if self.output_queue is not None:
|
||||||
|
pack = Package(None, stop_sig=True)
|
||||||
|
self.output_queue.put(pack)
|
||||||
|
if self.scene is not None:
|
||||||
|
self.logger.info(
|
||||||
|
f"Scene {self.scene.name} generate finish, success rate: {self.success_case}/{self.total_case}"
|
||||||
|
)
|
||||||
|
raise StopIteration("no data")
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.exception(f"Error during data dumping: {e}")
|
||||||
|
raise e
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def dump(self, seq, obs):
|
||||||
|
raise NotImplementedError("This method should be overridden by subclasses")
|
||||||
16
nimbus/components/load/__init__.py
Normal file
16
nimbus/components/load/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
# flake8: noqa: F401
|
||||||
|
# pylint: disable=C0413
|
||||||
|
|
||||||
|
from .base_randomizer import LayoutRandomizer
|
||||||
|
from .base_scene_loader import SceneLoader
|
||||||
|
|
||||||
|
scene_loader_dict = {}
|
||||||
|
layout_randomizer_dict = {}
|
||||||
|
|
||||||
|
|
||||||
|
def register_loader(type_name: str, cls: SceneLoader):
|
||||||
|
scene_loader_dict[type_name] = cls
|
||||||
|
|
||||||
|
|
||||||
|
def register_randomizer(type_name: str, cls: LayoutRandomizer):
|
||||||
|
layout_randomizer_dict[type_name] = cls
|
||||||
72
nimbus/components/load/base_randomizer.py
Normal file
72
nimbus/components/load/base_randomizer.py
Normal file
@@ -0,0 +1,72 @@
|
|||||||
|
import sys
|
||||||
|
import time
|
||||||
|
from abc import abstractmethod
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from nimbus.components.data.iterator import Iterator
|
||||||
|
from nimbus.components.data.scene import Scene
|
||||||
|
from nimbus.daemon.decorators import status_monitor
|
||||||
|
|
||||||
|
|
||||||
|
class LayoutRandomizer(Iterator):
|
||||||
|
"""
|
||||||
|
Base class for layout randomization in a scene. This class defines the structure for randomizing scenes and
|
||||||
|
tracking the randomization process. It manages the current scene, randomization count, and provides hooks for
|
||||||
|
subclasses to implement specific randomization logic.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
scene_iter (Iterator): An iterator that provides scenes to be randomized.
|
||||||
|
random_num (int): The number of randomizations to perform for each scene before moving to the next one.
|
||||||
|
strict_mode (bool): If True, the randomizer will check the generation status of the current scene and retry
|
||||||
|
randomization if it was not successful. This ensures that only successfully generated
|
||||||
|
scenes are counted towards the randomization limit.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, scene_iter: Iterator, random_num: int, strict_mode: bool = False):
|
||||||
|
super().__init__()
|
||||||
|
self.scene_iter = scene_iter
|
||||||
|
self.random_num = random_num
|
||||||
|
self.strict_mode = strict_mode
|
||||||
|
self.cur_index = sys.maxsize
|
||||||
|
self.scene: Optional[Scene] = None
|
||||||
|
|
||||||
|
def reset(self, scene):
|
||||||
|
self.cur_index = 0
|
||||||
|
self.scene = scene
|
||||||
|
|
||||||
|
def _fetch_next_scene(self):
|
||||||
|
scene = next(self.scene_iter)
|
||||||
|
self.reset(scene)
|
||||||
|
|
||||||
|
@status_monitor()
|
||||||
|
def _randomize_with_status(self, scene) -> Scene:
|
||||||
|
scene = self.randomize_scene(self.scene)
|
||||||
|
return scene
|
||||||
|
|
||||||
|
def _next(self) -> Scene:
|
||||||
|
try:
|
||||||
|
if self.strict_mode and self.scene is not None:
|
||||||
|
if not self.scene.get_generate_status():
|
||||||
|
self.logger.info("strict_mode is open, retry the randomization to generate sequence.")
|
||||||
|
st = time.time()
|
||||||
|
scene = self._randomize_with_status(self.scene)
|
||||||
|
self.collect_seq_info(1, time.time() - st)
|
||||||
|
return scene
|
||||||
|
if self.cur_index >= self.random_num:
|
||||||
|
self._fetch_next_scene()
|
||||||
|
if self.cur_index < self.random_num:
|
||||||
|
st = time.time()
|
||||||
|
scene = self._randomize_with_status(self.scene)
|
||||||
|
self.collect_seq_info(1, time.time() - st)
|
||||||
|
self.cur_index += 1
|
||||||
|
return scene
|
||||||
|
except StopIteration:
|
||||||
|
raise StopIteration("No more scenes to randomize.")
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.exception(f"Error during scene idx {self.cur_index} randomization: {e}")
|
||||||
|
self.cur_index += 1
|
||||||
|
raise e
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def randomize_scene(self, scene) -> Scene:
|
||||||
|
raise NotImplementedError("This method should be overridden by subclasses")
|
||||||
41
nimbus/components/load/base_scene_loader.py
Normal file
41
nimbus/components/load/base_scene_loader.py
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
from abc import abstractmethod
|
||||||
|
|
||||||
|
from nimbus.components.data.iterator import Iterator
|
||||||
|
from nimbus.components.data.scene import Scene
|
||||||
|
|
||||||
|
|
||||||
|
class SceneLoader(Iterator):
|
||||||
|
"""
|
||||||
|
Base class for scene loading in a simulation environment. This class defines the structure for loading scenes
|
||||||
|
and tracking the loading process. It manages the current package iterator and provides hooks for subclasses
|
||||||
|
to implement specific scene loading logic.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pack_iter (Iterator): An iterator that provides packages containing scene information to be loaded.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, pack_iter):
|
||||||
|
super().__init__()
|
||||||
|
self.pack_iter = pack_iter
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def load_asset(self) -> Scene:
|
||||||
|
"""
|
||||||
|
Abstract method to load and initialize a scene.
|
||||||
|
|
||||||
|
Subclasses must implement this method to define the specific logic for creating and configuring
|
||||||
|
a scene object based on the current state of the iterator.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Scene: A fully initialized Scene object.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError("This method must be implemented by subclasses")
|
||||||
|
|
||||||
|
def _next(self) -> Scene:
|
||||||
|
try:
|
||||||
|
return self.load_asset()
|
||||||
|
except StopIteration:
|
||||||
|
raise StopIteration("No more scenes to load.")
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.exception(f"Error during scene loading: {e}")
|
||||||
|
raise e
|
||||||
7
nimbus/components/plan_with_render/__init__.py
Normal file
7
nimbus/components/plan_with_render/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
from nimbus.components.data.iterator import Iterator
|
||||||
|
|
||||||
|
plan_with_render_dict = {}
|
||||||
|
|
||||||
|
|
||||||
|
def register(type_name: str, cls: Iterator):
|
||||||
|
plan_with_render_dict[type_name] = cls
|
||||||
7
nimbus/components/planner/__init__.py
Normal file
7
nimbus/components/planner/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
from .base_seq_planner import SequencePlanner
|
||||||
|
|
||||||
|
seq_planner_dict = {}
|
||||||
|
|
||||||
|
|
||||||
|
def register(type_name: str, cls: SequencePlanner):
|
||||||
|
seq_planner_dict[type_name] = cls
|
||||||
102
nimbus/components/planner/base_seq_planner.py
Normal file
102
nimbus/components/planner/base_seq_planner.py
Normal file
@@ -0,0 +1,102 @@
|
|||||||
|
import sys
|
||||||
|
import time
|
||||||
|
from abc import abstractmethod
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from nimbus.components.data.iterator import Iterator
|
||||||
|
from nimbus.components.data.scene import Scene
|
||||||
|
from nimbus.components.data.sequence import Sequence
|
||||||
|
from nimbus.daemon.decorators import status_monitor
|
||||||
|
from nimbus.utils.flags import is_debug_mode
|
||||||
|
from nimbus.utils.types import ARGS, TYPE
|
||||||
|
|
||||||
|
from .planner import path_planner_dict
|
||||||
|
|
||||||
|
|
||||||
|
class SequencePlanner(Iterator):
|
||||||
|
"""
|
||||||
|
A base class for sequence planning in a simulation environment. This class defines the structure for generating
|
||||||
|
sequences based on scenes and tracking the planning process. It manages the current scene, episode count
|
||||||
|
and provides hooks for subclasses to implement specific sequence generation logic.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
scene_iter (Iterator): An iterator that provides scenes to be processed for sequence planning.
|
||||||
|
planner_cfg (dict): A dictionary containing configuration parameters for the planner,
|
||||||
|
such as the type of planner to use and its arguments.
|
||||||
|
episodes (int): The number of episodes to generate for each scene before moving to the next one. Default is 1.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, scene_iter: Iterator[Scene], planner_cfg: dict, episodes: int = 1):
|
||||||
|
super().__init__()
|
||||||
|
self.scene_iter = scene_iter
|
||||||
|
self.planner_cfg = planner_cfg
|
||||||
|
self.episodes = episodes
|
||||||
|
self.current_episode = sys.maxsize
|
||||||
|
self.scene: Optional[Scene] = None
|
||||||
|
|
||||||
|
@status_monitor()
|
||||||
|
def _plan_with_status(self) -> Optional[Sequence]:
|
||||||
|
seq = self.generate_sequence()
|
||||||
|
return seq
|
||||||
|
|
||||||
|
def _next(self) -> tuple[Scene, Sequence]:
|
||||||
|
try:
|
||||||
|
if self.scene is None or self.current_episode >= self.episodes:
|
||||||
|
try:
|
||||||
|
self.scene = next(self.scene_iter)
|
||||||
|
self.current_episode = 0
|
||||||
|
if self.scene is None:
|
||||||
|
return None, None
|
||||||
|
self.initialize(self.scene)
|
||||||
|
except StopIteration:
|
||||||
|
raise StopIteration("No more scene to process.")
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.exception(f"Error loading next scene: {e}")
|
||||||
|
if is_debug_mode():
|
||||||
|
raise e
|
||||||
|
self.current_episode = sys.maxsize
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
while True:
|
||||||
|
compute_start_time = time.time()
|
||||||
|
seq = self._plan_with_status()
|
||||||
|
compute_end_time = time.time()
|
||||||
|
self.current_episode += 1
|
||||||
|
|
||||||
|
if seq is not None:
|
||||||
|
self.collect_compute_frame_info(seq.get_length(), compute_end_time - compute_start_time)
|
||||||
|
return self.scene, seq
|
||||||
|
|
||||||
|
if self.current_episode >= self.episodes:
|
||||||
|
return self.scene, None
|
||||||
|
|
||||||
|
self.logger.info(f"Generate seq failed and retry. Current episode id is {self.current_episode}")
|
||||||
|
|
||||||
|
except StopIteration:
|
||||||
|
raise StopIteration("No more scene to process.")
|
||||||
|
except Exception as e:
|
||||||
|
scene_name = getattr(self.scene, "name", "<unknown>")
|
||||||
|
self.logger.exception(
|
||||||
|
f"Error during idx {self.current_episode} sequence generation for scene {scene_name}: {e}"
|
||||||
|
)
|
||||||
|
if is_debug_mode():
|
||||||
|
raise e
|
||||||
|
self.current_episode += 1
|
||||||
|
return self.scene, None
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def generate_sequence(self) -> Optional[Sequence]:
|
||||||
|
raise NotImplementedError("This method should be overridden by subclasses")
|
||||||
|
|
||||||
|
def _initialize(self, scene):
|
||||||
|
if self.planner_cfg is not None:
|
||||||
|
self.logger.info(f"init {self.planner_cfg[TYPE]} planner in seq_planner")
|
||||||
|
self.planner = path_planner_dict[self.planner_cfg[TYPE]](scene, **self.planner_cfg.get(ARGS, {}))
|
||||||
|
else:
|
||||||
|
self.planner = None
|
||||||
|
self.logger.info("planner config is None in seq_planner and skip initialize")
|
||||||
|
|
||||||
|
def initialize(self, scene):
|
||||||
|
init_start_time = time.time()
|
||||||
|
self._initialize(scene)
|
||||||
|
self.record_init_time(time.time() - init_start_time)
|
||||||
5
nimbus/components/planner/planner/__init__.py
Normal file
5
nimbus/components/planner/planner/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
path_planner_dict = {}
|
||||||
|
|
||||||
|
|
||||||
|
def register(type_name: str, cls):
|
||||||
|
path_planner_dict[type_name] = cls
|
||||||
7
nimbus/components/render/__init__.py
Normal file
7
nimbus/components/render/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
from .base_renderer import BaseRenderer
|
||||||
|
|
||||||
|
renderer_dict = {}
|
||||||
|
|
||||||
|
|
||||||
|
def register(type_name: str, cls: BaseRenderer):
|
||||||
|
renderer_dict[type_name] = cls
|
||||||
80
nimbus/components/render/base_renderer.py
Normal file
80
nimbus/components/render/base_renderer.py
Normal file
@@ -0,0 +1,80 @@
|
|||||||
|
import time
|
||||||
|
from abc import abstractmethod
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from nimbus.components.data.iterator import Iterator
|
||||||
|
from nimbus.components.data.observation import Observations
|
||||||
|
from nimbus.components.data.scene import Scene
|
||||||
|
from nimbus.components.data.sequence import Sequence
|
||||||
|
from nimbus.daemon.decorators import status_monitor
|
||||||
|
|
||||||
|
|
||||||
|
class BaseRenderer(Iterator):
|
||||||
|
"""
|
||||||
|
Base class for rendering in a simulation environment. This class defines the structure for rendering scenes and
|
||||||
|
tracking the rendering process. It manages the current scene and provides hooks for subclasses to implement
|
||||||
|
specific rendering logic.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
scene_seq_iter (Iterator): An iterator that provides pairs of scenes and sequences to be rendered. Each item
|
||||||
|
from the iterator should be a tuple containing a scene and its corresponding sequence.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, scene_seq_iter: Iterator[tuple[Scene, Sequence]]):
|
||||||
|
super().__init__()
|
||||||
|
self.scene_seq_iter = scene_seq_iter
|
||||||
|
self.scene: Optional[Scene] = None
|
||||||
|
|
||||||
|
@status_monitor()
|
||||||
|
def _generate_obs_with_status(self, seq) -> Optional[Observations]:
|
||||||
|
compute_start_time = time.time()
|
||||||
|
obs = self.generate_obs(seq)
|
||||||
|
end_start_time = time.time()
|
||||||
|
if obs is not None:
|
||||||
|
self.collect_compute_frame_info(len(obs), end_start_time - compute_start_time)
|
||||||
|
return obs
|
||||||
|
|
||||||
|
def _next(self):
|
||||||
|
try:
|
||||||
|
scene, seq = next(self.scene_seq_iter)
|
||||||
|
if scene is not None:
|
||||||
|
if self.scene is None:
|
||||||
|
self.reset(scene)
|
||||||
|
elif scene.task_id != self.scene.task_id or scene.name != self.scene.name:
|
||||||
|
self.logger.info(f"Scene changed: {self.scene.name} -> {scene.name}")
|
||||||
|
self.reset(scene)
|
||||||
|
if seq is None:
|
||||||
|
return scene, None, None
|
||||||
|
obs = self._generate_obs_with_status(seq)
|
||||||
|
if obs is None:
|
||||||
|
return scene, None, None
|
||||||
|
return scene, seq, obs
|
||||||
|
except StopIteration:
|
||||||
|
raise StopIteration("No more sequences to process.")
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.exception(f"Error during rendering: {e}")
|
||||||
|
raise e
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def generate_obs(self, seq) -> Optional[Observations]:
|
||||||
|
raise NotImplementedError("This method should be overridden by subclasses")
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def _lazy_init(self):
|
||||||
|
raise NotImplementedError("This method should be overridden by subclasses")
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def _close_resource(self):
|
||||||
|
raise NotImplementedError("This method should be overridden by subclasses")
|
||||||
|
|
||||||
|
def reset(self, scene):
|
||||||
|
try:
|
||||||
|
self.scene = scene
|
||||||
|
self._close_resource()
|
||||||
|
init_start_time = time.time()
|
||||||
|
self._lazy_init()
|
||||||
|
self.record_init_time(time.time() - init_start_time)
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.exception(f"Error initializing renderer: {e}")
|
||||||
|
self.scene = None
|
||||||
|
raise e
|
||||||
7
nimbus/components/store/__init__.py
Normal file
7
nimbus/components/store/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
from .base_writer import BaseWriter
|
||||||
|
|
||||||
|
writer_dict = {}
|
||||||
|
|
||||||
|
|
||||||
|
def register(type_name: str, cls: BaseWriter):
|
||||||
|
writer_dict[type_name] = cls
|
||||||
163
nimbus/components/store/base_writer.py
Normal file
163
nimbus/components/store/base_writer.py
Normal file
@@ -0,0 +1,163 @@
|
|||||||
|
import time
|
||||||
|
from abc import abstractmethod
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
from copy import copy
|
||||||
|
|
||||||
|
from nimbus.components.data.iterator import Iterator
|
||||||
|
from nimbus.components.data.observation import Observations
|
||||||
|
from nimbus.components.data.scene import Scene
|
||||||
|
from nimbus.components.data.sequence import Sequence
|
||||||
|
from nimbus.daemon import ComponentStatus, StatusReporter
|
||||||
|
from nimbus.utils.flags import is_debug_mode
|
||||||
|
from nimbus.utils.utils import unpack_iter_data
|
||||||
|
|
||||||
|
|
||||||
|
def run_batch(func, args):
|
||||||
|
for arg in args:
|
||||||
|
func(*arg)
|
||||||
|
|
||||||
|
|
||||||
|
class BaseWriter(Iterator):
|
||||||
|
"""
|
||||||
|
A base class for writing generated sequences and observations to disk. This class defines the structure for
|
||||||
|
writing data and tracking the writing process. It manages the current scene, success and total case counts,
|
||||||
|
and provides hooks for subclasses to implement specific data writing logic. The writer supports both synchronous
|
||||||
|
and asynchronous batch writing modes, allowing for efficient data handling in various scenarios.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data_iter (Iterator): An iterator that provides data to be written, typically containing scenes,
|
||||||
|
sequences, and observations.
|
||||||
|
seq_output_dir (str): The directory where generated sequences will be saved. Can be None
|
||||||
|
if sequence output is not needed.
|
||||||
|
obs_output_dir (str): The directory where generated observations will be saved. Can be None
|
||||||
|
if observation output is not needed.
|
||||||
|
batch_async (bool): If True, the writer will use asynchronous batch writing to improve performance
|
||||||
|
when handling large amounts of data. Default is True.
|
||||||
|
async_threshold (int): The maximum number of asynchronous write operations that can be in progress
|
||||||
|
at the same time. If the threshold is reached, the writer will wait for the oldest operation
|
||||||
|
to complete before starting a new one. Default is 1.
|
||||||
|
batch_size (int): The number of data items to write in each batch when using asynchronous writing.
|
||||||
|
Default is 2, and it will be capped at 8 to prevent potential issues with too many concurrent operations.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
data_iter: Iterator[tuple[Scene, Sequence, Observations]],
|
||||||
|
seq_output_dir: str,
|
||||||
|
obs_output_dir: str,
|
||||||
|
batch_async: bool = True,
|
||||||
|
async_threshold: int = 1,
|
||||||
|
batch_size: int = 2,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
assert (
|
||||||
|
seq_output_dir is not None or obs_output_dir is not None
|
||||||
|
), "At least one output directory must be provided"
|
||||||
|
self.data_iter = data_iter
|
||||||
|
self.seq_output_dir = seq_output_dir
|
||||||
|
self.obs_output_dir = obs_output_dir
|
||||||
|
self.scene = None
|
||||||
|
self.async_mode = batch_async
|
||||||
|
self.batch_size = batch_size if batch_size <= 8 else 8
|
||||||
|
if batch_async and batch_size > self.batch_size:
|
||||||
|
self.logger.info("Batch size is larger than 8(probably cause program hang), batch size will be set to 8")
|
||||||
|
self.async_threshold = async_threshold
|
||||||
|
self.flush_executor = ThreadPoolExecutor(max_workers=max(1, 64 // self.batch_size))
|
||||||
|
self.flush_threads = []
|
||||||
|
self.data_buffer = []
|
||||||
|
self.logger.info(
|
||||||
|
f"Batch Async Write Mode: {self.async_mode}, async threshold: {self.async_threshold}, batch size:"
|
||||||
|
f" {self.batch_size}"
|
||||||
|
)
|
||||||
|
self.total_case = 0
|
||||||
|
self.success_case = 0
|
||||||
|
self.last_scene_key = None
|
||||||
|
self.status_reporter = StatusReporter(self.__class__.__name__)
|
||||||
|
|
||||||
|
def _next(self):
|
||||||
|
try:
|
||||||
|
data = next(self.data_iter)
|
||||||
|
scene, seq, obs = unpack_iter_data(data)
|
||||||
|
|
||||||
|
new_key = (scene.task_id, scene.name, scene.task_exec_num) if scene is not None else None
|
||||||
|
|
||||||
|
self.scene = scene
|
||||||
|
|
||||||
|
if new_key != self.last_scene_key:
|
||||||
|
if self.scene is not None and self.last_scene_key is not None:
|
||||||
|
self.logger.info(
|
||||||
|
f"Scene {self.scene.name} generate finish, success rate: {self.success_case}/{self.total_case}"
|
||||||
|
)
|
||||||
|
self.success_case = 0
|
||||||
|
self.total_case = 0
|
||||||
|
self.last_scene_key = new_key
|
||||||
|
|
||||||
|
if self.scene is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
self.total_case += 1
|
||||||
|
|
||||||
|
self.status_reporter.update_status(ComponentStatus.RUNNING)
|
||||||
|
if seq is None and obs is None:
|
||||||
|
self.logger.info(f"generate failed, skip once! success rate: {self.success_case}/{self.total_case}")
|
||||||
|
self.scene.update_generate_status(success=False)
|
||||||
|
return None
|
||||||
|
scene_name = self.scene.name
|
||||||
|
io_start_time = time.time()
|
||||||
|
if self.async_mode:
|
||||||
|
cp_start_time = time.time()
|
||||||
|
cp = copy(self.scene.wf)
|
||||||
|
cp_end_time = time.time()
|
||||||
|
if self.scene.wf is not None:
|
||||||
|
self.logger.info(f"Scene {scene_name} workflow copy time: {cp_end_time - cp_start_time:.2f}s")
|
||||||
|
self.data_buffer.append((cp, scene_name, seq, obs))
|
||||||
|
if len(self.data_buffer) >= self.batch_size:
|
||||||
|
self.flush_threads = [t for t in self.flush_threads if not t.done()]
|
||||||
|
|
||||||
|
if len(self.flush_threads) >= self.async_threshold:
|
||||||
|
self.logger.info("Max async workers reached, waiting for the oldest thread to finish")
|
||||||
|
self.flush_threads[0].result()
|
||||||
|
self.flush_threads = self.flush_threads[1:]
|
||||||
|
|
||||||
|
to_flush_buffer = self.data_buffer.copy()
|
||||||
|
async_flush = self.flush_executor.submit(run_batch, self.flush_to_disk, to_flush_buffer)
|
||||||
|
if is_debug_mode():
|
||||||
|
async_flush.result() # surface exceptions immediately in debug mode
|
||||||
|
self.flush_threads.append(async_flush)
|
||||||
|
self.data_buffer = []
|
||||||
|
flush_length = len(obs) if obs is not None else len(seq)
|
||||||
|
else:
|
||||||
|
flush_length = self.flush_to_disk(self.scene.wf, scene_name, seq, obs)
|
||||||
|
self.success_case += 1
|
||||||
|
self.scene.update_generate_status(success=True)
|
||||||
|
self.collect_io_frame_info(flush_length, time.time() - io_start_time)
|
||||||
|
self.status_reporter.update_status(ComponentStatus.COMPLETED)
|
||||||
|
return None
|
||||||
|
except StopIteration:
|
||||||
|
if self.async_mode:
|
||||||
|
if len(self.data_buffer) > 0:
|
||||||
|
async_flush = self.flush_executor.submit(run_batch, self.flush_to_disk, self.data_buffer)
|
||||||
|
self.flush_threads.append(async_flush)
|
||||||
|
for thread in self.flush_threads:
|
||||||
|
thread.result()
|
||||||
|
if self.scene is not None:
|
||||||
|
self.logger.info(
|
||||||
|
f"Scene {self.scene.name} generate finish, success rate: {self.success_case}/{self.total_case}"
|
||||||
|
)
|
||||||
|
raise StopIteration("no data")
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.exception(f"Error during data writing: {e}")
|
||||||
|
raise e
|
||||||
|
|
||||||
|
def __del__(self):
|
||||||
|
for thread in self.flush_threads:
|
||||||
|
thread.result()
|
||||||
|
self.logger.info(f"Writer {len(self.flush_threads)} threads closed")
|
||||||
|
# Close the simulation app if it exists
|
||||||
|
if self.scene is not None and self.scene.simulation_app is not None:
|
||||||
|
self.logger.info("Closing simulation app")
|
||||||
|
self.scene.simulation_app.close()
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def flush_to_disk(self, task, scene_name, seq, obs):
|
||||||
|
raise NotImplementedError("This method should be overridden by subclasses")
|
||||||
4
nimbus/daemon/__init__.py
Normal file
4
nimbus/daemon/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
# flake8: noqa: E401
|
||||||
|
from .status import ComponentStatus, StatusInfo
|
||||||
|
from .status_monitor import StatusMonitor
|
||||||
|
from .status_reporter import StatusReporter
|
||||||
24
nimbus/daemon/decorators.py
Normal file
24
nimbus/daemon/decorators.py
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
from functools import wraps
|
||||||
|
|
||||||
|
from nimbus.daemon import ComponentStatus, StatusReporter
|
||||||
|
|
||||||
|
|
||||||
|
def status_monitor(running_status=ComponentStatus.RUNNING, completed_status=ComponentStatus.COMPLETED):
|
||||||
|
def decorator(func):
|
||||||
|
@wraps(func)
|
||||||
|
def wrapper(self, *args, **kwargs):
|
||||||
|
if not hasattr(self, "status_reporter"):
|
||||||
|
self.status_reporter = StatusReporter(self.__class__.__name__)
|
||||||
|
|
||||||
|
self.status_reporter.update_status(running_status)
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = func(self, *args, **kwargs)
|
||||||
|
self.status_reporter.update_status(completed_status)
|
||||||
|
return result
|
||||||
|
except Exception as e:
|
||||||
|
raise e
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
return decorator
|
||||||
21
nimbus/daemon/status.py
Normal file
21
nimbus/daemon/status.py
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
import time
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
|
||||||
|
class ComponentStatus(Enum):
|
||||||
|
IDLE = "idle"
|
||||||
|
READY = "ready"
|
||||||
|
RUNNING = "running"
|
||||||
|
COMPLETED = "completed"
|
||||||
|
TIMEOUT = "timeout"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class StatusInfo:
|
||||||
|
component_id: str
|
||||||
|
status: ComponentStatus
|
||||||
|
last_update: float = field(default_factory=time.time)
|
||||||
|
|
||||||
|
def get_status_duration(self) -> float:
|
||||||
|
return time.time() - self.last_update
|
||||||
160
nimbus/daemon/status_monitor.py
Normal file
160
nimbus/daemon/status_monitor.py
Normal file
@@ -0,0 +1,160 @@
|
|||||||
|
import threading
|
||||||
|
from typing import Dict, Optional
|
||||||
|
|
||||||
|
from .status import ComponentStatus, StatusInfo
|
||||||
|
|
||||||
|
|
||||||
|
class StatusMonitor:
|
||||||
|
_instance = None
|
||||||
|
_lock = threading.Lock()
|
||||||
|
|
||||||
|
DEFAULT_TIMEOUTS = {
|
||||||
|
ComponentStatus.IDLE: 100,
|
||||||
|
ComponentStatus.READY: float("inf"),
|
||||||
|
ComponentStatus.RUNNING: 360,
|
||||||
|
ComponentStatus.COMPLETED: float("inf"),
|
||||||
|
ComponentStatus.TIMEOUT: float("inf"),
|
||||||
|
}
|
||||||
|
|
||||||
|
def __new__(cls):
|
||||||
|
if cls._instance is None:
|
||||||
|
with cls._lock:
|
||||||
|
if cls._instance is None:
|
||||||
|
cls._instance = super().__new__(cls)
|
||||||
|
return cls._instance
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
if not hasattr(self, "initialized"):
|
||||||
|
self.components: Dict[str, StatusInfo] = {}
|
||||||
|
self.status_timeouts = self.DEFAULT_TIMEOUTS.copy()
|
||||||
|
self.initialized = True
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_instance(cls):
|
||||||
|
return cls()
|
||||||
|
|
||||||
|
def set_logger(self, logger):
|
||||||
|
self.logger = logger
|
||||||
|
|
||||||
|
def set_status_timeout(self, status: ComponentStatus, timeout_seconds: float):
|
||||||
|
self.status_timeouts[status] = timeout_seconds
|
||||||
|
|
||||||
|
def set_component_timeouts(self, timeouts: Dict[str, float]):
|
||||||
|
converted_timeouts = {}
|
||||||
|
|
||||||
|
for status_name, timeout_value in timeouts.items():
|
||||||
|
try:
|
||||||
|
if isinstance(status_name, str):
|
||||||
|
status = ComponentStatus[status_name.upper()]
|
||||||
|
elif isinstance(status_name, ComponentStatus):
|
||||||
|
status = status_name
|
||||||
|
else:
|
||||||
|
self._record(
|
||||||
|
f"Warning: Invalid status type '{type(status_name)}' for status '{status_name}', skipping"
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
timeout_value = float(timeout_value)
|
||||||
|
if timeout_value < 0:
|
||||||
|
timeout_value = float("inf")
|
||||||
|
|
||||||
|
converted_timeouts[status] = timeout_value
|
||||||
|
self._record(f"Set timeout for {status.value}: {timeout_value}s")
|
||||||
|
|
||||||
|
except (ValueError, TypeError) as e:
|
||||||
|
self._record(
|
||||||
|
f"Warning: Invalid timeout value '{timeout_value}' for status '{status_name}': {e}, skipping"
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
except KeyError:
|
||||||
|
self._record(
|
||||||
|
f"Warning: Unknown status '{status_name}', skipping. Available statuses:"
|
||||||
|
f" {[s.name for s in ComponentStatus]}"
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
except Exception as e:
|
||||||
|
self._record(f"Error processing status '{status_name}': {e}, skipping")
|
||||||
|
continue
|
||||||
|
|
||||||
|
self.status_timeouts.update(converted_timeouts)
|
||||||
|
|
||||||
|
def register_update(self, status_info: StatusInfo):
|
||||||
|
self.components[status_info.component_id] = status_info
|
||||||
|
|
||||||
|
def get_all_status(self) -> Dict[str, StatusInfo]:
|
||||||
|
return self.components.copy()
|
||||||
|
|
||||||
|
def get_status(self, component_id: str) -> Optional[StatusInfo]:
|
||||||
|
return self.components.get(component_id)
|
||||||
|
|
||||||
|
def get_timeout_components(self) -> Dict[str, StatusInfo]:
|
||||||
|
timeout_components = {}
|
||||||
|
for component_id, status_info in self.components.items():
|
||||||
|
if status_info.status == ComponentStatus.TIMEOUT:
|
||||||
|
timeout_components[component_id] = status_info
|
||||||
|
return timeout_components
|
||||||
|
|
||||||
|
def get_components_length(self):
|
||||||
|
return len(self.components)
|
||||||
|
|
||||||
|
def check_and_update_timeouts(self) -> Dict[str, StatusInfo]:
|
||||||
|
newly_timeout_components = {}
|
||||||
|
components = self.get_all_status()
|
||||||
|
for component_id, status_info in components.items():
|
||||||
|
if status_info.status == ComponentStatus.TIMEOUT:
|
||||||
|
newly_timeout_components[component_id] = status_info
|
||||||
|
continue
|
||||||
|
|
||||||
|
time_since_update = status_info.get_status_duration()
|
||||||
|
timeout_threshold = self.status_timeouts.get(status_info.status, 300)
|
||||||
|
self._record(
|
||||||
|
f"[COMPONENT DETAIL] {component_id}: "
|
||||||
|
f"Status={status_info.status}, "
|
||||||
|
f"Duration={status_info.get_status_duration():.1f}s, "
|
||||||
|
f"Threshold={timeout_threshold}s"
|
||||||
|
)
|
||||||
|
|
||||||
|
if time_since_update > timeout_threshold:
|
||||||
|
self._record(
|
||||||
|
f"Component {component_id} timeout: {status_info.status.value} for {time_since_update:.1f}s"
|
||||||
|
f" (threshold: {timeout_threshold}s)"
|
||||||
|
)
|
||||||
|
|
||||||
|
status_info.status = ComponentStatus.TIMEOUT
|
||||||
|
status_info.last_update = time_since_update
|
||||||
|
newly_timeout_components[component_id] = status_info
|
||||||
|
|
||||||
|
return newly_timeout_components
|
||||||
|
|
||||||
|
def clear(self):
|
||||||
|
self.components.clear()
|
||||||
|
self._record("Cleared all registered components.")
|
||||||
|
|
||||||
|
def get_component_status_duration(self, component_id: str) -> Optional[float]:
|
||||||
|
status_info = self.components.get(component_id)
|
||||||
|
if status_info:
|
||||||
|
return status_info.get_status_duration()
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_all_status_with_duration(self) -> Dict[str, Dict]:
|
||||||
|
result = {}
|
||||||
|
for comp_id, status_info in self.components.items():
|
||||||
|
result[comp_id] = {
|
||||||
|
"status": status_info.status,
|
||||||
|
"duration": status_info.get_status_duration(),
|
||||||
|
"timeout_threshold": self.status_timeouts.get(status_info.status, 300),
|
||||||
|
"last_update": status_info.last_update,
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
|
||||||
|
def set_check_interval(self, interval_seconds: float):
|
||||||
|
self.check_interval = interval_seconds
|
||||||
|
self._record(f"Set daemon check interval to {interval_seconds}s")
|
||||||
|
|
||||||
|
def _record(self, info):
|
||||||
|
if hasattr(self, "logger") and self.logger is not None:
|
||||||
|
self.logger.info(f"[STATUS MONITOR]: {info}")
|
||||||
|
else:
|
||||||
|
print(f"[STATUS MONITOR]: {info}")
|
||||||
21
nimbus/daemon/status_reporter.py
Normal file
21
nimbus/daemon/status_reporter.py
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
import threading
|
||||||
|
import time
|
||||||
|
|
||||||
|
from .status import ComponentStatus, StatusInfo
|
||||||
|
from .status_monitor import StatusMonitor
|
||||||
|
|
||||||
|
|
||||||
|
class StatusReporter:
|
||||||
|
def __init__(self, component_id: str):
|
||||||
|
self.component_id = component_id
|
||||||
|
self._status_info = StatusInfo(component_id, ComponentStatus.IDLE)
|
||||||
|
self._lock = threading.Lock()
|
||||||
|
|
||||||
|
def update_status(self, status: ComponentStatus):
|
||||||
|
with self._lock:
|
||||||
|
self._status_info = StatusInfo(component_id=self.component_id, status=status, last_update=time.time())
|
||||||
|
StatusMonitor.get_instance().register_update(self._status_info)
|
||||||
|
|
||||||
|
def get_status(self) -> StatusInfo:
|
||||||
|
with self._lock:
|
||||||
|
return self._status_info
|
||||||
66
nimbus/data_engine.py
Normal file
66
nimbus/data_engine.py
Normal file
@@ -0,0 +1,66 @@
|
|||||||
|
from time import time
|
||||||
|
|
||||||
|
from nimbus.dist_sim.head_node import HeadNode
|
||||||
|
from nimbus.scheduler.sches import gen_pipe, gen_scheduler
|
||||||
|
from nimbus.utils.logging import configure_logging
|
||||||
|
from nimbus.utils.random import set_all_seeds
|
||||||
|
from nimbus.utils.types import (
|
||||||
|
NAME,
|
||||||
|
SAFE_THRESHOLD,
|
||||||
|
STAGE_PIPE,
|
||||||
|
WORKER_SCHEDULE,
|
||||||
|
StageInput,
|
||||||
|
)
|
||||||
|
from nimbus.utils.utils import consume_stage
|
||||||
|
|
||||||
|
|
||||||
|
class DataEngine:
|
||||||
|
def __init__(self, config, master_seed=None):
|
||||||
|
if master_seed is not None:
|
||||||
|
master_seed = int(master_seed)
|
||||||
|
set_all_seeds(master_seed)
|
||||||
|
exp_name = config[NAME]
|
||||||
|
configure_logging(exp_name, config=config)
|
||||||
|
self._sche_list = gen_scheduler(config)
|
||||||
|
self._stage_input = StageInput()
|
||||||
|
|
||||||
|
def run(self):
|
||||||
|
for stage in self._sche_list:
|
||||||
|
self._stage_input = stage.run(self._stage_input)
|
||||||
|
consume_stage(self._stage_input)
|
||||||
|
|
||||||
|
|
||||||
|
class DistPipeDataEngine:
|
||||||
|
def __init__(self, config, master_seed=None):
|
||||||
|
self._sche_list = gen_scheduler(config)
|
||||||
|
self.config = config
|
||||||
|
self._stage_input = StageInput()
|
||||||
|
exp_name = config[NAME]
|
||||||
|
self.logger = configure_logging(exp_name, config=config)
|
||||||
|
master_seed = int(master_seed) if master_seed is not None else None
|
||||||
|
self.pipe_list = gen_pipe(config, self._sche_list, exp_name, master_seed=master_seed)
|
||||||
|
self.head_nodes = {}
|
||||||
|
|
||||||
|
def run(self):
|
||||||
|
self.logger.info("[DistPipeDataEngine]: %s", self.pipe_list)
|
||||||
|
st_time = time()
|
||||||
|
cur_pipe_queue = None
|
||||||
|
pre_worker_num = 0
|
||||||
|
worker_schedule = self.config[STAGE_PIPE].get(WORKER_SCHEDULE, False)
|
||||||
|
for idx, pipe in enumerate(self.pipe_list):
|
||||||
|
self.head_nodes[idx] = HeadNode(
|
||||||
|
cur_pipe_queue,
|
||||||
|
pipe,
|
||||||
|
pre_worker_num,
|
||||||
|
self.config[STAGE_PIPE][SAFE_THRESHOLD],
|
||||||
|
worker_schedule,
|
||||||
|
self.logger,
|
||||||
|
idx,
|
||||||
|
)
|
||||||
|
self.head_nodes[idx].run()
|
||||||
|
cur_pipe_queue = self.head_nodes[idx].result_queue()
|
||||||
|
pre_worker_num = len(pipe)
|
||||||
|
for _, value in self.head_nodes.items():
|
||||||
|
value.wait_stop()
|
||||||
|
et_time = time()
|
||||||
|
self.logger.info("execution duration: %s", et_time - st_time)
|
||||||
0
nimbus/dist_sim/__init__.py
Normal file
0
nimbus/dist_sim/__init__.py
Normal file
201
nimbus/dist_sim/head_node.py
Normal file
201
nimbus/dist_sim/head_node.py
Normal file
@@ -0,0 +1,201 @@
|
|||||||
|
import traceback
|
||||||
|
from threading import Thread
|
||||||
|
from time import sleep, time
|
||||||
|
|
||||||
|
import ray
|
||||||
|
from ray.util.queue import Queue
|
||||||
|
|
||||||
|
from nimbus.components.data.package import Package
|
||||||
|
from nimbus.dist_sim.task_board import TaskBoard
|
||||||
|
from nimbus.scheduler.inner_pipe import PipeWorkerGroup
|
||||||
|
|
||||||
|
|
||||||
|
class HeadNode:
|
||||||
|
def __init__(
|
||||||
|
self, data_queue, workers: PipeWorkerGroup, pre_worker_num, safe_threshold, worker_schedule, logger, idx
|
||||||
|
):
|
||||||
|
self.idx = idx
|
||||||
|
self.data_queue = data_queue
|
||||||
|
self.logger = logger
|
||||||
|
self.worker_group = workers
|
||||||
|
logger.info(f"workers: {list(workers.keys())}")
|
||||||
|
self.pre_worker_num = pre_worker_num
|
||||||
|
self.safe_threshold = safe_threshold
|
||||||
|
self.worker_schedule = worker_schedule
|
||||||
|
logger.info(f"safe_threshold: {self.safe_threshold}")
|
||||||
|
logger.info(f"worker_schedule: {self.worker_schedule}")
|
||||||
|
self.task_queue = Queue() if data_queue is not None else None
|
||||||
|
self.output_queue = Queue()
|
||||||
|
self.GEN_STOP_SIG = False
|
||||||
|
self.task_board = TaskBoard()
|
||||||
|
self.gen_thread = Thread(target=self.gen_tasks, args=())
|
||||||
|
self.gen_thread.start()
|
||||||
|
self.should_stop = False
|
||||||
|
self.run_thread = None
|
||||||
|
# Map runner ObjectRef to worker name for proper cleanup
|
||||||
|
self.runner_to_worker = {}
|
||||||
|
self.all_workers_spawned = False
|
||||||
|
|
||||||
|
def gen_tasks(self):
|
||||||
|
self.logger.info(f"headnode: {self.idx}: =============start gen task=============")
|
||||||
|
pre_worker_stop_num = 0
|
||||||
|
while not self.GEN_STOP_SIG:
|
||||||
|
if self.data_queue is None:
|
||||||
|
self.logger.info(f"headnode: {self.idx}: =============Gen Tasks stop==============")
|
||||||
|
self.all_workers_spawned = True
|
||||||
|
return
|
||||||
|
if self.data_queue.empty():
|
||||||
|
sleep(0)
|
||||||
|
continue
|
||||||
|
if self.task_queue is not None and self.task_queue.size() >= self.safe_threshold:
|
||||||
|
sleep(1)
|
||||||
|
continue
|
||||||
|
task = self.data_queue.get()
|
||||||
|
assert isinstance(
|
||||||
|
task, Package
|
||||||
|
), f"the transfered type of data should be Package type, but it is {type(task)}"
|
||||||
|
if task.should_stop():
|
||||||
|
pre_worker_stop_num += 1
|
||||||
|
self.logger.info(
|
||||||
|
f"headnode: {self.idx}: Received stop signal from upstream worker"
|
||||||
|
f" ({pre_worker_stop_num}/{self.pre_worker_num})"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Dynamic worker scheduling: spawn new worker when upstream worker finishes
|
||||||
|
if self.worker_schedule:
|
||||||
|
self.logger.info(
|
||||||
|
f"headnode: {self.idx}: Worker schedule enabled, will spawn 1 new worker after resource release"
|
||||||
|
)
|
||||||
|
# Wait for upstream resources to be released by upstream HeadNode's wait_stop()
|
||||||
|
# Retry mechanism to handle resource release timing
|
||||||
|
max_retries = 30 # 30 * 2s = 60s max wait
|
||||||
|
retry_interval = 2
|
||||||
|
|
||||||
|
for retry in range(max_retries):
|
||||||
|
try:
|
||||||
|
self.logger.info(
|
||||||
|
f"headnode: {self.idx}: Attempting to spawn new worker (attempt"
|
||||||
|
f" {retry + 1}/{max_retries})..."
|
||||||
|
)
|
||||||
|
created_workers = self.worker_group.spawn(1)
|
||||||
|
if created_workers:
|
||||||
|
for worker_name, worker_bundle in created_workers:
|
||||||
|
# Start the new worker
|
||||||
|
runner = worker_bundle["worker"].run.remote(self.task_queue, self.output_queue)
|
||||||
|
self.runner_to_worker[runner] = worker_name
|
||||||
|
self.logger.info(
|
||||||
|
f"headnode: {self.idx}: Successfully spawned and started new worker:"
|
||||||
|
f" {worker_name}"
|
||||||
|
)
|
||||||
|
sleep(5)
|
||||||
|
break # Success, exit retry loop
|
||||||
|
except Exception as e:
|
||||||
|
if retry < max_retries - 1:
|
||||||
|
self.logger.warning(
|
||||||
|
f"headnode: {self.idx}: Failed to spawn worker (attempt {retry + 1}), will retry in"
|
||||||
|
f" {retry_interval}s: {e}"
|
||||||
|
)
|
||||||
|
sleep(retry_interval)
|
||||||
|
else:
|
||||||
|
self.logger.error(
|
||||||
|
f"headnode: {self.idx}: Failed to spawn new worker after"
|
||||||
|
f" {max_retries} attempts: {e}"
|
||||||
|
)
|
||||||
|
self.logger.error(traceback.format_exc())
|
||||||
|
|
||||||
|
if pre_worker_stop_num == self.pre_worker_num:
|
||||||
|
for _ in range(len(self.worker_group)):
|
||||||
|
self.logger.info(f"headnode: {self.idx}: get stop signal")
|
||||||
|
stop_pack = Package(None, stop_sig=True)
|
||||||
|
self.task_board.reg_task(stop_pack)
|
||||||
|
self.all_workers_spawned = True
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
self.task_board.reg_task(task)
|
||||||
|
if self.data_queue and not self.data_queue.empty():
|
||||||
|
task = self.data_queue.get_nowait()
|
||||||
|
self.task_board.reg_task(task)
|
||||||
|
self.logger.info("=============Gen Tasks stop==============")
|
||||||
|
self.all_workers_spawned = True
|
||||||
|
|
||||||
|
def result_queue(self):
|
||||||
|
return self.output_queue
|
||||||
|
|
||||||
|
def run(self):
|
||||||
|
self.logger.info(f"headnode: {self.idx}: ==============Running Head Node================")
|
||||||
|
for worker_name, worker_bundle in self.worker_group.items():
|
||||||
|
runner = worker_bundle["worker"].run.remote(self.task_queue, self.output_queue)
|
||||||
|
self.runner_to_worker[runner] = worker_name
|
||||||
|
sleep(5)
|
||||||
|
|
||||||
|
def inner_run():
|
||||||
|
while not self.should_stop:
|
||||||
|
tasks = self.task_board.get_tasks(timeout=0.05)
|
||||||
|
if len(tasks) == 0:
|
||||||
|
sleep(0)
|
||||||
|
continue
|
||||||
|
while self.task_queue.size() >= self.safe_threshold and not self.should_stop:
|
||||||
|
sleep(1)
|
||||||
|
for _, task in enumerate(tasks):
|
||||||
|
self.task_queue.put(task)
|
||||||
|
|
||||||
|
self.run_thread = Thread(target=inner_run)
|
||||||
|
self.run_thread.start()
|
||||||
|
|
||||||
|
def sig_stop(self):
|
||||||
|
self.logger.info(f"headnode: {self.idx}: ============Gen Stop===============")
|
||||||
|
self.GEN_STOP_SIG = True
|
||||||
|
self.gen_thread.join()
|
||||||
|
|
||||||
|
def wait_stop(self):
|
||||||
|
if self.worker_schedule and self.idx != 0:
|
||||||
|
self.logger.info(f"headnode: {self.idx}: Waiting for all worker spawning to complete...")
|
||||||
|
timeout = 600 # 600 seconds timeout
|
||||||
|
start_time = time()
|
||||||
|
while not self.all_workers_spawned:
|
||||||
|
if time() - start_time > timeout:
|
||||||
|
self.logger.warning(
|
||||||
|
f"headnode: {self.idx}: Timeout waiting for worker spawning completion after {timeout}s"
|
||||||
|
)
|
||||||
|
break
|
||||||
|
sleep(0.1)
|
||||||
|
|
||||||
|
if self.all_workers_spawned:
|
||||||
|
self.logger.info(f"headnode: {self.idx}: All worker spawning completed, proceeding to wait for runners")
|
||||||
|
|
||||||
|
remaining_runners = list(self.runner_to_worker.keys())
|
||||||
|
for runner in remaining_runners:
|
||||||
|
self.logger.info(f"headnode: {self.idx}: remaining runner include: {self.runner_to_worker[runner]}")
|
||||||
|
|
||||||
|
while remaining_runners:
|
||||||
|
ready, _ = ray.wait(remaining_runners, num_returns=len(remaining_runners), timeout=1.0)
|
||||||
|
|
||||||
|
for finished_runner in ready:
|
||||||
|
worker_name = self.runner_to_worker.get(finished_runner, "unknown")
|
||||||
|
self.logger.info(f"headnode: {self.idx}: Worker {worker_name} finished")
|
||||||
|
try:
|
||||||
|
ray.get(finished_runner)
|
||||||
|
self.logger.info(f"headnode: {self.idx}: Worker {worker_name} completed successfully")
|
||||||
|
self.worker_group.remove(worker_name, self.logger)
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Worker {worker_name} failed, error stack:")
|
||||||
|
self.logger.error(e)
|
||||||
|
if worker_name in self.worker_group.keys():
|
||||||
|
self.worker_group.remove(worker_name, self.logger)
|
||||||
|
|
||||||
|
remaining_runners.remove(finished_runner)
|
||||||
|
self.runner_to_worker.pop(finished_runner, None)
|
||||||
|
|
||||||
|
if not ready:
|
||||||
|
sleep(1)
|
||||||
|
|
||||||
|
self.logger.info(f"headnode: {self.idx}: ==============stop head================")
|
||||||
|
self.should_stop = True
|
||||||
|
if self.run_thread is not None:
|
||||||
|
self.run_thread.join()
|
||||||
|
self.sig_stop()
|
||||||
|
|
||||||
|
def __del__(self):
|
||||||
|
if self.task_queue is not None:
|
||||||
|
self.task_queue.shutdown()
|
||||||
|
self.output_queue.shutdown()
|
||||||
42
nimbus/dist_sim/task_board.py
Normal file
42
nimbus/dist_sim/task_board.py
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
import time
|
||||||
|
from threading import Lock
|
||||||
|
|
||||||
|
|
||||||
|
class Task:
|
||||||
|
def __init__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def update_state(self, state):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class TaskBoard:
|
||||||
|
def __init__(self):
|
||||||
|
self.tasks = []
|
||||||
|
self.flying_tasks = []
|
||||||
|
self.finished_tasks = []
|
||||||
|
self.task_cnt = 0
|
||||||
|
self.task_lock = Lock()
|
||||||
|
self.flying_task_lock = Lock()
|
||||||
|
|
||||||
|
def reg_task(self, task):
|
||||||
|
with self.task_lock:
|
||||||
|
self.tasks.append(task)
|
||||||
|
self.task_cnt += 1
|
||||||
|
|
||||||
|
def get_tasks(self, timeout=0):
|
||||||
|
st_time = time.time()
|
||||||
|
while len(self.tasks) == 0:
|
||||||
|
if time.time() - st_time > timeout:
|
||||||
|
return []
|
||||||
|
pass
|
||||||
|
with self.task_lock:
|
||||||
|
tasks = self.tasks.copy()
|
||||||
|
self.tasks = []
|
||||||
|
return tasks
|
||||||
|
|
||||||
|
def commit_task(self, tasks):
|
||||||
|
raise NotImplementedError("commit_task not implemented")
|
||||||
|
|
||||||
|
def finished(self):
|
||||||
|
raise NotImplementedError("finished not implemented")
|
||||||
0
nimbus/scheduler/__init__.py
Normal file
0
nimbus/scheduler/__init__.py
Normal file
277
nimbus/scheduler/inner_pipe.py
Normal file
277
nimbus/scheduler/inner_pipe.py
Normal file
@@ -0,0 +1,277 @@
|
|||||||
|
import math
|
||||||
|
import os
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
|
||||||
|
import ray
|
||||||
|
|
||||||
|
from nimbus.daemon.status_monitor import StatusMonitor
|
||||||
|
from nimbus.scheduler.stages import DedumpStage, DumpStage
|
||||||
|
from nimbus.utils.logging import configure_logging
|
||||||
|
from nimbus.utils.random import set_all_seeds
|
||||||
|
from nimbus.utils.types import MONITOR_CHECK_INTERVAL, STATUS_TIMEOUTS, StageInput
|
||||||
|
from nimbus.utils.utils import init_env, pipe_consume_stage
|
||||||
|
|
||||||
|
|
||||||
|
def iter_to_obj(iter_obj):
|
||||||
|
return pipe_consume_stage(iter_obj), True
|
||||||
|
|
||||||
|
|
||||||
|
def _consume_N(iter_obj, N=1):
|
||||||
|
print("consume: ", iter_obj)
|
||||||
|
results = []
|
||||||
|
finish = False
|
||||||
|
for _ in range(N):
|
||||||
|
try:
|
||||||
|
obj = next(iter_obj)
|
||||||
|
results.append(obj)
|
||||||
|
except StopIteration:
|
||||||
|
finish = True
|
||||||
|
return results, finish
|
||||||
|
|
||||||
|
|
||||||
|
def consume_N(stage_input):
|
||||||
|
finish = False
|
||||||
|
if hasattr(stage_input, "Args"):
|
||||||
|
stage_input.Args, finish = _consume_N(stage_input.Args[0])
|
||||||
|
if hasattr(stage_input, "Kwargs"):
|
||||||
|
if stage_input.Kwargs is not None:
|
||||||
|
stage_input.Kwargs = {key: _consume_N(value) for key, value in stage_input.Kwargs.items()}
|
||||||
|
return stage_input, finish
|
||||||
|
|
||||||
|
|
||||||
|
class PipeWorkerGroup:
|
||||||
|
"""
|
||||||
|
Manages a group of pipe workers and their supervisors.
|
||||||
|
Supports dynamic worker spawning for worker_schedule feature.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
pipe_name,
|
||||||
|
exp_name,
|
||||||
|
pipe_num,
|
||||||
|
stage_list,
|
||||||
|
master_seed,
|
||||||
|
supervisor_class,
|
||||||
|
inner_pipe_class,
|
||||||
|
initial_instances=0,
|
||||||
|
):
|
||||||
|
self.workers = {}
|
||||||
|
self._next_worker_idx = 0
|
||||||
|
self.pipe_name = pipe_name
|
||||||
|
self.exp_name = exp_name
|
||||||
|
self.pipe_num = pipe_num
|
||||||
|
self.stage_list = stage_list
|
||||||
|
self.master_seed = master_seed
|
||||||
|
self.supervisor_class = supervisor_class
|
||||||
|
self.inner_pipe_class = inner_pipe_class
|
||||||
|
|
||||||
|
if initial_instances > 0:
|
||||||
|
self.spawn(initial_instances)
|
||||||
|
|
||||||
|
def spawn(self, count):
|
||||||
|
"""
|
||||||
|
Spawn new workers dynamically.
|
||||||
|
Returns list of (name, bundle) tuples for created workers.
|
||||||
|
"""
|
||||||
|
created = []
|
||||||
|
for _ in range(count):
|
||||||
|
name = f"p{self.pipe_num}_w{self._next_worker_idx}"
|
||||||
|
worker_seed = self.master_seed + self._next_worker_idx if self.master_seed is not None else None
|
||||||
|
supervisor = self.supervisor_class.remote(name)
|
||||||
|
pipe_actor = self.inner_pipe_class.remote(self.stage_list, name, supervisor, seed=worker_seed)
|
||||||
|
ray.get(supervisor.set_pipe.remote(pipe_actor))
|
||||||
|
supervisor.run.remote()
|
||||||
|
bundle = {"worker": pipe_actor, "supervisor": supervisor}
|
||||||
|
self.workers[name] = bundle
|
||||||
|
created.append((name, bundle))
|
||||||
|
self._next_worker_idx += 1
|
||||||
|
time.sleep(3)
|
||||||
|
|
||||||
|
if created:
|
||||||
|
print(f"{self.pipe_name}: spawned {len(created)} workers - {[name for name, _ in created]}")
|
||||||
|
return created
|
||||||
|
|
||||||
|
def items(self):
|
||||||
|
"""Return items view of workers dictionary."""
|
||||||
|
return self.workers.items()
|
||||||
|
|
||||||
|
def values(self):
|
||||||
|
"""Return values view of workers dictionary."""
|
||||||
|
return self.workers.values()
|
||||||
|
|
||||||
|
def keys(self):
|
||||||
|
"""Return keys view of workers dictionary."""
|
||||||
|
return self.workers.keys()
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
"""Return number of workers in the group."""
|
||||||
|
return len(self.workers)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
worker_names = list(self.workers.keys())
|
||||||
|
return f"PipeWorkerGroup({worker_names})"
|
||||||
|
|
||||||
|
def __getitem__(self, key):
|
||||||
|
"""Support dictionary-style access."""
|
||||||
|
return self.workers[key]
|
||||||
|
|
||||||
|
def remove(self, name, logger):
|
||||||
|
"""Remove a worker from the group."""
|
||||||
|
ray.kill(self.workers[name]["worker"])
|
||||||
|
logger.info(f"killed worker actor {name} to release GPU resouces")
|
||||||
|
ray.kill(self.workers[name]["supervisor"])
|
||||||
|
logger.info(f"Supervisor {name} killed successfully")
|
||||||
|
if name in self.workers:
|
||||||
|
del self.workers[name]
|
||||||
|
|
||||||
|
|
||||||
|
def make_pipe(pipe_name, exp_name, pipe_num, stage_list, dev, instance_num, total_processes, config, master_seed=None):
|
||||||
|
gpu_num = 0
|
||||||
|
if dev == "gpu":
|
||||||
|
resources = ray.cluster_resources()
|
||||||
|
total_gpus = resources.get("GPU", 0)
|
||||||
|
assert total_gpus > 0, "not enough gpu resources"
|
||||||
|
processes_per_gpu = math.ceil(total_processes / total_gpus)
|
||||||
|
gpu_num = 1.0 / processes_per_gpu
|
||||||
|
|
||||||
|
@ray.remote
|
||||||
|
class Supervisor:
|
||||||
|
def __init__(self, name):
|
||||||
|
self.name = "supervisor_" + name
|
||||||
|
self.pipe_worker = None
|
||||||
|
self.logger = configure_logging(exp_name, self.name)
|
||||||
|
self.logger.info("Supervisor started")
|
||||||
|
self.monitor = StatusMonitor.get_instance()
|
||||||
|
self.monitor.set_logger(self.logger)
|
||||||
|
|
||||||
|
self._last_status_check = 0.0
|
||||||
|
self.check_interval = config.get(MONITOR_CHECK_INTERVAL, 120)
|
||||||
|
self.logger.info(f"Monitor check interval: {self.check_interval} seconds")
|
||||||
|
if config.get(STATUS_TIMEOUTS, None) is not None:
|
||||||
|
self.monitor.set_component_timeouts(config[STATUS_TIMEOUTS])
|
||||||
|
|
||||||
|
def set_pipe(self, pipe_worker):
|
||||||
|
self.logger.info("set pipe worker")
|
||||||
|
self.pipe_worker = pipe_worker
|
||||||
|
|
||||||
|
def set_queue(self, input_queue, output_queue):
|
||||||
|
self.input_queue = input_queue
|
||||||
|
self.output_queue = output_queue
|
||||||
|
|
||||||
|
def _restart_worker(self):
|
||||||
|
try:
|
||||||
|
ray.kill(self.pipe_worker, no_restart=False)
|
||||||
|
self.logger.info("trigger restart of the actor")
|
||||||
|
except Exception as ke:
|
||||||
|
self.logger.error(f"restart actor error: {ke}")
|
||||||
|
|
||||||
|
def update_component_state(self, components_state):
|
||||||
|
for _, state in components_state.items():
|
||||||
|
self.monitor.register_update(state)
|
||||||
|
|
||||||
|
def _start_daemon(self):
|
||||||
|
miss_cnt = 0
|
||||||
|
while True:
|
||||||
|
now = time.time()
|
||||||
|
if now - self._last_status_check >= self.check_interval:
|
||||||
|
try:
|
||||||
|
timeout_components = self.monitor.check_and_update_timeouts()
|
||||||
|
if len(timeout_components) > 0:
|
||||||
|
self.logger.warning(f"Components timeout: {timeout_components}, restart the pipe worker")
|
||||||
|
self._restart_worker()
|
||||||
|
self.monitor.clear()
|
||||||
|
else:
|
||||||
|
if self.monitor.get_components_length() == 0:
|
||||||
|
miss_cnt += 1
|
||||||
|
self.logger.info(f"No components timeout detected, miss count: {miss_cnt}")
|
||||||
|
if miss_cnt >= 5:
|
||||||
|
self.logger.info("No components detected for 5 consecutive checks, restart pipe worker")
|
||||||
|
self._restart_worker()
|
||||||
|
self.monitor.clear()
|
||||||
|
miss_cnt = 0
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Get components status failed: {e}")
|
||||||
|
self._restart_worker()
|
||||||
|
self.monitor.clear()
|
||||||
|
self._last_status_check = now
|
||||||
|
time.sleep(1)
|
||||||
|
|
||||||
|
def run(self):
|
||||||
|
assert self.pipe_worker is not None, "pipe worker is not set"
|
||||||
|
thread = threading.Thread(target=self._start_daemon, daemon=True)
|
||||||
|
thread.start()
|
||||||
|
|
||||||
|
@ray.remote(num_gpus=gpu_num, max_restarts=3, max_task_retries=3)
|
||||||
|
class InnerPipe:
|
||||||
|
def __init__(self, stage_list, name, supervisor, seed=None):
|
||||||
|
if seed is not None:
|
||||||
|
set_all_seeds(seed)
|
||||||
|
self.stages = stage_list
|
||||||
|
self.name = name
|
||||||
|
self.supervisor = supervisor
|
||||||
|
init_env()
|
||||||
|
self.logger = configure_logging(exp_name, self.name)
|
||||||
|
self.logger.info(f"Working on gpu {os.environ.get('CUDA_VISIBLE_DEVICES')}")
|
||||||
|
if ray.get_runtime_context().was_current_actor_reconstructed is True:
|
||||||
|
msg = (
|
||||||
|
f"{'='*80}\n"
|
||||||
|
"!!! ATTENTION !!!\n"
|
||||||
|
f"!!! InnerPipe {name} WAS RECONSTRUCTED due to SYSTEM ERROR !!!\n"
|
||||||
|
"!!! Please CHECK LOGS in /tmp/ray/session_latest/logs/ for details !!!\n"
|
||||||
|
f"{'='*80}\n"
|
||||||
|
)
|
||||||
|
self.logger.info(msg)
|
||||||
|
|
||||||
|
self.monitor = StatusMonitor.get_instance()
|
||||||
|
self.monitor.set_logger(self.logger)
|
||||||
|
|
||||||
|
self.monitor_check_interval = config.get(MONITOR_CHECK_INTERVAL, 120)
|
||||||
|
|
||||||
|
def _update_supervisor(self):
|
||||||
|
while True:
|
||||||
|
for _ in range(self.monitor_check_interval):
|
||||||
|
time.sleep(1)
|
||||||
|
components_status = self.monitor.get_all_status()
|
||||||
|
ray.get(self.supervisor.update_component_state.remote(components_status))
|
||||||
|
|
||||||
|
def run(self, input_queue, output_queue):
|
||||||
|
self.logger.info(f"[InnerPipe stages]: {self.stages}")
|
||||||
|
|
||||||
|
thread = threading.Thread(target=self._update_supervisor, daemon=True)
|
||||||
|
thread.start()
|
||||||
|
self.logger.info("Reporter started, start running pipe")
|
||||||
|
|
||||||
|
mid_results = StageInput()
|
||||||
|
# if input_queue is None:
|
||||||
|
# mid_results = StageInput()
|
||||||
|
# else:
|
||||||
|
# mid_results = StageInput((input_queue,), {})
|
||||||
|
for _, stage in enumerate(self.stages):
|
||||||
|
if isinstance(stage, DumpStage):
|
||||||
|
mid_results = stage.run(mid_results, output_queue)
|
||||||
|
elif isinstance(stage, DedumpStage):
|
||||||
|
mid_results = stage.run(mid_results, input_queue)
|
||||||
|
else:
|
||||||
|
mid_results = stage.run(mid_results)
|
||||||
|
result, finish = iter_to_obj(mid_results)
|
||||||
|
self.logger.info("====================================")
|
||||||
|
self.logger.info(f"result: {result}, finish: {finish}")
|
||||||
|
self.logger.info("====================================")
|
||||||
|
ray.kill(self.supervisor)
|
||||||
|
self.logger.info("actor finished")
|
||||||
|
return finish
|
||||||
|
|
||||||
|
group = PipeWorkerGroup(
|
||||||
|
pipe_name=pipe_name,
|
||||||
|
exp_name=exp_name,
|
||||||
|
pipe_num=pipe_num,
|
||||||
|
stage_list=stage_list,
|
||||||
|
master_seed=master_seed,
|
||||||
|
supervisor_class=Supervisor,
|
||||||
|
inner_pipe_class=InnerPipe,
|
||||||
|
initial_instances=instance_num,
|
||||||
|
)
|
||||||
|
print(pipe_name, group)
|
||||||
|
return group
|
||||||
115
nimbus/scheduler/instructions.py
Normal file
115
nimbus/scheduler/instructions.py
Normal file
@@ -0,0 +1,115 @@
|
|||||||
|
from abc import abstractmethod
|
||||||
|
|
||||||
|
from nimbus.components.dedump import dedumper_dict
|
||||||
|
from nimbus.components.dump import dumper_dict
|
||||||
|
from nimbus.components.load import layout_randomizer_dict, scene_loader_dict
|
||||||
|
from nimbus.components.plan_with_render import plan_with_render_dict
|
||||||
|
from nimbus.components.planner import seq_planner_dict
|
||||||
|
from nimbus.components.render import renderer_dict
|
||||||
|
from nimbus.components.store import writer_dict
|
||||||
|
from nimbus.utils.types import ARGS, PLANNER, TYPE
|
||||||
|
|
||||||
|
|
||||||
|
class Instruction:
|
||||||
|
def __init__(self, config):
|
||||||
|
self.config = config
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def run(self, stage_input):
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
|
class LoadSceneInstruction(Instruction):
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__(config)
|
||||||
|
self.scene_iter = scene_loader_dict[self.config[TYPE]]
|
||||||
|
|
||||||
|
def run(self, stage_input):
|
||||||
|
pack_iter = pack_iter = stage_input.Args[0] if stage_input.Args is not None else None
|
||||||
|
return self.scene_iter(pack_iter=pack_iter, **self.config.get(ARGS, {}))
|
||||||
|
|
||||||
|
|
||||||
|
class RandomizeLayoutInstruction(Instruction):
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__(config)
|
||||||
|
self.layout_randomlizer = layout_randomizer_dict[self.config[TYPE]]
|
||||||
|
|
||||||
|
def run(self, stage_input):
|
||||||
|
scene_iterator = stage_input.Args[0]
|
||||||
|
extend_scene_iterator = self.layout_randomlizer(scene_iterator, **self.config.get(ARGS, {}))
|
||||||
|
return extend_scene_iterator
|
||||||
|
|
||||||
|
|
||||||
|
class PlanPathInstruction(Instruction):
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__(config)
|
||||||
|
self.seq_planner = seq_planner_dict[self.config[TYPE]]
|
||||||
|
|
||||||
|
def run(self, stage_input):
|
||||||
|
scene_iter = stage_input.Args[0]
|
||||||
|
planner_cfg = self.config[PLANNER] if PLANNER in self.config else None
|
||||||
|
return self.seq_planner(scene_iter, planner_cfg, **self.config.get(ARGS, {}))
|
||||||
|
|
||||||
|
|
||||||
|
class RenderInstruction(Instruction):
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__(config)
|
||||||
|
self.renderer = renderer_dict[self.config[TYPE]]
|
||||||
|
|
||||||
|
def run(self, stage_input):
|
||||||
|
scene_seqs_iter = stage_input.Args[0]
|
||||||
|
obs_iter = self.renderer(scene_seqs_iter, **self.config.get(ARGS, {}))
|
||||||
|
return obs_iter
|
||||||
|
|
||||||
|
|
||||||
|
class PlanWithRenderInstruction(Instruction):
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__(config)
|
||||||
|
self.plan_with_render = plan_with_render_dict[config[TYPE]]
|
||||||
|
|
||||||
|
def run(self, stage_input):
|
||||||
|
scene_iter = stage_input.Args[0]
|
||||||
|
plan_with_render_iter = self.plan_with_render(scene_iter, **self.config.get(ARGS, {}))
|
||||||
|
return plan_with_render_iter
|
||||||
|
|
||||||
|
|
||||||
|
class StoreInstruction(Instruction):
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__(config)
|
||||||
|
self.writer = writer_dict[config[TYPE]]
|
||||||
|
|
||||||
|
def run(self, stage_input):
|
||||||
|
seqs_obs_iter = stage_input.Args[0]
|
||||||
|
store_iter = self.writer(seqs_obs_iter, **self.config.get(ARGS, {}))
|
||||||
|
return store_iter
|
||||||
|
|
||||||
|
|
||||||
|
class DumpInstruction(Instruction):
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__(config)
|
||||||
|
self.dumper = dumper_dict[config[TYPE]]
|
||||||
|
|
||||||
|
def run(self, stage_input, output_queue=None):
|
||||||
|
seqs_obs_iter = stage_input.Args[0]
|
||||||
|
dump_iter = self.dumper(seqs_obs_iter, output_queue=output_queue, **self.config.get(ARGS, {}))
|
||||||
|
return dump_iter
|
||||||
|
|
||||||
|
|
||||||
|
class DeDumpInstruction(Instruction):
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__(config)
|
||||||
|
self.dedumper = dedumper_dict[config[TYPE]]
|
||||||
|
|
||||||
|
def run(self, stage_input, input_queue=None):
|
||||||
|
dump_iter = self.dedumper(input_queue=input_queue, **self.config.get(ARGS, {}))
|
||||||
|
return dump_iter
|
||||||
|
|
||||||
|
|
||||||
|
class ComposeInstruction(Instruction):
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__(config)
|
||||||
|
|
||||||
|
|
||||||
|
class AnnotateDataInstruction(Instruction):
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__(config)
|
||||||
80
nimbus/scheduler/sches.py
Normal file
80
nimbus/scheduler/sches.py
Normal file
@@ -0,0 +1,80 @@
|
|||||||
|
from nimbus.scheduler.inner_pipe import make_pipe
|
||||||
|
from nimbus.scheduler.stages import (
|
||||||
|
DedumpStage,
|
||||||
|
DumpStage,
|
||||||
|
LoadStage,
|
||||||
|
PlanStage,
|
||||||
|
PlanWithRenderStage,
|
||||||
|
RenderStage,
|
||||||
|
StoreStage,
|
||||||
|
)
|
||||||
|
from nimbus.utils.types import (
|
||||||
|
DEDUMP_STAGE,
|
||||||
|
DUMP_STAGE,
|
||||||
|
LOAD_STAGE,
|
||||||
|
PLAN_STAGE,
|
||||||
|
PLAN_WITH_RENDER_STAGE,
|
||||||
|
RENDER_STAGE,
|
||||||
|
STAGE_DEV,
|
||||||
|
STAGE_NUM,
|
||||||
|
STAGE_PIPE,
|
||||||
|
STORE_STAGE,
|
||||||
|
WORKER_NUM,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def gen_scheduler(config):
|
||||||
|
stages = []
|
||||||
|
if LOAD_STAGE in config:
|
||||||
|
stages.append(LoadStage(config[LOAD_STAGE]))
|
||||||
|
if PLAN_WITH_RENDER_STAGE in config:
|
||||||
|
stages.append(PlanWithRenderStage(config[PLAN_WITH_RENDER_STAGE]))
|
||||||
|
if PLAN_STAGE in config:
|
||||||
|
stages.append(PlanStage(config[PLAN_STAGE]))
|
||||||
|
if DUMP_STAGE in config:
|
||||||
|
stages.append(DumpStage(config[DUMP_STAGE]))
|
||||||
|
if DEDUMP_STAGE in config:
|
||||||
|
stages.append(DedumpStage(config[DEDUMP_STAGE]))
|
||||||
|
if RENDER_STAGE in config:
|
||||||
|
stages.append(RenderStage(config[RENDER_STAGE]))
|
||||||
|
if STORE_STAGE in config:
|
||||||
|
stages.append(StoreStage(config[STORE_STAGE]))
|
||||||
|
return stages
|
||||||
|
|
||||||
|
|
||||||
|
def gen_pipe(config, stage_list, exp_name, master_seed=None):
|
||||||
|
if STAGE_PIPE in config:
|
||||||
|
pipe_stages_num = config[STAGE_PIPE][STAGE_NUM]
|
||||||
|
pipe_stages_dev = config[STAGE_PIPE][STAGE_DEV]
|
||||||
|
pipe_worker_num = config[STAGE_PIPE][WORKER_NUM]
|
||||||
|
inner_pipes = []
|
||||||
|
pipe_num = 0
|
||||||
|
total_processes = 0
|
||||||
|
for worker_num in config[STAGE_PIPE][WORKER_NUM]:
|
||||||
|
total_processes += worker_num
|
||||||
|
for num, dev, worker_num in zip(pipe_stages_num, pipe_stages_dev, pipe_worker_num):
|
||||||
|
stages = stage_list[:num]
|
||||||
|
print("===========================")
|
||||||
|
print(f"inner stage num: {num}, device type: {dev}")
|
||||||
|
print(f"stages: {stages}")
|
||||||
|
print("===========================")
|
||||||
|
stage_list = stage_list[num:]
|
||||||
|
pipe_name = "pipe"
|
||||||
|
for stage in stages:
|
||||||
|
pipe_name += f"_{stage.__class__.__name__}"
|
||||||
|
pipe_workers = make_pipe(
|
||||||
|
pipe_name,
|
||||||
|
exp_name,
|
||||||
|
pipe_num,
|
||||||
|
stages,
|
||||||
|
dev,
|
||||||
|
worker_num,
|
||||||
|
total_processes,
|
||||||
|
config[STAGE_PIPE],
|
||||||
|
master_seed=master_seed,
|
||||||
|
)
|
||||||
|
inner_pipes.append(pipe_workers)
|
||||||
|
pipe_num += 1
|
||||||
|
return inner_pipes
|
||||||
|
else:
|
||||||
|
return [make_pipe.InnerPipe(stage_list)]
|
||||||
137
nimbus/scheduler/stages.py
Normal file
137
nimbus/scheduler/stages.py
Normal file
@@ -0,0 +1,137 @@
|
|||||||
|
from abc import abstractmethod
|
||||||
|
|
||||||
|
from nimbus.scheduler.instructions import (
|
||||||
|
DeDumpInstruction,
|
||||||
|
DumpInstruction,
|
||||||
|
Instruction,
|
||||||
|
LoadSceneInstruction,
|
||||||
|
PlanPathInstruction,
|
||||||
|
PlanWithRenderInstruction,
|
||||||
|
RandomizeLayoutInstruction,
|
||||||
|
RenderInstruction,
|
||||||
|
StoreInstruction,
|
||||||
|
)
|
||||||
|
from nimbus.utils.types import (
|
||||||
|
DEDUMPER,
|
||||||
|
DUMPER,
|
||||||
|
LAYOUT_RANDOM_GENERATOR,
|
||||||
|
PLAN_WITH_RENDER,
|
||||||
|
RENDERER,
|
||||||
|
SCENE_LOADER,
|
||||||
|
SEQ_PLANNER,
|
||||||
|
WRITER,
|
||||||
|
StageInput,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Stage:
|
||||||
|
def __init__(self, config):
|
||||||
|
self.config = config
|
||||||
|
self.instructions: list[Instruction] = []
|
||||||
|
self.output_queue = None
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def run(self, stage_input):
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
|
class LoadStage(Stage):
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__(config)
|
||||||
|
if SCENE_LOADER in config:
|
||||||
|
self.instructions.append(LoadSceneInstruction(config[SCENE_LOADER]))
|
||||||
|
if LAYOUT_RANDOM_GENERATOR in config:
|
||||||
|
self.instructions.append(RandomizeLayoutInstruction(config[LAYOUT_RANDOM_GENERATOR]))
|
||||||
|
|
||||||
|
def run(self, stage_input: StageInput):
|
||||||
|
for instruction in self.instructions:
|
||||||
|
scene_iterator = instruction.run(stage_input)
|
||||||
|
stage_input = StageInput((scene_iterator,), {})
|
||||||
|
return stage_input
|
||||||
|
|
||||||
|
|
||||||
|
class PlanStage(Stage):
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__(config)
|
||||||
|
if SEQ_PLANNER in config:
|
||||||
|
self.instructions.append(PlanPathInstruction(config[SEQ_PLANNER]))
|
||||||
|
|
||||||
|
def run(self, stage_input: StageInput):
|
||||||
|
for instruction in self.instructions:
|
||||||
|
scene_seqs_iter = instruction.run(stage_input)
|
||||||
|
stage_input = StageInput((scene_seqs_iter,), {})
|
||||||
|
return stage_input
|
||||||
|
|
||||||
|
|
||||||
|
class RenderStage(Stage):
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__(config)
|
||||||
|
self.instructions.append(RenderInstruction(config[RENDERER]))
|
||||||
|
|
||||||
|
def run(self, stage_input: StageInput):
|
||||||
|
for instruction in self.instructions:
|
||||||
|
obs_iter = instruction.run(stage_input)
|
||||||
|
stage_input = StageInput((obs_iter,), {})
|
||||||
|
return stage_input
|
||||||
|
|
||||||
|
|
||||||
|
class PlanWithRenderStage(Stage):
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__(config)
|
||||||
|
self.instructions.append(PlanWithRenderInstruction(config[PLAN_WITH_RENDER]))
|
||||||
|
|
||||||
|
def run(self, stage_input: StageInput):
|
||||||
|
for instruction in self.instructions:
|
||||||
|
scene_seqs_iter = instruction.run(stage_input)
|
||||||
|
stage_input = StageInput((scene_seqs_iter,), {})
|
||||||
|
return stage_input
|
||||||
|
|
||||||
|
|
||||||
|
class StoreStage(Stage):
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__(config)
|
||||||
|
if WRITER in config:
|
||||||
|
self.instructions.append(StoreInstruction(config[WRITER]))
|
||||||
|
|
||||||
|
def run(self, stage_input: StageInput):
|
||||||
|
for instruction in self.instructions:
|
||||||
|
store_iter = instruction.run(stage_input)
|
||||||
|
stage_input = StageInput((store_iter,), {})
|
||||||
|
return stage_input
|
||||||
|
|
||||||
|
|
||||||
|
class DumpStage(Stage):
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__(config)
|
||||||
|
self.instructions.append(DumpInstruction(config[DUMPER]))
|
||||||
|
|
||||||
|
def run(self, stage_input: StageInput, output_queue=None):
|
||||||
|
for instruction in self.instructions:
|
||||||
|
dump_iter = instruction.run(stage_input, output_queue)
|
||||||
|
stage_input = StageInput((dump_iter,), {})
|
||||||
|
return stage_input
|
||||||
|
|
||||||
|
|
||||||
|
class DedumpStage(Stage):
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__(config)
|
||||||
|
if DEDUMPER in config:
|
||||||
|
self.instructions.append(DeDumpInstruction(config[DEDUMPER]))
|
||||||
|
if SCENE_LOADER in config:
|
||||||
|
self.instructions.append(LoadSceneInstruction(config[SCENE_LOADER]))
|
||||||
|
if LAYOUT_RANDOM_GENERATOR in config:
|
||||||
|
self.instructions.append(RandomizeLayoutInstruction(config[LAYOUT_RANDOM_GENERATOR]))
|
||||||
|
if SEQ_PLANNER in config:
|
||||||
|
self.instructions.append(PlanPathInstruction(config[SEQ_PLANNER]))
|
||||||
|
|
||||||
|
def run(self, stage_input: StageInput, input_queue=None):
|
||||||
|
if input_queue is not None:
|
||||||
|
self.input_queue = input_queue
|
||||||
|
|
||||||
|
for instruction in self.instructions:
|
||||||
|
if isinstance(instruction, DeDumpInstruction):
|
||||||
|
result = instruction.run(stage_input, input_queue)
|
||||||
|
else:
|
||||||
|
result = instruction.run(stage_input)
|
||||||
|
stage_input = StageInput((result,), {})
|
||||||
|
return stage_input
|
||||||
20
nimbus/utils/config.py
Normal file
20
nimbus/utils/config.py
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
from omegaconf import OmegaConf
|
||||||
|
|
||||||
|
|
||||||
|
def load_config(*yaml_files, cli_args=None):
|
||||||
|
if cli_args is None:
|
||||||
|
cli_args = []
|
||||||
|
yaml_confs = [OmegaConf.load(f) for f in yaml_files]
|
||||||
|
cli_conf = OmegaConf.from_cli(cli_args)
|
||||||
|
conf = OmegaConf.merge(*yaml_confs, cli_conf)
|
||||||
|
OmegaConf.resolve(conf)
|
||||||
|
return conf
|
||||||
|
|
||||||
|
|
||||||
|
def config_to_primitive(config, resolve=True):
|
||||||
|
return OmegaConf.to_container(config, resolve=resolve)
|
||||||
|
|
||||||
|
|
||||||
|
def save_config(config, path):
|
||||||
|
with open(path, "w", encoding="utf-8") as fp:
|
||||||
|
OmegaConf.save(config=config, f=fp)
|
||||||
138
nimbus/utils/config_processor.py
Normal file
138
nimbus/utils/config_processor.py
Normal file
@@ -0,0 +1,138 @@
|
|||||||
|
"""
|
||||||
|
Config Processor: Responsible for identifying, converting, and loading configuration files.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from omegaconf import DictConfig, OmegaConf
|
||||||
|
|
||||||
|
from nimbus.utils.config import load_config
|
||||||
|
|
||||||
|
|
||||||
|
class ConfigProcessor:
|
||||||
|
"""Config processor class"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _check_config_path_exists(self, config, path):
|
||||||
|
"""
|
||||||
|
Check if a configuration path exists in the config object
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: OmegaConf config object
|
||||||
|
path: String path like 'stage_pipe.worker_num' or 'load_stage.scene_loader.args.random_num'
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: Whether the path exists in the config
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
keys = path.split(".")
|
||||||
|
current = config
|
||||||
|
for key in keys:
|
||||||
|
if isinstance(current, DictConfig):
|
||||||
|
if key not in current:
|
||||||
|
return False
|
||||||
|
current = current[key]
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _validate_cli_args(self, config, cli_args):
|
||||||
|
"""
|
||||||
|
Validate that all CLI arguments correspond to existing paths in the config
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: OmegaConf config object
|
||||||
|
cli_args: List of command line arguments
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If any CLI argument path doesn't exist in the config
|
||||||
|
"""
|
||||||
|
if not cli_args:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Clean up CLI args to remove -- prefix if present
|
||||||
|
cleaned_cli_args = []
|
||||||
|
for arg in cli_args:
|
||||||
|
if arg.startswith("--"):
|
||||||
|
cleaned_cli_args.append(arg[2:]) # Remove the -- prefix
|
||||||
|
else:
|
||||||
|
cleaned_cli_args.append(arg)
|
||||||
|
|
||||||
|
# Parse CLI args to get the override paths
|
||||||
|
try:
|
||||||
|
cli_conf = OmegaConf.from_cli(cleaned_cli_args)
|
||||||
|
except Exception as e:
|
||||||
|
raise ValueError(f"Invalid CLI argument format: {e}. Please use format like: stage_pipe.worker_num='[2,4]'")
|
||||||
|
|
||||||
|
def check_nested_paths(conf, prefix=""):
|
||||||
|
"""Recursively check all paths in the CLI config"""
|
||||||
|
for key, value in conf.items():
|
||||||
|
current_path = f"{prefix}.{key}" if prefix else key
|
||||||
|
|
||||||
|
if isinstance(value, DictConfig):
|
||||||
|
# Check if this intermediate path exists
|
||||||
|
if not self._check_config_path_exists(config, current_path):
|
||||||
|
raise ValueError(f"Configuration path '{current_path}' does not exist in the config file")
|
||||||
|
# Recursively check nested paths
|
||||||
|
check_nested_paths(value, current_path)
|
||||||
|
else:
|
||||||
|
# Check if this leaf path exists
|
||||||
|
if not self._check_config_path_exists(config, current_path):
|
||||||
|
raise ValueError(f"Configuration path '{current_path}' does not exist in the config file")
|
||||||
|
|
||||||
|
try:
|
||||||
|
check_nested_paths(cli_conf)
|
||||||
|
except ValueError:
|
||||||
|
raise
|
||||||
|
except Exception:
|
||||||
|
# If there's an issue parsing CLI args, provide helpful error message
|
||||||
|
raise ValueError("Invalid CLI argument format. Please use format like: --key=value or --nested.key=value")
|
||||||
|
|
||||||
|
def process_config(self, config_path, cli_args=None):
|
||||||
|
"""
|
||||||
|
Process the config file
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config_path: Path to the config file
|
||||||
|
cli_args: List of command line arguments
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
OmegaConf: Processed config object
|
||||||
|
"""
|
||||||
|
# Clean up CLI args to remove -- prefix if present
|
||||||
|
cleaned_cli_args = []
|
||||||
|
if cli_args:
|
||||||
|
for arg in cli_args:
|
||||||
|
if arg.startswith("--"):
|
||||||
|
cleaned_cli_args.append(arg[2:]) # Remove the -- prefix
|
||||||
|
else:
|
||||||
|
cleaned_cli_args.append(arg)
|
||||||
|
|
||||||
|
# Load config first without CLI args to validate paths
|
||||||
|
try:
|
||||||
|
base_config = load_config(config_path)
|
||||||
|
except Exception as e:
|
||||||
|
raise ValueError(f"Error loading config: {e}")
|
||||||
|
|
||||||
|
# Validate that CLI arguments correspond to existing paths
|
||||||
|
if cli_args:
|
||||||
|
self._validate_cli_args(base_config, cli_args)
|
||||||
|
|
||||||
|
# Now load config with CLI args (validation passed)
|
||||||
|
config = load_config(config_path, cli_args=cleaned_cli_args)
|
||||||
|
|
||||||
|
return config
|
||||||
|
|
||||||
|
def print_final_config(self, config):
|
||||||
|
"""
|
||||||
|
Print the final running config
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: OmegaConf config object
|
||||||
|
"""
|
||||||
|
print("=" * 50)
|
||||||
|
print("final config:")
|
||||||
|
print("=" * 50)
|
||||||
|
print(OmegaConf.to_yaml(config))
|
||||||
23
nimbus/utils/flags.py
Normal file
23
nimbus/utils/flags.py
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
|
_DEBUG_KEY = "NIMBUS_DEBUG"
|
||||||
|
_RANDOM_SEED_KEY = "NIMBUS_RANDOM_SEED"
|
||||||
|
|
||||||
|
|
||||||
|
def set_debug_mode(enabled: bool) -> None:
|
||||||
|
"""Set debug mode. Must be called before ray.init() to propagate to Ray workers."""
|
||||||
|
os.environ[_DEBUG_KEY] = "1" if enabled else "0"
|
||||||
|
|
||||||
|
|
||||||
|
def is_debug_mode() -> bool:
|
||||||
|
return os.environ.get(_DEBUG_KEY, "0") == "1"
|
||||||
|
|
||||||
|
|
||||||
|
def set_random_seed(seed: int) -> None:
|
||||||
|
"""Set global random seed. Must be called before ray.init() to propagate to Ray workers."""
|
||||||
|
os.environ[_RANDOM_SEED_KEY] = str(seed)
|
||||||
|
|
||||||
|
|
||||||
|
def get_random_seed() -> int | None:
|
||||||
|
val = os.environ.get(_RANDOM_SEED_KEY)
|
||||||
|
return int(val) if val is not None else None
|
||||||
48
nimbus/utils/logging.py
Normal file
48
nimbus/utils/logging.py
Normal file
@@ -0,0 +1,48 @@
|
|||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from nimbus.utils.config import save_config
|
||||||
|
|
||||||
|
|
||||||
|
def configure_logging(exp_name, name=None, config=None):
|
||||||
|
pod_name = os.environ.get("POD_NAME", None)
|
||||||
|
if pod_name is not None:
|
||||||
|
exp_name = f"{exp_name}/{pod_name}"
|
||||||
|
log_dir = os.path.join("./output", exp_name)
|
||||||
|
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||||
|
if name is None:
|
||||||
|
log_name = f"de_time_profile_{timestamp}.log"
|
||||||
|
else:
|
||||||
|
log_name = f"de_{name}_time_profile_{timestamp}.log"
|
||||||
|
|
||||||
|
log_file = os.path.join(log_dir, log_name)
|
||||||
|
|
||||||
|
max_retries = 3
|
||||||
|
for attempt in range(max_retries):
|
||||||
|
try:
|
||||||
|
os.makedirs(log_dir, exist_ok=True)
|
||||||
|
break
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Warning: Stale file handle when creating {log_dir}, attempt {attempt + 1}/{max_retries}")
|
||||||
|
if attempt < max_retries - 1:
|
||||||
|
time.sleep(3)
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
raise RuntimeError(f"Failed to create log directory {log_dir} after {max_retries} attempts") from e
|
||||||
|
|
||||||
|
if config is not None:
|
||||||
|
config_log_file = os.path.join(log_dir, "de_config.yaml")
|
||||||
|
save_config(config, config_log_file)
|
||||||
|
|
||||||
|
logger = logging.getLogger("de_logger")
|
||||||
|
logger.setLevel(logging.INFO)
|
||||||
|
|
||||||
|
fh = logging.FileHandler(log_file, mode="a")
|
||||||
|
formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
|
||||||
|
fh.setFormatter(formatter)
|
||||||
|
logger.addHandler(fh)
|
||||||
|
logger.info("Start Data Engine")
|
||||||
|
|
||||||
|
return logger
|
||||||
33
nimbus/utils/random.py
Normal file
33
nimbus/utils/random.py
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
import os
|
||||||
|
import random
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
# Try to import open3d, but don't fail if it's not installed
|
||||||
|
try:
|
||||||
|
import open3d as o3d
|
||||||
|
except ImportError:
|
||||||
|
o3d = None
|
||||||
|
|
||||||
|
|
||||||
|
def set_all_seeds(seed):
|
||||||
|
"""
|
||||||
|
Sets seeds for all relevant random number generators to ensure reproducibility.
|
||||||
|
"""
|
||||||
|
os.environ["PYTHONHASHSEED"] = str(seed)
|
||||||
|
print(f"set seed {seed} for all libraries")
|
||||||
|
seed = int(seed)
|
||||||
|
np.random.seed(seed)
|
||||||
|
random.seed(seed)
|
||||||
|
|
||||||
|
if o3d and hasattr(o3d, "utility") and hasattr(o3d.utility, "random"):
|
||||||
|
o3d.utility.random.seed(seed)
|
||||||
|
|
||||||
|
torch.manual_seed(seed)
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.manual_seed_all(seed)
|
||||||
|
|
||||||
|
# These settings are crucial for deterministic results with CuDNN
|
||||||
|
torch.backends.cudnn.deterministic = True
|
||||||
|
torch.backends.cudnn.benchmark = False
|
||||||
65
nimbus/utils/types.py
Normal file
65
nimbus/utils/types.py
Normal file
@@ -0,0 +1,65 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Dict, Optional, Tuple
|
||||||
|
|
||||||
|
NAME = "name"
|
||||||
|
|
||||||
|
# stage name
|
||||||
|
LOAD_STAGE = "load_stage"
|
||||||
|
PLAN_STAGE = "plan_stage"
|
||||||
|
RENDER_STAGE = "render_stage"
|
||||||
|
PLAN_WITH_RENDER_STAGE = "plan_with_render_stage"
|
||||||
|
STORE_STAGE = "store_stage"
|
||||||
|
STAGE_PIPE = "stage_pipe"
|
||||||
|
DUMP_STAGE = "dump_stage"
|
||||||
|
DEDUMP_STAGE = "dedump_stage"
|
||||||
|
|
||||||
|
# instruction name
|
||||||
|
# LOAD_STAGE
|
||||||
|
SCENE_LOADER = "scene_loader"
|
||||||
|
LAYOUT_RANDOM_GENERATOR = "layout_random_generator"
|
||||||
|
INDEX_GENERATOR = "index_generator"
|
||||||
|
DEDUMPER = "dedumper"
|
||||||
|
|
||||||
|
# PLAN_STAGE
|
||||||
|
SEQ_PLANNER = "seq_planner"
|
||||||
|
PLANNER = "planner"
|
||||||
|
SIMULATOR = "simulator"
|
||||||
|
|
||||||
|
# RENDER_STAGE
|
||||||
|
RENDERER = "renderer"
|
||||||
|
|
||||||
|
# PLAN_WITH_RENDER_STAGE
|
||||||
|
PLAN_WITH_RENDER = "plan_with_render"
|
||||||
|
|
||||||
|
# PIPE_STAGE
|
||||||
|
STAGE_NUM = "stage_num"
|
||||||
|
STAGE_DEV = "stage_dev"
|
||||||
|
WORKER_NUM = "worker_num"
|
||||||
|
WORKER_SCHEDULE = "worker_schedule"
|
||||||
|
SAFE_THRESHOLD = "safe_threshold"
|
||||||
|
STATUS_TIMEOUTS = "status_timeouts"
|
||||||
|
MONITOR_CHECK_INTERVAL = "monitor_check_interval"
|
||||||
|
|
||||||
|
# STORE_STAGE
|
||||||
|
WRITER = "writer"
|
||||||
|
DUMPER = "dumper"
|
||||||
|
|
||||||
|
OUTPUT_PATH = "output_path"
|
||||||
|
INPUT_PATH = "input_path"
|
||||||
|
|
||||||
|
TYPE = "type"
|
||||||
|
ARGS = "args"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class StageInput:
|
||||||
|
"""
|
||||||
|
A data class that encapsulates the input for a stage in the processing pipeline.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
Args (Optional[Tuple]): Positional arguments passed to the stage's processing function.
|
||||||
|
Kwargs (Optional[Dict]): Keyword arguments passed to the stage's processing function.
|
||||||
|
"""
|
||||||
|
|
||||||
|
Args: Optional[Tuple] = None
|
||||||
|
Kwargs: Optional[Dict] = None
|
||||||
182
nimbus/utils/utils.py
Normal file
182
nimbus/utils/utils.py
Normal file
@@ -0,0 +1,182 @@
|
|||||||
|
import functools
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
from typing import Tuple, Type, Union
|
||||||
|
|
||||||
|
from nimbus.components.data.observation import Observations
|
||||||
|
from nimbus.components.data.scene import Scene
|
||||||
|
from nimbus.components.data.sequence import Sequence
|
||||||
|
|
||||||
|
|
||||||
|
def init_env():
|
||||||
|
sys.path.append("./")
|
||||||
|
sys.path.append("./data_engine")
|
||||||
|
sys.path.append("workflows/simbox")
|
||||||
|
|
||||||
|
|
||||||
|
def unpack_iter_data(data: tuple):
|
||||||
|
assert len(data) <= 3, "not support yet"
|
||||||
|
scene = None
|
||||||
|
seq = None
|
||||||
|
obs = None
|
||||||
|
for item in data:
|
||||||
|
if isinstance(item, Scene):
|
||||||
|
scene = item
|
||||||
|
elif isinstance(item, Sequence):
|
||||||
|
seq = item
|
||||||
|
elif isinstance(item, Observations):
|
||||||
|
obs = item
|
||||||
|
return scene, seq, obs
|
||||||
|
|
||||||
|
|
||||||
|
def consume_stage(stage_input):
|
||||||
|
if hasattr(stage_input, "Args"):
|
||||||
|
consume_iterators(stage_input.Args)
|
||||||
|
for value in stage_input.Args:
|
||||||
|
if hasattr(value, "__del__"):
|
||||||
|
value.__del__() # pylint: disable=C2801
|
||||||
|
if hasattr(stage_input, "Kwargs"):
|
||||||
|
if stage_input.Kwargs is not None:
|
||||||
|
for value in stage_input.Kwargs.values():
|
||||||
|
consume_iterators(value)
|
||||||
|
if hasattr(value, "__del__"):
|
||||||
|
value.__del__() # pylint: disable=C2801
|
||||||
|
|
||||||
|
|
||||||
|
# prevent isaac sim close pipe worker in advance
|
||||||
|
def pipe_consume_stage(stage_input):
|
||||||
|
if hasattr(stage_input, "Args"):
|
||||||
|
consume_iterators(stage_input.Args)
|
||||||
|
if hasattr(stage_input, "Kwargs"):
|
||||||
|
if stage_input.Kwargs is not None:
|
||||||
|
for value in stage_input.Kwargs.values():
|
||||||
|
consume_iterators(value)
|
||||||
|
|
||||||
|
|
||||||
|
def consume_iterators(obj):
|
||||||
|
# from pdb import set_trace; set_trace()
|
||||||
|
if isinstance(obj, (str, bytes)):
|
||||||
|
return obj
|
||||||
|
if isinstance(obj, dict):
|
||||||
|
return {key: consume_iterators(value) for key, value in obj.items()}
|
||||||
|
if isinstance(obj, list):
|
||||||
|
return [consume_iterators(item) for item in obj]
|
||||||
|
if isinstance(obj, tuple):
|
||||||
|
return tuple(consume_iterators(item) for item in obj)
|
||||||
|
if hasattr(obj, "__iter__"):
|
||||||
|
for item in obj:
|
||||||
|
consume_iterators(item)
|
||||||
|
return obj
|
||||||
|
|
||||||
|
|
||||||
|
def scene_names_postprocess(scene_names: list) -> list:
|
||||||
|
"""
|
||||||
|
Distributes a list of scene names (folders) among multiple workers in a distributed environment.
|
||||||
|
This function is designed to work with Deep Learning Container (DLC) environments, where worker
|
||||||
|
information is extracted from environment variables. It assigns a subset of the input scene names
|
||||||
|
to the current worker based on its rank and the total number of workers, using a round-robin strategy.
|
||||||
|
If not running in a DLC environment, all scene names are assigned to a single worker.
|
||||||
|
Args:
|
||||||
|
scene_names (list): List of scene names (typically folder names) to be distributed.
|
||||||
|
Returns:
|
||||||
|
list: The subset of scene names assigned to the current worker.
|
||||||
|
Raises:
|
||||||
|
PermissionError: If there is a permission issue accessing the input directory.
|
||||||
|
RuntimeError: For any other errors encountered during processing.
|
||||||
|
Notes:
|
||||||
|
- The function expects certain environment variables (e.g., POD_NAME, WORLD_SIZE) to be set
|
||||||
|
in DLC environments.
|
||||||
|
- If multiple workers are present, the input list is sorted before distribution to ensure
|
||||||
|
consistent assignment across workers.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _get_dlc_worker_info():
|
||||||
|
"""Extract worker rank and world size from DLC environment variables."""
|
||||||
|
pod_name = os.environ.get("POD_NAME")
|
||||||
|
|
||||||
|
if pod_name:
|
||||||
|
# Match worker-N or master-N patterns
|
||||||
|
match = re.search(r"dlc.*?-(worker|master)-(\d+)$", pod_name)
|
||||||
|
if match:
|
||||||
|
node_type, node_id = match.groups()
|
||||||
|
world_size = int(os.environ.get("WORLD_SIZE", "1"))
|
||||||
|
|
||||||
|
if node_type == "worker":
|
||||||
|
rank = int(node_id)
|
||||||
|
else: # master node
|
||||||
|
rank = world_size - 1
|
||||||
|
|
||||||
|
return rank, world_size
|
||||||
|
|
||||||
|
# Default for non-DLC environment
|
||||||
|
return 0, 1
|
||||||
|
|
||||||
|
def _distribute_folders(all_folders, rank, world_size):
|
||||||
|
"""Distribute folders among workers using round-robin strategy."""
|
||||||
|
if not all_folders:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Only sort when there are multiple workers to ensure consistency
|
||||||
|
if world_size > 1:
|
||||||
|
all_folders.sort()
|
||||||
|
|
||||||
|
# Distribute using slicing: worker i gets folders at indices i, i+world_size, ...
|
||||||
|
return all_folders[rank::world_size]
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Get all subfolders
|
||||||
|
all_subfolders = scene_names
|
||||||
|
if not all_subfolders:
|
||||||
|
print(f"Warning: No scene found in {scene_names}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Get worker identity and distribute folders
|
||||||
|
rank, world_size = _get_dlc_worker_info()
|
||||||
|
assigned_folders = _distribute_folders(all_subfolders, rank, world_size)
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"DLC Worker {rank}/{world_size}: Assigned {len(assigned_folders)} out of "
|
||||||
|
f"{len(all_subfolders)} total folders"
|
||||||
|
)
|
||||||
|
|
||||||
|
return assigned_folders
|
||||||
|
|
||||||
|
except PermissionError:
|
||||||
|
raise PermissionError(f"No permission to access directory: {scene_names}")
|
||||||
|
except Exception as e:
|
||||||
|
raise RuntimeError(f"Error reading input directory {scene_names}: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
def retry_on_exception(
|
||||||
|
max_retries: int = 3, retry_exceptions: Union[bool, Tuple[Type[Exception], ...]] = True, delay: float = 1.0
|
||||||
|
):
|
||||||
|
def decorator(func):
|
||||||
|
@functools.wraps(func)
|
||||||
|
def wrapper(self, *args, **kwargs):
|
||||||
|
last_exception = None
|
||||||
|
for attempt in range(max_retries + 1):
|
||||||
|
try:
|
||||||
|
if attempt > 0:
|
||||||
|
print(f"Retry attempt {attempt}/{max_retries} for {func.__name__}")
|
||||||
|
return func(self, *args, **kwargs)
|
||||||
|
except Exception as e:
|
||||||
|
last_exception = e
|
||||||
|
should_retry = False
|
||||||
|
if retry_exceptions is True:
|
||||||
|
should_retry = True
|
||||||
|
elif isinstance(retry_exceptions, (tuple, list)):
|
||||||
|
should_retry = isinstance(e, retry_exceptions)
|
||||||
|
|
||||||
|
if should_retry and attempt < max_retries:
|
||||||
|
print(f"Error in {func.__name__}: {e}. Retrying in {delay} seconds...")
|
||||||
|
time.sleep(delay)
|
||||||
|
else:
|
||||||
|
raise
|
||||||
|
if last_exception:
|
||||||
|
raise last_exception
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
return decorator
|
||||||
14
nimbus_extension/__init__.py
Normal file
14
nimbus_extension/__init__.py
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
"""
|
||||||
|
nimbus_extension — official component implementations for the nimbus framework.
|
||||||
|
|
||||||
|
Importing this package registers all built-in components into the nimbus
|
||||||
|
component registries (dumper_dict, renderer_dict, etc.) so they are
|
||||||
|
available for use in pipeline configs.
|
||||||
|
|
||||||
|
Usage in launcher.py::
|
||||||
|
|
||||||
|
import nimbus_extension # registers all components
|
||||||
|
from nimbus import run_data_engine
|
||||||
|
"""
|
||||||
|
|
||||||
|
from . import components # noqa: F401 triggers all register() calls
|
||||||
2
nimbus_extension/components/__init__.py
Normal file
2
nimbus_extension/components/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
# flake8: noqa: #F401
|
||||||
|
from . import dedump, dump, load, plan_with_render, planner, render, store
|
||||||
5
nimbus_extension/components/dedump/__init__.py
Normal file
5
nimbus_extension/components/dedump/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
import nimbus.components.dedump as _dedump
|
||||||
|
|
||||||
|
from .base_dedumper import Dedumper
|
||||||
|
|
||||||
|
_dedump.register("de", Dedumper)
|
||||||
37
nimbus_extension/components/dedump/base_dedumper.py
Normal file
37
nimbus_extension/components/dedump/base_dedumper.py
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
import time
|
||||||
|
|
||||||
|
from nimbus.components.data.iterator import Iterator
|
||||||
|
from nimbus.components.data.package import Package
|
||||||
|
|
||||||
|
|
||||||
|
class Dedumper(Iterator):
|
||||||
|
def __init__(self, input_queue=None):
|
||||||
|
super().__init__()
|
||||||
|
self.input_queue = input_queue
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def _next(self) -> Package:
|
||||||
|
try:
|
||||||
|
self.logger.info("Dedumper try to get package from queue")
|
||||||
|
package = self.input_queue.get()
|
||||||
|
self.logger.info(f"get task {package.task_name} package from queue")
|
||||||
|
st = time.time()
|
||||||
|
|
||||||
|
assert isinstance(package, Package), f"the transfered data type must be Package, but it is {type(package)}"
|
||||||
|
if package.should_stop():
|
||||||
|
self.logger.info("received stop signal")
|
||||||
|
raise StopIteration()
|
||||||
|
else:
|
||||||
|
assert (
|
||||||
|
package.is_serialized() and package.task_id >= 0
|
||||||
|
), "received data must be deserialized and task id must be greater than 0"
|
||||||
|
package.deserialize()
|
||||||
|
self.collect_compute_frame_info(1, time.time() - st)
|
||||||
|
return package
|
||||||
|
except StopIteration:
|
||||||
|
raise StopIteration("No more packages to process.")
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.exception(f"Error during dedumping: {e}")
|
||||||
|
raise e
|
||||||
5
nimbus_extension/components/dump/__init__.py
Normal file
5
nimbus_extension/components/dump/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
import nimbus.components.dump as _dump
|
||||||
|
|
||||||
|
from .env_dumper import EnvDumper
|
||||||
|
|
||||||
|
_dump.register("env", EnvDumper)
|
||||||
10
nimbus_extension/components/dump/env_dumper.py
Normal file
10
nimbus_extension/components/dump/env_dumper.py
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
from nimbus.components.dump import BaseDumper
|
||||||
|
|
||||||
|
|
||||||
|
class EnvDumper(BaseDumper):
|
||||||
|
def __init__(self, data_iter, output_queue=None):
|
||||||
|
super().__init__(data_iter, output_queue=output_queue)
|
||||||
|
|
||||||
|
def dump(self, seq, obs):
|
||||||
|
ser_obj = self.scene.wf.dump_plan_info()
|
||||||
|
return ser_obj
|
||||||
8
nimbus_extension/components/load/__init__.py
Normal file
8
nimbus_extension/components/load/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
import nimbus.components.load as _load
|
||||||
|
|
||||||
|
from .env_loader import EnvLoader
|
||||||
|
from .env_randomizer import EnvRandomizer
|
||||||
|
|
||||||
|
_load.register_loader("env_loader", EnvLoader)
|
||||||
|
|
||||||
|
_load.register_randomizer("env_randomizer", EnvRandomizer)
|
||||||
180
nimbus_extension/components/load/env_loader.py
Normal file
180
nimbus_extension/components/load/env_loader.py
Normal file
@@ -0,0 +1,180 @@
|
|||||||
|
import time
|
||||||
|
from fractions import Fraction
|
||||||
|
|
||||||
|
from nimbus.components.data.iterator import Iterator
|
||||||
|
from nimbus.components.data.package import Package
|
||||||
|
from nimbus.components.data.scene import Scene
|
||||||
|
from nimbus.components.load import SceneLoader
|
||||||
|
from nimbus.daemon import ComponentStatus, StatusReporter
|
||||||
|
from nimbus.daemon.decorators import status_monitor
|
||||||
|
from nimbus.utils.flags import get_random_seed
|
||||||
|
from workflows.base import create_workflow
|
||||||
|
|
||||||
|
|
||||||
|
class EnvLoader(SceneLoader):
|
||||||
|
"""
|
||||||
|
Environment loader that initializes Isaac Sim and loads scenes based on workflow configurations.
|
||||||
|
|
||||||
|
This loader integrates with the workflow system to manage scene loading and task execution.
|
||||||
|
It supports two operating modes:
|
||||||
|
- Standalone mode (pack_iter=None): Loads tasks directly from workflow configuration
|
||||||
|
- Pipeline mode (pack_iter provided): Loads tasks from a package iterator
|
||||||
|
|
||||||
|
It also supports task repetition for data augmentation across different random seeds.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pack_iter (Iterator[Package]): An iterator from the previous component. None for standalone.
|
||||||
|
cfg_path (str): Path to the workflow configuration file.
|
||||||
|
workflow_type (str): Type of workflow to create (e.g., 'SimBoxDualWorkFlow').
|
||||||
|
simulator (dict): Simulator configuration including physics_dt, rendering_dt, headless, etc.
|
||||||
|
task_repeat (int): How many times to repeat each task before advancing (-1 means single execution).
|
||||||
|
need_preload (bool): Whether to preload assets on scene initialization.
|
||||||
|
scene_info (str): Configuration key for scene information in the workflow config.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
pack_iter: Iterator[Package],
|
||||||
|
cfg_path: str,
|
||||||
|
workflow_type: str,
|
||||||
|
simulator: dict,
|
||||||
|
task_repeat: int = -1,
|
||||||
|
need_preload: bool = False,
|
||||||
|
scene_info: str = "dining_room_scene_info",
|
||||||
|
):
|
||||||
|
init_start_time = time.time()
|
||||||
|
super().__init__(pack_iter)
|
||||||
|
|
||||||
|
self.status_reporter = StatusReporter(self.__class__.__name__)
|
||||||
|
self.status_reporter.update_status(ComponentStatus.IDLE)
|
||||||
|
self.need_preload = need_preload
|
||||||
|
self.task_repeat_cnt = task_repeat
|
||||||
|
self.task_repeat_idx = 0
|
||||||
|
self.workflow_type = workflow_type
|
||||||
|
|
||||||
|
# Parse simulator config
|
||||||
|
physics_dt = simulator.get("physics_dt", "1/30")
|
||||||
|
rendering_dt = simulator.get("rendering_dt", "1/30")
|
||||||
|
if isinstance(physics_dt, str):
|
||||||
|
physics_dt = float(Fraction(physics_dt))
|
||||||
|
if isinstance(rendering_dt, str):
|
||||||
|
rendering_dt = float(Fraction(rendering_dt))
|
||||||
|
|
||||||
|
from isaacsim import SimulationApp
|
||||||
|
|
||||||
|
self.simulation_app = SimulationApp(
|
||||||
|
{
|
||||||
|
"headless": simulator.get("headless", True),
|
||||||
|
"anti_aliasing": simulator.get("anti_aliasing", 3),
|
||||||
|
"multi_gpu": simulator.get("multi_gpu", True),
|
||||||
|
"renderer": simulator.get("renderer", "RayTracedLighting"),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
self.logger.info(f"simulator params: physics dt={physics_dt}, rendering dt={rendering_dt}")
|
||||||
|
from omni.isaac.core import World
|
||||||
|
|
||||||
|
world = World(
|
||||||
|
physics_dt=physics_dt,
|
||||||
|
rendering_dt=rendering_dt,
|
||||||
|
stage_units_in_meters=simulator.get("stage_units_in_meters", 1.0),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Import workflow extensions and create workflow
|
||||||
|
from workflows import import_extensions
|
||||||
|
|
||||||
|
import_extensions(workflow_type)
|
||||||
|
self.workflow = create_workflow(
|
||||||
|
workflow_type,
|
||||||
|
world,
|
||||||
|
cfg_path,
|
||||||
|
scene_info=scene_info,
|
||||||
|
random_seed=get_random_seed(),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.scene = None
|
||||||
|
self.task_finish = False
|
||||||
|
self.cur_index = 0
|
||||||
|
self.record_init_time(time.time() - init_start_time)
|
||||||
|
|
||||||
|
self.status_reporter.update_status(ComponentStatus.READY)
|
||||||
|
|
||||||
|
@status_monitor()
|
||||||
|
def _init_next_task(self):
|
||||||
|
"""
|
||||||
|
Internal helper method to initialize and return the next task as a Scene object.
|
||||||
|
|
||||||
|
Handles task repetition logic and advances the task index when all repetitions are complete.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Scene: Initialized scene object for the next task.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
StopIteration: When all tasks have been exhausted.
|
||||||
|
"""
|
||||||
|
if self.scene is not None and self.task_repeat_cnt > 0 and self.task_repeat_idx < self.task_repeat_cnt:
|
||||||
|
self.logger.info(f"Task execute times {self.task_repeat_idx + 1}/{self.task_repeat_cnt}")
|
||||||
|
self.workflow.init_task(self.cur_index - 1, self.need_preload)
|
||||||
|
self.task_repeat_idx += 1
|
||||||
|
scene = Scene(
|
||||||
|
name=self.workflow.get_task_name(),
|
||||||
|
wf=self.workflow,
|
||||||
|
task_id=self.cur_index - 1,
|
||||||
|
task_exec_num=self.task_repeat_idx,
|
||||||
|
simulation_app=self.simulation_app,
|
||||||
|
)
|
||||||
|
return scene
|
||||||
|
if self.cur_index >= len(self.workflow.task_cfgs):
|
||||||
|
self.logger.info("No more tasks to load, stopping iteration.")
|
||||||
|
raise StopIteration
|
||||||
|
self.logger.info(f"Loading task {self.cur_index + 1}/{len(self.workflow.task_cfgs)}")
|
||||||
|
self.workflow.init_task(self.cur_index, self.need_preload)
|
||||||
|
self.task_repeat_idx = 1
|
||||||
|
scene = Scene(
|
||||||
|
name=self.workflow.get_task_name(),
|
||||||
|
wf=self.workflow,
|
||||||
|
task_id=self.cur_index,
|
||||||
|
task_exec_num=self.task_repeat_idx,
|
||||||
|
simulation_app=self.simulation_app,
|
||||||
|
)
|
||||||
|
self.cur_index += 1
|
||||||
|
return scene
|
||||||
|
|
||||||
|
def load_asset(self) -> Scene:
|
||||||
|
"""
|
||||||
|
Load and initialize the next scene from workflow.
|
||||||
|
|
||||||
|
Supports two modes:
|
||||||
|
- Standalone: Iterates through workflow tasks directly
|
||||||
|
- Pipeline: Synchronizes with incoming packages and applies plan info to scene
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Scene: The loaded and initialized Scene object.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
StopIteration: When no more scenes are available.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Standalone mode: load tasks directly from workflow
|
||||||
|
if self.pack_iter is None:
|
||||||
|
self.scene = self._init_next_task()
|
||||||
|
# Pipeline mode: load tasks from package iterator
|
||||||
|
else:
|
||||||
|
package = next(self.pack_iter)
|
||||||
|
self.cur_index = package.task_id
|
||||||
|
|
||||||
|
# Initialize scene if this is the first package or a new task
|
||||||
|
if self.scene is None:
|
||||||
|
self.scene = self._init_next_task()
|
||||||
|
elif self.cur_index > self.scene.task_id:
|
||||||
|
self.scene = self._init_next_task()
|
||||||
|
|
||||||
|
# Apply plan information from package to scene
|
||||||
|
package.data = self.scene.wf.dedump_plan_info(package.data)
|
||||||
|
self.scene.add_plan_info(package.data)
|
||||||
|
|
||||||
|
return self.scene
|
||||||
|
except StopIteration:
|
||||||
|
raise StopIteration
|
||||||
|
except Exception as e:
|
||||||
|
raise e
|
||||||
46
nimbus_extension/components/load/env_randomizer.py
Normal file
46
nimbus_extension/components/load/env_randomizer.py
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
|
from nimbus.components.data.iterator import Iterator
|
||||||
|
from nimbus.components.data.scene import Scene
|
||||||
|
from nimbus.components.load import LayoutRandomizer
|
||||||
|
|
||||||
|
|
||||||
|
class EnvRandomizer(LayoutRandomizer):
|
||||||
|
"""
|
||||||
|
Environment randomizer that extends the base layout randomizer to include additional randomization
|
||||||
|
capabilities specific to the simulation environment.
|
||||||
|
This class can be used to randomize various aspects of the environment, such as object placements,
|
||||||
|
textures, lighting conditions, and other scene parameters, based on the provided configuration.
|
||||||
|
The randomization process can be controlled through the number of randomizations to perform and
|
||||||
|
whether to operate in strict mode.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
scene_iter (Iterator[Scene]): An iterator that yields scenes to be randomized.
|
||||||
|
random_num (int): How many randomizations to perform for each scene.
|
||||||
|
strict_mode (bool): Whether to operate in strict mode, which enforces certain constraints
|
||||||
|
on the randomization process.
|
||||||
|
input_dir (str): Directory from which to load additional randomization data such as object
|
||||||
|
placements or textures. If None, randomization is performed without loading additional data.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, scene_iter: Iterator[Scene], random_num: int = 1, strict_mode: bool = False, input_dir: str = None
|
||||||
|
):
|
||||||
|
super().__init__(scene_iter, random_num, strict_mode)
|
||||||
|
assert self.random_num > 0, "random_num must be greater than 0"
|
||||||
|
self.input_dir = input_dir
|
||||||
|
if self.input_dir is not None:
|
||||||
|
self.paths_names = os.listdir(self.input_dir)
|
||||||
|
self.random_num = len(self.paths_names)
|
||||||
|
|
||||||
|
def randomize_scene(self, scene) -> Scene:
|
||||||
|
if scene.plan_info is None:
|
||||||
|
path = None
|
||||||
|
if self.input_dir is not None:
|
||||||
|
path = os.path.join(self.input_dir, self.paths_names[self.cur_index])
|
||||||
|
if not scene.wf.randomization(path):
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
if not scene.wf.randomization_from_mem(scene.plan_info):
|
||||||
|
return None
|
||||||
|
return scene
|
||||||
5
nimbus_extension/components/plan_with_render/__init__.py
Normal file
5
nimbus_extension/components/plan_with_render/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
import nimbus.components.plan_with_render as _pwr
|
||||||
|
|
||||||
|
from .plan_with_render import EnvPlanWithRender
|
||||||
|
|
||||||
|
_pwr.register("plan_with_render", EnvPlanWithRender)
|
||||||
@@ -0,0 +1,80 @@
|
|||||||
|
import sys
|
||||||
|
import time
|
||||||
|
|
||||||
|
from nimbus.components.data.iterator import Iterator
|
||||||
|
from nimbus.components.data.observation import Observations
|
||||||
|
from nimbus.components.data.scene import Scene
|
||||||
|
from nimbus.daemon.decorators import status_monitor
|
||||||
|
from nimbus.utils.flags import is_debug_mode
|
||||||
|
|
||||||
|
|
||||||
|
class EnvPlanWithRender(Iterator):
|
||||||
|
"""
|
||||||
|
A component that integrates planning and rendering for a given scene. It takes an iterator of scenes as
|
||||||
|
input, performs planning and rendering for each scene, and produces sequences and observations as output.
|
||||||
|
The component manages the planning and rendering process, including tracking the current episode and
|
||||||
|
collecting performance metrics.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
scene_iter (Iterator[Scene]): An iterator that yields scenes to be processed for planning and rendering.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, scene_iter: Iterator[Scene]):
|
||||||
|
super().__init__()
|
||||||
|
self.scene_iter = scene_iter
|
||||||
|
self.episodes = 1
|
||||||
|
self.current_episode = sys.maxsize
|
||||||
|
self.scene = None
|
||||||
|
|
||||||
|
@status_monitor()
|
||||||
|
def plan_with_render(self):
|
||||||
|
wf = self.scene.wf
|
||||||
|
obs_num = wf.plan_with_render()
|
||||||
|
if obs_num <= 0:
|
||||||
|
return None
|
||||||
|
# Assuming rgb is a dictionary of lists, get the length from one of the lists.
|
||||||
|
obs = Observations(self.scene.name, str(self.current_episode), length=obs_num)
|
||||||
|
return obs
|
||||||
|
|
||||||
|
def _next(self):
|
||||||
|
try:
|
||||||
|
if self.scene is None or self.current_episode >= self.episodes:
|
||||||
|
try:
|
||||||
|
self.scene = next(self.scene_iter)
|
||||||
|
self.current_episode = 0
|
||||||
|
if self.scene is None:
|
||||||
|
return None, None, None
|
||||||
|
except StopIteration:
|
||||||
|
raise StopIteration("No more scene to process.")
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.exception(f"Error loading next scene: {e}")
|
||||||
|
if is_debug_mode():
|
||||||
|
raise e
|
||||||
|
self.current_episode = sys.maxsize
|
||||||
|
return None, None, None
|
||||||
|
|
||||||
|
while True:
|
||||||
|
compute_start_time = time.time()
|
||||||
|
obs = self.plan_with_render()
|
||||||
|
compute_end_time = time.time()
|
||||||
|
self.current_episode += 1
|
||||||
|
|
||||||
|
if obs is not None:
|
||||||
|
self.collect_compute_frame_info(obs.get_length(), compute_end_time - compute_start_time)
|
||||||
|
return self.scene, None, obs
|
||||||
|
|
||||||
|
if self.current_episode >= self.episodes:
|
||||||
|
return self.scene, None, None
|
||||||
|
|
||||||
|
self.logger.info(f"Generate seq failed and retry. Current episode id is {self.current_episode}")
|
||||||
|
except StopIteration:
|
||||||
|
raise StopIteration("No more scene to process.")
|
||||||
|
except Exception as e:
|
||||||
|
scene_name = getattr(self.scene, "name", "<unknown>")
|
||||||
|
self.logger.exception(
|
||||||
|
f"Error during idx {self.current_episode} sequence plan with render for scene {scene_name}: {e}"
|
||||||
|
)
|
||||||
|
if is_debug_mode():
|
||||||
|
raise e
|
||||||
|
self.current_episode += 1
|
||||||
|
return self.scene, None, None
|
||||||
7
nimbus_extension/components/planner/__init__.py
Normal file
7
nimbus_extension/components/planner/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
import nimbus.components.planner as _planner
|
||||||
|
|
||||||
|
from .env_planner import EnvSeqPlanner
|
||||||
|
from .env_reader import EnvReader
|
||||||
|
|
||||||
|
_planner.register("env_planner", EnvSeqPlanner)
|
||||||
|
_planner.register("env_reader", EnvReader)
|
||||||
25
nimbus_extension/components/planner/env_planner.py
Normal file
25
nimbus_extension/components/planner/env_planner.py
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
from nimbus.components.data.iterator import Iterator
|
||||||
|
from nimbus.components.data.scene import Scene
|
||||||
|
from nimbus.components.data.sequence import Sequence
|
||||||
|
from nimbus.components.planner import SequencePlanner
|
||||||
|
|
||||||
|
|
||||||
|
class EnvSeqPlanner(SequencePlanner):
|
||||||
|
"""
|
||||||
|
A sequence planner that generates sequences based on the environment's workflow.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
scene_iter (Iterator[Scene]): An iterator that provides scenes to be processed for sequence planning.
|
||||||
|
planner_cfg (dict): A dictionary containing configuration parameters for the planner,
|
||||||
|
such as the type of planner to use and its arguments.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, scene_iter: Iterator[Scene], planner_cfg: dict):
|
||||||
|
super().__init__(scene_iter, planner_cfg, episodes=1)
|
||||||
|
|
||||||
|
def generate_sequence(self):
|
||||||
|
wf = self.scene.wf
|
||||||
|
sequence = wf.generate_seq()
|
||||||
|
if len(sequence) <= 0:
|
||||||
|
return None
|
||||||
|
return Sequence(self.scene.name, str(self.current_episode), length=len(sequence), data=sequence)
|
||||||
32
nimbus_extension/components/planner/env_reader.py
Normal file
32
nimbus_extension/components/planner/env_reader.py
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
from nimbus.components.data.iterator import Iterator
|
||||||
|
from nimbus.components.data.scene import Scene
|
||||||
|
from nimbus.components.data.sequence import Sequence
|
||||||
|
from nimbus.components.planner import SequencePlanner
|
||||||
|
|
||||||
|
|
||||||
|
class EnvReader(SequencePlanner):
|
||||||
|
"""
|
||||||
|
A sequence planner that generates sequences based on the environment's workflow.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
scene_iter (Iterator[Scene]): An iterator that provides scenes to be processed for sequence planning.
|
||||||
|
planner_cfg (dict): A dictionary containing configuration parameters for the planner,
|
||||||
|
such as the type of planner to use and its arguments.
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, scene_iter: Iterator[Scene], planner_cfg: dict):
|
||||||
|
super().__init__(scene_iter, planner_cfg, episodes=1)
|
||||||
|
|
||||||
|
def generate_sequence(self):
|
||||||
|
wf = self.scene.wf
|
||||||
|
if self.scene.plan_info is None:
|
||||||
|
sequence = wf.recover_seq(None)
|
||||||
|
else:
|
||||||
|
sequence = wf.recover_seq_from_mem(self.scene.plan_info)
|
||||||
|
if len(sequence) == 0:
|
||||||
|
return None
|
||||||
|
return Sequence(self.scene.name, str(self.current_episode), length=len(sequence), data=sequence)
|
||||||
|
|
||||||
|
def _initialize(self, scene):
|
||||||
|
pass
|
||||||
5
nimbus_extension/components/render/__init__.py
Normal file
5
nimbus_extension/components/render/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
import nimbus.components.render as _render
|
||||||
|
|
||||||
|
from .env_renderer import EnvRenderer
|
||||||
|
|
||||||
|
_render.register("env_renderer", EnvRenderer)
|
||||||
25
nimbus_extension/components/render/env_renderer.py
Normal file
25
nimbus_extension/components/render/env_renderer.py
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
from nimbus.components.data.observation import Observations
|
||||||
|
from nimbus.components.render import BaseRenderer
|
||||||
|
|
||||||
|
|
||||||
|
class EnvRenderer(BaseRenderer):
|
||||||
|
"""
|
||||||
|
Renderer for environment simulation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, scene_seq_iter):
|
||||||
|
super().__init__(scene_seq_iter)
|
||||||
|
|
||||||
|
def _lazy_init(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _close_resource(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def generate_obs(self, seq):
|
||||||
|
wf = self.scene.wf
|
||||||
|
obs_num = wf.seq_replay(seq.data)
|
||||||
|
if obs_num <= 0:
|
||||||
|
return None
|
||||||
|
obs = Observations(seq.scene_name, seq.index, length=obs_num)
|
||||||
|
return obs
|
||||||
5
nimbus_extension/components/store/__init__.py
Normal file
5
nimbus_extension/components/store/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
import nimbus.components.store as _store
|
||||||
|
|
||||||
|
from .env_writer import EnvWriter
|
||||||
|
|
||||||
|
_store.register("env_writer", EnvWriter)
|
||||||
58
nimbus_extension/components/store/env_writer.py
Normal file
58
nimbus_extension/components/store/env_writer.py
Normal file
@@ -0,0 +1,58 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
|
from nimbus.components.store import BaseWriter
|
||||||
|
|
||||||
|
|
||||||
|
class EnvWriter(BaseWriter):
|
||||||
|
"""
|
||||||
|
A writer that saves generated sequences and observations to disk for environment simulations.
|
||||||
|
This class extends the BaseWriter to provide specific implementations for handling data related
|
||||||
|
to environment simulations.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data_iter (Iterator): An iterator that provides data to be written, typically containing scenes,
|
||||||
|
sequences, and observations.
|
||||||
|
seq_output_dir (str): The directory where generated sequences will be saved. Can be None
|
||||||
|
if sequence output is not needed.
|
||||||
|
obs_output_dir (str): The directory where generated observations will be saved. Can be None
|
||||||
|
if observation output is not needed.
|
||||||
|
batch_async (bool): If True, the writer will use asynchronous batch writing to improve performance
|
||||||
|
when handling large amounts of data. Default is True.
|
||||||
|
async_threshold (int): The maximum number of asynchronous write operations that can be in progress
|
||||||
|
at the same time. If the threshold is reached, the writer will wait for the oldest operation
|
||||||
|
to complete before starting a new one. Default is 1.
|
||||||
|
batch_size (int): The number of data items to write in each batch when using asynchronous writing.
|
||||||
|
Default is 1, and it will be capped at 8 to prevent potential issues with too many concurrent operations.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, data_iter, seq_output_dir=None, output_dir=None, batch_async=True, async_threshold=1, batch_size=1
|
||||||
|
):
|
||||||
|
super().__init__(
|
||||||
|
data_iter,
|
||||||
|
seq_output_dir,
|
||||||
|
output_dir,
|
||||||
|
batch_async=batch_async,
|
||||||
|
async_threshold=async_threshold,
|
||||||
|
batch_size=batch_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
def flush_to_disk(self, task, scene_name, seq, obs):
|
||||||
|
try:
|
||||||
|
scene_name = self.scene.name
|
||||||
|
if obs is not None and self.obs_output_dir is not None:
|
||||||
|
log_dir = os.path.join(self.obs_output_dir, scene_name)
|
||||||
|
self.logger.info(f"Try to save obs in {log_dir}")
|
||||||
|
length = task.save(log_dir)
|
||||||
|
self.logger.info(f"Saved {length} obs output saved in {log_dir}")
|
||||||
|
elif seq is not None and self.seq_output_dir is not None:
|
||||||
|
log_dir = os.path.join(self.seq_output_dir, scene_name)
|
||||||
|
self.logger.info(f"Try to save seq in {log_dir}")
|
||||||
|
length = task.save_seq(log_dir)
|
||||||
|
self.logger.info(f"Saved {length} seq output saved in {log_dir}")
|
||||||
|
else:
|
||||||
|
self.logger.info("Skip this storage")
|
||||||
|
return length
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.info(f"Failed to save data for scene {scene_name}: {e}")
|
||||||
|
raise e
|
||||||
27
pyproject.toml
Normal file
27
pyproject.toml
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
[tool.pytest.ini_options]
|
||||||
|
# Test configuration for better isolation
|
||||||
|
addopts = [
|
||||||
|
"--verbose",
|
||||||
|
"--tb=short",
|
||||||
|
"--strict-markers",
|
||||||
|
]
|
||||||
|
|
||||||
|
# Test discovery
|
||||||
|
testpaths = ["test"]
|
||||||
|
python_files = ["test_*.py", "*_test.py"]
|
||||||
|
python_classes = ["Test*"]
|
||||||
|
python_functions = ["test_*"]
|
||||||
|
|
||||||
|
# Warnings
|
||||||
|
filterwarnings = [
|
||||||
|
"ignore::DeprecationWarning",
|
||||||
|
"ignore::PendingDeprecationWarning",
|
||||||
|
]
|
||||||
|
|
||||||
|
# Markers
|
||||||
|
markers = [
|
||||||
|
"unit: Unit tests",
|
||||||
|
"integration: Integration tests",
|
||||||
|
"slow: Slow tests",
|
||||||
|
"isaac_sim: Tests requiring Isaac Sim",
|
||||||
|
]
|
||||||
8
pyrightconfig.json
Normal file
8
pyrightconfig.json
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
{
|
||||||
|
"exclude": [
|
||||||
|
"workflows/simbox/assets",
|
||||||
|
"workflows/simbox/curobo",
|
||||||
|
"workflows/simbox/panda_drake",
|
||||||
|
"profile"
|
||||||
|
]
|
||||||
|
}
|
||||||
22
requirements.txt
Normal file
22
requirements.txt
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
open3d-cpu
|
||||||
|
opencv-python
|
||||||
|
pathfinding
|
||||||
|
imageio[ffmpeg]
|
||||||
|
plyfile
|
||||||
|
omegaconf
|
||||||
|
pydantic
|
||||||
|
toml
|
||||||
|
shapely
|
||||||
|
ray
|
||||||
|
pympler
|
||||||
|
scikit-image
|
||||||
|
lmdb
|
||||||
|
setuptools
|
||||||
|
wheel
|
||||||
|
drake
|
||||||
|
colored
|
||||||
|
transforms3d
|
||||||
|
concave-hull
|
||||||
|
tomli
|
||||||
|
ninja
|
||||||
|
usd-core==24.11
|
||||||
78
scripts/download_assets.sh
Normal file
78
scripts/download_assets.sh
Normal file
@@ -0,0 +1,78 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
# Download assets from HuggingFace.
|
||||||
|
# Usage:
|
||||||
|
# bash scripts/download_assets.sh [OPTIONS]
|
||||||
|
#
|
||||||
|
# Options:
|
||||||
|
# --min Download only required scene assets (for quick testing)
|
||||||
|
# --full Download all scene assets including all robots and tasks (default)
|
||||||
|
# --with-curobo Also download CuRobo package
|
||||||
|
# --with-drake Also download panda_drake package
|
||||||
|
# --local-dir DIR Where to save (default: current directory)
|
||||||
|
#
|
||||||
|
# Examples:
|
||||||
|
# bash scripts/download_assets.sh --min
|
||||||
|
# bash scripts/download_assets.sh --full --with-curobo --with-drake
|
||||||
|
# bash scripts/download_assets.sh --min --with-curobo --local-dir /data/assets
|
||||||
|
|
||||||
|
set -e
|
||||||
|
|
||||||
|
REPO_ID="InternRobotics/InternData-A1"
|
||||||
|
REPO_TYPE="dataset"
|
||||||
|
ASSET_PREFIX="InternDataAssets"
|
||||||
|
|
||||||
|
MODE="full"
|
||||||
|
LOCAL_DIR="."
|
||||||
|
WITH_CUROBO=false
|
||||||
|
WITH_DRAKE=false
|
||||||
|
|
||||||
|
while [[ $# -gt 0 ]]; do
|
||||||
|
case "$1" in
|
||||||
|
--min) MODE="min"; shift ;;
|
||||||
|
--full) MODE="full"; shift ;;
|
||||||
|
--with-curobo) WITH_CUROBO=true; shift ;;
|
||||||
|
--with-drake) WITH_DRAKE=true; shift ;;
|
||||||
|
--local-dir) LOCAL_DIR="$2"; shift 2 ;;
|
||||||
|
*) echo "Unknown option: $1"; exit 1 ;;
|
||||||
|
esac
|
||||||
|
done
|
||||||
|
|
||||||
|
info() { echo -e "\033[32m[INFO]\033[0m $*"; }
|
||||||
|
|
||||||
|
download() {
|
||||||
|
info "Downloading $2 ..."
|
||||||
|
huggingface-cli download "$REPO_ID" --repo-type "$REPO_TYPE" --include "$1" --local-dir "$LOCAL_DIR"
|
||||||
|
}
|
||||||
|
|
||||||
|
# --- Scene assets: required (both modes) ---
|
||||||
|
info "========== Downloading required scene assets =========="
|
||||||
|
REQUIRED_DIRS=("background_textures" "envmap_lib" "floor_textures" "table_textures" "table0")
|
||||||
|
for dir in "${REQUIRED_DIRS[@]}"; do
|
||||||
|
download "${ASSET_PREFIX}/assets/${dir}/*" "$dir"
|
||||||
|
done
|
||||||
|
download "${ASSET_PREFIX}/assets/table_info.json" "table_info.json"
|
||||||
|
|
||||||
|
# --- Scene assets: full only (all robots + all tasks) ---
|
||||||
|
if [[ "$MODE" == "full" ]]; then
|
||||||
|
info "========== Downloading all robots and tasks =========="
|
||||||
|
for robot in lift2 franka frankarobotiq split_aloha_mid_360 G1_120s; do
|
||||||
|
download "${ASSET_PREFIX}/assets/${robot}/*" "robot: ${robot}"
|
||||||
|
done
|
||||||
|
for task in basic art long_horizon pick_and_place; do
|
||||||
|
download "${ASSET_PREFIX}/assets/${task}/*" "task: ${task}"
|
||||||
|
done
|
||||||
|
fi
|
||||||
|
|
||||||
|
# --- CuRobo ---
|
||||||
|
if [[ "$WITH_CUROBO" == true ]]; then
|
||||||
|
info "========== Downloading CuRobo =========="
|
||||||
|
download "${ASSET_PREFIX}/curobo/*" "curobo"
|
||||||
|
fi
|
||||||
|
|
||||||
|
# --- panda_drake ---
|
||||||
|
if [[ "$WITH_DRAKE" == true ]]; then
|
||||||
|
info "========== Downloading panda_drake =========="
|
||||||
|
download "${ASSET_PREFIX}/panda_drake/*" "panda_drake"
|
||||||
|
fi
|
||||||
|
|
||||||
|
info "Done! (mode=${MODE}, curobo=${WITH_CUROBO}, drake=${WITH_DRAKE}, local-dir=${LOCAL_DIR})"
|
||||||
70
scripts/simbox/simbox_pipe.sh
Normal file
70
scripts/simbox/simbox_pipe.sh
Normal file
@@ -0,0 +1,70 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
cfg_path="$1"
|
||||||
|
|
||||||
|
if [ $# -lt 1 ]; then
|
||||||
|
echo "Error: Missing required parameter"
|
||||||
|
echo "Usage: bash $0 <config_path> [random_num] [random_seed]"
|
||||||
|
echo ""
|
||||||
|
echo "Parameters:"
|
||||||
|
echo " config_path - Full path to the config file (with .yml/.yaml extension)"
|
||||||
|
echo " random_num - (Optional) Number of samples to generate (default: 10)"
|
||||||
|
echo " random_seed - (Optional) Random seed for reproducibility"
|
||||||
|
echo " scene_info - (Optional) Scene info key to use"
|
||||||
|
echo ""
|
||||||
|
echo "Example:"
|
||||||
|
echo " bash $0 workflows/simbox/core/configs/tasks/long_horizon/split_aloha/sort_the_rubbish/sort_the_rubbish_part0.yaml"
|
||||||
|
echo " bash $0 workflows/simbox/core/configs/tasks/long_horizon/split_aloha/sort_the_rubbish/sort_the_rubbish_part0.yaml 10"
|
||||||
|
echo " bash $0 workflows/simbox/core/configs/tasks/long_horizon/split_aloha/sort_the_rubbish/sort_the_rubbish_part0.yaml 10 42"
|
||||||
|
echo " bash $0 workflows/simbox/core/configs/tasks/long_horizon/split_aloha/sort_the_rubbish/sort_the_rubbish_part0.yaml 10 42 living_room_scene_info"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
cfg_path="$1"
|
||||||
|
random_num=10
|
||||||
|
if [ $# -ge 2 ]; then
|
||||||
|
random_num="$2"
|
||||||
|
fi
|
||||||
|
random_seed=""
|
||||||
|
if [ $# -ge 3 ]; then
|
||||||
|
random_seed="$3"
|
||||||
|
fi
|
||||||
|
scene_info=""
|
||||||
|
if [ $# -ge 4 ]; then
|
||||||
|
scene_info="$4"
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ -z "$cfg_path" ]; then
|
||||||
|
echo "Error: Config path parameter is required and cannot be empty"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Extract custom_path and config_name from the full path
|
||||||
|
custom_path=$(dirname "$cfg_path")
|
||||||
|
config_name=$(basename "$cfg_path" .yaml)
|
||||||
|
|
||||||
|
echo "Config path: $cfg_path"
|
||||||
|
echo "Custom path: $custom_path"
|
||||||
|
echo "Config name: $config_name"
|
||||||
|
echo "Random num: $random_num"
|
||||||
|
|
||||||
|
if [ ! -f "$cfg_path" ]; then
|
||||||
|
echo "Error: Configuration file not found: $cfg_path"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
name_with_split="${config_name}_pipe${random_seed:+_seed_${random_seed}}"
|
||||||
|
|
||||||
|
echo "Running with config: $cfg_path"
|
||||||
|
echo "Output name: $name_with_split"
|
||||||
|
|
||||||
|
set -x
|
||||||
|
/isaac-sim/python.sh launcher.py --config configs/simbox/de_pipe_template.yaml \
|
||||||
|
--name="$name_with_split" \
|
||||||
|
--load_stage.scene_loader.args.cfg_path="$cfg_path" \
|
||||||
|
--load_stage.layout_random_generator.args.random_num="$random_num" \
|
||||||
|
--dedump_stage.scene_loader.args.cfg_path="$cfg_path" \
|
||||||
|
--store_stage.writer.args.output_dir="output/$name_with_split/" \
|
||||||
|
${scene_info:+--load_stage.env_loader.args.scene_info="$scene_info"} \
|
||||||
|
${random_seed:+--random_seed="$random_seed"}
|
||||||
|
set +x
|
||||||
67
scripts/simbox/simbox_plan_and_render.sh
Normal file
67
scripts/simbox/simbox_plan_and_render.sh
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
if [ $# -lt 1 ]; then
|
||||||
|
echo "Error: Missing required parameter"
|
||||||
|
echo "Usage: bash $0 <config_path> [random_num] [random_seed]"
|
||||||
|
echo ""
|
||||||
|
echo "Parameters:"
|
||||||
|
echo " config_path - Full path to the config file (with .yml extension)"
|
||||||
|
echo " random_num - (Optional) Number of samples to generate (default: 10)"
|
||||||
|
echo " random_seed - (Optional) Random seed for reproducibility"
|
||||||
|
echo " scene_info - (Optional) Scene info key to use"
|
||||||
|
echo ""
|
||||||
|
echo "Example:"
|
||||||
|
echo " bash $0 workflows/simbox/core/configs/tasks/long_horizon/split_aloha/sort_the_rubbish/sort_the_rubbish_part0.yaml"
|
||||||
|
echo " bash $0 workflows/simbox/core/configs/tasks/long_horizon/split_aloha/sort_the_rubbish/sort_the_rubbish_part0.yaml 10"
|
||||||
|
echo " bash $0 workflows/simbox/core/configs/tasks/long_horizon/split_aloha/sort_the_rubbish/sort_the_rubbish_part0.yaml 10 42"
|
||||||
|
echo " bash $0 workflows/simbox/core/configs/tasks/long_horizon/split_aloha/sort_the_rubbish/sort_the_rubbish_part0.yaml 10 42 living_room_scene_info"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
cfg_path="$1"
|
||||||
|
random_num=10
|
||||||
|
if [ $# -ge 2 ]; then
|
||||||
|
random_num="$2"
|
||||||
|
fi
|
||||||
|
random_seed=""
|
||||||
|
if [ $# -ge 3 ]; then
|
||||||
|
random_seed="$3"
|
||||||
|
fi
|
||||||
|
scene_info=""
|
||||||
|
if [ $# -ge 4 ]; then
|
||||||
|
scene_info="$4"
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ -z "$cfg_path" ]; then
|
||||||
|
echo "Error: Config path parameter is required and cannot be empty"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Extract custom_path and config_name from the full path
|
||||||
|
custom_path=$(dirname "$cfg_path")
|
||||||
|
config_name=$(basename "$cfg_path" .yaml)
|
||||||
|
|
||||||
|
echo "Config path: $cfg_path"
|
||||||
|
echo "Custom path: $custom_path"
|
||||||
|
echo "Config name: $config_name"
|
||||||
|
echo "Random num: $random_num"
|
||||||
|
|
||||||
|
if [ ! -f "$cfg_path" ]; then
|
||||||
|
echo "Error: Configuration file not found: $cfg_path"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
name_with_split="${config_name}_plan_and_render${random_seed:+_seed_${random_seed}}"
|
||||||
|
|
||||||
|
echo "Running with config: $cfg_path"
|
||||||
|
echo "Output name: $name_with_split"
|
||||||
|
|
||||||
|
set -x
|
||||||
|
/isaac-sim/python.sh launcher.py --config configs/simbox/de_plan_and_render_template.yaml \
|
||||||
|
--name="$name_with_split" \
|
||||||
|
--load_stage.scene_loader.args.cfg_path="$cfg_path" \
|
||||||
|
--load_stage.layout_random_generator.args.random_num="$random_num" \
|
||||||
|
--store_stage.writer.args.output_dir="output/$name_with_split/" \
|
||||||
|
${scene_info:+--load_stage.env_loader.args.scene_info="$scene_info"} \
|
||||||
|
${random_seed:+--random_seed="$random_seed"}
|
||||||
|
set +x
|
||||||
67
scripts/simbox/simbox_plan_with_render.sh
Normal file
67
scripts/simbox/simbox_plan_with_render.sh
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
if [ $# -lt 1 ]; then
|
||||||
|
echo "Error: Missing required parameter"
|
||||||
|
echo "Usage: bash $0 <config_path> [random_num] [random_seed]"
|
||||||
|
echo ""
|
||||||
|
echo "Parameters:"
|
||||||
|
echo " config_path - Full path to the config file (with .yml extension)"
|
||||||
|
echo " random_num - (Optional) Number of samples to generate (default: 10)"
|
||||||
|
echo " random_seed - (Optional) Random seed for reproducibility"
|
||||||
|
echo " scene_info - (Optional) Scene info key to use"
|
||||||
|
echo ""
|
||||||
|
echo "Example:"
|
||||||
|
echo " bash $0 workflows/simbox/core/configs/tasks/long_horizon/split_aloha/sort_the_rubbish/sort_the_rubbish_part0.yaml"
|
||||||
|
echo " bash $0 workflows/simbox/core/configs/tasks/long_horizon/split_aloha/sort_the_rubbish/sort_the_rubbish_part0.yaml 10"
|
||||||
|
echo " bash $0 workflows/simbox/core/configs/tasks/long_horizon/split_aloha/sort_the_rubbish/sort_the_rubbish_part0.yaml 10 42"
|
||||||
|
echo " bash $0 workflows/simbox/core/configs/tasks/long_horizon/split_aloha/sort_the_rubbish/sort_the_rubbish_part0.yaml 10 42 living_room_scene_info"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
cfg_path="$1"
|
||||||
|
random_num=10
|
||||||
|
if [ $# -ge 2 ]; then
|
||||||
|
random_num="$2"
|
||||||
|
fi
|
||||||
|
random_seed=""
|
||||||
|
if [ $# -ge 3 ]; then
|
||||||
|
random_seed="$3"
|
||||||
|
fi
|
||||||
|
scene_info=""
|
||||||
|
if [ $# -ge 4 ]; then
|
||||||
|
scene_info="$4"
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ -z "$cfg_path" ]; then
|
||||||
|
echo "Error: Config path parameter is required and cannot be empty"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Extract custom_path and config_name from the full path
|
||||||
|
custom_path=$(dirname "$cfg_path")
|
||||||
|
config_name=$(basename "$cfg_path" .yaml)
|
||||||
|
|
||||||
|
echo "Config path: $cfg_path"
|
||||||
|
echo "Custom path: $custom_path"
|
||||||
|
echo "Config name: $config_name"
|
||||||
|
echo "Random num: $random_num"
|
||||||
|
|
||||||
|
if [ ! -f "$cfg_path" ]; then
|
||||||
|
echo "Error: Configuration file not found: $cfg_path"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
name_with_split="${config_name}_plan_with_render${random_seed:+_seed_${random_seed}}"
|
||||||
|
|
||||||
|
echo "Running with config: $cfg_path"
|
||||||
|
echo "Output name: $name_with_split"
|
||||||
|
|
||||||
|
set -x
|
||||||
|
/isaac-sim/python.sh launcher.py --config configs/simbox/de_plan_with_render_template.yaml \
|
||||||
|
--name="$name_with_split" \
|
||||||
|
--load_stage.scene_loader.args.cfg_path="$cfg_path" \
|
||||||
|
--load_stage.layout_random_generator.args.random_num="$random_num" \
|
||||||
|
--store_stage.writer.args.output_dir="output/$name_with_split/" \
|
||||||
|
${scene_info:+--load_stage.env_loader.args.scene_info="$scene_info"} \
|
||||||
|
${random_seed:+--random_seed="$random_seed"}
|
||||||
|
set +x
|
||||||
1
tests/__init__.py
Normal file
1
tests/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
# Tests module for DataEngine
|
||||||
1
tests/integration/__init__.py
Normal file
1
tests/integration/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
# Integration tests module
|
||||||
1
tests/integration/base/__init__.py
Normal file
1
tests/integration/base/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
# Base test infrastructure
|
||||||
401
tests/integration/base/test_harness.py
Normal file
401
tests/integration/base/test_harness.py
Normal file
@@ -0,0 +1,401 @@
|
|||||||
|
"""
|
||||||
|
Integration Test Harness base class that encapsulates common test logic.
|
||||||
|
Reduces boilerplate code and provides a consistent testing interface.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, Optional, Tuple
|
||||||
|
|
||||||
|
from .utils import set_test_seeds
|
||||||
|
|
||||||
|
# Add paths to sys.path for proper imports
|
||||||
|
sys.path.append("./")
|
||||||
|
sys.path.append("./data_engine")
|
||||||
|
|
||||||
|
|
||||||
|
# Import data_engine modules only when needed to avoid import errors during framework testing
|
||||||
|
def _import_data_engine_modules():
|
||||||
|
"""Import data_engine modules when they are actually needed."""
|
||||||
|
try:
|
||||||
|
from nimbus import run_data_engine
|
||||||
|
from nimbus.utils.config_processor import ConfigProcessor
|
||||||
|
from nimbus.utils.utils import init_env
|
||||||
|
|
||||||
|
return init_env, run_data_engine, ConfigProcessor
|
||||||
|
except ImportError as e:
|
||||||
|
raise ImportError(f"Failed to import data_engine modules. Ensure dependencies are installed: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
class IntegrationTestHarness:
|
||||||
|
"""
|
||||||
|
Base class for integration tests that provides common functionality.
|
||||||
|
|
||||||
|
This class encapsulates:
|
||||||
|
- Configuration loading and processing
|
||||||
|
- Output directory cleanup
|
||||||
|
- Test pipeline execution (direct or subprocess)
|
||||||
|
- Output validation
|
||||||
|
- Data comparison with reference
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, config_path: str, seed: int = 42, load_num: int = 0, random_num: int = 0, episodes: int = 0):
|
||||||
|
"""
|
||||||
|
Initialize the test harness with configuration and seed.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config_path: Path to the test configuration YAML file
|
||||||
|
seed: Random seed for reproducible results (default: 42)
|
||||||
|
"""
|
||||||
|
self.config_path = config_path
|
||||||
|
self.seed = seed
|
||||||
|
self.load_num = load_num
|
||||||
|
self.random_num = random_num
|
||||||
|
self.episodes = episodes
|
||||||
|
self.output_dir = None
|
||||||
|
|
||||||
|
# Import data_engine modules
|
||||||
|
self._init_env, self._run_data_engine, self._ConfigProcessor = _import_data_engine_modules()
|
||||||
|
self.processor = self._ConfigProcessor()
|
||||||
|
|
||||||
|
# Initialize environment (same as launcher.py)
|
||||||
|
self._init_env()
|
||||||
|
|
||||||
|
self.config = self.load_and_process_config() if config_path else None
|
||||||
|
# Set random seeds for reproducibility
|
||||||
|
self.modify_config()
|
||||||
|
|
||||||
|
set_test_seeds(seed)
|
||||||
|
|
||||||
|
from nimbus.utils.flags import set_debug_mode, set_random_seed
|
||||||
|
|
||||||
|
set_debug_mode(True) # Enable debug mode for better error visibility during tests
|
||||||
|
set_random_seed(seed)
|
||||||
|
|
||||||
|
def modify_config(self):
|
||||||
|
"""Modify configuration parameters as needed before running the pipeline."""
|
||||||
|
if self.config and "load_stage" in self.config:
|
||||||
|
if "layout_random_generator" in self.config.load_stage:
|
||||||
|
if "args" in self.config.load_stage.layout_random_generator:
|
||||||
|
if self.random_num > 0:
|
||||||
|
self.config.load_stage.layout_random_generator.args.random_num = self.random_num
|
||||||
|
if "input_dir" in self.config.load_stage.layout_random_generator.args:
|
||||||
|
if "simbox" in self.config.name:
|
||||||
|
input_path = (
|
||||||
|
"/shared/smartbot_new/zhangyuchang/CI/manip/"
|
||||||
|
"simbox/simbox_plan_ci/seq_path/BananaBaseTask/plan"
|
||||||
|
)
|
||||||
|
self.config.load_stage.layout_random_generator.args.input_dir = input_path
|
||||||
|
if self.config and "plan_stage" in self.config:
|
||||||
|
if "seq_planner" in self.config.plan_stage:
|
||||||
|
if "args" in self.config.plan_stage.seq_planner:
|
||||||
|
if self.episodes > 0:
|
||||||
|
self.config.plan_stage.seq_planner.args.episodes = self.episodes
|
||||||
|
if self.load_num > 0:
|
||||||
|
self.config.load_stage.scene_loader.args.load_num = self.load_num
|
||||||
|
if self.config and "name" in self.config:
|
||||||
|
self.config.name = self.config.name + "_ci"
|
||||||
|
if self.config and "store_stage" in self.config:
|
||||||
|
if hasattr(self.config.store_stage.writer.args, "obs_output_dir"):
|
||||||
|
self.config.store_stage.writer.args.obs_output_dir = f"output/{self.config.name}/obs_path/"
|
||||||
|
if hasattr(self.config.store_stage.writer.args, "seq_output_dir"):
|
||||||
|
self.config.store_stage.writer.args.seq_output_dir = f"output/{self.config.name}/seq_path/"
|
||||||
|
if hasattr(self.config.store_stage.writer.args, "output_dir"):
|
||||||
|
self.config.store_stage.writer.args.output_dir = f"output/{self.config.name}/"
|
||||||
|
self._extract_output_dir()
|
||||||
|
|
||||||
|
def load_and_process_config(self) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Load and process the test configuration file.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Processed configuration dictionary
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
AssertionError: If config file does not exist
|
||||||
|
"""
|
||||||
|
assert os.path.exists(self.config_path), f"Config file not found: {self.config_path}"
|
||||||
|
|
||||||
|
self.config = self.processor.process_config(self.config_path)
|
||||||
|
self.processor.print_final_config(self.config)
|
||||||
|
|
||||||
|
# Extract output directory from config
|
||||||
|
self._extract_output_dir()
|
||||||
|
|
||||||
|
return self.config
|
||||||
|
|
||||||
|
def _extract_output_dir(self):
|
||||||
|
"""Extract and expand the output directory path from config."""
|
||||||
|
# Try navigation test output path
|
||||||
|
output_dir = self.config.get("store_stage", {}).get("writer", {}).get("args", {}).get("seq_output_dir")
|
||||||
|
|
||||||
|
# If not found, try common output_dir used in render configs
|
||||||
|
if not output_dir:
|
||||||
|
output_dir = self.config.get("store_stage", {}).get("writer", {}).get("args", {}).get("output_dir")
|
||||||
|
|
||||||
|
# Process the output directory if found
|
||||||
|
if output_dir and isinstance(output_dir, str):
|
||||||
|
name = self.config.get("name", "test_output")
|
||||||
|
output_dir = output_dir.replace("${name}", name)
|
||||||
|
self.output_dir = os.path.abspath(output_dir)
|
||||||
|
|
||||||
|
def cleanup_output_directory(self):
|
||||||
|
"""Clean up existing output directory if it exists."""
|
||||||
|
if self.output_dir and os.path.exists(self.output_dir):
|
||||||
|
# Use ignore_errors=True to handle NFS caching issues
|
||||||
|
shutil.rmtree(self.output_dir, ignore_errors=True)
|
||||||
|
# If directory still exists (NFS delay), try removing with onerror handler
|
||||||
|
if os.path.exists(self.output_dir):
|
||||||
|
|
||||||
|
def handle_remove_error(func, path, exc_info): # pylint: disable=W0613
|
||||||
|
"""Handle errors during removal, with retry for NFS issues."""
|
||||||
|
import time
|
||||||
|
|
||||||
|
time.sleep(0.1) # Brief delay for NFS sync
|
||||||
|
try:
|
||||||
|
if os.path.isdir(path):
|
||||||
|
os.rmdir(path)
|
||||||
|
else:
|
||||||
|
os.remove(path)
|
||||||
|
except OSError:
|
||||||
|
pass # Ignore if still fails
|
||||||
|
|
||||||
|
shutil.rmtree(self.output_dir, onerror=handle_remove_error)
|
||||||
|
print(f"Cleaned up existing output directory: {self.output_dir}")
|
||||||
|
|
||||||
|
def run_data_engine_direct(self, **kwargs) -> None:
|
||||||
|
"""
|
||||||
|
Run the test pipeline directly in the current process.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
**kwargs: Additional arguments to pass to run_data_engine
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
Exception: If pipeline execution fails
|
||||||
|
"""
|
||||||
|
if not self.config:
|
||||||
|
self.load_and_process_config()
|
||||||
|
self.modify_config()
|
||||||
|
self.cleanup_output_directory()
|
||||||
|
|
||||||
|
try:
|
||||||
|
self._run_data_engine(self.config, **kwargs)
|
||||||
|
except Exception as e:
|
||||||
|
raise e
|
||||||
|
|
||||||
|
def run_data_engine_subprocess(
|
||||||
|
self,
|
||||||
|
runner_script: str,
|
||||||
|
interpreter: str = "python",
|
||||||
|
timeout: int = 1800,
|
||||||
|
compare_output: bool = False,
|
||||||
|
reference_dir: str = "",
|
||||||
|
comparator: str = "simbox",
|
||||||
|
comparator_args: Optional[Dict[str, Any]] = None,
|
||||||
|
) -> subprocess.CompletedProcess:
|
||||||
|
"""
|
||||||
|
Run the test pipeline in a subprocess using a runner script.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
runner_script: Path to the runner script to execute
|
||||||
|
interpreter: Command to use for running the script (default: "python")
|
||||||
|
timeout: Timeout in seconds for subprocess execution (default: 1800)
|
||||||
|
compare_output: Whether to compare generated output with a reference directory
|
||||||
|
reference_dir: Path to reference directory containing meta_info.pkl and lmdb
|
||||||
|
comparator: Which comparator to use (default: "simbox")
|
||||||
|
comparator_args: Optional extra arguments for comparator (e.g. {"tolerance": 1e-6, "image_psnr": 37.0})
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
subprocess.CompletedProcess object with execution results
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
subprocess.TimeoutExpired: If execution exceeds timeout
|
||||||
|
AssertionError: If subprocess returns non-zero exit code
|
||||||
|
"""
|
||||||
|
self.cleanup_output_directory()
|
||||||
|
|
||||||
|
# Build command based on interpreter type
|
||||||
|
if interpreter == "blenderproc":
|
||||||
|
cmd = ["blenderproc", "run", runner_script]
|
||||||
|
if os.environ.get("BLENDER_CUSTOM_PATH"):
|
||||||
|
cmd.extend(["--custom-blender-path", os.environ["BLENDER_CUSTOM_PATH"]])
|
||||||
|
elif interpreter.endswith(".sh"):
|
||||||
|
# For scripts like /isaac-sim/python.sh
|
||||||
|
cmd = [interpreter, runner_script]
|
||||||
|
else:
|
||||||
|
cmd = [interpreter, runner_script]
|
||||||
|
|
||||||
|
if not self.output_dir:
|
||||||
|
self._extract_output_dir()
|
||||||
|
output_dir = self.output_dir
|
||||||
|
|
||||||
|
print(f"Running command: {' '.join(cmd)}")
|
||||||
|
print(f"Expected output directory: {output_dir}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = subprocess.run(cmd, capture_output=False, text=True, timeout=timeout, check=False)
|
||||||
|
|
||||||
|
# Print subprocess output for debugging
|
||||||
|
if result.stdout:
|
||||||
|
print("STDOUT:", result.stdout[-2000:]) # Last 2000 chars
|
||||||
|
if result.stderr:
|
||||||
|
print("STDERR:", result.stderr[-1000:]) # Last 1000 chars
|
||||||
|
print("Return code:", result.returncode)
|
||||||
|
|
||||||
|
if compare_output and result.returncode == 0:
|
||||||
|
if not output_dir:
|
||||||
|
raise RuntimeError(
|
||||||
|
"Output directory not configured. Expected one of "
|
||||||
|
"store_stage.writer.args.(seq_output_dir|obs_output_dir|output_dir)."
|
||||||
|
)
|
||||||
|
|
||||||
|
if output_dir:
|
||||||
|
for root, dirs, files in os.walk(output_dir):
|
||||||
|
if "lmdb" in dirs and "meta_info.pkl" in files:
|
||||||
|
output_dir = root
|
||||||
|
break
|
||||||
|
|
||||||
|
if os.path.exists(reference_dir):
|
||||||
|
# Find the reference render directory
|
||||||
|
for root, dirs, files in os.walk(reference_dir):
|
||||||
|
if "lmdb" in dirs and "meta_info.pkl" in files:
|
||||||
|
reference_dir = root
|
||||||
|
break
|
||||||
|
|
||||||
|
# Build comparator command according to requested comparator
|
||||||
|
comp = (comparator or "simbox").lower()
|
||||||
|
comparator_args = comparator_args or {}
|
||||||
|
|
||||||
|
if comp == "simbox":
|
||||||
|
compare_cmd = [
|
||||||
|
"/isaac-sim/python.sh",
|
||||||
|
"tests/integration/data_comparators/simbox_comparator.py",
|
||||||
|
"--dir1",
|
||||||
|
output_dir,
|
||||||
|
"--dir2",
|
||||||
|
reference_dir,
|
||||||
|
]
|
||||||
|
# Optional numeric/image thresholds
|
||||||
|
if "tolerance" in comparator_args and comparator_args["tolerance"] is not None:
|
||||||
|
compare_cmd += ["--tolerance", str(comparator_args["tolerance"])]
|
||||||
|
if "image_psnr" in comparator_args and comparator_args["image_psnr"] is not None:
|
||||||
|
compare_cmd += ["--image-psnr", str(comparator_args["image_psnr"])]
|
||||||
|
|
||||||
|
print(f"Running comparison: {' '.join(compare_cmd)}")
|
||||||
|
compare_result = subprocess.run(
|
||||||
|
compare_cmd, capture_output=True, text=True, timeout=600, check=False
|
||||||
|
)
|
||||||
|
|
||||||
|
print("Comparison STDOUT:")
|
||||||
|
print(compare_result.stdout)
|
||||||
|
print("Comparison STDERR:")
|
||||||
|
print(compare_result.stderr)
|
||||||
|
|
||||||
|
if compare_result.returncode != 0:
|
||||||
|
raise RuntimeError("Simbox comparison failed: outputs differ")
|
||||||
|
|
||||||
|
if "Successfully loaded data from both directories" in compare_result.stdout:
|
||||||
|
print("✓ Both output directories have valid structure (meta_info.pkl + lmdb)")
|
||||||
|
|
||||||
|
print("✓ Simbox render test completed with numeric-aligned comparison")
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown comparator: {comp}. Use 'simbox'.")
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
except subprocess.TimeoutExpired:
|
||||||
|
raise RuntimeError(f"Test timed out after {timeout} seconds")
|
||||||
|
|
||||||
|
def validate_output_generated(self, min_files: int = 1) -> bool:
|
||||||
|
"""
|
||||||
|
Validate that output was generated in the expected directory.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
min_files: Minimum number of files expected in output (default: 1)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if validation passes
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
AssertionError: If output directory doesn't exist or has too few files
|
||||||
|
"""
|
||||||
|
assert self.output_dir is not None, "Output directory not configured"
|
||||||
|
assert os.path.exists(self.output_dir), f"Expected output directory was not created: {self.output_dir}"
|
||||||
|
|
||||||
|
output_files = list(Path(self.output_dir).rglob("*"))
|
||||||
|
assert (
|
||||||
|
len(output_files) >= min_files
|
||||||
|
), f"Expected at least {min_files} files but found {len(output_files)} in: {self.output_dir}"
|
||||||
|
|
||||||
|
print(f"✓ Pipeline completed successfully. Output generated in: {self.output_dir}")
|
||||||
|
print(f"✓ Generated {len(output_files)} files/directories")
|
||||||
|
return True
|
||||||
|
|
||||||
|
def compare_with_reference(self, reference_dir: str, comparator_func, **comparator_kwargs) -> Tuple[bool, str]:
|
||||||
|
"""
|
||||||
|
Compare generated output with reference data using provided comparator function.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
reference_dir: Path to reference data directory
|
||||||
|
comparator_func: Function to use for comparison
|
||||||
|
**comparator_kwargs: Additional arguments for comparator function
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (success: bool, message: str)
|
||||||
|
"""
|
||||||
|
if not os.path.exists(reference_dir):
|
||||||
|
print(f"Reference directory not found, skipping comparison: {reference_dir}")
|
||||||
|
return True, "Reference data not available"
|
||||||
|
|
||||||
|
success, message = comparator_func(self.output_dir, reference_dir, **comparator_kwargs)
|
||||||
|
|
||||||
|
if success:
|
||||||
|
print(f"✓ Results match reference data: {message}")
|
||||||
|
else:
|
||||||
|
print(f"✗ Result comparison failed: {message}")
|
||||||
|
|
||||||
|
return success, message
|
||||||
|
|
||||||
|
def run_test_end_to_end(
|
||||||
|
self,
|
||||||
|
reference_dir: Optional[str] = None,
|
||||||
|
comparator_func=None,
|
||||||
|
comparator_kwargs: Optional[Dict] = None,
|
||||||
|
min_output_files: int = 1,
|
||||||
|
**pipeline_kwargs,
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
Run a complete end-to-end test including comparison with reference.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
reference_dir: Optional path to reference data for comparison
|
||||||
|
comparator_func: Optional function for comparing results
|
||||||
|
comparator_kwargs: Optional kwargs for comparator function
|
||||||
|
min_output_files: Minimum expected output files
|
||||||
|
**pipeline_kwargs: Additional arguments for pipeline execution
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if test passes, False otherwise
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
Exception: If any test step fails
|
||||||
|
"""
|
||||||
|
# Load configuration
|
||||||
|
self.load_and_process_config()
|
||||||
|
# Run pipeline
|
||||||
|
self.run_data_engine_direct(**pipeline_kwargs, master_seed=self.seed)
|
||||||
|
|
||||||
|
# Validate output
|
||||||
|
self.validate_output_generated(min_files=min_output_files)
|
||||||
|
|
||||||
|
# Compare with reference if provided
|
||||||
|
if reference_dir and comparator_func:
|
||||||
|
comparator_kwargs = comparator_kwargs or {}
|
||||||
|
success, message = self.compare_with_reference(reference_dir, comparator_func, **comparator_kwargs)
|
||||||
|
assert success, f"Comparison with reference failed: {message}"
|
||||||
|
|
||||||
|
return True
|
||||||
52
tests/integration/base/utils.py
Normal file
52
tests/integration/base/utils.py
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
"""
|
||||||
|
Utility functions for integration tests, including centralized seed setting.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import random
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
try:
|
||||||
|
import torch
|
||||||
|
except ImportError:
|
||||||
|
torch = None
|
||||||
|
|
||||||
|
try:
|
||||||
|
import open3d as o3d
|
||||||
|
except ImportError:
|
||||||
|
o3d = None
|
||||||
|
|
||||||
|
|
||||||
|
def set_test_seeds(seed):
|
||||||
|
"""
|
||||||
|
Set random seeds for all relevant libraries to ensure reproducible results.
|
||||||
|
|
||||||
|
This function sets seeds for:
|
||||||
|
- Python's random module
|
||||||
|
- NumPy
|
||||||
|
- PyTorch (if available)
|
||||||
|
- Open3D (if available)
|
||||||
|
- PyTorch CUDA (if available)
|
||||||
|
- PyTorch CUDNN settings for deterministic behavior
|
||||||
|
|
||||||
|
Args:
|
||||||
|
seed (int): The seed value to use for all random number generators
|
||||||
|
"""
|
||||||
|
# Set Python's built-in random seed
|
||||||
|
random.seed(seed)
|
||||||
|
|
||||||
|
# Set NumPy seed
|
||||||
|
np.random.seed(seed)
|
||||||
|
|
||||||
|
# Set PyTorch seeds if available
|
||||||
|
if torch is not None:
|
||||||
|
torch.manual_seed(seed)
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.manual_seed_all(seed)
|
||||||
|
# Configure CUDNN for deterministic behavior
|
||||||
|
torch.backends.cudnn.deterministic = True
|
||||||
|
torch.backends.cudnn.benchmark = False
|
||||||
|
|
||||||
|
# Set Open3D seed if available
|
||||||
|
if o3d is not None:
|
||||||
|
o3d.utility.random.seed(seed)
|
||||||
8
tests/integration/data_comparators/__init__.py
Normal file
8
tests/integration/data_comparators/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
"""
|
||||||
|
Data comparators module for integration tests.
|
||||||
|
Provides functions to compare generated data with reference data.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from .sequence_comparator import compare_navigation_results
|
||||||
|
|
||||||
|
__all__ = ['compare_navigation_results']
|
||||||
551
tests/integration/data_comparators/sequence_comparator.py
Normal file
551
tests/integration/data_comparators/sequence_comparator.py
Normal file
@@ -0,0 +1,551 @@
|
|||||||
|
"""
|
||||||
|
Sequence data comparator for navigation pipeline tests.
|
||||||
|
Provides functions to compare generated navigation sequences with reference data.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import math
|
||||||
|
import numpy as np
|
||||||
|
import cv2 # OpenCV is available per requirements
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Tuple, Any, Dict, Optional
|
||||||
|
|
||||||
|
|
||||||
|
def compare_navigation_results(generated_dir: str, reference_dir: str) -> Tuple[bool, str]:
|
||||||
|
"""Original JSON trajectory sequence comparison (unchanged logic).
|
||||||
|
|
||||||
|
NOTE: Do not modify this function's core behavior. Image comparison is handled by a separate
|
||||||
|
wrapper function `compare_navigation_and_images` to avoid side effects on existing tests.
|
||||||
|
"""
|
||||||
|
# --- Enhanced logic ---
|
||||||
|
# To support both "caller passes seq_path root directory" and "legacy call (leaf trajectory directory)" forms,
|
||||||
|
# here we use a symmetric data.json discovery strategy for both generated and reference sides:
|
||||||
|
# 1. If the current directory directly contains data.json, use that file.
|
||||||
|
# 2. Otherwise, traverse one level down into subdirectories (sorted alphabetically), looking for <dir>/data.json.
|
||||||
|
# 3. Otherwise, search within two nested levels (dir/subdir/data.json) and use the first match found.
|
||||||
|
# 4. If not found, report an error. This is compatible with the legacy "generated=root, reference=leaf" usage,
|
||||||
|
# and also allows both sides to provide the root directory.
|
||||||
|
|
||||||
|
if not os.path.isdir(generated_dir):
|
||||||
|
return False, f"Generated directory does not exist or is not a directory: {generated_dir}"
|
||||||
|
if not os.path.isdir(reference_dir):
|
||||||
|
return False, f"Reference directory does not exist or is not a directory: {reference_dir}"
|
||||||
|
|
||||||
|
try:
|
||||||
|
generated_file = _locate_first_data_json(generated_dir)
|
||||||
|
if generated_file is None:
|
||||||
|
return False, f"Could not locate data.json under generated directory: {generated_dir}"
|
||||||
|
except Exception as e: # pylint: disable=broad-except
|
||||||
|
return False, f"Error locating generated data file: {e}"
|
||||||
|
|
||||||
|
try:
|
||||||
|
reference_file = _locate_first_data_json(reference_dir)
|
||||||
|
if reference_file is None:
|
||||||
|
# To preserve legacy behavior, if reference_dir/data.json exists but was not detected above (should not happen in theory), check once more
|
||||||
|
candidate = os.path.join(reference_dir, "data.json")
|
||||||
|
if os.path.isfile(candidate):
|
||||||
|
reference_file = candidate
|
||||||
|
else:
|
||||||
|
return False, f"Could not locate data.json under reference directory: {reference_dir}"
|
||||||
|
except Exception as e: # pylint: disable=broad-except
|
||||||
|
return False, f"Error locating reference data file: {e}"
|
||||||
|
|
||||||
|
return compare_trajectory_sequences(generated_file, reference_file)
|
||||||
|
|
||||||
|
|
||||||
|
def compare_navigation_and_images(
|
||||||
|
generated_seq_dir: str,
|
||||||
|
reference_seq_dir: str,
|
||||||
|
generated_root_for_images: Optional[str] = None,
|
||||||
|
reference_root_for_images: Optional[str] = None,
|
||||||
|
rgb_abs_tolerance: int = 0,
|
||||||
|
depth_abs_tolerance: float = 0.0,
|
||||||
|
allowed_rgb_diff_ratio: float = 0.0,
|
||||||
|
allowed_depth_diff_ratio: float = 0.5,
|
||||||
|
depth_scale_auto: bool = False,
|
||||||
|
fail_if_images_missing: bool = False,
|
||||||
|
) -> Tuple[bool, str]:
|
||||||
|
"""Wrapper that preserves original JSON comparison while optionally adding first-frame image comparison.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
generated_seq_dir: Path to generated seq_path root used by original comparator.
|
||||||
|
reference_seq_dir: Path to reference seq_path root.
|
||||||
|
generated_root_for_images: Root (parent of obs_path) or the obs_path itself for generated images.
|
||||||
|
reference_root_for_images: Same as above for reference. If None, image comparison may be skipped.
|
||||||
|
rgb_abs_tolerance: RGB absolute per-channel tolerance.
|
||||||
|
depth_abs_tolerance: Depth absolute tolerance.
|
||||||
|
allowed_rgb_diff_ratio: Allowed differing RGB pixel ratio.
|
||||||
|
allowed_depth_diff_ratio: Allowed differing depth pixel ratio.
|
||||||
|
depth_scale_auto: Auto scale depth if uint16 millimeters.
|
||||||
|
fail_if_images_missing: If True, treat missing obs_path as failure; otherwise skip.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(success, message) combined result.
|
||||||
|
"""
|
||||||
|
traj_ok, traj_msg = compare_navigation_results(generated_seq_dir, reference_seq_dir)
|
||||||
|
|
||||||
|
# Determine image roots; default to parent of seq_dir if not explicitly provided
|
||||||
|
gen_img_root = generated_root_for_images or os.path.dirname(generated_seq_dir.rstrip(os.sep))
|
||||||
|
ref_img_root = reference_root_for_images or os.path.dirname(reference_seq_dir.rstrip(os.sep))
|
||||||
|
|
||||||
|
img_ok = True
|
||||||
|
img_msg = "image comparison skipped"
|
||||||
|
|
||||||
|
if generated_root_for_images is not None or reference_root_for_images is not None:
|
||||||
|
# User explicitly passed at least one root -> attempt compare
|
||||||
|
img_ok, img_msg = compare_first_frame_images(
|
||||||
|
generated_root=gen_img_root,
|
||||||
|
reference_root=ref_img_root,
|
||||||
|
rgb_abs_tolerance=rgb_abs_tolerance,
|
||||||
|
depth_abs_tolerance=depth_abs_tolerance,
|
||||||
|
allowed_rgb_diff_ratio=allowed_rgb_diff_ratio,
|
||||||
|
allowed_depth_diff_ratio=allowed_depth_diff_ratio,
|
||||||
|
depth_scale_auto=depth_scale_auto,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Implicit attempt only if both obs_path exist under parent paths
|
||||||
|
gen_obs_candidate = os.path.join(gen_img_root, "obs_path")
|
||||||
|
ref_obs_candidate = os.path.join(ref_img_root, "obs_path")
|
||||||
|
if os.path.isdir(gen_obs_candidate) and os.path.isdir(ref_obs_candidate):
|
||||||
|
img_ok, img_msg = compare_first_frame_images(
|
||||||
|
generated_root=gen_img_root,
|
||||||
|
reference_root=ref_img_root,
|
||||||
|
rgb_abs_tolerance=rgb_abs_tolerance,
|
||||||
|
depth_abs_tolerance=depth_abs_tolerance,
|
||||||
|
allowed_rgb_diff_ratio=allowed_rgb_diff_ratio,
|
||||||
|
allowed_depth_diff_ratio=allowed_depth_diff_ratio,
|
||||||
|
depth_scale_auto=depth_scale_auto,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
if fail_if_images_missing:
|
||||||
|
missing = []
|
||||||
|
if not os.path.isdir(gen_obs_candidate):
|
||||||
|
missing.append(f"generated:{gen_obs_candidate}")
|
||||||
|
if not os.path.isdir(ref_obs_candidate):
|
||||||
|
missing.append(f"reference:{ref_obs_candidate}")
|
||||||
|
img_ok = False
|
||||||
|
img_msg = "obs_path missing -> " + ", ".join(missing)
|
||||||
|
else:
|
||||||
|
img_msg = "obs_path not found in one or both roots; skipped"
|
||||||
|
|
||||||
|
overall = traj_ok and img_ok
|
||||||
|
message = f"trajectory: {traj_msg}; images: {img_msg}"
|
||||||
|
return overall, message if overall else f"Mismatch - {message}"
|
||||||
|
|
||||||
|
|
||||||
|
def compare_trajectory_sequences(generated_file: str, reference_file: str, tolerance: float = 1e-6) -> Tuple[bool, str]:
|
||||||
|
"""
|
||||||
|
Compare trajectory sequence files with numerical tolerance.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
generated_file: Path to generated trajectory file
|
||||||
|
reference_file: Path to reference trajectory file
|
||||||
|
tolerance: Numerical tolerance for floating point comparisons
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple[bool, str]: (success, message)
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Check if files exist
|
||||||
|
if not os.path.exists(generated_file):
|
||||||
|
return False, f"Generated file does not exist: {generated_file}"
|
||||||
|
|
||||||
|
if not os.path.exists(reference_file):
|
||||||
|
return False, f"Reference file does not exist: {reference_file}"
|
||||||
|
|
||||||
|
# Load JSON files
|
||||||
|
with open(generated_file, 'r') as f:
|
||||||
|
generated_data = json.load(f)
|
||||||
|
|
||||||
|
with open(reference_file, 'r') as f:
|
||||||
|
reference_data = json.load(f)
|
||||||
|
|
||||||
|
# Compare the JSON structures
|
||||||
|
success, message = _compare_data_structures(generated_data, reference_data, tolerance)
|
||||||
|
|
||||||
|
if success:
|
||||||
|
return True, "Trajectory sequences match within tolerance"
|
||||||
|
else:
|
||||||
|
return False, f"Trajectory sequences differ: {message}"
|
||||||
|
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
return False, f"JSON decode error: {e}"
|
||||||
|
except Exception as e:
|
||||||
|
return False, f"Error comparing trajectory sequences: {e}"
|
||||||
|
|
||||||
|
|
||||||
|
def _compare_data_structures(data1: Any, data2: Any, tolerance: float, path: str = "") -> Tuple[bool, str]:
|
||||||
|
"""
|
||||||
|
Recursively compare two data structures with numerical tolerance.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data1: First data structure
|
||||||
|
data2: Second data structure
|
||||||
|
tolerance: Numerical tolerance for floating point comparisons
|
||||||
|
path: Current path in the data structure for error reporting
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple[bool, str]: (success, error_message)
|
||||||
|
"""
|
||||||
|
# Check if types are the same
|
||||||
|
if type(data1) != type(data2):
|
||||||
|
return False, f"Type mismatch at {path}: {type(data1)} vs {type(data2)}"
|
||||||
|
|
||||||
|
# Handle dictionaries
|
||||||
|
if isinstance(data1, dict):
|
||||||
|
if set(data1.keys()) != set(data2.keys()):
|
||||||
|
return False, f"Key mismatch at {path}: {set(data1.keys())} vs {set(data2.keys())}"
|
||||||
|
|
||||||
|
for key in data1.keys():
|
||||||
|
new_path = f"{path}.{key}" if path else key
|
||||||
|
success, message = _compare_data_structures(data1[key], data2[key], tolerance, new_path)
|
||||||
|
if not success:
|
||||||
|
return False, message
|
||||||
|
|
||||||
|
# Handle lists
|
||||||
|
elif isinstance(data1, list):
|
||||||
|
if len(data1) != len(data2):
|
||||||
|
return False, f"List length mismatch at {path}: {len(data1)} vs {len(data2)}"
|
||||||
|
|
||||||
|
for i, (item1, item2) in enumerate(zip(data1, data2)):
|
||||||
|
new_path = f"{path}[{i}]" if path else f"[{i}]"
|
||||||
|
success, message = _compare_data_structures(item1, item2, tolerance, new_path)
|
||||||
|
if not success:
|
||||||
|
return False, message
|
||||||
|
|
||||||
|
# Handle numerical values
|
||||||
|
elif isinstance(data1, (int, float)):
|
||||||
|
if isinstance(data2, (int, float)):
|
||||||
|
if abs(data1 - data2) > tolerance:
|
||||||
|
return (
|
||||||
|
False,
|
||||||
|
f"Numerical difference at {path}: {data1} vs {data2} (diff: {abs(data1 - data2)}, tolerance: {tolerance})",
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return False, f"Type mismatch at {path}: number vs {type(data2)}"
|
||||||
|
|
||||||
|
# Handle strings and other exact comparison types
|
||||||
|
elif isinstance(data1, (str, bool, type(None))):
|
||||||
|
if data1 != data2:
|
||||||
|
return False, f"Value mismatch at {path}: {data1} vs {data2}"
|
||||||
|
|
||||||
|
# Handle unknown types
|
||||||
|
else:
|
||||||
|
if data1 != data2:
|
||||||
|
return False, f"Value mismatch at {path}: {data1} vs {data2}"
|
||||||
|
|
||||||
|
return True, ""
|
||||||
|
|
||||||
|
def _locate_first_data_json(root: str) -> Optional[str]:
|
||||||
|
"""Locate a data.json file under root with a shallow, deterministic strategy.
|
||||||
|
|
||||||
|
Strategy (stop at first match to keep behavior predictable & lightweight):
|
||||||
|
1. If root/data.json exists -> return it.
|
||||||
|
2. Enumerate immediate subdirectories (sorted). For each d:
|
||||||
|
- if d/data.json exists -> return it.
|
||||||
|
3. Enumerate immediate subdirectories again; for each d enumerate its subdirectories (sorted) and
|
||||||
|
look for d/sub/data.json -> return first match.
|
||||||
|
4. If none found -> return None.
|
||||||
|
"""
|
||||||
|
# 1. root/data.json
|
||||||
|
candidate = os.path.join(root, "data.json")
|
||||||
|
if os.path.isfile(candidate):
|
||||||
|
return candidate
|
||||||
|
|
||||||
|
try:
|
||||||
|
first_level = [d for d in os.listdir(root) if os.path.isdir(os.path.join(root, d))]
|
||||||
|
except FileNotFoundError:
|
||||||
|
return None
|
||||||
|
|
||||||
|
first_level.sort()
|
||||||
|
|
||||||
|
# 2. d/data.json
|
||||||
|
for d in first_level:
|
||||||
|
c = os.path.join(root, d, "data.json")
|
||||||
|
if os.path.isfile(c):
|
||||||
|
return c
|
||||||
|
|
||||||
|
# 3. d/sub/data.json
|
||||||
|
for d in first_level:
|
||||||
|
d_path = os.path.join(root, d)
|
||||||
|
try:
|
||||||
|
second_level = [s for s in os.listdir(d_path) if os.path.isdir(os.path.join(d_path, s))]
|
||||||
|
except FileNotFoundError:
|
||||||
|
continue
|
||||||
|
second_level.sort()
|
||||||
|
for s in second_level:
|
||||||
|
c = os.path.join(d_path, s, "data.json")
|
||||||
|
if os.path.isfile(c):
|
||||||
|
return c
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def compare_first_frame_images(
|
||||||
|
generated_root: str,
|
||||||
|
reference_root: str,
|
||||||
|
rgb_dir_name: str = "rgb",
|
||||||
|
depth_dir_name: str = "depth",
|
||||||
|
scene_dir: Optional[str] = None,
|
||||||
|
traj_dir: Optional[str] = None,
|
||||||
|
rgb_abs_tolerance: int = 0,
|
||||||
|
depth_abs_tolerance: float = 0.0,
|
||||||
|
allowed_rgb_diff_ratio: float = 0.0,
|
||||||
|
allowed_depth_diff_ratio: float = 0.0,
|
||||||
|
compute_psnr: bool = True,
|
||||||
|
compute_mse: bool = True,
|
||||||
|
depth_scale_auto: bool = False,
|
||||||
|
) -> Tuple[bool, str]:
|
||||||
|
"""Compare only the first frame (index 0) of RGB & depth images between generated and reference.
|
||||||
|
|
||||||
|
This is a lightweight check to validate pipeline correctness without scanning all frames.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
generated_root: Path to generated run root (may contain `obs_path` or be the `obs_path`).
|
||||||
|
reference_root: Path to reference run root (same structure as generated_root).
|
||||||
|
rgb_dir_name: Subdirectory name for RGB frames under a trajectory directory.
|
||||||
|
depth_dir_name: Subdirectory name for depth frames under a trajectory directory.
|
||||||
|
scene_dir: Optional explicit scene directory name (e.g. "6f"); if None will auto-pick first.
|
||||||
|
traj_dir: Optional explicit trajectory directory (e.g. "0"); if None will auto-pick first.
|
||||||
|
rgb_abs_tolerance: Per-channel absolute pixel tolerance (0 requires exact match).
|
||||||
|
depth_abs_tolerance: Absolute tolerance for depth value differences (after optional scaling).
|
||||||
|
allowed_rgb_diff_ratio: Max allowed ratio of differing RGB pixels (0.01 -> 1%).
|
||||||
|
allowed_depth_diff_ratio: Max allowed ratio of differing depth pixels beyond tolerance.
|
||||||
|
compute_psnr: Whether to compute PSNR metric for reporting.
|
||||||
|
compute_mse: Whether to compute MSE metric for reporting.
|
||||||
|
depth_scale_auto: If True, attempt simple heuristic scaling for uint16 depth (divide by 1000 if max > 10000).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(success, message) summary of comparison.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
gen_obs = _resolve_obs_path(generated_root)
|
||||||
|
ref_obs = _resolve_obs_path(reference_root)
|
||||||
|
if gen_obs is None:
|
||||||
|
return False, f"Cannot locate obs_path under generated root: {generated_root}"
|
||||||
|
if ref_obs is None:
|
||||||
|
return False, f"Cannot locate obs_path under reference root: {reference_root}"
|
||||||
|
|
||||||
|
scene_dir = scene_dir or _pick_first_subdir(gen_obs)
|
||||||
|
if scene_dir is None:
|
||||||
|
return False, f"No scene directory found in {gen_obs}"
|
||||||
|
ref_scene_dir = scene_dir if os.path.isdir(os.path.join(ref_obs, scene_dir)) else _pick_first_subdir(ref_obs)
|
||||||
|
if ref_scene_dir is None:
|
||||||
|
return False, f"No matching scene directory in reference: {ref_obs}"
|
||||||
|
|
||||||
|
gen_scene_path = os.path.join(gen_obs, scene_dir)
|
||||||
|
ref_scene_path = os.path.join(ref_obs, ref_scene_dir)
|
||||||
|
|
||||||
|
traj_dir = traj_dir or _pick_first_subdir(gen_scene_path)
|
||||||
|
if traj_dir is None:
|
||||||
|
return False, f"No trajectory directory in {gen_scene_path}"
|
||||||
|
ref_traj_dir = (
|
||||||
|
traj_dir if os.path.isdir(os.path.join(ref_scene_path, traj_dir)) else _pick_first_subdir(ref_scene_path)
|
||||||
|
)
|
||||||
|
if ref_traj_dir is None:
|
||||||
|
return False, f"No trajectory directory in reference scene path {ref_scene_path}"
|
||||||
|
|
||||||
|
gen_traj_path = os.path.join(gen_scene_path, traj_dir)
|
||||||
|
ref_traj_path = os.path.join(ref_scene_path, ref_traj_dir)
|
||||||
|
|
||||||
|
# RGB comparison
|
||||||
|
rgb_result, rgb_msg = _compare_single_frame_rgb(
|
||||||
|
gen_traj_path,
|
||||||
|
ref_traj_path,
|
||||||
|
rgb_dir_name,
|
||||||
|
rgb_abs_tolerance,
|
||||||
|
allowed_rgb_diff_ratio,
|
||||||
|
compute_psnr,
|
||||||
|
compute_mse,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Depth comparison (optional if depth folder exists)
|
||||||
|
depth_result, depth_msg = _compare_single_frame_depth(
|
||||||
|
gen_traj_path,
|
||||||
|
ref_traj_path,
|
||||||
|
depth_dir_name,
|
||||||
|
depth_abs_tolerance,
|
||||||
|
allowed_depth_diff_ratio,
|
||||||
|
compute_psnr,
|
||||||
|
compute_mse,
|
||||||
|
depth_scale_auto,
|
||||||
|
)
|
||||||
|
|
||||||
|
success = rgb_result and depth_result
|
||||||
|
combined_msg = f"RGB: {rgb_msg}; Depth: {depth_msg}"
|
||||||
|
return success, ("Images match - " + combined_msg) if success else ("Image mismatch - " + combined_msg)
|
||||||
|
except Exception as e: # pylint: disable=broad-except
|
||||||
|
return False, f"Error during first-frame image comparison: {e}"
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_obs_path(root: str) -> Optional[str]:
|
||||||
|
"""Return the obs_path directory. Accept either the root itself or its child."""
|
||||||
|
if not os.path.isdir(root):
|
||||||
|
return None
|
||||||
|
if os.path.basename(root) == "obs_path":
|
||||||
|
return root
|
||||||
|
candidate = os.path.join(root, "obs_path")
|
||||||
|
return candidate if os.path.isdir(candidate) else None
|
||||||
|
|
||||||
|
|
||||||
|
def _pick_first_subdir(parent: str) -> Optional[str]:
|
||||||
|
"""Pick the first alphanumerically sorted subdirectory name under parent."""
|
||||||
|
try:
|
||||||
|
subs = [d for d in os.listdir(parent) if os.path.isdir(os.path.join(parent, d))]
|
||||||
|
if not subs:
|
||||||
|
return None
|
||||||
|
subs.sort()
|
||||||
|
return subs[0]
|
||||||
|
except FileNotFoundError:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _find_first_frame_file(folder: str, exts: Tuple[str, ...]) -> Optional[str]:
|
||||||
|
"""Find the smallest numeral file with one of extensions; returns absolute path."""
|
||||||
|
if not os.path.isdir(folder):
|
||||||
|
return None
|
||||||
|
candidates = []
|
||||||
|
for f in os.listdir(folder):
|
||||||
|
lower = f.lower()
|
||||||
|
for e in exts:
|
||||||
|
if lower.endswith(e):
|
||||||
|
num_part = os.path.splitext(f)[0]
|
||||||
|
if num_part.isdigit():
|
||||||
|
candidates.append((int(num_part), f))
|
||||||
|
elif f.startswith("0"): # fallback for names like 0.jpg
|
||||||
|
candidates.append((0, f))
|
||||||
|
break
|
||||||
|
if not candidates:
|
||||||
|
return None
|
||||||
|
candidates.sort(key=lambda x: x[0])
|
||||||
|
return os.path.join(folder, candidates[0][1])
|
||||||
|
|
||||||
|
|
||||||
|
def _compare_single_frame_rgb(
|
||||||
|
gen_traj_path: str,
|
||||||
|
ref_traj_path: str,
|
||||||
|
rgb_dir_name: str,
|
||||||
|
abs_tol: int,
|
||||||
|
allowed_ratio: float,
|
||||||
|
compute_psnr: bool,
|
||||||
|
compute_mse: bool,
|
||||||
|
) -> Tuple[bool, str]:
|
||||||
|
rgb_gen_dir = os.path.join(gen_traj_path, rgb_dir_name)
|
||||||
|
rgb_ref_dir = os.path.join(ref_traj_path, rgb_dir_name)
|
||||||
|
if not os.path.isdir(rgb_gen_dir) or not os.path.isdir(rgb_ref_dir):
|
||||||
|
return (
|
||||||
|
False,
|
||||||
|
f"RGB directory missing (generated: {os.path.isdir(rgb_gen_dir)}, reference: {os.path.isdir(rgb_ref_dir)})",
|
||||||
|
)
|
||||||
|
|
||||||
|
gen_file = _find_first_frame_file(rgb_gen_dir, (".jpg", ".png", ".jpeg"))
|
||||||
|
ref_file = _find_first_frame_file(rgb_ref_dir, (".jpg", ".png", ".jpeg"))
|
||||||
|
if not gen_file or not ref_file:
|
||||||
|
return False, "First RGB frame file not found in one of the directories"
|
||||||
|
|
||||||
|
gen_img = cv2.imread(gen_file, cv2.IMREAD_COLOR)
|
||||||
|
ref_img = cv2.imread(ref_file, cv2.IMREAD_COLOR)
|
||||||
|
if gen_img is None or ref_img is None:
|
||||||
|
return False, "Failed to read RGB images"
|
||||||
|
if gen_img.shape != ref_img.shape:
|
||||||
|
return False, f"RGB shape mismatch {gen_img.shape} vs {ref_img.shape}"
|
||||||
|
|
||||||
|
diff = np.abs(gen_img.astype(np.int16) - ref_img.astype(np.int16))
|
||||||
|
diff_mask = np.any(diff > abs_tol, axis=2)
|
||||||
|
diff_ratio = float(diff_mask.sum()) / diff_mask.size
|
||||||
|
|
||||||
|
metrics_parts = [f"diff_pixels_ratio={diff_ratio:.4f}"]
|
||||||
|
flag = False
|
||||||
|
if compute_mse or compute_psnr:
|
||||||
|
mse = float((diff**2).mean())
|
||||||
|
if compute_mse:
|
||||||
|
metrics_parts.append(f"mse={mse:.2f}")
|
||||||
|
if compute_psnr:
|
||||||
|
if mse == 0.0:
|
||||||
|
psnr = float('inf')
|
||||||
|
flag = True
|
||||||
|
else:
|
||||||
|
psnr = 10.0 * math.log10((255.0**2) / mse)
|
||||||
|
if math.isinf(psnr):
|
||||||
|
metrics_parts.append("psnr=inf")
|
||||||
|
flag = True
|
||||||
|
else:
|
||||||
|
metrics_parts.append(f"psnr={psnr:.2f}dB")
|
||||||
|
if psnr >= 40.0:
|
||||||
|
flag = True
|
||||||
|
|
||||||
|
passed = diff_ratio <= allowed_ratio or flag
|
||||||
|
status = "OK" if passed else "FAIL"
|
||||||
|
return passed, f"{status} (abs_tol={abs_tol}, allowed_ratio={allowed_ratio}, {' '.join(metrics_parts)})"
|
||||||
|
|
||||||
|
|
||||||
|
def _compare_single_frame_depth(
|
||||||
|
gen_traj_path: str,
|
||||||
|
ref_traj_path: str,
|
||||||
|
depth_dir_name: str,
|
||||||
|
abs_tol: float,
|
||||||
|
allowed_ratio: float,
|
||||||
|
compute_psnr: bool,
|
||||||
|
compute_mse: bool,
|
||||||
|
auto_scale: bool,
|
||||||
|
) -> Tuple[bool, str]:
|
||||||
|
depth_gen_dir = os.path.join(gen_traj_path, depth_dir_name)
|
||||||
|
depth_ref_dir = os.path.join(ref_traj_path, depth_dir_name)
|
||||||
|
if not os.path.isdir(depth_gen_dir) or not os.path.isdir(depth_ref_dir):
|
||||||
|
return (
|
||||||
|
False,
|
||||||
|
f"Depth directory missing (generated: {os.path.isdir(depth_gen_dir)}, reference: {os.path.isdir(depth_ref_dir)})",
|
||||||
|
)
|
||||||
|
|
||||||
|
gen_file = _find_first_frame_file(depth_gen_dir, (".png", ".exr"))
|
||||||
|
ref_file = _find_first_frame_file(depth_ref_dir, (".png", ".exr"))
|
||||||
|
if not gen_file or not ref_file:
|
||||||
|
return False, "First depth frame file not found in one of the directories"
|
||||||
|
|
||||||
|
gen_img = cv2.imread(gen_file, cv2.IMREAD_UNCHANGED)
|
||||||
|
ref_img = cv2.imread(ref_file, cv2.IMREAD_UNCHANGED)
|
||||||
|
if gen_img is None or ref_img is None:
|
||||||
|
return False, "Failed to read depth images"
|
||||||
|
if gen_img.shape != ref_img.shape:
|
||||||
|
return False, f"Depth shape mismatch {gen_img.shape} vs {ref_img.shape}"
|
||||||
|
|
||||||
|
gen_depth = _prepare_depth_array(gen_img, auto_scale)
|
||||||
|
ref_depth = _prepare_depth_array(ref_img, auto_scale)
|
||||||
|
if gen_depth.shape != ref_depth.shape:
|
||||||
|
return False, f"Depth array shape mismatch {gen_depth.shape} vs {ref_depth.shape}"
|
||||||
|
|
||||||
|
diff = np.abs(gen_depth - ref_depth)
|
||||||
|
diff_mask = diff > abs_tol
|
||||||
|
diff_ratio = float(diff_mask.sum()) / diff_mask.size
|
||||||
|
|
||||||
|
metrics_parts = [f"diff_pixels_ratio={diff_ratio:.4f}"]
|
||||||
|
if compute_mse or compute_psnr:
|
||||||
|
mse = float((diff**2).mean())
|
||||||
|
if compute_mse:
|
||||||
|
metrics_parts.append(f"mse={mse:.6f}")
|
||||||
|
if compute_psnr:
|
||||||
|
# Estimate dynamic range from reference depth
|
||||||
|
dr = float(ref_depth.max() - ref_depth.min()) or 1.0
|
||||||
|
if mse == 0.0:
|
||||||
|
psnr = float('inf')
|
||||||
|
else:
|
||||||
|
psnr = 10.0 * math.log10((dr**2) / mse)
|
||||||
|
metrics_parts.append("psnr=inf" if math.isinf(psnr) else f"psnr={psnr:.2f}dB")
|
||||||
|
|
||||||
|
passed = diff_ratio <= allowed_ratio
|
||||||
|
status = "OK" if passed else "FAIL"
|
||||||
|
return passed, f"{status} (abs_tol={abs_tol}, allowed_ratio={allowed_ratio}, {' '.join(metrics_parts)})"
|
||||||
|
|
||||||
|
|
||||||
|
def _prepare_depth_array(arr: np.ndarray, auto_scale: bool) -> np.ndarray:
|
||||||
|
"""Convert raw depth image to float32 array; apply simple heuristic scaling if requested."""
|
||||||
|
if arr.dtype == np.uint16:
|
||||||
|
depth = arr.astype(np.float32)
|
||||||
|
if auto_scale and depth.max() > 10000: # likely millimeters
|
||||||
|
depth /= 1000.0
|
||||||
|
return depth
|
||||||
|
if arr.dtype == np.float32:
|
||||||
|
return arr
|
||||||
|
# Fallback: convert to float
|
||||||
|
return arr.astype(np.float32)
|
||||||
398
tests/integration/data_comparators/simbox_comparator.py
Normal file
398
tests/integration/data_comparators/simbox_comparator.py
Normal file
@@ -0,0 +1,398 @@
|
|||||||
|
"""
|
||||||
|
Simbox Output Comparator
|
||||||
|
|
||||||
|
This module provides functionality to compare two Simbox task output directories.
|
||||||
|
It compares both meta_info.pkl and LMDB database contents, handling different data types:
|
||||||
|
- JSON data (dict/list)
|
||||||
|
- Scalar data (numerical arrays/lists)
|
||||||
|
- Image data (encoded images)
|
||||||
|
- Proprioception data (joint states, gripper states)
|
||||||
|
- Object data (object poses and properties)
|
||||||
|
- Action data (robot actions)
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import pickle
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import lmdb
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
class SimboxComparator:
|
||||||
|
"""Comparator for Simbox task output directories."""
|
||||||
|
|
||||||
|
def __init__(self, dir1: str, dir2: str, tolerance: float = 1e-6, image_psnr_threshold: float = 30.0):
|
||||||
|
"""
|
||||||
|
Initialize the comparator.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dir1: Path to the first output directory
|
||||||
|
dir2: Path to the second output directory
|
||||||
|
tolerance: Numerical tolerance for floating point comparisons
|
||||||
|
image_psnr_threshold: PSNR threshold (dB) for considering images as acceptable match
|
||||||
|
"""
|
||||||
|
self.dir1 = Path(dir1)
|
||||||
|
self.dir2 = Path(dir2)
|
||||||
|
self.tolerance = tolerance
|
||||||
|
self.image_psnr_threshold = image_psnr_threshold
|
||||||
|
self.mismatches = []
|
||||||
|
self.warnings = []
|
||||||
|
self.image_psnr_values: List[float] = []
|
||||||
|
|
||||||
|
def load_directory(self, directory: Path) -> Tuple[Optional[Dict], Optional[Any], Optional[Any]]:
|
||||||
|
"""Load meta_info.pkl and LMDB database from directory."""
|
||||||
|
meta_path = directory / "meta_info.pkl"
|
||||||
|
lmdb_path = directory / "lmdb"
|
||||||
|
|
||||||
|
if not directory.is_dir() or not meta_path.exists() or not lmdb_path.is_dir():
|
||||||
|
print(f"Error: '{directory}' is not a valid output directory.")
|
||||||
|
print("It must contain 'meta_info.pkl' and an 'lmdb' subdirectory.")
|
||||||
|
return None, None, None
|
||||||
|
|
||||||
|
try:
|
||||||
|
with open(meta_path, "rb") as f:
|
||||||
|
meta_info = pickle.load(f)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error loading meta_info.pkl from {directory}: {e}")
|
||||||
|
return None, None, None
|
||||||
|
|
||||||
|
try:
|
||||||
|
env = lmdb.open(str(lmdb_path), readonly=True, lock=False, readahead=False, meminit=False)
|
||||||
|
txn = env.begin(write=False)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error opening LMDB database at {lmdb_path}: {e}")
|
||||||
|
return None, None, None
|
||||||
|
|
||||||
|
return meta_info, txn, env
|
||||||
|
|
||||||
|
def compare_metadata(self, meta1: Dict, meta2: Dict) -> bool:
|
||||||
|
"""Compare high-level metadata."""
|
||||||
|
identical = True
|
||||||
|
|
||||||
|
if meta1.get("num_steps") != meta2.get("num_steps"):
|
||||||
|
self.mismatches.append(f"num_steps differ: {meta1.get('num_steps')} vs {meta2.get('num_steps')}")
|
||||||
|
identical = False
|
||||||
|
|
||||||
|
return identical
|
||||||
|
|
||||||
|
def get_key_categories(self, meta: Dict) -> Dict[str, set]:
|
||||||
|
"""Extract key categories from metadata."""
|
||||||
|
key_to_category = {}
|
||||||
|
for category, keys in meta.get("keys", {}).items():
|
||||||
|
for key in keys:
|
||||||
|
key_bytes = key if isinstance(key, bytes) else key.encode()
|
||||||
|
key_to_category[key_bytes] = category
|
||||||
|
|
||||||
|
return key_to_category
|
||||||
|
|
||||||
|
def compare_json_data(self, key: bytes, data1: Any, data2: Any) -> bool:
|
||||||
|
"""Compare JSON/dict/list data."""
|
||||||
|
if type(data1) != type(data2):
|
||||||
|
self.mismatches.append(f"[{key.decode()}] Type mismatch: {type(data1).__name__} vs {type(data2).__name__}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
if isinstance(data1, dict):
|
||||||
|
if set(data1.keys()) != set(data2.keys()):
|
||||||
|
self.mismatches.append(f"[{key.decode()}] Dict keys differ")
|
||||||
|
return False
|
||||||
|
for k in data1.keys():
|
||||||
|
if not self.compare_json_data(key, data1[k], data2[k]):
|
||||||
|
return False
|
||||||
|
elif isinstance(data1, list):
|
||||||
|
if len(data1) != len(data2):
|
||||||
|
self.mismatches.append(f"[{key.decode()}] List length differ: {len(data1)} vs {len(data2)}")
|
||||||
|
return False
|
||||||
|
# For lists, compare sample elements to avoid excessive output
|
||||||
|
if len(data1) > 10:
|
||||||
|
sample_indices = [0, len(data1) // 2, -1]
|
||||||
|
for idx in sample_indices:
|
||||||
|
if not self.compare_json_data(key, data1[idx], data2[idx]):
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
for i, (v1, v2) in enumerate(zip(data1, data2)):
|
||||||
|
if not self.compare_json_data(key, v1, v2):
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
if data1 != data2:
|
||||||
|
self.mismatches.append(f"[{key.decode()}] Value mismatch: {data1} vs {data2}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
def compare_numerical_data(self, key: bytes, data1: Any, data2: Any) -> bool:
|
||||||
|
"""Compare numerical data (arrays, lists of numbers)."""
|
||||||
|
# Convert to numpy arrays for comparison
|
||||||
|
try:
|
||||||
|
if isinstance(data1, list):
|
||||||
|
arr1 = np.array(data1)
|
||||||
|
arr2 = np.array(data2)
|
||||||
|
else:
|
||||||
|
arr1 = data1
|
||||||
|
arr2 = data2
|
||||||
|
|
||||||
|
if arr1.shape != arr2.shape:
|
||||||
|
self.mismatches.append(f"[{key.decode()}] Shape mismatch: {arr1.shape} vs {arr2.shape}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
if not np.allclose(arr1, arr2, rtol=self.tolerance, atol=self.tolerance):
|
||||||
|
diff = np.abs(arr1 - arr2)
|
||||||
|
max_diff = np.max(diff)
|
||||||
|
mean_diff = np.mean(diff)
|
||||||
|
self.mismatches.append(
|
||||||
|
f"[{key.decode()}] Numerical difference: max={max_diff:.6e}, mean={mean_diff:.6e}"
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.warnings.append(f"[{key.decode()}] Error comparing numerical data: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
def compare_image_data(self, key: bytes, data1: np.ndarray, data2: np.ndarray) -> bool:
|
||||||
|
"""Compare image data (encoded as uint8 arrays)."""
|
||||||
|
try:
|
||||||
|
# Decode images
|
||||||
|
img1 = cv2.imdecode(data1, cv2.IMREAD_UNCHANGED)
|
||||||
|
img2 = cv2.imdecode(data2, cv2.IMREAD_UNCHANGED)
|
||||||
|
|
||||||
|
if img1 is None or img2 is None:
|
||||||
|
self.warnings.append(f"[{key.decode()}] Could not decode image, using binary comparison")
|
||||||
|
return np.array_equal(data1, data2)
|
||||||
|
|
||||||
|
# Compare shapes
|
||||||
|
if img1.shape != img2.shape:
|
||||||
|
self.mismatches.append(f"[{key.decode()}] Image shape mismatch: {img1.shape} vs {img2.shape}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Calculate PSNR for tracking average quality
|
||||||
|
if np.array_equal(img1, img2):
|
||||||
|
psnr = 100.0
|
||||||
|
else:
|
||||||
|
diff_float = img1.astype(np.float32) - img2.astype(np.float32)
|
||||||
|
mse = np.mean(diff_float ** 2)
|
||||||
|
if mse == 0:
|
||||||
|
psnr = 100.0
|
||||||
|
else:
|
||||||
|
max_pixel = 255.0
|
||||||
|
psnr = 20 * np.log10(max_pixel / np.sqrt(mse))
|
||||||
|
|
||||||
|
self.image_psnr_values.append(psnr)
|
||||||
|
|
||||||
|
try:
|
||||||
|
print(f"[{key.decode()}] PSNR: {psnr:.2f} dB")
|
||||||
|
except Exception:
|
||||||
|
print(f"[<binary key>] PSNR: {psnr:.2f} dB")
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.warnings.append(f"[{key.decode()}] Error comparing image: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _save_comparison_image(self, key: bytes, img1: np.ndarray, img2: np.ndarray, diff: np.ndarray):
|
||||||
|
"""Save comparison visualization for differing images."""
|
||||||
|
try:
|
||||||
|
output_dir = Path("image_comparisons")
|
||||||
|
output_dir.mkdir(exist_ok=True)
|
||||||
|
|
||||||
|
# Normalize difference for visualization
|
||||||
|
if len(diff.shape) == 3:
|
||||||
|
diff_vis = np.clip(diff * 5, 0, 255).astype(np.uint8)
|
||||||
|
else:
|
||||||
|
diff_vis = np.clip(diff * 5, 0, 255).astype(np.uint8)
|
||||||
|
|
||||||
|
# Ensure RGB format for concatenation
|
||||||
|
def ensure_rgb(img):
|
||||||
|
if len(img.shape) == 2:
|
||||||
|
return cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
|
||||||
|
elif len(img.shape) == 3 and img.shape[2] == 4:
|
||||||
|
return cv2.cvtColor(img, cv2.COLOR_BGRA2RGB)
|
||||||
|
return img
|
||||||
|
|
||||||
|
img1_rgb = ensure_rgb(img1)
|
||||||
|
img2_rgb = ensure_rgb(img2)
|
||||||
|
diff_rgb = ensure_rgb(diff_vis)
|
||||||
|
|
||||||
|
# Concatenate horizontally
|
||||||
|
combined = np.hstack([img1_rgb, img2_rgb, diff_rgb])
|
||||||
|
|
||||||
|
# Save
|
||||||
|
safe_key = key.decode().replace("/", "_").replace("\\", "_").replace(":", "_")
|
||||||
|
output_path = output_dir / f"diff_{safe_key}.png"
|
||||||
|
cv2.imwrite(str(output_path), cv2.cvtColor(combined, cv2.COLOR_RGB2BGR))
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.warnings.append(f"Failed to save comparison image for {key.decode()}: {e}")
|
||||||
|
|
||||||
|
def compare_value(self, key: bytes, category: str, val1: bytes, val2: bytes) -> bool:
|
||||||
|
"""Compare a single key-value pair based on its category."""
|
||||||
|
# Output the category for the current key being compared
|
||||||
|
try:
|
||||||
|
print(f"[{key.decode()}] category: {category}")
|
||||||
|
except Exception:
|
||||||
|
print(f"[<binary key>] category: {category}")
|
||||||
|
|
||||||
|
if val1 is None and val2 is None:
|
||||||
|
return True
|
||||||
|
if val1 is None or val2 is None:
|
||||||
|
self.mismatches.append(f"[{key.decode()}] Key exists in one dataset but not the other")
|
||||||
|
return False
|
||||||
|
|
||||||
|
try:
|
||||||
|
data1 = pickle.loads(val1)
|
||||||
|
data2 = pickle.loads(val2)
|
||||||
|
except Exception as e:
|
||||||
|
self.warnings.append(f"[{key.decode()}] Error unpickling data: {e}")
|
||||||
|
return val1 == val2
|
||||||
|
|
||||||
|
# Route to appropriate comparison based on category
|
||||||
|
if category == "json_data":
|
||||||
|
return self.compare_json_data(key, data1, data2)
|
||||||
|
elif category in ["scalar_data", "proprio_data", "object_data", "action_data"]:
|
||||||
|
return self.compare_numerical_data(key, data1, data2)
|
||||||
|
elif category.startswith("images."):
|
||||||
|
# Image data is stored as numpy uint8 array
|
||||||
|
if isinstance(data1, np.ndarray) and isinstance(data2, np.ndarray):
|
||||||
|
return self.compare_image_data(key, data1, data2)
|
||||||
|
else:
|
||||||
|
self.warnings.append(f"[{key.decode()}] Expected numpy array for image data")
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
# Unknown category, try generic comparison
|
||||||
|
self.warnings.append(f"[{key.decode()}] Unknown category '{category}', using binary comparison")
|
||||||
|
return val1 == val2
|
||||||
|
|
||||||
|
def compare(self) -> bool:
|
||||||
|
"""Execute full comparison."""
|
||||||
|
print(f"Comparing directories:")
|
||||||
|
print(f" Dir1: {self.dir1}")
|
||||||
|
print(f" Dir2: {self.dir2}\n")
|
||||||
|
|
||||||
|
# Load data
|
||||||
|
meta1, txn1, env1 = self.load_directory(self.dir1)
|
||||||
|
meta2, txn2, env2 = self.load_directory(self.dir2)
|
||||||
|
|
||||||
|
if meta1 is None or meta2 is None:
|
||||||
|
print("Aborting due to loading errors.")
|
||||||
|
return False
|
||||||
|
|
||||||
|
print("Successfully loaded data from both directories.\n")
|
||||||
|
|
||||||
|
# Compare metadata
|
||||||
|
print("Comparing metadata...")
|
||||||
|
self.compare_metadata(meta1, meta2)
|
||||||
|
|
||||||
|
# Get key categories
|
||||||
|
key_cat1 = self.get_key_categories(meta1)
|
||||||
|
key_cat2 = self.get_key_categories(meta2)
|
||||||
|
|
||||||
|
keys1 = set(key_cat1.keys())
|
||||||
|
keys2 = set(key_cat2.keys())
|
||||||
|
|
||||||
|
# Check key sets
|
||||||
|
if keys1 != keys2:
|
||||||
|
missing_in_2 = sorted([k.decode() for k in keys1 - keys2])
|
||||||
|
missing_in_1 = sorted([k.decode() for k in keys2 - keys1])
|
||||||
|
if missing_in_2:
|
||||||
|
self.mismatches.append(f"Keys missing in dir2: {missing_in_2[:10]}")
|
||||||
|
if missing_in_1:
|
||||||
|
self.mismatches.append(f"Keys missing in dir1: {missing_in_1[:10]}")
|
||||||
|
|
||||||
|
# Compare common keys
|
||||||
|
common_keys = sorted(list(keys1.intersection(keys2)))
|
||||||
|
print(f"Comparing {len(common_keys)} common keys...\n")
|
||||||
|
|
||||||
|
for i, key in enumerate(common_keys):
|
||||||
|
if i % 100 == 0 and i > 0:
|
||||||
|
print(f"Progress: {i}/{len(common_keys)} keys compared...")
|
||||||
|
|
||||||
|
category = key_cat1.get(key, "unknown")
|
||||||
|
val1 = txn1.get(key)
|
||||||
|
val2 = txn2.get(key)
|
||||||
|
|
||||||
|
self.compare_value(key, category, val1, val2)
|
||||||
|
|
||||||
|
if self.image_psnr_values:
|
||||||
|
avg_psnr = sum(self.image_psnr_values) / len(self.image_psnr_values)
|
||||||
|
print(
|
||||||
|
f"\nImage PSNR average over {len(self.image_psnr_values)} images: "
|
||||||
|
f"{avg_psnr:.2f} dB (threshold {self.image_psnr_threshold:.2f} dB)"
|
||||||
|
)
|
||||||
|
if avg_psnr < self.image_psnr_threshold:
|
||||||
|
self.mismatches.append(
|
||||||
|
f"Average image PSNR {avg_psnr:.2f} dB below threshold "
|
||||||
|
f"{self.image_psnr_threshold:.2f} dB"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
print("\nNo image entries found for PSNR calculation.")
|
||||||
|
|
||||||
|
# Cleanup
|
||||||
|
env1.close()
|
||||||
|
env2.close()
|
||||||
|
|
||||||
|
# Print results
|
||||||
|
print("\n" + "=" * 80)
|
||||||
|
print("COMPARISON RESULTS")
|
||||||
|
print("=" * 80)
|
||||||
|
|
||||||
|
if self.warnings:
|
||||||
|
print(f"\nWarnings ({len(self.warnings)}):")
|
||||||
|
for warning in self.warnings[:20]:
|
||||||
|
print(f" - {warning}")
|
||||||
|
if len(self.warnings) > 20:
|
||||||
|
print(f" ... and {len(self.warnings) - 20} more warnings")
|
||||||
|
|
||||||
|
if self.mismatches:
|
||||||
|
print(f"\nMismatches found ({len(self.mismatches)}):")
|
||||||
|
for mismatch in self.mismatches[:30]:
|
||||||
|
print(f" - {mismatch}")
|
||||||
|
if len(self.mismatches) > 30:
|
||||||
|
print(f" ... and {len(self.mismatches) - 30} more mismatches")
|
||||||
|
print("\n❌ RESULT: Directories are DIFFERENT")
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
print("\n✅ RESULT: Directories are IDENTICAL")
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
"""Main entry point."""
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="Compare two Simbox task output directories.",
|
||||||
|
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||||
|
epilog="""
|
||||||
|
Examples:
|
||||||
|
%(prog)s --dir1 output/run1 --dir2 output/run2
|
||||||
|
%(prog)s --dir1 output/run1 --dir2 output/run2 --tolerance 1e-5
|
||||||
|
%(prog)s --dir1 output/run1 --dir2 output/run2 --image-psnr 40.0
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
parser.add_argument("--dir1", type=str, required=True, help="Path to the first output directory")
|
||||||
|
parser.add_argument("--dir2", type=str, required=True, help="Path to the second output directory")
|
||||||
|
parser.add_argument(
|
||||||
|
"--tolerance",
|
||||||
|
type=float,
|
||||||
|
default=1e-6,
|
||||||
|
help="Numerical tolerance for floating point comparisons (default: 1e-6)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--image-psnr",
|
||||||
|
type=float,
|
||||||
|
default=37.0,
|
||||||
|
help="PSNR threshold (dB) for considering images as matching (default: 37.0)",
|
||||||
|
)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
comparator = SimboxComparator(args.dir1, args.dir2, tolerance=args.tolerance, image_psnr_threshold=args.image_psnr)
|
||||||
|
|
||||||
|
success = comparator.compare()
|
||||||
|
exit(0 if success else 1)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
0
tests/integration/simbox/__init__.py
Normal file
0
tests/integration/simbox/__init__.py
Normal file
30
tests/integration/simbox/runners/simbox_pipeline_runner.py
Normal file
30
tests/integration/simbox/runners/simbox_pipeline_runner.py
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
"""
|
||||||
|
simbox pipeline runner for Isaac Sim subprocess execution.
|
||||||
|
This script runs in the Isaac Sim environment and is called by the main test.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
import sys
|
||||||
|
|
||||||
|
from tests.integration.base.test_harness import IntegrationTestHarness
|
||||||
|
|
||||||
|
# Add paths to sys.path
|
||||||
|
sys.path.append("./")
|
||||||
|
sys.path.append("./data_engine")
|
||||||
|
sys.path.append("./tests/integration")
|
||||||
|
|
||||||
|
|
||||||
|
def run_simbox_pipeline():
|
||||||
|
"""
|
||||||
|
Run the simbox pipeline test in Isaac Sim environment.
|
||||||
|
"""
|
||||||
|
harness = IntegrationTestHarness(config_path="configs/simbox/de_pipe_template.yaml", seed=42, random_num=1)
|
||||||
|
|
||||||
|
# Run the pipeline and validate output
|
||||||
|
harness.run_test_end_to_end(min_output_files=6)
|
||||||
|
|
||||||
|
print("✓ simbox pipeline test completed successfully")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
run_simbox_pipeline()
|
||||||
29
tests/integration/simbox/runners/simbox_plan_runner.py
Normal file
29
tests/integration/simbox/runners/simbox_plan_runner.py
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
"""
|
||||||
|
simbox plan runner for Isaac Sim subprocess execution.
|
||||||
|
This script runs in the Isaac Sim environment and is called by the main test.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
|
||||||
|
from tests.integration.base.test_harness import IntegrationTestHarness
|
||||||
|
|
||||||
|
# Add paths to sys.path
|
||||||
|
sys.path.append("./")
|
||||||
|
sys.path.append("./data_engine")
|
||||||
|
sys.path.append("./tests/integration")
|
||||||
|
|
||||||
|
|
||||||
|
def run_simbox_plan():
|
||||||
|
"""
|
||||||
|
Run the simbox plan test in Isaac Sim environment.
|
||||||
|
"""
|
||||||
|
harness = IntegrationTestHarness(config_path="configs/simbox/de_plan_template.yaml", seed=42, random_num=1)
|
||||||
|
|
||||||
|
# Run the pipeline and validate output
|
||||||
|
harness.run_test_end_to_end(min_output_files=1)
|
||||||
|
|
||||||
|
print("✓ simbox plan test completed successfully")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
run_simbox_plan()
|
||||||
29
tests/integration/simbox/runners/simbox_render_runner.py
Normal file
29
tests/integration/simbox/runners/simbox_render_runner.py
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
"""
|
||||||
|
simbox render runner for Isaac Sim subprocess execution.
|
||||||
|
This script runs in the Isaac Sim environment and is called by the main test.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
|
||||||
|
from tests.integration.base.test_harness import IntegrationTestHarness
|
||||||
|
|
||||||
|
# Add paths to sys.path
|
||||||
|
sys.path.append("./")
|
||||||
|
sys.path.append("./data_engine")
|
||||||
|
sys.path.append("./tests/integration")
|
||||||
|
|
||||||
|
|
||||||
|
def run_simbox_render():
|
||||||
|
"""
|
||||||
|
Run the simbox render test in Isaac Sim environment.
|
||||||
|
"""
|
||||||
|
harness = IntegrationTestHarness(config_path="configs/simbox/de_render_template.yaml", seed=42)
|
||||||
|
|
||||||
|
# Run the pipeline and validate output
|
||||||
|
harness.run_test_end_to_end(min_output_files=1)
|
||||||
|
|
||||||
|
print("✓ simbox render test completed successfully")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
run_simbox_render()
|
||||||
104
tests/integration/simbox/test_simbox.py
Normal file
104
tests/integration/simbox/test_simbox.py
Normal file
@@ -0,0 +1,104 @@
|
|||||||
|
"""
|
||||||
|
simbox tests that run in Isaac Sim environment using subprocess wrappers.
|
||||||
|
Migrated from original test files to use the new IntegrationTestHarness framework.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
|
||||||
|
from tests.integration.base.test_harness import IntegrationTestHarness
|
||||||
|
|
||||||
|
# Add path for proper imports
|
||||||
|
sys.path.append("./")
|
||||||
|
sys.path.append("./data_engine")
|
||||||
|
sys.path.append("./tests/integration")
|
||||||
|
|
||||||
|
|
||||||
|
def test_simbox_pipeline():
|
||||||
|
"""
|
||||||
|
Test simbox pipeline by running it in Isaac Sim subprocess.
|
||||||
|
This test uses a subprocess wrapper to handle Isaac Sim process separation.
|
||||||
|
"""
|
||||||
|
harness = IntegrationTestHarness(config_path="configs/simbox/de_pipe_template.yaml", seed=42, random_num=1)
|
||||||
|
|
||||||
|
# Run in subprocess using Isaac Sim Python interpreter
|
||||||
|
result = harness.run_data_engine_subprocess(
|
||||||
|
runner_script="tests/integration/simbox/runners/simbox_pipeline_runner.py",
|
||||||
|
interpreter="/isaac-sim/python.sh",
|
||||||
|
timeout=1800, # 30 minutes
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify subprocess completed successfully
|
||||||
|
assert result.returncode == 0, f"simbox pipeline test failed with return code: {result.returncode}"
|
||||||
|
|
||||||
|
# Validate that output was generated
|
||||||
|
# harness.validate_output_generated(min_files=6)
|
||||||
|
|
||||||
|
print("✓ simbox pipeline test completed successfully")
|
||||||
|
|
||||||
|
|
||||||
|
def test_simbox_plan():
|
||||||
|
"""
|
||||||
|
Test simbox plan generation by running it in Isaac Sim subprocess.
|
||||||
|
"""
|
||||||
|
harness = IntegrationTestHarness(
|
||||||
|
config_path="configs/simbox/de_plan_template.yaml",
|
||||||
|
seed=42,
|
||||||
|
random_num=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Run in subprocess using Isaac Sim Python interpreter
|
||||||
|
result = harness.run_data_engine_subprocess(
|
||||||
|
runner_script="tests/integration/simbox/runners/simbox_plan_runner.py",
|
||||||
|
interpreter="/isaac-sim/python.sh",
|
||||||
|
timeout=1800, # 30 minutes
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify subprocess completed successfully
|
||||||
|
assert result.returncode == 0, f"simbox plan test failed with return code: {result.returncode}"
|
||||||
|
|
||||||
|
# Validate that output was generated
|
||||||
|
# harness.validate_output_generated(min_files=1)
|
||||||
|
|
||||||
|
print("✓ simbox plan test completed successfully")
|
||||||
|
|
||||||
|
|
||||||
|
def test_simbox_render():
|
||||||
|
"""
|
||||||
|
Test simbox render by running it in Isaac Sim subprocess.
|
||||||
|
"""
|
||||||
|
harness = IntegrationTestHarness(
|
||||||
|
config_path="configs/simbox/de_render_template.yaml",
|
||||||
|
# config_path="tests/integration/simbox/configs/simbox_test_render_configs.yaml",
|
||||||
|
seed=42,
|
||||||
|
)
|
||||||
|
|
||||||
|
reference_dir = "/shared/smartbot_new/zhangyuchang/CI/manip/simbox/simbox_render_ci"
|
||||||
|
# Run in subprocess using Isaac Sim Python interpreter
|
||||||
|
result = harness.run_data_engine_subprocess(
|
||||||
|
runner_script="tests/integration/simbox/runners/simbox_render_runner.py",
|
||||||
|
interpreter="/isaac-sim/python.sh",
|
||||||
|
timeout=1800, # 30 minutes
|
||||||
|
compare_output=True,
|
||||||
|
reference_dir=reference_dir,
|
||||||
|
comparator="simbox",
|
||||||
|
comparator_args={
|
||||||
|
# Use defaults; override here if needed
|
||||||
|
# "tolerance": 1e-6,
|
||||||
|
# "image_psnr": 37.0,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify subprocess completed successfully
|
||||||
|
assert result.returncode == 0, f"simbox render test failed with return code: {result.returncode}"
|
||||||
|
|
||||||
|
# Validate that output was generated
|
||||||
|
# harness.validate_output_generated(min_files=1)
|
||||||
|
|
||||||
|
print("✓ simbox render test completed successfully")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# Run all tests when script is executed directly
|
||||||
|
test_simbox_plan()
|
||||||
|
test_simbox_render()
|
||||||
|
test_simbox_pipeline()
|
||||||
127
tests/run_tests.sh
Executable file
127
tests/run_tests.sh
Executable file
@@ -0,0 +1,127 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
# Integration test runner for SimBox DataEngine
|
||||||
|
|
||||||
|
RED='\033[0;31m'
|
||||||
|
GREEN='\033[0;32m'
|
||||||
|
YELLOW='\033[1;33m'
|
||||||
|
BLUE='\033[0;34m'
|
||||||
|
CYAN='\033[0;36m'
|
||||||
|
BOLD='\033[1m'
|
||||||
|
NC='\033[0m'
|
||||||
|
|
||||||
|
TIMESTAMP=$(date +%Y%m%d_%H%M%S)
|
||||||
|
LOG_FILE="tests/test_results_${TIMESTAMP}.log"
|
||||||
|
SUMMARY_FILE="tests/test_summary_${TIMESTAMP}.txt"
|
||||||
|
TEMP_LOG="tests/temp_test_output.log"
|
||||||
|
|
||||||
|
declare -a TEST_SUITES=()
|
||||||
|
TEST_SUITES+=("SimBox Tests (Isaac Sim):3:/isaac-sim/python.sh tests/integration/simbox/test_simbox.py")
|
||||||
|
|
||||||
|
TOTAL_SUITES=${#TEST_SUITES[@]}
|
||||||
|
|
||||||
|
echo "Starting SimBox DataEngine Integration Tests..."
|
||||||
|
echo "=============================================="
|
||||||
|
echo -e "${BOLD}${CYAN}TEST EXECUTION PLAN:${NC}"
|
||||||
|
echo -e " ${BOLD}Total Test Suites: ${TOTAL_SUITES}${NC}"
|
||||||
|
echo -e " ${BOLD}SimBox Scenarios Covered:${NC}"
|
||||||
|
echo -e " - Pipeline: Full end-to-end workflow"
|
||||||
|
echo -e " - Plan: Trajectory planning generation"
|
||||||
|
echo -e " - Render: Scene rendering with validation"
|
||||||
|
echo ""
|
||||||
|
|
||||||
|
echo -e "${BLUE}Detailed logs: $LOG_FILE${NC}"
|
||||||
|
echo "=============================================="
|
||||||
|
|
||||||
|
TOTAL_TEST_SUITES=0
|
||||||
|
PASSED_TEST_SUITES=0
|
||||||
|
FAILED_TEST_SUITES=0
|
||||||
|
|
||||||
|
run_test_suite() {
|
||||||
|
local suite_name="$1"
|
||||||
|
local expected_sessions="$2"
|
||||||
|
local test_command="$3"
|
||||||
|
local current_suite=$4
|
||||||
|
|
||||||
|
TOTAL_TEST_SUITES=$((TOTAL_TEST_SUITES + 1))
|
||||||
|
|
||||||
|
echo -e "${BOLD}${BLUE}[$current_suite/$TOTAL_SUITES] Starting: $suite_name${NC}"
|
||||||
|
echo -e " ${CYAN}-> Running ${expected_sessions} SimBox tests with Isaac Sim${NC}"
|
||||||
|
|
||||||
|
echo "Test Suite: $suite_name" >> "$LOG_FILE"
|
||||||
|
echo "Expected Test Functions: $expected_sessions" >> "$LOG_FILE"
|
||||||
|
echo "Command: $test_command" >> "$LOG_FILE"
|
||||||
|
echo "Started at: $(date)" >> "$LOG_FILE"
|
||||||
|
echo "----------------------------------------" >> "$LOG_FILE"
|
||||||
|
|
||||||
|
echo -e "${CYAN}Executing: $test_command${NC}"
|
||||||
|
eval "$test_command" > "$TEMP_LOG" 2>&1
|
||||||
|
local exit_status=$?
|
||||||
|
|
||||||
|
cat "$TEMP_LOG" >> "$LOG_FILE"
|
||||||
|
echo "" >> "$LOG_FILE"
|
||||||
|
|
||||||
|
if [ $exit_status -eq 0 ]; then
|
||||||
|
echo -e "${GREEN}✓ [$current_suite/$TOTAL_SUITES] PASSED: $suite_name${NC}"
|
||||||
|
PASSED_TEST_SUITES=$((PASSED_TEST_SUITES + 1))
|
||||||
|
else
|
||||||
|
echo -e "${RED}✗ [$current_suite/$TOTAL_SUITES] FAILED: $suite_name${NC}"
|
||||||
|
FAILED_TEST_SUITES=$((FAILED_TEST_SUITES + 1))
|
||||||
|
echo -e "${YELLOW} Check $LOG_FILE for error details${NC}"
|
||||||
|
echo -e "${YELLOW}Recent output:${NC}"
|
||||||
|
tail -20 "$TEMP_LOG"
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo ""
|
||||||
|
rm -f "$TEMP_LOG"
|
||||||
|
}
|
||||||
|
|
||||||
|
> "$SUMMARY_FILE"
|
||||||
|
> "$LOG_FILE"
|
||||||
|
|
||||||
|
echo "Test execution started at: $(date)" > "$SUMMARY_FILE"
|
||||||
|
echo "Planned suites: $TOTAL_SUITES" >> "$SUMMARY_FILE"
|
||||||
|
echo "=======================================" >> "$SUMMARY_FILE"
|
||||||
|
|
||||||
|
echo "Test execution started at: $(date)" >> "$LOG_FILE"
|
||||||
|
echo "================================================" >> "$LOG_FILE"
|
||||||
|
|
||||||
|
for i in "${!TEST_SUITES[@]}"; do
|
||||||
|
suite_info="${TEST_SUITES[$i]}"
|
||||||
|
suite_name=$(echo "$suite_info" | cut -d':' -f1)
|
||||||
|
expected_sessions=$(echo "$suite_info" | cut -d':' -f2)
|
||||||
|
test_command=$(echo "$suite_info" | cut -d':' -f3-)
|
||||||
|
current_suite=$((i + 1))
|
||||||
|
|
||||||
|
run_test_suite "$suite_name" "$expected_sessions" "$test_command" $current_suite
|
||||||
|
done
|
||||||
|
|
||||||
|
echo "=============================================="
|
||||||
|
echo "TEST EXECUTION SUMMARY"
|
||||||
|
echo "=============================================="
|
||||||
|
echo -e "${CYAN}Test Suites:${NC}"
|
||||||
|
echo -e " Total: $TOTAL_TEST_SUITES"
|
||||||
|
echo -e " ${GREEN}Passed: $PASSED_TEST_SUITES${NC}"
|
||||||
|
echo -e " ${RED}Failed: $FAILED_TEST_SUITES${NC}"
|
||||||
|
echo ""
|
||||||
|
|
||||||
|
echo "" >> "$SUMMARY_FILE"
|
||||||
|
echo "FINAL SUMMARY:" >> "$SUMMARY_FILE"
|
||||||
|
echo "Suites - Total: $TOTAL_TEST_SUITES, Passed: $PASSED_TEST_SUITES, Failed: $FAILED_TEST_SUITES" >> "$SUMMARY_FILE"
|
||||||
|
|
||||||
|
if [ $FAILED_TEST_SUITES -eq 0 ]; then
|
||||||
|
echo -e "${GREEN}ALL TESTS PASSED${NC}"
|
||||||
|
echo "ALL TESTS PASSED at $(date)" >> "$SUMMARY_FILE"
|
||||||
|
else
|
||||||
|
echo -e "${RED}SOME TEST SUITES FAILED${NC}"
|
||||||
|
echo "SOME TEST SUITES FAILED at $(date)" >> "$SUMMARY_FILE"
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo ""
|
||||||
|
echo -e "${BLUE}Results: $SUMMARY_FILE${NC}"
|
||||||
|
echo -e "${BLUE}Logs: $LOG_FILE${NC}"
|
||||||
|
|
||||||
|
if [ $FAILED_TEST_SUITES -eq 0 ]; then
|
||||||
|
exit 0
|
||||||
|
else
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
7
workflows/__init__.py
Normal file
7
workflows/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
# flake8: noqa: F401
|
||||||
|
# pylint: disable=W0611
|
||||||
|
def import_extensions(workflow_type):
|
||||||
|
if workflow_type == "SimBoxDualWorkFlow":
|
||||||
|
import workflows.simbox_dual_workflow
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported workflow type: {workflow_type}")
|
||||||
203
workflows/base.py
Normal file
203
workflows/base.py
Normal file
@@ -0,0 +1,203 @@
|
|||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from copy import deepcopy
|
||||||
|
from functools import wraps
|
||||||
|
|
||||||
|
|
||||||
|
class NimbusWorkFlow(ABC):
|
||||||
|
workflows = {}
|
||||||
|
|
||||||
|
# pylint: disable=W0613
|
||||||
|
def __init__(self, world, task_cfg_path: str, **kwargs):
|
||||||
|
"""Initialize the workflow.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
world: The simulation world instance.
|
||||||
|
task_cfg_path (str): Path to the task configuration file.
|
||||||
|
Each workflow subclass is responsible for parsing this file.
|
||||||
|
**kwargs: Workflow-specific parameters.
|
||||||
|
Subclasses declare only the kwargs they need; unused ones are silently ignored.
|
||||||
|
"""
|
||||||
|
self.world = world
|
||||||
|
self.task_cfg_path = task_cfg_path
|
||||||
|
self.task_cfgs = self.parse_task_cfgs(task_cfg_path)
|
||||||
|
|
||||||
|
def init_task(self, index, need_preload: bool = True):
|
||||||
|
assert index < len(self.task_cfgs), "Index out of range for task configurations."
|
||||||
|
self.task_cfg = self.task_cfgs[index]
|
||||||
|
self.reset(need_preload)
|
||||||
|
|
||||||
|
def __copy__(self):
|
||||||
|
new_wf = type(self).__new__(type(self))
|
||||||
|
new_wf.__dict__.update(self.__dict__)
|
||||||
|
|
||||||
|
if hasattr(self, "logger"):
|
||||||
|
new_wf.logger = deepcopy(self.logger)
|
||||||
|
|
||||||
|
if hasattr(self, "recoder"):
|
||||||
|
new_wf.recoder = deepcopy(self.recoder)
|
||||||
|
|
||||||
|
return new_wf
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def parse_task_cfgs(self, task_cfg_path) -> list:
|
||||||
|
"""
|
||||||
|
Parse the task configuration file.
|
||||||
|
Args:
|
||||||
|
task_cfg_path (str): Path to the task configuration file.
|
||||||
|
Returns:
|
||||||
|
list: List of task configurations.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_task_name(self) -> str:
|
||||||
|
"""Get the name of the current task.
|
||||||
|
Returns:
|
||||||
|
str: name of the current task
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def reset(self, need_preload):
|
||||||
|
"""Reset the environment to the initial state of the current task.
|
||||||
|
Args:
|
||||||
|
need_preload (bool): Whether to preload objects in the environment. Defaults to True.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def randomization(self, layout_path=None) -> bool:
|
||||||
|
"""Randomize the environment layout in one task.
|
||||||
|
Args:
|
||||||
|
layout_path (str, optional): Path to the layout file. Defaults to None.
|
||||||
|
Returns:
|
||||||
|
bool: True if randomization is successful, False otherwise.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def generate_seq(self) -> list:
|
||||||
|
"""Generate a sequence of states for the current task.
|
||||||
|
Returns:
|
||||||
|
list: Sequence of states which be replayed for the current task.
|
||||||
|
If the sequence is not generated, return an empty list.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def seq_replay(self, sequence: list) -> int:
|
||||||
|
"""Replay the sequence and generate observations.
|
||||||
|
Args:
|
||||||
|
sequence (list): Sequence of states to be replayed.
|
||||||
|
Returns:
|
||||||
|
int: Length of the replayed sequence.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def save(self, save_path: str) -> int:
|
||||||
|
"""Save the all information.
|
||||||
|
Args:
|
||||||
|
save_path (str): Path to save the observations.
|
||||||
|
Returns:
|
||||||
|
int: Length of the saved observations.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
# plan mode
|
||||||
|
def save_seq(self, save_path: str) -> int:
|
||||||
|
"""Save the generated sequence without observations.
|
||||||
|
Args:
|
||||||
|
save_path (str): Path to save the sequence.
|
||||||
|
Returns:
|
||||||
|
int: Length of the saved sequence.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError("This method should be implemented in the subclass.")
|
||||||
|
|
||||||
|
# render mode
|
||||||
|
def recover_seq(self, seq_path: str) -> list:
|
||||||
|
"""Recover sequence from a sequence file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
seq_path (str): Path to the sequence file.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError("This method should be implemented in the subclass.")
|
||||||
|
|
||||||
|
# plan with render mode
|
||||||
|
def generate_seq_with_obs(self) -> int:
|
||||||
|
"""Generate a sequence with observation for the current task.
|
||||||
|
(For debug or future RL)
|
||||||
|
Returns:
|
||||||
|
int: Length of the generated sequence.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError("This method should be implemented in the subclass.")
|
||||||
|
|
||||||
|
# pipeline mode
|
||||||
|
def dump_plan_info(self) -> bytes:
|
||||||
|
"""Dump the layout and sequence plan information of the current task.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bytes: Serialized plan information including layout and sequence data.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError("This method should be implemented in the subclass.")
|
||||||
|
|
||||||
|
# pipeline mode
|
||||||
|
def dedump_plan_info(self, ser_obj: bytes) -> object:
|
||||||
|
"""Deserialize the layout and plan information of the current task.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ser_obj (bytes): Serialized plan information generated from dump_plan_info().
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
object: Deserialized layout and sequence information.
|
||||||
|
This will be used as input for randomization_from_mem() and recover_seq_from_mem().
|
||||||
|
"""
|
||||||
|
raise NotImplementedError("This method should be implemented in the subclass.")
|
||||||
|
|
||||||
|
# pipeline mode
|
||||||
|
def randomization_from_mem(self, data: object) -> bool:
|
||||||
|
"""Perform randomization using in-memory plan data.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data (object): Deserialized layout and sequence information.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if randomization succeeds, False otherwise.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError("This method should be implemented in the subclass.")
|
||||||
|
|
||||||
|
# pipeline mode
|
||||||
|
def recover_seq_from_mem(self, data: object) -> list:
|
||||||
|
"""Recover sequence from in-memory plan data.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data (object): Deserialized layout and sequence information.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list: Recovered sequence of states.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError("This method should be implemented in the subclass.")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def register(cls, name: str):
|
||||||
|
"""
|
||||||
|
Register a workflow with its name(decorator).
|
||||||
|
Args:
|
||||||
|
name(str): name of the workflow
|
||||||
|
"""
|
||||||
|
|
||||||
|
def decorator(wfs_class):
|
||||||
|
cls.workflows[name] = wfs_class
|
||||||
|
|
||||||
|
@wraps(wfs_class)
|
||||||
|
def wrapped_function(*args, **kwargs):
|
||||||
|
return wfs_class(*args, **kwargs)
|
||||||
|
|
||||||
|
return wrapped_function
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
|
def create_workflow(workflow_type: str, world, task_cfg_path: str, **kwargs):
|
||||||
|
wf_cls = NimbusWorkFlow.workflows[workflow_type]
|
||||||
|
return wf_cls(world, task_cfg_path, **kwargs)
|
||||||
118
workflows/simbox/core/cameras/README.md
Normal file
118
workflows/simbox/core/cameras/README.md
Normal file
@@ -0,0 +1,118 @@
|
|||||||
|
# Cameras
|
||||||
|
|
||||||
|
Template-based cameras for simbox tasks. All cameras currently use a single generic implementation, `CustomCamera`, which is configured entirely from the task YAML.
|
||||||
|
|
||||||
|
## Available cameras
|
||||||
|
|
||||||
|
| Camera class | Notes |
|
||||||
|
|-----------------|-------|
|
||||||
|
| `CustomCamera` | Generic pinhole RGB-D camera with configurable intrinsics and pose. |
|
||||||
|
|
||||||
|
Importing `CustomCamera` in your task (e.g. `banana.py`) is enough to register it via `@register_camera`.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Customizing a camera configuration
|
||||||
|
|
||||||
|
Camera behavior is controlled by the config (`cfg`) passed into `CustomCamera.__init__` in `banana.py`. You typically edit the YAML under `configs/simbox/...`.
|
||||||
|
|
||||||
|
### 1. Top-level camera fields
|
||||||
|
|
||||||
|
Each camera entry in the YAML should provide:
|
||||||
|
|
||||||
|
- **`name`**: Unique camera name (string). Used for prim paths and as the key in `task.cameras`.
|
||||||
|
- **`parent`**: Optional prim path (under the task root) that the camera mount is attached to. Empty string (`""`) means no specific parent.
|
||||||
|
- **`translation`**: Initial camera translation in world or parent frame, as a list of three floats `[x, y, z]` (meters).
|
||||||
|
- **`orientation`**: Initial camera orientation as a quaternion `[w, x, y, z]`.
|
||||||
|
- **`camera_axes`**: Axes convention for `set_local_pose` (e.g. `[1.0, 0.0, 0.0]` etc. – follow existing configs).
|
||||||
|
|
||||||
|
These values are used in `banana.py` when calling:
|
||||||
|
|
||||||
|
```python
|
||||||
|
camera.set_local_pose(
|
||||||
|
translation=cfg["translation"],
|
||||||
|
orientation=cfg["orientation"],
|
||||||
|
camera_axes=cfg["camera_axes"],
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2. Required `params` fields
|
||||||
|
|
||||||
|
Inside each camera config there is a `params` dict that controls the optics and intrinsics. `CustomCamera` expects:
|
||||||
|
|
||||||
|
- **`pixel_size`** (`float`, microns)
|
||||||
|
Physical pixel size on the sensor. Used to compute horizontal/vertical aperture and focal length.
|
||||||
|
|
||||||
|
- **`f_number`** (`float`)
|
||||||
|
Lens f-number. Used in `set_lens_aperture(f_number * 100.0)`.
|
||||||
|
|
||||||
|
- **`focus_distance`** (`float`, meters)
|
||||||
|
Focus distance passed to `set_focus_distance`.
|
||||||
|
|
||||||
|
- **`camera_params`** (`[fx, fy, cx, cy]`)
|
||||||
|
Intrinsic matrix parameters in pixel units:
|
||||||
|
- `fx`, `fy`: focal lengths in x/y (pixels)
|
||||||
|
- `cx`, `cy`: principal point (pixels)
|
||||||
|
|
||||||
|
- **`resolution_width`** (`int`)
|
||||||
|
Image width in pixels.
|
||||||
|
|
||||||
|
- **`resolution_height`** (`int`)
|
||||||
|
Image height in pixels.
|
||||||
|
|
||||||
|
Optional:
|
||||||
|
|
||||||
|
- **`output_mode`** (`"rgb"` or `"diffuse_albedo"`, default `"rgb"`)
|
||||||
|
Controls which color source is used in `get_observations()`.
|
||||||
|
|
||||||
|
### 3. How the parameters are used in `CustomCamera`
|
||||||
|
|
||||||
|
Given `cfg["params"]`, `CustomCamera` does the following:
|
||||||
|
|
||||||
|
- Computes the camera apertures and focal length:
|
||||||
|
- `horizontal_aperture = pixel_size * 1e-3 * width`
|
||||||
|
- `vertical_aperture = pixel_size * 1e-3 * height`
|
||||||
|
- `focal_length_x = fx * pixel_size * 1e-3`
|
||||||
|
- `focal_length_y = fy * pixel_size * 1e-3`
|
||||||
|
- `focal_length = (focal_length_x + focal_length_y) / 2`
|
||||||
|
- Sets optical parameters:
|
||||||
|
- `set_focal_length(focal_length / 10.0)`
|
||||||
|
- `set_focus_distance(focus_distance)`
|
||||||
|
- `set_lens_aperture(f_number * 100.0)`
|
||||||
|
- `set_horizontal_aperture(horizontal_aperture / 10.0)`
|
||||||
|
- `set_vertical_aperture(vertical_aperture / 10.0)`
|
||||||
|
- `set_clipping_range(0.05, 1.0e5)`
|
||||||
|
- `set_projection_type("pinhole")`
|
||||||
|
- Recomputes intrinsic matrix `K` on the fly:
|
||||||
|
|
||||||
|
```python
|
||||||
|
fx = width * self.get_focal_length() / self.get_horizontal_aperture()
|
||||||
|
fy = height * self.get_focal_length() / self.get_vertical_aperture()
|
||||||
|
self.is_camera_matrix = np.array([[fx, 0.0, cx], [0.0, fy, cy], [0.0, 0.0, 1.0]])
|
||||||
|
```
|
||||||
|
|
||||||
|
### 4. Outputs from `get_observations()`
|
||||||
|
|
||||||
|
`CustomCamera.get_observations()` returns a dict:
|
||||||
|
|
||||||
|
- **`color_image`**: RGB image (`H x W x 3`, float32), either from `get_rgba()` or `DiffuseAlbedo` depending on `output_mode`.
|
||||||
|
- **`depth_image`**: Depth map from `get_depth()` (same resolution as color).
|
||||||
|
- **`camera2env_pose`**: 4x4 transform from camera to environment, computed from USD prims.
|
||||||
|
- **`camera_params`**: 3x3 intrinsic matrix `K` as a Python list.
|
||||||
|
|
||||||
|
These are the values consumed by tasks (e.g. `banana.py`) for perception and planning.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Summary checklist for a new camera
|
||||||
|
|
||||||
|
To add or tweak a camera in a task YAML:
|
||||||
|
|
||||||
|
1. **Choose a `name`** and, optionally, a `parent` prim under the task root.
|
||||||
|
2. **Set pose**: `translation`, `orientation` (quaternion `[w, x, y, z]`), and `camera_axes`.
|
||||||
|
3. Under `params`, provide:
|
||||||
|
- `pixel_size`, `f_number`, `focus_distance`
|
||||||
|
- `camera_params = [fx, fy, cx, cy]`
|
||||||
|
- `resolution_width`, `resolution_height`
|
||||||
|
- optional `output_mode` (`"rgb"` or `"diffuse_albedo"`).
|
||||||
|
4. Ensure your task (e.g. `banana.py`) constructs `CustomCamera` with this `cfg` (this is already wired up in the current code).
|
||||||
21
workflows/simbox/core/cameras/__init__.py
Normal file
21
workflows/simbox/core/cameras/__init__.py
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
"""Camera module initialization."""
|
||||||
|
|
||||||
|
from core.cameras.base_camera import CAMERA_DICT
|
||||||
|
|
||||||
|
from .custom_camera import CustomCamera
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"CustomCamera",
|
||||||
|
"get_camera_cls",
|
||||||
|
"get_camera_dict",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def get_camera_cls(category_name):
|
||||||
|
"""Get camera class by category name."""
|
||||||
|
return CAMERA_DICT[category_name]
|
||||||
|
|
||||||
|
|
||||||
|
def get_camera_dict():
|
||||||
|
"""Get camera dictionary."""
|
||||||
|
return CAMERA_DICT
|
||||||
9
workflows/simbox/core/cameras/base_camera.py
Normal file
9
workflows/simbox/core/cameras/base_camera.py
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
CAMERA_DICT = {}
|
||||||
|
|
||||||
|
|
||||||
|
def register_camera(target_class):
|
||||||
|
# key = "_".join(re.sub(r"([A-Z0-9])", r" \1", target_class.__name__).split()).lower()
|
||||||
|
key = target_class.__name__
|
||||||
|
assert key not in CAMERA_DICT
|
||||||
|
CAMERA_DICT[key] = target_class
|
||||||
|
return target_class
|
||||||
163
workflows/simbox/core/cameras/custom_camera.py
Normal file
163
workflows/simbox/core/cameras/custom_camera.py
Normal file
@@ -0,0 +1,163 @@
|
|||||||
|
import numpy as np
|
||||||
|
import omni.replicator.core as rep
|
||||||
|
from core.cameras.base_camera import register_camera
|
||||||
|
from omni.isaac.core.prims import XFormPrim
|
||||||
|
from omni.isaac.core.utils.prims import get_prim_at_path
|
||||||
|
from omni.isaac.core.utils.transformations import (
|
||||||
|
get_relative_transform,
|
||||||
|
pose_from_tf_matrix,
|
||||||
|
)
|
||||||
|
from omni.isaac.sensor import Camera
|
||||||
|
|
||||||
|
|
||||||
|
@register_camera
|
||||||
|
class CustomCamera(Camera):
|
||||||
|
"""Generic pinhole RGB-D camera used in simbox tasks."""
|
||||||
|
|
||||||
|
def __init__(self, cfg, prim_path, root_prim_path, reference_path, name, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
cfg: Config dict with required keys:
|
||||||
|
- params: Dict containing:
|
||||||
|
- pixel_size: Pixel size in microns
|
||||||
|
- f_number: F-number
|
||||||
|
- focus_distance: Focus distance in meters
|
||||||
|
- camera_params: [fx, fy, cx, cy] camera intrinsics
|
||||||
|
- resolution_width: Image width
|
||||||
|
- resolution_height: Image height
|
||||||
|
- output_mode (optional): "rgb" or "diffuse_albedo"
|
||||||
|
prim_path: Camera prim path in USD stage
|
||||||
|
root_prim_path: Root prim path in USD stage
|
||||||
|
reference_path: Reference prim path for camera mounting
|
||||||
|
name: Camera name
|
||||||
|
"""
|
||||||
|
# ===== Initialize camera =====
|
||||||
|
super().__init__(
|
||||||
|
prim_path=prim_path,
|
||||||
|
name=name,
|
||||||
|
resolution=(cfg["params"]["resolution_width"], cfg["params"]["resolution_height"]),
|
||||||
|
*args,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
self.initialize()
|
||||||
|
self.add_motion_vectors_to_frame()
|
||||||
|
self.add_semantic_segmentation_to_frame()
|
||||||
|
self.add_distance_to_image_plane_to_frame()
|
||||||
|
|
||||||
|
# ===== From cfg =====
|
||||||
|
pixel_size = cfg["params"].get("pixel_size")
|
||||||
|
f_number = cfg["params"].get("f_number")
|
||||||
|
focus_distance = cfg["params"].get("focus_distance")
|
||||||
|
fx, fy, cx, cy = cfg["params"].get("camera_params")
|
||||||
|
width = cfg["params"].get("resolution_width")
|
||||||
|
height = cfg["params"].get("resolution_height")
|
||||||
|
self.output_mode = cfg.get("output_mode", "rgb")
|
||||||
|
|
||||||
|
# ===== Compute and set camera parameters =====
|
||||||
|
horizontal_aperture = pixel_size * 1e-3 * width
|
||||||
|
vertical_aperture = pixel_size * 1e-3 * height
|
||||||
|
focal_length_x = fx * pixel_size * 1e-3
|
||||||
|
focal_length_y = fy * pixel_size * 1e-3
|
||||||
|
focal_length = (focal_length_x + focal_length_y) / 2
|
||||||
|
|
||||||
|
self.set_focal_length(focal_length / 10.0)
|
||||||
|
self.set_focus_distance(focus_distance)
|
||||||
|
self.set_lens_aperture(f_number * 100.0)
|
||||||
|
self.set_horizontal_aperture(horizontal_aperture / 10.0)
|
||||||
|
self.set_vertical_aperture(vertical_aperture / 10.0)
|
||||||
|
self.set_clipping_range(0.05, 1.0e5)
|
||||||
|
self.set_projection_type("pinhole")
|
||||||
|
|
||||||
|
fx = width * self.get_focal_length() / self.get_horizontal_aperture()
|
||||||
|
fy = height * self.get_focal_length() / self.get_vertical_aperture()
|
||||||
|
self.is_camera_matrix = np.array([[fx, 0.0, cx], [0.0, fy, cy], [0.0, 0.0, 1.0]])
|
||||||
|
|
||||||
|
self.reference_path = reference_path
|
||||||
|
self.root_prim_path = root_prim_path
|
||||||
|
self.parent_camera_prim_path = str(self.prim.GetParent().GetPath())
|
||||||
|
self.parent_camera_xform = XFormPrim(self.parent_camera_prim_path)
|
||||||
|
|
||||||
|
if self.output_mode == "diffuse_albedo":
|
||||||
|
self.add_diffuse_albedo_to_frame()
|
||||||
|
|
||||||
|
def add_diffuse_albedo_to_frame(self) -> None:
|
||||||
|
"""Attach the diffuse_albedo annotator to this camera."""
|
||||||
|
if "DiffuseAlbedo" not in self._custom_annotators:
|
||||||
|
self._custom_annotators["DiffuseAlbedo"] = rep.AnnotatorRegistry.get_annotator("DiffuseAlbedo")
|
||||||
|
self._custom_annotators["DiffuseAlbedo"].attach([self._render_product_path])
|
||||||
|
self._current_frame["DiffuseAlbedo"] = None
|
||||||
|
|
||||||
|
def remove_diffuse_albedo_from_frame(self) -> None:
|
||||||
|
if self._custom_annotators["DiffuseAlbedo"] is not None:
|
||||||
|
self._custom_annotators["DiffuseAlbedo"].detach([self._render_product_path])
|
||||||
|
self._custom_annotators["DiffuseAlbedo"] = None
|
||||||
|
self._current_frame.pop("DiffuseAlbedo", None)
|
||||||
|
|
||||||
|
def add_specular_albedo_to_frame(self) -> None:
|
||||||
|
"""Attach the specular_albedo annotator to this camera."""
|
||||||
|
if self._custom_annotators["SpecularAlbedo"] is None:
|
||||||
|
self._custom_annotators["SpecularAlbedo"] = rep.AnnotatorRegistry.get_annotator("SpecularAlbedo")
|
||||||
|
self._custom_annotators["SpecularAlbedo"].attach([self._render_product_path])
|
||||||
|
self._current_frame["SpecularAlbedo"] = None
|
||||||
|
|
||||||
|
def remove_specular_albedo_from_frame(self) -> None:
|
||||||
|
if self._custom_annotators["SpecularAlbedo"] is not None:
|
||||||
|
self._custom_annotators["SpecularAlbedo"].detach([self._render_product_path])
|
||||||
|
self._custom_annotators["SpecularAlbedo"] = None
|
||||||
|
self._current_frame.pop("SpecularAlbedo", None)
|
||||||
|
|
||||||
|
def add_direct_diffuse_to_frame(self) -> None:
|
||||||
|
"""Attach the direct_diffuse annotator to this camera."""
|
||||||
|
if self._custom_annotators["DirectDiffuse"] is None:
|
||||||
|
self._custom_annotators["DirectDiffuse"] = rep.AnnotatorRegistry.get_annotator("DirectDiffuse")
|
||||||
|
self._custom_annotators["DirectDiffuse"].attach([self._render_product_path])
|
||||||
|
self._current_frame["DirectDiffuse"] = None
|
||||||
|
|
||||||
|
def remove_direct_diffuse_from_frame(self) -> None:
|
||||||
|
if self._custom_annotators["DirectDiffuse"] is not None:
|
||||||
|
self._custom_annotators["DirectDiffuse"].detach([self._render_product_path])
|
||||||
|
self._custom_annotators["DirectDiffuse"] = None
|
||||||
|
self._current_frame.pop("DirectDiffuse", None)
|
||||||
|
|
||||||
|
def add_indirect_diffuse_to_frame(self) -> None:
|
||||||
|
"""Attach the indirect_diffuse annotator to this camera."""
|
||||||
|
if self._custom_annotators["IndirectDiffuse"] is None:
|
||||||
|
self._custom_annotators["IndirectDiffuse"] = rep.AnnotatorRegistry.get_annotator("IndirectDiffuse")
|
||||||
|
self._custom_annotators["IndirectDiffuse"].attach([self._render_product_path])
|
||||||
|
self._current_frame["IndirectDiffuse"] = None
|
||||||
|
|
||||||
|
def remove_indirect_diffuse_from_frame(self) -> None:
|
||||||
|
if self._custom_annotators["IndirectDiffuse"] is not None:
|
||||||
|
self._custom_annotators["IndirectDiffuse"].detach([self._render_product_path])
|
||||||
|
self._custom_annotators["IndirectDiffuse"] = None
|
||||||
|
self._current_frame.pop("IndirectDiffuse", None)
|
||||||
|
|
||||||
|
def get_observations(self):
|
||||||
|
if self.reference_path:
|
||||||
|
camera_mount2env_pose = get_relative_transform(
|
||||||
|
get_prim_at_path(self.reference_path), get_prim_at_path(self.root_prim_path)
|
||||||
|
)
|
||||||
|
camera_mount2env_pose = pose_from_tf_matrix(camera_mount2env_pose)
|
||||||
|
self.parent_camera_xform.set_local_pose(
|
||||||
|
translation=camera_mount2env_pose[0],
|
||||||
|
orientation=camera_mount2env_pose[1],
|
||||||
|
)
|
||||||
|
camera2env_pose = get_relative_transform(
|
||||||
|
get_prim_at_path(self.prim_path), get_prim_at_path(self.root_prim_path)
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.output_mode == "rgb":
|
||||||
|
color_image = self.get_rgba()[..., :3]
|
||||||
|
elif self.output_mode == "diffuse_albedo":
|
||||||
|
color_image = self._custom_annotators["DiffuseAlbedo"].get_data()[..., :3]
|
||||||
|
else:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
obs = {
|
||||||
|
"color_image": color_image,
|
||||||
|
"depth_image": self.get_depth(),
|
||||||
|
"camera2env_pose": camera2env_pose,
|
||||||
|
"camera_params": self.is_camera_matrix.tolist(),
|
||||||
|
}
|
||||||
|
|
||||||
|
return obs
|
||||||
73
workflows/simbox/core/configs/arenas/azure_loong_arena.yaml
Normal file
73
workflows/simbox/core/configs/arenas/azure_loong_arena.yaml
Normal file
@@ -0,0 +1,73 @@
|
|||||||
|
name: azure_loong_arena
|
||||||
|
fixtures:
|
||||||
|
-
|
||||||
|
name: table
|
||||||
|
path: table0/instance.usd
|
||||||
|
target_class: GeometryObject
|
||||||
|
translation: [0.0, 0.0, 0.705]
|
||||||
|
# euler: [0.0, 0.0, 0.0]
|
||||||
|
# quaternion: [1.0, 0.0, 0.0, 0.0]
|
||||||
|
scale: [0.0015, 0.0015, 0.001]
|
||||||
|
texture:
|
||||||
|
texture_lib: "val2017"
|
||||||
|
apply_randomization: False
|
||||||
|
texture_id: 0
|
||||||
|
texture_scale: [0.001, 0.001]
|
||||||
|
-
|
||||||
|
name: floor
|
||||||
|
target_class: PlaneObject
|
||||||
|
size: [5.0, 5.0]
|
||||||
|
translation: [0, 0, 0]
|
||||||
|
texture:
|
||||||
|
texture_lib: "floor_textures"
|
||||||
|
apply_randomization: False
|
||||||
|
texture_id: 1
|
||||||
|
texture_scale: [1.0, 1.0]
|
||||||
|
# -
|
||||||
|
# name: background0
|
||||||
|
# target_class: PlaneObject
|
||||||
|
# size: [3.0, 5.0]
|
||||||
|
# translation: [-2, 0, 1]
|
||||||
|
# euler: [0.0, 90.0, 0.0]
|
||||||
|
# texture:
|
||||||
|
# texture_lib: "background_textures"
|
||||||
|
# apply_randomization: False
|
||||||
|
# texture_id: 1
|
||||||
|
# texture_scale: [1.0, 1.0]
|
||||||
|
# -
|
||||||
|
# name: background1
|
||||||
|
# target_class: PlaneObject
|
||||||
|
# size: [3.0, 5.0]
|
||||||
|
# translation: [2, 0, 1]
|
||||||
|
# euler: [0.0, 90.0, 0.0]
|
||||||
|
# texture:
|
||||||
|
# texture_lib: "background_textures"
|
||||||
|
# apply_randomization: False
|
||||||
|
# texture_id: 1
|
||||||
|
# texture_scale: [1.0, 1.0]
|
||||||
|
# -
|
||||||
|
# name: background2
|
||||||
|
# target_class: PlaneObject
|
||||||
|
# size: [5.0, 3.0]
|
||||||
|
# translation: [0, -2, 1]
|
||||||
|
# euler: [90.0, 0.0, 0.0]
|
||||||
|
# texture:
|
||||||
|
# texture_lib: "background_textures"
|
||||||
|
# apply_randomization: False
|
||||||
|
# texture_id: 1
|
||||||
|
# texture_scale: [1.0, 1.0]
|
||||||
|
# -
|
||||||
|
# name: background3
|
||||||
|
# target_class: PlaneObject
|
||||||
|
# size: [5.0, 3.0]
|
||||||
|
# translation: [0, 2, 1]
|
||||||
|
# euler: [90.0, 0.0, 0.0]
|
||||||
|
# texture:
|
||||||
|
# texture_lib: "background_textures"
|
||||||
|
# apply_randomization: False
|
||||||
|
# texture_id: 1
|
||||||
|
# texture_scale: [1.0, 1.0]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# affordance region should be defined here
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user