Browse Source

Refactoring `python-node-api`

tags/v0.0.0-test.4
haixuanTao 3 years ago
parent
commit
1903c636cb
2 changed files with 10 additions and 19 deletions
  1. +1
    -1
      apis/python/node/Cargo.toml
  2. +9
    -18
      apis/python/node/src/lib.rs

+ 1
- 1
apis/python/node/Cargo.toml View File

@@ -2,6 +2,7 @@
name = "dora-node-api-python"
version = "0.1.0"
edition = "2021"
license = "Apache-2.0"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

@@ -9,7 +10,6 @@ edition = "2021"
dora-node-api = { path = "../../rust/node" }
pyo3 = "0.16"
eyre = "0.6"
pollster = "0.2"
futures = "0.3.21"
tokio = { version = "1.17.0", features = ["rt", "sync", "macros"] }
serde_yaml = "0.8.23"


+ 9
- 18
apis/python/node/src/lib.rs View File

@@ -8,8 +8,8 @@ use std::sync::Arc;
use std::thread;
use tokio::sync::mpsc;
use tokio::sync::mpsc::{Receiver, Sender};

#[pyclass]
// #[repr(transparent)]
pub struct PyDoraNode {
// pub node: DoraNode,
pub rx_input: Receiver<Input>,
@@ -28,8 +28,8 @@ impl IntoPy<PyObject> for PyInput {
impl PyDoraNode {
#[staticmethod]
pub fn init_from_env() -> Self {
let (tx_input, rx_input) = mpsc::channel(10);
let (tx_output, mut rx_output) = mpsc::channel::<(String, Vec<u8>)>(10);
let (tx_input, rx_input) = mpsc::channel(1);
let (tx_output, mut rx_output) = mpsc::channel::<(String, Vec<u8>)>(1);

// Dispatching a tokio threadpool enables us to conveniently use Dora Future stream
// through tokio channel.
@@ -41,18 +41,14 @@ impl PyDoraNode {
let _node = node.clone();
let receive_handle = tokio::spawn(async move {
let mut inputs = _node.inputs().await.unwrap();
loop {
if let Some(input) = inputs.next().await {
tx_input.send(input).await.unwrap()
};
while let Some(input) = inputs.next().await {
tx_input.send(input).await.unwrap()
}
});
let send_handle = tokio::spawn(async move {
loop {
if let Some((output_str, data)) = rx_output.recv().await {
let output_id = DataId::from(output_str);
node.send_output(&output_id, data.as_slice()).await.unwrap()
};
while let Some((output_str, data)) = rx_output.recv().await {
let output_id = DataId::from(output_str);
node.send_output(&output_id, data.as_slice()).await.unwrap()
}
});
let (_, _) = tokio::join!(receive_handle, send_handle);
@@ -70,11 +66,7 @@ impl PyDoraNode {
}

pub fn __next__(&mut self) -> PyResult<Option<PyInput>> {
if let Some(input) = self.rx_input.blocking_recv() {
Ok(Some(PyInput(input)))
} else {
Ok(None)
}
Ok(self.rx_input.blocking_recv().map(PyInput))
}

pub fn send_output(&self, output_str: String, data: Vec<u8>) -> () {
@@ -85,7 +77,6 @@ impl PyDoraNode {
}
}

/// This module is implemented in Rust.
#[pymodule]
fn dora(_py: Python, m: &PyModule) -> PyResult<()> {
m.add_class::<PyDoraNode>().unwrap();


Loading…
Cancel
Save