Ruff is a code linting and formatting tool that can improve the code quality of our nodes and appears to be a better fit than the previously used black & pylint. This PR makes the initial transition towards this tool.
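For orientation, the hunks below apply the same command mapping everywhere (a rough sketch; `ruff format` is arguably the closer drop-in for black, but this PR standardizes on `ruff check --fix`):

```bash
pip install ruff     # replaces: pip install black pylint
ruff check .         # lint, replaces: pylint --disable=C,R **/*.py
ruff check . --fix   # auto-fix findings; used here where `black .` previously ran
```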
| @@ -310,15 +310,14 @@ jobs: | |||
| mv .venv/Scripts .venv/bin # venv is placed under `Scripts` on Windows | |||
| fi | |||
| source .venv/bin/activate | |||
| pip3 install maturin black pylint pytest | |||
| pip3 install maturin ruff pytest | |||
| maturin build -m apis/python/node/Cargo.toml | |||
| pip3 install target/wheels/* | |||
| dora new test_python_project --lang python --internal-create-with-path-dependencies | |||
| cd test_python_project | |||
| # Check Compliancy | |||
| black . --check | |||
| pylint --disable=C,R **/*.py | |||
| ruff check . | |||
| pip install -e ./*/ | |||
| pytest | |||
| @@ -86,7 +86,7 @@ jobs: | |||
| run: | | |||
| curl -sSL https://install.python-poetry.org | python3 - | |||
| echo "$HOME/.local/bin" >> $GITHUB_PATH | |||
| pip install black pylint pytest | |||
| pip install ruff pytest | |||
| - name: Set up Rust | |||
| if: runner.os == 'Linux' || github.event_name == 'workflow_dispatch' || (github.event_name == 'release' && startsWith(github.ref, 'refs/tags/')) | |||
| @@ -43,8 +43,7 @@ else | |||
| if [ -f "$dir/pyproject.toml" ]; then | |||
| echo "Running linting and tests for Python project in $dir..." | |||
| pip install . | |||
| poetry run black --check . | |||
| poetry run pylint --disable=C,R --ignored-modules=cv2,pyrealsense2 **/*.py | |||
| ruff check . | |||
| poetry run pytest | |||
| fi | |||
| fi | |||
| @@ -10,16 +10,16 @@ pip install -e . | |||
| ## Contribution Guide | |||
| - Format with [black](https://github.com/psf/black): | |||
| - Format with [ruff](https://docs.astral.sh/ruff/): | |||
| ```bash | |||
| black . # Format | |||
| ruff check . --fix | |||
| ``` | |||
| - Lint with [pylint](https://github.com/pylint-dev/pylint): | |||
| - Lint with ruff: | |||
| ```bash | |||
| pylint --disable=C,R --ignored-modules=cv2 . # Lint | |||
| ruff check . | |||
| ``` | |||
| - Test with [pytest](https://github.com/pytest-dev/pytest) | |||
| @@ -12,13 +12,12 @@ packages = [{ include = "__node_name__" }] | |||
| [tool.poetry.dependencies] | |||
| dora-rs = "^0.3.6" | |||
| numpy = "< 2.0.0" | |||
| pyarrow = ">= 5.0.0" | |||
| pyarrow = ">= 15.0.0" | |||
| python = "^3.7" | |||
| [tool.poetry.dev-dependencies] | |||
| pytest = ">= 8.3.4" | |||
| pylint = ">= 3.3.2" | |||
| black = ">= 24.10" | |||
| pytest = ">= 6.3.4" | |||
| ruff = ">= 0.9.1" | |||
| [tool.poetry.scripts] | |||
| __node-name__ = "__node_name__.main:main" | |||
| @@ -77,9 +77,10 @@ fn create_custom_node( | |||
| .with_context(|| format!("failed to write `{}`", node_path.display()))?; | |||
| // tests/tests___node_name__.py | |||
| let node_path = root | |||
| .join("tests") | |||
| .join(format!("test_{}.py", name.replace(" ", "_"))); | |||
| let node_path = root.join("tests").join(format!( | |||
| "test_{}.py", | |||
| name.replace(" ", "_").replace("-", "_") | |||
| )); | |||
| let file = replace_space(_TEST_PY, &name); | |||
| fs::write(&node_path, file) | |||
| .with_context(|| format!("failed to write `{}`", node_path.display()))?; | |||
| @@ -90,8 +91,8 @@ fn create_custom_node( | |||
| ); | |||
| println!(" cd {}", Path::new(".").join(&root).display()); | |||
| println!(" pip install -e . # Install",); | |||
| println!(" black . # Format"); | |||
| println!(" pylint --disable=C,R . # Lint",); | |||
| println!(" ruff check . --fix # Format"); | |||
| println!(" ruff check . # Lint",); | |||
| println!(" pytest . # Test"); | |||
| Ok(()) | |||
| @@ -1,7 +1,6 @@ | |||
| from dora import Node | |||
| import numpy as np | |||
| import h5py | |||
| f = h5py.File("data/episode_0.hdf5", "r") | |||
| @@ -1,7 +1,6 @@ | |||
| from dora import Node | |||
| import numpy as np | |||
| import h5py | |||
| import os | |||
| @@ -1,6 +1,5 @@ | |||
| import os | |||
| import cv2 | |||
| import time | |||
| from dora import DoraStatus | |||
| from utils import LABELS | |||
| @@ -5,7 +5,7 @@ readme_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), "README.m | |||
| # Read the content of the README file | |||
| try: | |||
| with open(readme_path, "r", encoding="utf-8") as f: | |||
| with open(readme_path, encoding="utf-8") as f: | |||
| __doc__ = f.read() | |||
| except FileNotFoundError: | |||
| __doc__ = "README file not found." | |||
| @@ -2,10 +2,10 @@ import os | |||
| os.environ["ARGOS_DEVICE_TYPE"] = "auto" | |||
| from dora import Node | |||
| import pyarrow as pa | |||
| import argostranslate.package | |||
| import argostranslate.translate | |||
| import pyarrow as pa | |||
| from dora import Node | |||
| from_code = os.getenv("SOURCE_LANGUAGE", "fr") | |||
| to_code = os.getenv("TARGET_LANGUAGE", "en") | |||
| @@ -15,8 +15,8 @@ argostranslate.package.update_package_index() | |||
| available_packages = argostranslate.package.get_available_packages() | |||
| package_to_install = next( | |||
| filter( | |||
| lambda x: x.from_code == from_code and x.to_code == to_code, available_packages | |||
| ) | |||
| lambda x: x.from_code == from_code and x.to_code == to_code, available_packages, | |||
| ), | |||
| ) | |||
| argostranslate.package.install_from_path(package_to_install.download()) | |||
| @@ -5,7 +5,7 @@ readme_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), "README.m | |||
| # Read the content of the README file | |||
| try: | |||
| with open(readme_path, "r", encoding="utf-8") as f: | |||
| with open(readme_path, encoding="utf-8") as f: | |||
| __doc__ = f.read() | |||
| except FileNotFoundError: | |||
| __doc__ = "README file not found." | |||
| @@ -47,7 +47,8 @@ def load_model(): | |||
| def load_model_mlx(): | |||
| from lightning_whisper_mlx import LightningWhisperMLX # noqa | |||
| # noqa: disable: import-error | |||
| from lightning_whisper_mlx import LightningWhisperMLX | |||
| whisper = LightningWhisperMLX(model="distil-large-v3", batch_size=12, quant=None) | |||
| return whisper | |||
| @@ -78,7 +79,8 @@ def cut_repetition(text, min_repeat_length=4, max_repeat_length=50): | |||
| if sum(1 for char in text if "\u4e00" <= char <= "\u9fff") / len(text) > 0.5: | |||
| # Chinese text processing | |||
| for repeat_length in range( | |||
| min_repeat_length, min(max_repeat_length, len(text) // 2) | |||
| min_repeat_length, | |||
| min(max_repeat_length, len(text) // 2), | |||
| ): | |||
| for i in range(len(text) - repeat_length * 2 + 1): | |||
| chunk1 = text[i : i + repeat_length] | |||
| @@ -90,7 +92,8 @@ def cut_repetition(text, min_repeat_length=4, max_repeat_length=50): | |||
| # Non-Chinese (space-separated) text processing | |||
| words = text.split() | |||
| for repeat_length in range( | |||
| min_repeat_length, min(max_repeat_length, len(words) // 2) | |||
| min_repeat_length, | |||
| min(max_repeat_length, len(words) // 2), | |||
| ): | |||
| for i in range(len(words) - repeat_length * 2 + 1): | |||
| chunk1 = " ".join(words[i : i + repeat_length]) | |||
| @@ -28,8 +28,7 @@ dora-distil-whisper = "dora_distil_whisper.main:main" | |||
| [tool.poetry.dev-dependencies] | |||
| pytest = ">= 6.3.4" | |||
| pylint = ">= 3.3.2" | |||
| black = ">= 22.10" | |||
| ruff = ">= 0.9.1" | |||
| [build-system] | |||
| requires = ["poetry-core>=1.8.0"] | |||
| @@ -2,7 +2,6 @@ import pytest | |||
| def test_import_main(): | |||
| from dora_distil_whisper.main import main | |||
| # Check that everything is working, and catch dora Runtime Exception as we're not running in a dora dataflow. | |||
| @@ -5,7 +5,7 @@ readme_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), "README.m | |||
| # Read the content of the README file | |||
| try: | |||
| with open(readme_path, "r", encoding="utf-8") as f: | |||
| with open(readme_path, encoding="utf-8") as f: | |||
| __doc__ = f.read() | |||
| except FileNotFoundError: | |||
| __doc__ = "README file not found." | |||
| @@ -1,12 +1,12 @@ | |||
| import argparse | |||
| import os | |||
| from dora import Node | |||
| RUNNER_CI = True if os.getenv("CI") == "true" else False | |||
| def main(): | |||
| # Handle dynamic nodes, ask for the name of the node in the dataflow, and the same values as the ENV variables. | |||
| parser = argparse.ArgumentParser(description="Simple arrow sender") | |||
| @@ -20,7 +20,7 @@ def main(): | |||
| args = parser.parse_args() | |||
| node = Node( | |||
| args.name | |||
| args.name, | |||
| ) # provide the name to connect to the dataflow if dynamic node | |||
| for event in node: | |||
| @@ -5,7 +5,7 @@ readme_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), "README.m | |||
| # Read the content of the README file | |||
| try: | |||
| with open(readme_path, "r", encoding="utf-8") as f: | |||
| with open(readme_path, encoding="utf-8") as f: | |||
| __doc__ = f.read() | |||
| except FileNotFoundError: | |||
| __doc__ = "README file not found." | |||
| @@ -1,9 +1,10 @@ | |||
| import os | |||
| from dora import Node | |||
| import numpy as np | |||
| import pyarrow as pa | |||
| import torch | |||
| import torchvision.transforms as T | |||
| from dora import Node | |||
| from PIL import Image | |||
| from torchvision.transforms.functional import InterpolationMode | |||
| from transformers import AutoModel, AutoTokenizer | |||
| @@ -20,7 +21,7 @@ def build_transform(input_size): | |||
| T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC), | |||
| T.ToTensor(), | |||
| T.Normalize(mean=MEAN, std=STD), | |||
| ] | |||
| ], | |||
| ) | |||
| return transform | |||
| @@ -42,7 +43,7 @@ def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_ | |||
| def dynamic_preprocess( | |||
| image, min_num=1, max_num=12, image_size=448, use_thumbnail=False | |||
| image, min_num=1, max_num=12, image_size=448, use_thumbnail=False, | |||
| ): | |||
| orig_width, orig_height = image.size | |||
| aspect_ratio = orig_width / orig_height | |||
| @@ -59,7 +60,7 @@ def dynamic_preprocess( | |||
| # find the closest aspect ratio to the target | |||
| target_aspect_ratio = find_closest_aspect_ratio( | |||
| aspect_ratio, target_ratios, orig_width, orig_height, image_size | |||
| aspect_ratio, target_ratios, orig_width, orig_height, image_size, | |||
| ) | |||
| # calculate the target width and height | |||
| @@ -91,7 +92,7 @@ def load_image(image_array: np.array, input_size=448, max_num=12): | |||
| image = Image.fromarray(image_array).convert("RGB") | |||
| transform = build_transform(input_size=input_size) | |||
| images = dynamic_preprocess( | |||
| image, image_size=input_size, use_thumbnail=True, max_num=max_num | |||
| image, image_size=input_size, use_thumbnail=True, max_num=max_num, | |||
| ) | |||
| pixel_values = [transform(image) for image in images] | |||
| pixel_values = torch.stack(pixel_values) | |||
| @@ -116,7 +117,7 @@ def main(): | |||
| .to(device) | |||
| ) | |||
| tokenizer = AutoTokenizer.from_pretrained( | |||
| model_path, trust_remote_code=True, use_fast=False | |||
| model_path, trust_remote_code=True, use_fast=False, | |||
| ) | |||
| node = Node() | |||
| @@ -138,10 +139,7 @@ def main(): | |||
| width = metadata["width"] | |||
| height = metadata["height"] | |||
| if encoding == "bgr8": | |||
| channels = 3 | |||
| storage_type = np.uint8 | |||
| elif encoding == "rgb8": | |||
| if encoding == "bgr8" or encoding == "rgb8": | |||
| channels = 3 | |||
| storage_type = np.uint8 | |||
| else: | |||
| @@ -168,7 +166,7 @@ def main(): | |||
| ) | |||
| generation_config = dict(max_new_tokens=1024, do_sample=True) | |||
| response = model.chat( | |||
| tokenizer, pixel_values, question, generation_config | |||
| tokenizer, pixel_values, question, generation_config, | |||
| ) | |||
| node.send_output( | |||
| "text", | |||
| @@ -5,7 +5,7 @@ readme_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), "README.m | |||
| # Read the content of the README file | |||
| try: | |||
| with open(readme_path, "r", encoding="utf-8") as f: | |||
| with open(readme_path, encoding="utf-8") as f: | |||
| __doc__ = f.read() | |||
| except FileNotFoundError: | |||
| __doc__ = "README file not found." | |||
| @@ -1,7 +1,7 @@ | |||
| from pynput import keyboard | |||
| from pynput.keyboard import Events | |||
| import pyarrow as pa | |||
| from dora import Node | |||
| from pynput import keyboard | |||
| from pynput.keyboard import Events | |||
| def main(): | |||
| @@ -5,7 +5,7 @@ readme_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), "README.m | |||
| # Read the content of the README file | |||
| try: | |||
| with open(readme_path, "r", encoding="utf-8") as f: | |||
| with open(readme_path, encoding="utf-8") as f: | |||
| __doc__ = f.read() | |||
| except FileNotFoundError: | |||
| __doc__ = "README file not found." | |||
| @@ -1,9 +1,9 @@ | |||
| import sounddevice as sd | |||
| import numpy as np | |||
| import pyarrow as pa | |||
| import time as tm | |||
| import os | |||
| import time as tm | |||
| import numpy as np | |||
| import pyarrow as pa | |||
| import sounddevice as sd | |||
| from dora import Node | |||
| MAX_DURATION = float(os.getenv("MAX_DURATION", "0.1")) | |||
| @@ -19,7 +19,6 @@ def main(): | |||
| always_none = node.next(timeout=0.001) is None | |||
| finished = False | |||
| # pylint: disable=unused-argument | |||
| def callback(indata, frames, time, status): | |||
| nonlocal buffer, node, start_recording_time, finished | |||
| @@ -36,7 +35,7 @@ def main(): | |||
| # Start recording | |||
| with sd.InputStream( | |||
| callback=callback, dtype=np.int16, channels=1, samplerate=SAMPLE_RATE | |||
| callback=callback, dtype=np.int16, channels=1, samplerate=SAMPLE_RATE, | |||
| ): | |||
| while not finished: | |||
| sd.sleep(int(1000)) | |||
| sd.sleep(1000) | |||
| @@ -5,7 +5,7 @@ readme_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), "README.m | |||
| # Read the content of the README file | |||
| try: | |||
| with open(readme_path, "r", encoding="utf-8") as f: | |||
| with open(readme_path, encoding="utf-8") as f: | |||
| __doc__ = f.read() | |||
| except FileNotFoundError: | |||
| __doc__ = "README file not found." | |||
| @@ -1,11 +1,12 @@ | |||
| from fastapi import FastAPI | |||
| from pydantic import BaseModel | |||
| import ast | |||
| import asyncio | |||
| from typing import List, Optional | |||
| import pyarrow as pa | |||
| import uvicorn | |||
| from dora import Node | |||
| import asyncio | |||
| import pyarrow as pa | |||
| import ast | |||
| from fastapi import FastAPI | |||
| from pydantic import BaseModel | |||
| DORA_RESPONSE_TIMEOUT = 10 | |||
| app = FastAPI() | |||
| @@ -55,13 +56,7 @@ async def create_chat_completion(request: ChatCompletionRequest): | |||
| print("Passing input as string") | |||
| if isinstance(data, list): | |||
| data = pa.array(data) # initialize pyarrow array | |||
| elif isinstance(data, str): | |||
| data = pa.array([data]) | |||
| elif isinstance(data, int): | |||
| data = pa.array([data]) | |||
| elif isinstance(data, float): | |||
| data = pa.array([data]) | |||
| elif isinstance(data, dict): | |||
| elif isinstance(data, str) or isinstance(data, int) or isinstance(data, float) or isinstance(data, dict): | |||
| data = pa.array([data]) | |||
| else: | |||
| data = pa.array(data) # initialize pyarrow array | |||
| @@ -73,12 +68,10 @@ async def create_chat_completion(request: ChatCompletionRequest): | |||
| if event["type"] == "ERROR": | |||
| response_str = "No response received. Err: " + event["value"][0].as_py() | |||
| break | |||
| elif event["type"] == "INPUT" and event["id"] == "v1/chat/completions": | |||
| if event["type"] == "INPUT" and event["id"] == "v1/chat/completions": | |||
| response = event["value"] | |||
| response_str = response[0].as_py() if response else "No response received" | |||
| break | |||
| else: | |||
| pass | |||
| return ChatCompletionResponse( | |||
| id="chatcmpl-1234", | |||
| @@ -90,7 +83,7 @@ async def create_chat_completion(request: ChatCompletionRequest): | |||
| "index": 0, | |||
| "message": {"role": "assistant", "content": response_str}, | |||
| "finish_reason": "stop", | |||
| } | |||
| }, | |||
| ], | |||
| usage={ | |||
| "prompt_tokens": len(data), | |||
| @@ -110,7 +103,7 @@ async def list_models(): | |||
| "object": "model", | |||
| "created": 1677610602, | |||
| "owned_by": "openai", | |||
| } | |||
| }, | |||
| ], | |||
| } | |||
| @@ -28,3 +28,6 @@ dora-openai-server = "dora_openai_server.main:main" | |||
| [build-system] | |||
| requires = ["poetry-core>=1.8.0"] | |||
| build-backend = "poetry.core.masonry.api" | |||
| [tool.ruff.lint] | |||
| extend-select = ["I"] | |||
| @@ -5,7 +5,7 @@ readme_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), "README.m | |||
| # Read the content of the README file | |||
| try: | |||
| with open(readme_path, "r", encoding="utf-8") as f: | |||
| with open(readme_path, encoding="utf-8") as f: | |||
| __doc__ = f.read() | |||
| except FileNotFoundError: | |||
| __doc__ = "README file not found." | |||
| @@ -1,10 +1,10 @@ | |||
| import os | |||
| from pathlib import Path | |||
| from dora import Node | |||
| import pyarrow as pa | |||
| import numpy as np | |||
| from transformers import AutoTokenizer, AutoModelForSeq2SeqLM | |||
| import numpy as np | |||
| import pyarrow as pa | |||
| from dora import Node | |||
| from transformers import AutoModelForSeq2SeqLM, AutoTokenizer | |||
| from_code = os.getenv("SOURCE_LANGUAGE", "zh") | |||
| to_code = os.getenv("TARGET_LANGUAGE", "en") | |||
| @@ -29,7 +29,7 @@ def cut_repetition(text, min_repeat_length=4, max_repeat_length=50): | |||
| if sum(1 for char in text if "\u4e00" <= char <= "\u9fff") / len(text) > 0.5: | |||
| # Chinese text processing | |||
| for repeat_length in range( | |||
| min_repeat_length, min(max_repeat_length, len(text) // 2) | |||
| min_repeat_length, min(max_repeat_length, len(text) // 2), | |||
| ): | |||
| for i in range(len(text) - repeat_length * 2 + 1): | |||
| chunk1 = text[i : i + repeat_length] | |||
| @@ -41,7 +41,7 @@ def cut_repetition(text, min_repeat_length=4, max_repeat_length=50): | |||
| # Non-Chinese (space-separated) text processing | |||
| words = text.split() | |||
| for repeat_length in range( | |||
| min_repeat_length, min(max_repeat_length, len(words) // 2) | |||
| min_repeat_length, min(max_repeat_length, len(words) // 2), | |||
| ): | |||
| for i in range(len(words) - repeat_length * 2 + 1): | |||
| chunk1 = " ".join(words[i : i + repeat_length]) | |||
| @@ -10,16 +10,16 @@ pip install -e . | |||
| ## Contribution Guide | |||
| - Format with [black](https://github.com/psf/black): | |||
| - Format with [ruff](https://docs.astral.sh/ruff/): | |||
| ```bash | |||
| black . # Format | |||
| ruff check . --fix | |||
| ``` | |||
| - Lint with [pylint](https://github.com/pylint-dev/pylint): | |||
| - Lint with ruff: | |||
| ```bash | |||
| pylint --disable=C,R --ignored-modules=cv2 . # Lint | |||
| ruff check . | |||
| ``` | |||
| - Test with [pytest](https://github.com/pytest-dev/pytest) | |||
| @@ -5,7 +5,7 @@ readme_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), "README.m | |||
| # Read the content of the README file | |||
| try: | |||
| with open(readme_path, "r", encoding="utf-8") as f: | |||
| with open(readme_path, encoding="utf-8") as f: | |||
| __doc__ = f.read() | |||
| except FileNotFoundError: | |||
| __doc__ = "README file not found." | |||
| @@ -1,5 +1,4 @@ | |||
| from .main import main | |||
| if __name__ == "__main__": | |||
| main() | |||
| @@ -1,10 +1,11 @@ | |||
| from dora import Node | |||
| import outetts | |||
| import argparse # Add argparse import | |||
| import pathlib | |||
| import os | |||
| import torch | |||
| import pathlib | |||
| import outetts | |||
| import pyarrow as pa | |||
| import torch | |||
| from dora import Node | |||
| PATH_SPEAKER = os.getenv("PATH_SPEAKER", "speaker.json") | |||
| @@ -45,7 +46,6 @@ def create_speaker(interface, path): | |||
| interface.save_speaker(speaker, "speaker.json") | |||
| print("saved speaker.json") | |||
| return | |||
| def main(arg_list: list[str] | None = None): | |||
| @@ -85,7 +85,7 @@ def main(arg_list: list[str] | None = None): | |||
| f"""Node received: | |||
| id: {event["id"]}, | |||
| value: {event["value"]}, | |||
| metadata: {event["metadata"]}""" | |||
| metadata: {event["metadata"]}""", | |||
| ) | |||
| elif event["id"] == "text": | |||
| @@ -1,7 +1,5 @@ | |||
| import pytest | |||
| from dora_outtetts.main import load_interface | |||
| from dora_outtetts.main import main | |||
| from dora_outtetts.main import load_interface, main | |||
| def test_import_main(): | |||
| @@ -21,8 +21,7 @@ outetts = "^0.2.3" | |||
| [tool.poetry.dev-dependencies] | |||
| pytest = ">= 6.3.4" | |||
| pylint = ">= 3.3.2" | |||
| black = ">= 22.10" | |||
| ruff = ">= 0.9.1" | |||
| [tool.poetry.scripts] | |||
| dora-outtetts = "dora_outtetts.main:main" | |||
| @@ -5,7 +5,7 @@ readme_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), "README.m | |||
| # Read the content of the README file | |||
| try: | |||
| with open(readme_path, "r", encoding="utf-8") as f: | |||
| with open(readme_path, encoding="utf-8") as f: | |||
| __doc__ = f.read() | |||
| except FileNotFoundError: | |||
| __doc__ = "README file not found." | |||
| @@ -1,19 +1,19 @@ | |||
| from threading import Thread | |||
| from dora import Node | |||
| import os | |||
| import time | |||
| from pathlib import Path | |||
| from threading import Thread | |||
| import numpy as np | |||
| import torch | |||
| import time | |||
| import pyaudio | |||
| import torch | |||
| from dora import Node | |||
| from parler_tts import ParlerTTSForConditionalGeneration, ParlerTTSStreamer | |||
| from transformers import ( | |||
| AutoTokenizer, | |||
| AutoFeatureExtractor, | |||
| set_seed, | |||
| AutoTokenizer, | |||
| StoppingCriteria, | |||
| StoppingCriteriaList, | |||
| set_seed, | |||
| ) | |||
| device = "cuda:0" # if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu" | |||
| @@ -31,7 +31,7 @@ if bool(os.getenv("USE_MODELSCOPE_HUB") in ["True", "true"]): | |||
| MODEL_NAME_OR_PATH = snapshot_download(MODEL_NAME_OR_PATH) | |||
| model = ParlerTTSForConditionalGeneration.from_pretrained( | |||
| MODEL_NAME_OR_PATH, torch_dtype=torch_dtype, low_cpu_mem_usage=True | |||
| MODEL_NAME_OR_PATH, torch_dtype=torch_dtype, low_cpu_mem_usage=True, | |||
| ).to(device) | |||
| model.generation_config.cache_implementation = "static" | |||
| model.forward = torch.compile(model.forward, mode="default") | |||
| @@ -58,7 +58,6 @@ stream = p.open(format=pyaudio.paInt16, channels=1, rate=sampling_rate, output=T | |||
| def play_audio(audio_array): | |||
| if np.issubdtype(audio_array.dtype, np.floating): | |||
| max_val = np.max(np.abs(audio_array)) | |||
| audio_array = (audio_array / max_val) * 32767 | |||
| @@ -73,7 +72,7 @@ class InterruptStoppingCriteria(StoppingCriteria): | |||
| self.stop_signal = False | |||
| def __call__( | |||
| self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs | |||
| self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs, | |||
| ) -> bool: | |||
| return self.stop_signal | |||
| @@ -109,7 +108,6 @@ def generate_base( | |||
| thread.start() | |||
| for new_audio in streamer: | |||
| current_time = time.time() | |||
| print(f"Time between iterations: {round(current_time - prev_time, 2)} seconds") | |||
| @@ -127,7 +125,7 @@ def generate_base( | |||
| if event["id"] == "stop": | |||
| stopping_criteria.stop() | |||
| break | |||
| elif event["id"] == "text": | |||
| if event["id"] == "text": | |||
| stopping_criteria.stop() | |||
| text = event["value"][0].as_py() | |||
| @@ -5,7 +5,7 @@ readme_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), "README.m | |||
| # Read the content of the README file | |||
| try: | |||
| with open(readme_path, "r", encoding="utf-8") as f: | |||
| with open(readme_path, encoding="utf-8") as f: | |||
| __doc__ = f.read() | |||
| except FileNotFoundError: | |||
| __doc__ = "README file not found." | |||
| @@ -1,16 +1,16 @@ | |||
| from piper_sdk import C_PiperInterface | |||
| from dora import Node | |||
| import pyarrow as pa | |||
| import numpy as np | |||
| import os | |||
| import time | |||
| import numpy as np | |||
| import pyarrow as pa | |||
| from dora import Node | |||
| from piper_sdk import C_PiperInterface | |||
| TEACH_MODE = os.getenv("TEACH_MODE", "False") in ["True", "true"] | |||
| def enable_fun(piper: C_PiperInterface): | |||
| """ | |||
| Enable the robotic arm and check the enable status; try for 5 s and exit the program if enabling times out | |||
| """Enable the robotic arm and check the enable status; try for 5 s and exit the program if enabling times out | |||
| """ | |||
| enable_flag = False | |||
| # Set the timeout in seconds | |||
| @@ -144,7 +144,6 @@ def main(): | |||
| ) | |||
| elif event["type"] == "STOP": | |||
| if not TEACH_MODE: | |||
| piper.MotionCtrl_2(0x01, 0x01, 50, 0x00) | |||
| piper.JointCtrl(0, 0, 0, 0, 0, 0) | |||
| @@ -22,16 +22,16 @@ pip install -e . | |||
| ## Contribution Guide | |||
| - Format with [black](https://github.com/psf/black): | |||
| - Format with [ruff](https://docs.astral.sh/ruff/): | |||
| ```bash | |||
| black . # Format | |||
| ruff check . --fix | |||
| ``` | |||
| - Lint with [pylint](https://github.com/pylint-dev/pylint): | |||
| - Lint with ruff: | |||
| ```bash | |||
| pylint --disable=C,R --ignored-modules=cv2 . # Lint | |||
| ruff check . | |||
| ``` | |||
| - Test with [pytest](https://github.com/pytest-dev/pytest) | |||
| @@ -5,7 +5,7 @@ readme_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), "README.m | |||
| # Read the content of the README file | |||
| try: | |||
| with open(readme_path, "r", encoding="utf-8") as f: | |||
| with open(readme_path, encoding="utf-8") as f: | |||
| __doc__ = f.read() | |||
| except FileNotFoundError: | |||
| __doc__ = "README file not found." | |||
| @@ -1,5 +1,4 @@ | |||
| from .main import main | |||
| if __name__ == "__main__": | |||
| main() | |||
| @@ -12,17 +12,16 @@ p = pyaudio.PyAudio() | |||
| def play_audio( | |||
| audio_array: pa.array, sample_rate: int, stream: pyaudio.Stream = None | |||
| audio_array: pa.array, sample_rate: int, stream: pyaudio.Stream = None, | |||
| ) -> pyaudio.Stream: | |||
| """Play audio using pyaudio and replace stream if already exists""" | |||
| if np.issubdtype(audio_array.dtype, np.floating): | |||
| max_val = np.max(np.abs(audio_array)) | |||
| audio_array = (audio_array / max_val) * 32767 | |||
| audio_array = audio_array.astype(np.int16) | |||
| if stream is None: | |||
| stream = p.open( | |||
| format=pyaudio.paInt16, channels=1, rate=sample_rate, output=True | |||
| format=pyaudio.paInt16, channels=1, rate=sample_rate, output=True, | |||
| ) | |||
| stream.write(audio_array.tobytes()) | |||
| return stream | |||
| @@ -18,8 +18,8 @@ pyaudio = ">= 0.1.0" | |||
| [tool.poetry.dev-dependencies] | |||
| pytest = ">= 6.3.4" | |||
| pylint = ">= 2.5.2" | |||
| black = ">= 22.10" | |||
| ruff = ">= 0.9.1" | |||
| [tool.poetry.scripts] | |||
| dora-pyaudio = "dora_pyaudio.main:main" | |||
| @@ -5,7 +5,7 @@ readme_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), "README.m | |||
| # Read the content of the README file | |||
| try: | |||
| with open(readme_path, "r", encoding="utf-8") as f: | |||
| with open(readme_path, encoding="utf-8") as f: | |||
| __doc__ = f.read() | |||
| except FileNotFoundError: | |||
| __doc__ = "README file not found." | |||
| @@ -13,21 +13,28 @@ | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ****************************************************************************** | |||
| import cv2 | |||
| import os | |||
| import cv2 | |||
| import numpy as np | |||
| import pyarrow as pa | |||
| from dora import Node | |||
| try: | |||
| from pyorbbecsdk import Context | |||
| from pyorbbecsdk import Config | |||
| from pyorbbecsdk import OBError | |||
| from pyorbbecsdk import OBSensorType, OBFormat | |||
| from pyorbbecsdk import Pipeline, FrameSet | |||
| from pyorbbecsdk import VideoStreamProfile | |||
| from pyorbbecsdk import VideoFrame | |||
| from pyorbbecsdk import ( | |||
| Config, | |||
| Context, | |||
| FrameSet, | |||
| OBError, | |||
| OBFormat, | |||
| OBSensorType, | |||
| Pipeline, | |||
| VideoFrame, | |||
| VideoStreamProfile, | |||
| ) | |||
| except ImportError as err: | |||
| print( | |||
| "Please install pyorbbecsdk first by following the instruction at: https://github.com/orbbec/pyorbbecsdk" | |||
| "Please install pyorbbecsdk first by following the instruction at: https://github.com/orbbec/pyorbbecsdk", | |||
| ) | |||
| raise err | |||
| @@ -42,7 +49,7 @@ class TemporalFilter: | |||
| result = frame | |||
| else: | |||
| result = cv2.addWeighted( | |||
| frame, self.alpha, self.previous_frame, 1 - self.alpha, 0 | |||
| frame, self.alpha, self.previous_frame, 1 - self.alpha, 0, | |||
| ) | |||
| self.previous_frame = result | |||
| return result | |||
| @@ -115,14 +122,11 @@ def frame_to_bgr_image(frame: VideoFrame): | |||
| image = np.resize(data, (height, width, 2)) | |||
| image = cv2.cvtColor(image, cv2.COLOR_YUV2BGR_UYVY) | |||
| else: | |||
| print("Unsupported color format: {}".format(color_format)) | |||
| print(f"Unsupported color format: {color_format}") | |||
| return None | |||
| return image | |||
| from dora import Node | |||
| import pyarrow as pa | |||
| ESC_KEY = 27 | |||
| MIN_DEPTH_METERS = 0.01 | |||
| MAX_DEPTH_METERS = 15.0 | |||
| @@ -141,7 +145,7 @@ def main(): | |||
| profile_list = pipeline.get_stream_profile_list(OBSensorType.COLOR_SENSOR) | |||
| try: | |||
| color_profile: VideoStreamProfile = profile_list.get_video_stream_profile( | |||
| 640, 480, OBFormat.RGB, 30 | |||
| 640, 480, OBFormat.RGB, 30, | |||
| ) | |||
| except OBError as e: | |||
| print(e) | |||
| @@ -150,7 +154,7 @@ def main(): | |||
| profile_list = pipeline.get_stream_profile_list(OBSensorType.DEPTH_SENSOR) | |||
| try: | |||
| depth_profile: VideoStreamProfile = profile_list.get_video_stream_profile( | |||
| 640, 480, OBFormat.Y16, 30 | |||
| 640, 480, OBFormat.Y16, 30, | |||
| ) | |||
| except OBError as e: | |||
| print(e) | |||
| @@ -200,7 +204,7 @@ def main(): | |||
| node.send_output("depth", storage) | |||
| # Convert to Image | |||
| depth_image = cv2.normalize( | |||
| depth_data, None, 0, 255, cv2.NORM_MINMAX, dtype=cv2.CV_8U | |||
| depth_data, None, 0, 255, cv2.NORM_MINMAX, dtype=cv2.CV_8U, | |||
| ) | |||
| # Send Depth Image | |||
| depth_image = cv2.applyColorMap(depth_image, cv2.COLORMAP_JET) | |||
| @@ -5,7 +5,7 @@ readme_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), "README.m | |||
| # Read the content of the README file | |||
| try: | |||
| with open(readme_path, "r", encoding="utf-8") as f: | |||
| with open(readme_path, encoding="utf-8") as f: | |||
| __doc__ = f.read() | |||
| except FileNotFoundError: | |||
| __doc__ = "README file not found." | |||
| @@ -4,11 +4,10 @@ import time | |||
| import cv2 | |||
| import numpy as np | |||
| import pyarrow as pa | |||
| import pyrealsense2 as rs | |||
| from dora import Node | |||
| RUNNER_CI = True if os.getenv("CI") == "true" else False | |||
| import pyrealsense2 as rs | |||
| def main(): | |||
| @@ -26,7 +25,7 @@ def main(): | |||
| serials = [device.get_info(rs.camera_info.serial_number) for device in devices] | |||
| if DEVICE_SERIAL and (DEVICE_SERIAL in serials): | |||
| raise ConnectionError( | |||
| f"Device with serial {DEVICE_SERIAL} not found within: {serials}." | |||
| f"Device with serial {DEVICE_SERIAL} not found within: {serials}.", | |||
| ) | |||
| pipeline = rs.pipeline() | |||
| @@ -53,7 +52,6 @@ def main(): | |||
| pa.array([]) # initialize pyarrow array | |||
| for event in node: | |||
| # Run this example in the CI for 10 seconds only. | |||
| if RUNNER_CI and time.time() - start_time > 10: | |||
| break | |||
| @@ -116,7 +114,7 @@ def main(): | |||
| # metadata["principal_point"] = [int(rgb_intr.ppx), int(rgb_intr.ppy)] | |||
| node.send_output("image", storage, metadata) | |||
| node.send_output( | |||
| "depth", pa.array(scaled_depth_image.ravel()), metadata | |||
| "depth", pa.array(scaled_depth_image.ravel()), metadata, | |||
| ) | |||
| elif event_type == "ERROR": | |||
| @@ -2,7 +2,6 @@ import pytest | |||
| def test_import_main(): | |||
| from dora_pyrealsense.main import main | |||
| # Check that everything is working, and catch dora Runtime Exception as we're not running in a dora dataflow. | |||
| @@ -5,7 +5,7 @@ readme_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), "README.m | |||
| # Read the content of the README file | |||
| try: | |||
| with open(readme_path, "r", encoding="utf-8") as f: | |||
| with open(readme_path, encoding="utf-8") as f: | |||
| __doc__ = f.read() | |||
| except FileNotFoundError: | |||
| __doc__ = "README file not found." | |||
| @@ -1,13 +1,14 @@ | |||
| import os | |||
| from dora import Node | |||
| import torch | |||
| from transformers import Qwen2VLForConditionalGeneration, AutoProcessor | |||
| from qwen_vl_utils import process_vision_info | |||
| from pathlib import Path | |||
| import cv2 | |||
| import numpy as np | |||
| import pyarrow as pa | |||
| import torch | |||
| from dora import Node | |||
| from PIL import Image | |||
| from pathlib import Path | |||
| import cv2 | |||
| from qwen_vl_utils import process_vision_info | |||
| from transformers import AutoProcessor, Qwen2VLForConditionalGeneration | |||
| DEFAULT_PATH = "Qwen/Qwen2-VL-2B-Instruct" | |||
| @@ -27,7 +28,7 @@ ADAPTER_PATH = os.getenv("ADAPTER_PATH", "") | |||
| # Check if flash_attn is installed | |||
| try: | |||
| import flash_attn as _ | |||
| import flash_attn as _ # noqa | |||
| model = Qwen2VLForConditionalGeneration.from_pretrained( | |||
| MODEL_NAME_OR_PATH, | |||
| @@ -52,10 +53,8 @@ processor = AutoProcessor.from_pretrained(MODEL_NAME_OR_PATH) | |||
| def generate(frames: dict, question): | |||
| """Generate the response to the question given the image using Qwen2 model. | |||
| """ | |||
| Generate the response to the question given the image using Qwen2 model. | |||
| """ | |||
| messages = [ | |||
| { | |||
| "role": "user", | |||
| @@ -69,12 +68,12 @@ def generate(frames: dict, question): | |||
| + [ | |||
| {"type": "text", "text": question}, | |||
| ], | |||
| } | |||
| }, | |||
| ] | |||
| # Preparation for inference | |||
| text = processor.apply_chat_template( | |||
| messages, tokenize=False, add_generation_prompt=True | |||
| messages, tokenize=False, add_generation_prompt=True, | |||
| ) | |||
| image_inputs, video_inputs = process_vision_info(messages) | |||
| inputs = processor( | |||
| @@ -118,7 +117,6 @@ def main(): | |||
| event_type = event["type"] | |||
| if event_type == "INPUT": | |||
| event_id = event["id"] | |||
| if "image" in event_id: | |||
| @@ -128,13 +126,7 @@ def main(): | |||
| width = metadata["width"] | |||
| height = metadata["height"] | |||
| if encoding == "bgr8": | |||
| channels = 3 | |||
| storage_type = np.uint8 | |||
| elif encoding == "rgb8": | |||
| channels = 3 | |||
| storage_type = np.uint8 | |||
| elif encoding in ["jpeg", "jpg", "jpe", "bmp", "webp", "png"]: | |||
| if encoding == "bgr8" or encoding == "rgb8" or encoding in ["jpeg", "jpg", "jpe", "bmp", "webp", "png"]: | |||
| channels = 3 | |||
| storage_type = np.uint8 | |||
| else: | |||
| @@ -1 +1 @@ | |||
| Subproject commit 198374ea8c4a2ec2ddae86c35448d21aa9756f37 | |||
| Subproject commit b2889e65cfe62571ced3ce88f00e7d80b41fee69 | |||
| @@ -7,7 +7,7 @@ readme_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), "README.m | |||
| # Read the content of the README file | |||
| try: | |||
| with open(readme_path, "r", encoding="utf-8") as f: | |||
| with open(readme_path, encoding="utf-8") as f: | |||
| __doc__ = f.read() | |||
| except FileNotFoundError: | |||
| __doc__ = "README file not found." | |||
| @@ -1,21 +1,22 @@ | |||
| # install dependencies as shown in the README here https://github.com/alik-git/RoboticsDiffusionTransformer?tab=readme-ov-file#installation | |||
| import yaml | |||
| import torch | |||
| import os | |||
| from pathlib import Path | |||
| import cv2 | |||
| import numpy as np | |||
| import pyarrow as pa | |||
| import torch | |||
| import yaml | |||
| from dora import Node | |||
| from PIL import Image | |||
| from dora_rdt_1b.RoboticsDiffusionTransformer.configs.state_vec import ( | |||
| STATE_VEC_IDX_MAPPING, | |||
| ) | |||
| from dora import Node | |||
| import cv2 | |||
| import pyarrow as pa | |||
| import os | |||
| from pathlib import Path | |||
| VISION_DEFAULT_PATH = "robotics-diffusion-transformer/rdt-1b" | |||
| ROBOTIC_MODEL_NAME_OR_PATH = os.getenv( | |||
| "ROBOTIC_MODEL_NAME_OR_PATH", VISION_DEFAULT_PATH | |||
| "ROBOTIC_MODEL_NAME_OR_PATH", VISION_DEFAULT_PATH, | |||
| ) | |||
| LANGUAGE_EMBEDDING_PATH = os.getenv("LANGUAGE_EMBEDDING", "lang_embed.pt") | |||
| @@ -33,7 +34,7 @@ config_path = ( | |||
| file_path / "RoboticsDiffusionTransformer/configs/base.yaml" | |||
| ) # default config | |||
| with open(config_path, "r", encoding="utf-8") as fp: | |||
| with open(config_path, encoding="utf-8") as fp: | |||
| config = yaml.safe_load(fp) | |||
| @@ -74,7 +75,7 @@ def get_language_embeddings(): | |||
| ) | |||
| return lang_embeddings.unsqueeze( | |||
| 0 | |||
| 0, | |||
| ) # Size: (B, L_lang, D) or None, language condition tokens (variable length), dimension D is assumed to be the same as the hidden size. | |||
| @@ -82,14 +83,13 @@ def expand2square(pil_img, background_color): | |||
| width, height = pil_img.size | |||
| if width == height: | |||
| return pil_img | |||
| elif width > height: | |||
| if width > height: | |||
| result = Image.new(pil_img.mode, (width, width), background_color) | |||
| result.paste(pil_img, (0, (width - height) // 2)) | |||
| return result | |||
| else: | |||
| result = Image.new(pil_img.mode, (height, height), background_color) | |||
| result.paste(pil_img, ((height - width) // 2, 0)) | |||
| return result | |||
| result = Image.new(pil_img.mode, (height, height), background_color) | |||
| result.paste(pil_img, ((height - width) // 2, 0)) | |||
| return result | |||
| def process_image(rgbs_lst, image_processor, vision_encoder): | |||
| @@ -156,12 +156,12 @@ def get_states(proprio): | |||
| B, N = 1, 1 # batch size and state history size | |||
| states = torch.zeros( | |||
| (B, N, config["model"]["state_token_dim"]), device=DEVICE, dtype=DTYPE | |||
| (B, N, config["model"]["state_token_dim"]), device=DEVICE, dtype=DTYPE, | |||
| ) | |||
| # suppose you do not have proprio | |||
| # it's kind of tricky, I strongly suggest adding proprio as input and further fine-tuning | |||
| proprio = torch.tensor(proprio, device=DEVICE, dtype=DTYPE).reshape( | |||
| (1, 1, -1) | |||
| (1, 1, -1), | |||
| ) # B, N = 1, 1 # batch size and state history size | |||
| # if you have proprio, you can do like this | |||
| @@ -170,19 +170,19 @@ def get_states(proprio): | |||
| states[:, :, STATE_INDICES] = proprio | |||
| state_elem_mask = torch.zeros( | |||
| (1, config["model"]["state_token_dim"]), device=DEVICE, dtype=torch.bool | |||
| (1, config["model"]["state_token_dim"]), device=DEVICE, dtype=torch.bool, | |||
| ) | |||
| state_elem_mask[:, STATE_INDICES] = True | |||
| states, state_elem_mask = states.to(DEVICE, dtype=DTYPE), state_elem_mask.to( | |||
| DEVICE, dtype=DTYPE | |||
| states, state_elem_mask = ( | |||
| states.to(DEVICE, dtype=DTYPE), | |||
| state_elem_mask.to(DEVICE, dtype=DTYPE), | |||
| ) | |||
| states = states[:, -1:, :] # only use the last state | |||
| return states, state_elem_mask, STATE_INDICES | |||
| def main(): | |||
| rdt = get_policy() | |||
| lang_embeddings = get_language_embeddings() | |||
| vision_encoder, image_processor = get_vision_model() | |||
| @@ -195,11 +195,9 @@ def main(): | |||
| frames = {} | |||
| joints = {} | |||
| with torch.no_grad(): | |||
| for event in node: | |||
| event_type = event["type"] | |||
| if event_type == "INPUT": | |||
| event_id = event["id"] | |||
| if "image" in event_id: | |||
| @@ -207,13 +205,7 @@ def main(): | |||
| metadata = event["metadata"] | |||
| encoding = metadata["encoding"] | |||
| if encoding == "bgr8": | |||
| channels = 3 | |||
| storage_type = np.uint8 | |||
| elif encoding == "rgb8": | |||
| channels = 3 | |||
| storage_type = np.uint8 | |||
| elif encoding in ["jpeg", "jpg", "jpe", "bmp", "webp", "png"]: | |||
| if encoding == "bgr8" or encoding == "rgb8" or encoding in ["jpeg", "jpg", "jpe", "bmp", "webp", "png"]: | |||
| channels = 3 | |||
| storage_type = np.uint8 | |||
| else: | |||
| @@ -243,13 +235,13 @@ def main(): | |||
| else: | |||
| raise RuntimeError(f"Unsupported image encoding: {encoding}") | |||
| frames[f"last_{event_id}"] = frames.get( | |||
| event_id, Image.fromarray(frame) | |||
| event_id, Image.fromarray(frame), | |||
| ) | |||
| frames[event_id] = Image.fromarray(frame) | |||
| elif "jointstate" in event_id: | |||
| joints[event_id] = event["value"].to_numpy() | |||
| elif "tick" == event_id: | |||
| elif event_id == "tick": | |||
| ## Wait for all images | |||
| if len(frames.keys()) < 6: | |||
| continue | |||
| @@ -270,7 +262,7 @@ def main(): | |||
| ], | |||
| ] | |||
| image_embeds = process_image( | |||
| rgbs_lst, image_processor, vision_encoder | |||
| rgbs_lst, image_processor, vision_encoder, | |||
| ) | |||
| ## Embed states | |||
| @@ -278,26 +270,26 @@ def main(): | |||
| [ | |||
| joints["jointstate_left"], | |||
| joints["jointstate_right"], | |||
| ] | |||
| ], | |||
| ) | |||
| states, state_elem_mask, state_indices = get_states(proprio=proprio) | |||
| actions = rdt.predict_action( | |||
| lang_tokens=lang_embeddings, | |||
| lang_attn_mask=torch.ones( | |||
| lang_embeddings.shape[:2], dtype=torch.bool, device=DEVICE | |||
| lang_embeddings.shape[:2], dtype=torch.bool, device=DEVICE, | |||
| ), | |||
| img_tokens=image_embeds, | |||
| state_tokens=states, # how can I get this? | |||
| action_mask=state_elem_mask.unsqueeze(1), # how can I get this? | |||
| ctrl_freqs=torch.tensor( | |||
| [25.0], device=DEVICE | |||
| [25.0], device=DEVICE, | |||
| ), # would this default work? | |||
| ) # (1, chunk_size, 128) | |||
| # select the meaning action via STATE_INDICES | |||
| action = actions[ | |||
| :, :, state_indices | |||
| :, :, state_indices, | |||
| ] # (1, chunk_size, len(STATE_INDICES)) = (1, chunk_size, 7+ 1) | |||
| action = action.detach().float().to("cpu").numpy() | |||
| node.send_output("action", pa.array(action.ravel())) | |||
| @@ -29,12 +29,9 @@ huggingface_hub = "0.23.5" | |||
| # flash_attn = "^2.6.1" # Install using: pip install -U flash-attn --no-build-isolation | |||
| [tool.pylint.MASTER] | |||
| ignore-paths = '^dora_rdt_1b/RoboticsDiffusionTransformer.*$' | |||
| [tool.poetry.dev-dependencies] | |||
| pytest = "^8.3.4" | |||
| pylint = "^3.3.2" | |||
| ruff = ">= 0.9.1" | |||
| [tool.black] | |||
| extend-exclude = 'dora_rdt_1b/RoboticsDiffusionTransformer' | |||
| @@ -46,3 +43,6 @@ dora-rdt-1b = "dora_rdt_1b.main:main" | |||
| [build-system] | |||
| requires = ["poetry-core>=1.8.0"] | |||
| build-backend = "poetry.core.masonry.api" | |||
| [tool.ruff] | |||
| exclude = ["dora_rdt_1b/RoboticsDiffusionTransformer"] | |||
| @@ -1,10 +1,10 @@ | |||
| import os | |||
| import numpy as np | |||
| import pytest | |||
| import torch | |||
| import numpy as np | |||
| from PIL import Image | |||
| from torchvision import transforms | |||
| import os | |||
| CI = os.environ.get("CI") | |||
| @@ -20,8 +20,8 @@ def test_import_main(): | |||
| # Check that everything is working, and catch dora Runtime Exception as we're not running in a dora dataflow. | |||
| # with pytest.raises(RuntimeError): | |||
| # main() | |||
| import dora_rdt_1b.RoboticsDiffusionTransformer as _ | |||
| import dora_rdt_1b as _ | |||
| import dora_rdt_1b.RoboticsDiffusionTransformer as _ # noqa | |||
| import dora_rdt_1b as _ # noqa | |||
| def test_download_policy(): | |||
| @@ -44,7 +44,6 @@ def test_download_vision_model(): | |||
| def test_download_language_embeddings(): | |||
| ## in the future we should add this test within CI | |||
| if CI: | |||
| return | |||
| @@ -55,7 +54,6 @@ def test_download_language_embeddings(): | |||
| def test_load_dummy_image(): | |||
| from dora_rdt_1b.main import config | |||
| # Load pretrained model (in HF style) | |||
| @@ -85,7 +83,7 @@ def test_load_dummy_image(): | |||
| # image pre-processing | |||
| # The background image used for padding | |||
| background_color = np.array( | |||
| [int(x * 255) for x in image_processor.image_mean], dtype=np.uint8 | |||
| [int(x * 255) for x in image_processor.image_mean], dtype=np.uint8, | |||
| ).reshape((1, 1, 3)) | |||
| background_image = ( | |||
| np.ones( | |||
| @@ -119,21 +117,20 @@ def test_load_dummy_image(): | |||
| width, height = pil_img.size | |||
| if width == height: | |||
| return pil_img | |||
| elif width > height: | |||
| if width > height: | |||
| result = Image.new( | |||
| pil_img.mode, (width, width), background_color | |||
| pil_img.mode, (width, width), background_color, | |||
| ) | |||
| result.paste(pil_img, (0, (width - height) // 2)) | |||
| return result | |||
| else: | |||
| result = Image.new( | |||
| pil_img.mode, (height, height), background_color | |||
| ) | |||
| result.paste(pil_img, ((height - width) // 2, 0)) | |||
| return result | |||
| result = Image.new( | |||
| pil_img.mode, (height, height), background_color, | |||
| ) | |||
| result.paste(pil_img, ((height - width) // 2, 0)) | |||
| return result | |||
| image = expand2square( | |||
| image, tuple(int(x * 255) for x in image_processor.image_mean) | |||
| image, tuple(int(x * 255) for x in image_processor.image_mean), | |||
| ) | |||
| image = image_processor.preprocess(image, return_tensors="pt")[ | |||
| "pixel_values" | |||
| @@ -144,7 +141,7 @@ def test_load_dummy_image(): | |||
| # encode images | |||
| image_embeds = vision_encoder(image_tensor).detach() | |||
| pytest.image_embeds = image_embeds.reshape( | |||
| -1, vision_encoder.hidden_size | |||
| -1, vision_encoder.hidden_size, | |||
| ).unsqueeze(0) | |||
| @@ -159,7 +156,7 @@ def test_dummy_states(): | |||
| # it's kind of tricky, I strongly suggest adding proprio as input and further fine-tuning | |||
| B, N = 1, 1 # batch size and state history size | |||
| states = torch.zeros( | |||
| (B, N, config["model"]["state_token_dim"]), device=DEVICE, dtype=DTYPE | |||
| (B, N, config["model"]["state_token_dim"]), device=DEVICE, dtype=DTYPE, | |||
| ) | |||
| # if you have proprio, you can do like this | |||
| @@ -168,7 +165,7 @@ def test_dummy_states(): | |||
| # states[:, :, STATE_INDICES] = proprio | |||
| state_elem_mask = torch.zeros( | |||
| (B, config["model"]["state_token_dim"]), device=DEVICE, dtype=torch.bool | |||
| (B, config["model"]["state_token_dim"]), device=DEVICE, dtype=torch.bool, | |||
| ) | |||
| from dora_rdt_1b.RoboticsDiffusionTransformer.configs.state_vec import ( | |||
| STATE_VEC_IDX_MAPPING, | |||
| @@ -187,8 +184,9 @@ def test_dummy_states(): | |||
| ] | |||
| state_elem_mask[:, STATE_INDICES] = True | |||
| states, state_elem_mask = states.to(DEVICE, dtype=DTYPE), state_elem_mask.to( | |||
| DEVICE, dtype=DTYPE | |||
| states, state_elem_mask = ( | |||
| states.to(DEVICE, dtype=DTYPE), | |||
| state_elem_mask.to(DEVICE, dtype=DTYPE), | |||
| ) | |||
| states = states[:, -1:, :] # only use the last state | |||
| pytest.states = states | |||
| @@ -211,7 +209,7 @@ def test_dummy_input(): | |||
| actions = rdt.predict_action( | |||
| lang_tokens=lang_embeddings, | |||
| lang_attn_mask=torch.ones( | |||
| lang_embeddings.shape[:2], dtype=torch.bool, device=DEVICE | |||
| lang_embeddings.shape[:2], dtype=torch.bool, device=DEVICE, | |||
| ), | |||
| img_tokens=image_embeds, | |||
| state_tokens=states, # how can I get this? | |||
| @@ -221,6 +219,6 @@ def test_dummy_input(): | |||
| # select the meaning action via STATE_INDICES | |||
| action = actions[ | |||
| :, :, STATE_INDICES | |||
| :, :, STATE_INDICES, | |||
| ] # (1, chunk_size, len(STATE_INDICES)) = (1, chunk_size, 7+ 1) | |||
| print(action) | |||
| @@ -5,7 +5,7 @@ readme_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), "README.m | |||
| # Read the content of the README file | |||
| try: | |||
| with open(readme_path, "r", encoding="utf-8") as f: | |||
| with open(readme_path, encoding="utf-8") as f: | |||
| __doc__ = f.read() | |||
| except FileNotFoundError: | |||
| __doc__ = "README file not found." | |||
| @@ -3,14 +3,15 @@ try: | |||
| from ugv_sdk_py import hunter_robot | |||
| except ImportError as err: | |||
| print( | |||
| "Please install ugv_sdk_py first by following the instruction at: https://github.com/westonrobot/ugv_sdk/tree/main?tab=readme-ov-file#build-the-package-as-a-python-package" | |||
| "Please install ugv_sdk_py first by following the instruction at: https://github.com/westonrobot/ugv_sdk/tree/main?tab=readme-ov-file#build-the-package-as-a-python-package", | |||
| ) | |||
| raise err | |||
| from dora import Node | |||
| import pyarrow as pa | |||
| import os | |||
| import pyarrow as pa | |||
| from dora import Node | |||
| def main(): | |||
| # Create an instance of HunterRobot | |||
| @@ -41,7 +42,7 @@ def main(): | |||
| [ | |||
| state.motion_state.linear_velocity, | |||
| state.motion_state.angular_velocity, | |||
| ] | |||
| ], | |||
| ), | |||
| ) | |||
| @@ -5,7 +5,7 @@ readme_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), "README.m | |||
| # Read the content of the README file | |||
| try: | |||
| with open(readme_path, "r", encoding="utf-8") as f: | |||
| with open(readme_path, encoding="utf-8") as f: | |||
| __doc__ = f.read() | |||
| except FileNotFoundError: | |||
| __doc__ = "README file not found." | |||
| @@ -1,9 +1,10 @@ | |||
| from dora import Node | |||
| import pyarrow as pa | |||
| import numpy as np | |||
| import os | |||
| from silero_vad import load_silero_vad, get_speech_timestamps | |||
| import numpy as np | |||
| import pyarrow as pa | |||
| import torch | |||
| from dora import Node | |||
| from silero_vad import get_speech_timestamps, load_silero_vad | |||
| model = load_silero_vad() | |||
| MIN_SILENCE_DURATION_MS = int(os.getenv("MIN_SILENCE_DURATION_MS", "200")) | |||
| @@ -38,14 +39,12 @@ def main(): | |||
| len(speech_timestamps) > 0 | |||
| and len(last_audios) > MIN_AUDIO_SAMPLING_DURAION_S | |||
| ): | |||
| # Check if the audio is not cut at the end. And only return if there is a long time spent | |||
| if speech_timestamps[-1]["end"] == len(audio): | |||
| continue | |||
| else: | |||
| audio = audio[0 : speech_timestamps[-1]["end"]] | |||
| node.send_output("audio", pa.array(audio)) | |||
| last_audios = [audio[speech_timestamps[-1]["end"] :]] | |||
| audio = audio[0 : speech_timestamps[-1]["end"]] | |||
| node.send_output("audio", pa.array(audio)) | |||
| last_audios = [audio[speech_timestamps[-1]["end"] :]] | |||
| # If there is no sound for too long return the audio | |||
| elif len(last_audios) > 75: | |||
| @@ -5,7 +5,7 @@ readme_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), "README.m | |||
| # Read the content of the README file | |||
| try: | |||
| with open(readme_path, "r", encoding="utf-8") as f: | |||
| with open(readme_path, encoding="utf-8") as f: | |||
| __doc__ = f.read() | |||
| except FileNotFoundError: | |||
| __doc__ = "README file not found." | |||
| @@ -3,15 +3,14 @@ import os | |||
| import numpy as np | |||
| import pyarrow as pa | |||
| from ultralytics import YOLO | |||
| from dora import Node | |||
| from ultralytics import YOLO | |||
| def main(): | |||
| # Handle dynamic nodes, ask for the name of the node in the dataflow, and the same values as the ENV variables. | |||
| parser = argparse.ArgumentParser( | |||
| description="UltraLytics YOLO: This node is used to perform object detection using the UltraLytics YOLO model." | |||
| description="UltraLytics YOLO: This node is used to perform object detection using the UltraLytics YOLO model.", | |||
| ) | |||
| parser.add_argument( | |||
| @@ -52,10 +51,7 @@ def main(): | |||
| width = metadata["width"] | |||
| height = metadata["height"] | |||
| if encoding == "bgr8": | |||
| channels = 3 | |||
| storage_type = np.uint8 | |||
| elif encoding == "rgb8": | |||
| if encoding == "bgr8" or encoding == "rgb8": | |||
| channels = 3 | |||
| storage_type = np.uint8 | |||
| else: | |||
| @@ -5,7 +5,7 @@ readme_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), "README.m | |||
| # Read the content of the README file | |||
| try: | |||
| with open(readme_path, "r", encoding="utf-8") as f: | |||
| with open(readme_path, encoding="utf-8") as f: | |||
| __doc__ = f.read() | |||
| except FileNotFoundError: | |||
| __doc__ = "README file not found." | |||
| @@ -1,11 +1,12 @@ | |||
| import os | |||
| import json | |||
| from dora import Node | |||
| import os | |||
| from pathlib import Path | |||
| import cv2 | |||
| import numpy as np | |||
| import pyarrow as pa | |||
| from dora import Node | |||
| from PIL import Image | |||
| from pathlib import Path | |||
| import cv2 | |||
| DEFAULT_QUESTION = os.getenv( | |||
| "DEFAULT_QUESTION", | |||
| @@ -14,13 +15,14 @@ DEFAULT_QUESTION = os.getenv( | |||
| def write_dict_to_json(file_path, key: str, new_data): | |||
| """ | |||
| Writes a dictionary to a JSON file. If the file already contains a list of entries, | |||
| """Writes a dictionary to a JSON file. If the file already contains a list of entries, | |||
| the new data will be appended to that list. Otherwise, it will create a new list. | |||
| Parameters: | |||
| Parameters | |||
| ---------- | |||
| - file_path: str, the path to the JSON file. | |||
| - new_data: dict, the dictionary to add to the JSON file. | |||
| """ | |||
| try: | |||
| # Open the JSON file and load its content | |||
| @@ -43,22 +45,22 @@ def write_dict_to_json(file_path, key: str, new_data): | |||
| def save_image_and_add_to_json( | |||
| frame_dict: dict, root_path, llama_root_path, jsonl_file, messages | |||
| frame_dict: dict, root_path, llama_root_path, jsonl_file, messages, | |||
| ): | |||
| """ | |||
| Saves an image from a NumPy array and adds a new JSON object as a line to a JSONL file. | |||
| """Saves an image from a NumPy array and adds a new JSON object as a line to a JSONL file. | |||
| The function generates a sequential numeric image filename starting from 0 and | |||
| follows the provided template structure. | |||
| Parameters: | |||
| Parameters | |||
| ---------- | |||
| - image_array: numpy.ndarray, the image data as a NumPy array. | |||
| - root_path: str, the root directory where the image will be saved. | |||
| - jsonl_file: str, the path to the JSONL file. | |||
| - messages: list of dicts, each containing 'content' and 'role'. | |||
| The image is saved as a PNG file, and the JSONL entry includes the 'messages' and 'images' keys. | |||
| """ | |||
| """ | |||
| # Create the root directory if it doesn't exist | |||
| os.makedirs(llama_root_path / root_path, exist_ok=True) | |||
| @@ -68,7 +70,7 @@ def save_image_and_add_to_json( | |||
| name | |||
| for name in os.listdir(llama_root_path / root_path) | |||
| if os.path.isfile(os.path.join(llama_root_path / root_path, name)) | |||
| ] | |||
| ], | |||
| ) | |||
| image_paths = [] | |||
| for event_id, data in frame_dict.items(): | |||
| @@ -94,9 +96,9 @@ def main(): | |||
| pa.array([]) # initialize pyarrow array | |||
| node = Node() | |||
| assert os.getenv( | |||
| "LLAMA_FACTORY_ROOT_PATH" | |||
| ), "LLAMA_FACTORY_ROOT_PATH is not set, Either git clone the repo or set the environment variable" | |||
| assert os.getenv("LLAMA_FACTORY_ROOT_PATH"), ( | |||
| "LLAMA_FACTORY_ROOT_PATH is not set, Either git clone the repo or set the environment variable" | |||
| ) | |||
| llama_factory_root_path = Path(os.getenv("LLAMA_FACTORY_ROOT_PATH")) / "data" | |||
| entry_name = os.getenv("ENTRY_NAME", "dora_demo") | |||
| @@ -141,13 +143,7 @@ def main(): | |||
| width = metadata["width"] | |||
| height = metadata["height"] | |||
| if encoding == "bgr8": | |||
| channels = 3 | |||
| storage_type = np.uint8 | |||
| elif encoding == "rgb8": | |||
| channels = 3 | |||
| storage_type = np.uint8 | |||
| elif encoding == "jpeg": | |||
| if encoding == "bgr8" or encoding == "rgb8" or encoding == "jpeg": | |||
| channels = 3 | |||
| storage_type = np.uint8 | |||
| else: | |||
| @@ -5,7 +5,7 @@ readme_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), "README.m | |||
| # Read the content of the README file | |||
| try: | |||
| with open(readme_path, "r", encoding="utf-8") as f: | |||
| with open(readme_path, encoding="utf-8") as f: | |||
| __doc__ = f.read() | |||
| except FileNotFoundError: | |||
| __doc__ = "README file not found." | |||
| @@ -4,7 +4,6 @@ import os | |||
| import cv2 | |||
| import numpy as np | |||
| import pyarrow as pa | |||
| from dora import Node | |||
| RUNNER_CI = True if os.getenv("CI") == "true" else False | |||
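One line ruff leaves untouched here is `RUNNER_CI = True if os.getenv("CI") == "true" else False`, presumably because the relevant `SIM` rule is not enabled in this configuration. flake8-simplify's SIM210 would reduce it, since `==` already yields a bool:

```python
import os

# The conditional expression is redundant: `==` already returns True/False.
RUNNER_CI = os.getenv("CI") == "true"
```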
| @@ -76,10 +75,9 @@ def yuv420p_to_bgr_opencv(yuv_array, width, height): | |||
| def main(): | |||
| # Handle dynamic nodes, ask for the name of the node in the dataflow, and the same values as the ENV variables. | |||
| parser = argparse.ArgumentParser( | |||
| description="OpenCV Plotter: This node is used to plot text and bounding boxes on an image." | |||
| description="OpenCV Plotter: This node is used to plot text and bounding boxes on an image.", | |||
| ) | |||
| parser.add_argument( | |||
| @@ -118,7 +116,7 @@ def main(): | |||
| plot_height = int(plot_height) | |||
| node = Node( | |||
| args.name | |||
| args.name, | |||
| ) # provide the name to connect to the dataflow if dynamic node | |||
| plot = Plot() | |||
| @@ -168,7 +166,6 @@ def main(): | |||
| plot.frame = cv2.imdecode(storage, cv2.IMREAD_COLOR) | |||
| elif encoding == "yuv420": | |||
| storage = storage.to_numpy() | |||
| # Convert back to BGR results in more saturated image. | |||
| @@ -201,7 +198,7 @@ def main(): | |||
| y + h / 2, | |||
| ) | |||
| for [x, y, w, h] in original_bbox | |||
| ] | |||
| ], | |||
| ) | |||
| else: | |||
| raise RuntimeError(f"Unsupported bbox format: {bbox_format}") | |||
| @@ -210,7 +207,7 @@ def main(): | |||
| "bbox": bbox, | |||
| "conf": arrow_bbox["conf"].values.to_numpy(), | |||
| "labels": arrow_bbox["labels"].values.to_numpy( | |||
| zero_copy_only=False | |||
| zero_copy_only=False, | |||
| ), | |||
| } | |||
| elif event_id == "text": | |||
| @@ -5,7 +5,7 @@ readme_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), "README.m | |||
| # Read the content of the README file | |||
| try: | |||
| with open(readme_path, "r", encoding="utf-8") as f: | |||
| with open(readme_path, encoding="utf-8") as f: | |||
| __doc__ = f.read() | |||
| except FileNotFoundError: | |||
| __doc__ = "README file not found." | |||
| @@ -5,7 +5,6 @@ import time | |||
| import cv2 | |||
| import numpy as np | |||
| import pyarrow as pa | |||
| from dora import Node | |||
| RUNNER_CI = True if os.getenv("CI") == "true" else False | |||
| @@ -16,7 +15,7 @@ FLIP = os.getenv("FLIP", "") | |||
| def main(): | |||
| # Handle dynamic nodes, ask for the name of the node in the dataflow, and the same values as the ENV variables. | |||
| parser = argparse.ArgumentParser( | |||
| description="OpenCV Video Capture: This node is used to capture video from a camera." | |||
| description="OpenCV Video Capture: This node is used to capture video from a camera.", | |||
| ) | |||
| parser.add_argument( | |||
| @@ -77,7 +76,6 @@ def main(): | |||
| pa.array([]) # initialize pyarrow array | |||
| for event in node: | |||
| # Run this example in the CI for 10 seconds only. | |||
| if RUNNER_CI and time.time() - start_time > 10: | |||
| break | |||
| @@ -95,7 +93,7 @@ def main(): | |||
| cv2.putText( | |||
| frame, | |||
| f"Error: no frame for camera at path {video_capture_path}.", | |||
| (int(30), int(30)), | |||
| (30, 30), | |||
| cv2.FONT_HERSHEY_SIMPLEX, | |||
| 0.50, | |||
| (255, 255, 255), | |||
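The `(int(30), int(30))` cleanup reads as pyupgrade's UP018: a primitive constructor applied to a literal is rewritten as the literal itself.

```python
origin = (int(30), int(30))  # flagged: int() on a literal is a no-op
origin = (30, 30)            # fixed form, as in the hunk above
```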
| @@ -5,7 +5,7 @@ readme_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), "README.m | |||
| # Read the content of the README file | |||
| try: | |||
| with open(readme_path, "r", encoding="utf-8") as f: | |||
| with open(readme_path, encoding="utf-8") as f: | |||
| __doc__ = f.read() | |||
| except FileNotFoundError: | |||
| __doc__ = "README file not found." | |||
| @@ -1,6 +1,6 @@ | |||
| import argparse | |||
| import os | |||
| import ast | |||
| import os | |||
| import pyarrow as pa | |||
| from dora import Node | |||
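The reshuffled imports are ruff's isort-compatible rule I001: alphabetical order within each group, standard library first, third-party after one blank line. (The doubled `import pyarrow as pa` lines in a later hunk are most likely this same fix, with the separating blank line lost in rendering.) The target layout:

```python
# Standard-library imports, alphabetized...
import argparse
import ast
import os

# ...then third-party packages, separated by one blank line.
import pyarrow as pa
from dora import Node
```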
| @@ -9,7 +9,6 @@ RUNNER_CI = True if os.getenv("CI") == "true" else False | |||
| def main(): | |||
| # Handle dynamic nodes, ask for the name of the node in the dataflow, and the same values as the ENV variables. | |||
| parser = argparse.ArgumentParser(description="Simple arrow sender") | |||
| @@ -33,18 +32,14 @@ def main(): | |||
| data = os.getenv("DATA", args.data) | |||
| node = Node( | |||
| args.name | |||
| args.name, | |||
| ) # provide the name to connect to the dataflow if dynamic node | |||
| data = ast.literal_eval(data) | |||
| if isinstance(data, list): | |||
| data = pa.array(data) # initialize pyarrow array | |||
| elif isinstance(data, str): | |||
| data = pa.array([data]) | |||
| elif isinstance(data, int): | |||
| data = pa.array([data]) | |||
| elif isinstance(data, float): | |||
| elif isinstance(data, str) or isinstance(data, int) or isinstance(data, float): | |||
| data = pa.array([data]) | |||
| else: | |||
| data = pa.array(data) # initialize pyarrow array | |||
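Again SIM114 collapsing identical `elif` arms. Note that ruff's SIM101 (duplicate `isinstance` calls on one variable) could go a step further than this PR does and merge the checks into a single tuple:

```python
import pyarrow as pa

data = "hello"  # hypothetical payload after ast.literal_eval

if isinstance(data, list):
    data = pa.array(data)  # already a sequence
elif isinstance(data, (str, int, float)):
    data = pa.array([data])  # wrap scalars in a one-element array
else:
    data = pa.array(data)
```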
| @@ -5,7 +5,7 @@ readme_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), "README.m | |||
| # Read the content of the README file | |||
| try: | |||
| with open(readme_path, "r", encoding="utf-8") as f: | |||
| with open(readme_path, encoding="utf-8") as f: | |||
| __doc__ = f.read() | |||
| except FileNotFoundError: | |||
| __doc__ = "README file not found." | |||
| @@ -1,16 +1,14 @@ | |||
| import argparse | |||
| import os | |||
| import ast | |||
| import os | |||
| import pyarrow as pa | |||
| from dora import Node | |||
| RUNNER_CI = True if os.getenv("CI") == "true" else False | |||
| def main(): | |||
| # Handle dynamic nodes, ask for the name of the node in the dataflow, and the same values as the ENV variables. | |||
| parser = argparse.ArgumentParser(description="Simple arrow sender") | |||
| @@ -34,29 +32,24 @@ def main(): | |||
| data = os.getenv("DATA", args.data) | |||
| node = Node( | |||
| args.name | |||
| args.name, | |||
| ) # provide the name to connect to the dataflow if dynamic node | |||
| if data is None: | |||
| raise ValueError( | |||
| "No data provided. Please specify `DATA` environment argument or as `--data` argument" | |||
| "No data provided. Please specify `DATA` environment argument or as `--data` argument", | |||
| ) | |||
| try: | |||
| data = ast.literal_eval(data) | |||
| except ValueError: | |||
| print("Passing input as string") | |||
| if isinstance(data, list): | |||
| data = pa.array(data) # initialize pyarrow array | |||
| elif isinstance(data, str) or isinstance(data, int) or isinstance(data, float): | |||
| data = pa.array([data]) | |||
| else: | |||
| try: | |||
| data = ast.literal_eval(data) | |||
| except ValueError: | |||
| print("Passing input as string") | |||
| if isinstance(data, list): | |||
| data = pa.array(data) # initialize pyarrow array | |||
| elif isinstance(data, str): | |||
| data = pa.array([data]) | |||
| elif isinstance(data, int): | |||
| data = pa.array([data]) | |||
| elif isinstance(data, float): | |||
| data = pa.array([data]) | |||
| else: | |||
| data = pa.array(data) # initialize pyarrow array | |||
| node.send_output("data", data) | |||
| data = pa.array(data) # initialize pyarrow array | |||
| node.send_output("data", data) | |||
| if __name__ == "__main__": | |||
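The larger rewrite above is mostly an indentation change: since the `data is None` branch raises, the `else:` wrapping the rest of the function was dead weight, and dropping it (flake8-return's RET506, superfluous `else` after `raise`, by our reading) dedents the whole tail one level. The resulting control flow, sketched with a made-up input:

```python
import ast

import pyarrow as pa

data = "[1, 2, 3]"  # hypothetical CLI/env input
if data is None:
    raise ValueError("No data provided. ...")
try:
    data = ast.literal_eval(data)  # "[1, 2, 3]" -> [1, 2, 3]
except ValueError:
    print("Passing input as string")
if isinstance(data, list):
    data = pa.array(data)
elif isinstance(data, (str, int, float)):
    data = pa.array([data])
else:
    data = pa.array(data)
```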
| @@ -5,7 +5,7 @@ readme_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), "README.m | |||
| # Read the content of the README file | |||
| try: | |||
| with open(readme_path, "r", encoding="utf-8") as f: | |||
| with open(readme_path, encoding="utf-8") as f: | |||
| __doc__ = f.read() | |||
| except FileNotFoundError: | |||
| __doc__ = "README file not found." | |||
| @@ -1,16 +1,15 @@ | |||
| import argparse | |||
| import os | |||
| import ast | |||
| import os | |||
| import time | |||
| import pyarrow as pa | |||
| import pyarrow as pa | |||
| from dora import Node | |||
| RUNNER_CI = True if os.getenv("CI") == "true" else False | |||
| def main(): | |||
| # Handle dynamic nodes, ask for the name of the node in the dataflow, and the same values as the ENV variables. | |||
| parser = argparse.ArgumentParser(description="Simple arrow sender") | |||
| @@ -36,7 +35,7 @@ def main(): | |||
| while True: | |||
| try: | |||
| node = Node( | |||
| args.name | |||
| args.name, | |||
| ) # provide the name to connect to the dataflow if dynamic node | |||
| except RuntimeError as err: | |||
| if err != last_err: | |||
| @@ -58,13 +57,7 @@ def main(): | |||
| print("Passing input as string") | |||
| if isinstance(data, list): | |||
| data = pa.array(data) # initialize pyarrow array | |||
| elif isinstance(data, str): | |||
| data = pa.array([data]) | |||
| elif isinstance(data, int): | |||
| data = pa.array([data]) | |||
| elif isinstance(data, float): | |||
| data = pa.array([data]) | |||
| elif isinstance(data, dict): | |||
| elif isinstance(data, str) or isinstance(data, int) or isinstance(data, float) or isinstance(data, dict): | |||
| data = pa.array([data]) | |||
| else: | |||
| data = pa.array(data) # initialize pyarrow array | |||
| @@ -82,13 +75,7 @@ def main(): | |||
| print("Passing input as string") | |||
| if isinstance(data, list): | |||
| data = pa.array(data) # initialize pyarrow array | |||
| elif isinstance(data, str): | |||
| data = pa.array([data]) | |||
| elif isinstance(data, int): | |||
| data = pa.array([data]) | |||
| elif isinstance(data, float): | |||
| data = pa.array([data]) | |||
| elif isinstance(data, dict): | |||
| elif isinstance(data, str) or isinstance(data, int) or isinstance(data, float) or isinstance(data, dict): | |||
| data = pa.array([data]) | |||
| else: | |||
| data = pa.array(data) # initialize pyarrow array | |||
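One caveat on these last two hunks: the merged four-way `or` chain now runs past ruff's default 88-column limit, so E501 would flag it if the pycodestyle line-length rule is enabled. A tuple `isinstance` check (SIM101) stays short and reads better:

```python
import pyarrow as pa

data = {"speed": 1.0}  # hypothetical payload

if isinstance(data, list):
    data = pa.array(data)
elif isinstance(data, (str, int, float, dict)):
    data = pa.array([data])  # dicts become a one-element struct array
else:
    data = pa.array(data)
```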