Browse Source

Add dora-openai-websocket

make-qwen-llm-configurable
haixuantao 5 months ago
parent
commit
733c86ae39
4 changed files with 936 additions and 34 deletions
  1. +432
    -34
      Cargo.lock
  2. +1
    -0
      Cargo.toml
  3. +29
    -0
      node-hub/dora-openai-websocket/Cargo.toml
  4. +474
    -0
      node-hub/dora-openai-websocket/src/main.rs

+ 432
- 34
Cargo.lock View File

@@ -264,6 +264,12 @@ dependencies = [
"libc",
]

[[package]]
name = "anes"
version = "0.1.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299"

[[package]]
name = "ansi_colours"
version = "1.2.3"
@@ -666,7 +672,7 @@ dependencies = [
"enumflags2",
"futures-channel",
"futures-util",
"rand 0.9.1",
"rand 0.9.2",
"raw-window-handle 0.6.2",
"serde",
"serde_repr",
@@ -713,6 +719,29 @@ dependencies = [
"syn 2.0.101",
]

[[package]]
name = "assert2"
version = "0.3.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0d6c710e60d14b07d8f42d0e702b16120865eea39edb751e75cd6bf401d18f14"
dependencies = [
"assert2-macros",
"diff",
"yansi",
]

[[package]]
name = "assert2-macros"
version = "0.3.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9008cbbba9e1d655538870b91fd93814bd82e6968f27788fc734375120ac6f57"
dependencies = [
"proc-macro2",
"quote",
"rustc_version",
"syn 2.0.101",
]

[[package]]
name = "assert_matches"
version = "1.5.0"
@@ -1179,14 +1208,14 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "edca88bc138befd0323b20752846e6587272d3b03b0343c8ea28a6f819e6e71f"
dependencies = [
"async-trait",
"axum-core",
"axum-core 0.4.5",
"bytes",
"futures-util",
"http 1.3.1",
"http-body 1.0.1",
"http-body-util",
"itoa",
"matchit",
"matchit 0.7.3",
"memchr",
"mime",
"percent-encoding",
@@ -1199,6 +1228,40 @@ dependencies = [
"tower-service",
]

[[package]]
name = "axum"
version = "0.8.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "021e862c184ae977658b36c4500f7feac3221ca5da43e3f25bd04ab6c79a29b5"
dependencies = [
"axum-core 0.5.2",
"bytes",
"form_urlencoded",
"futures-util",
"http 1.3.1",
"http-body 1.0.1",
"http-body-util",
"hyper 1.6.0",
"hyper-util",
"itoa",
"matchit 0.8.4",
"memchr",
"mime",
"percent-encoding",
"pin-project-lite",
"rustversion",
"serde",
"serde_json",
"serde_path_to_error",
"serde_urlencoded",
"sync_wrapper",
"tokio",
"tower 0.5.2",
"tower-layer",
"tower-service",
"tracing",
]

[[package]]
name = "axum-core"
version = "0.4.5"
@@ -1219,6 +1282,26 @@ dependencies = [
"tower-service",
]

[[package]]
name = "axum-core"
version = "0.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "68464cd0412f486726fb3373129ef5d2993f90c34bc2bc1c1e9943b2f4fc7ca6"
dependencies = [
"bytes",
"futures-core",
"http 1.3.1",
"http-body 1.0.1",
"http-body-util",
"mime",
"pin-project-lite",
"rustversion",
"sync_wrapper",
"tower-layer",
"tower-service",
"tracing",
]

[[package]]
name = "az"
version = "1.2.1"
@@ -1240,6 +1323,12 @@ dependencies = [
"windows-targets 0.52.6",
]

[[package]]
name = "base"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8a956d500c2380c818e09d3d7c79ba4a1d7fc6354464f1fceaa5705483a29930"

[[package]]
name = "base64"
version = "0.13.1"
@@ -1640,7 +1729,7 @@ dependencies = [
"metal 0.27.0",
"num-traits",
"num_cpus",
"rand 0.9.1",
"rand 0.9.2",
"rand_distr",
"rayon",
"safetensors",
@@ -1732,6 +1821,12 @@ dependencies = [
"thiserror 1.0.69",
]

[[package]]
name = "cast"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5"

[[package]]
name = "cc"
version = "1.2.17"
@@ -1867,6 +1962,33 @@ version = "1.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6e4de3bc4ea267985becf712dc6d9eed8b04c953b3fcfb339ebc87acd9804901"

[[package]]
name = "ciborium"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "42e69ffd6f0917f5c029256a24d0161db17cea3997d185db0d35926308770f0e"
dependencies = [
"ciborium-io",
"ciborium-ll",
"serde",
]

[[package]]
name = "ciborium-io"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "05afea1e0a06c9be33d539b876f1ce3692f4afea2cb41f740e7743225ed1c757"

[[package]]
name = "ciborium-ll"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "57663b653d948a338bfb3eeba9bb2fd5fcfaecb9e199e87e1eda4d9e8b240fd9"
dependencies = [
"ciborium-io",
"half",
]

[[package]]
name = "cipher"
version = "0.4.4"
@@ -2303,6 +2425,42 @@ dependencies = [
"cfg-if 1.0.0",
]

[[package]]
name = "criterion"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e7c76e09c1aae2bc52b3d2f29e13c6572553b30c4aa1b8a49fd70de6412654cb"
dependencies = [
"anes",
"atty",
"cast",
"ciborium",
"clap 3.2.25",
"criterion-plot",
"itertools 0.10.5",
"lazy_static",
"num-traits",
"oorandom",
"plotters",
"rayon",
"regex",
"serde",
"serde_derive",
"serde_json",
"tinytemplate",
"walkdir",
]

[[package]]
name = "criterion-plot"
version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6b50826342786a51a89e2da3a28f1c32b06e387201bc2d19791f622c673706b1"
dependencies = [
"cast",
"itertools 0.10.5",
]

[[package]]
name = "crossbeam"
version = "0.8.4"
@@ -2809,6 +2967,12 @@ version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ab03c107fafeb3ee9f5925686dbb7a73bc76e3932abb0d2b365cb64b169cf04c"

[[package]]
name = "diff"
version = "0.1.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "56254986775e3233ffa9c4d7d3faaf6d36a2c09d30b20687e9f88bc8bafc16c8"

[[package]]
name = "digest"
version = "0.10.7"
@@ -3266,6 +3430,33 @@ dependencies = [
"uuid 1.16.0",
]

[[package]]
name = "dora-openai-websocket"
version = "0.1.0"
dependencies = [
"anyhow",
"assert2",
"axum 0.8.4",
"base",
"base64 0.22.1",
"bytes",
"criterion",
"dora-cli",
"dora-node-api",
"fastwebsockets",
"http-body-util",
"hyper 1.6.0",
"hyper-util",
"rand 0.9.2",
"rustls-pemfile 1.0.4",
"serde",
"serde_json",
"tokio",
"tokio-rustls 0.24.1",
"trybuild",
"webpki-roots 0.23.1",
]

[[package]]
name = "dora-operator-api"
version = "0.3.12"
@@ -3362,7 +3553,7 @@ dependencies = [
"ndarray 0.15.6",
"pinyin",
"pyo3",
"rand 0.9.1",
"rand 0.9.2",
"rerun",
"tokio",
]
@@ -4161,6 +4352,26 @@ version = "2.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be"

[[package]]
name = "fastwebsockets"
version = "0.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "305d3ba574508e27190906d11707dad683e0494e6b85eae9b044cb2734a5e422"
dependencies = [
"base64 0.21.7",
"bytes",
"http-body-util",
"hyper 1.6.0",
"hyper-util",
"pin-project",
"rand 0.8.5",
"sha1",
"simdutf8",
"thiserror 1.0.69",
"tokio",
"utf-8",
]

[[package]]
name = "fdeflate"
version = "0.3.7"
@@ -4272,7 +4483,7 @@ dependencies = [
"cudarc",
"half",
"num-traits",
"rand 0.9.1",
"rand 0.9.2",
"rand_distr",
]

@@ -5052,7 +5263,7 @@ dependencies = [
"cfg-if 1.0.0",
"crunchy",
"num-traits",
"rand 0.9.1",
"rand 0.9.2",
"rand_distr",
]

@@ -5392,7 +5603,7 @@ dependencies = [
"rustls 0.23.25",
"rustls-pki-types",
"tokio",
"tokio-rustls",
"tokio-rustls 0.26.2",
"tower-service",
"webpki-roots 0.26.8",
]
@@ -6575,6 +6786,12 @@ version = "0.7.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0e7465ac9959cc2b1404e8e2367b43684a6d13790fe23056cc8c6c5a6b7bcb94"

[[package]]
name = "matchit"
version = "0.8.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "47e1ffaa40ddd1f3ed91f717a33c8c0ee23fff369e3aa8772b9605cc1d22f4c3"

[[package]]
name = "matrixmultiply"
version = "0.3.9"
@@ -6858,7 +7075,7 @@ dependencies = [
"image",
"indexmap 2.8.0",
"mistralrs-core",
"rand 0.9.1",
"rand 0.9.2",
"reqwest",
"serde",
"serde_json",
@@ -6910,7 +7127,7 @@ dependencies = [
"objc",
"once_cell",
"radix_trie",
"rand 0.9.1",
"rand 0.9.2",
"rand_isaac",
"rayon",
"regex",
@@ -6931,7 +7148,7 @@ dependencies = [
"tokio",
"tokio-rayon",
"toktrie_hf_tokenizers",
"toml",
"toml 0.8.20",
"tqdm",
"tracing",
"tracing-subscriber",
@@ -7893,6 +8110,12 @@ dependencies = [
"pkg-config",
]

[[package]]
name = "oorandom"
version = "11.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d6790f58c7ff633d8771f42965289203411a5e5c68388703c06e14f24770b41e"

[[package]]
name = "openssl-probe"
version = "0.1.6"
@@ -8081,7 +8304,7 @@ dependencies = [
"glob",
"opentelemetry 0.29.1",
"percent-encoding",
"rand 0.9.1",
"rand 0.9.2",
"serde_json",
"thiserror 2.0.12",
"tokio",
@@ -8531,6 +8754,34 @@ dependencies = [
"time",
]

[[package]]
name = "plotters"
version = "0.3.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5aeb6f403d7a4911efb1e33402027fc44f29b5bf6def3effcc22d7bb75f2b747"
dependencies = [
"num-traits",
"plotters-backend",
"plotters-svg",
"wasm-bindgen",
"web-sys",
]

[[package]]
name = "plotters-backend"
version = "0.3.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "df42e13c12958a16b3f7f4386b9ab1f3e7933914ecea48da7139435263a4172a"

[[package]]
name = "plotters-svg"
version = "0.3.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "51bae2ac328883f7acdfea3d66a7c35751187f870bc81f94563733a154d7a670"
dependencies = [
"plotters-backend",
]

[[package]]
name = "ply-rs"
version = "0.1.3"
@@ -9161,7 +9412,7 @@ checksum = "b820744eb4dc9b57a3398183639c511b5a26d2ed702cedd3febaa1393caa22cc"
dependencies = [
"bytes",
"getrandom 0.3.2",
"rand 0.9.1",
"rand 0.9.2",
"ring 0.17.14",
"rustc-hash 2.1.1",
"rustls 0.23.25",
@@ -9261,9 +9512,9 @@ dependencies = [

[[package]]
name = "rand"
version = "0.9.1"
version = "0.9.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9fbfd9d094a40bf3ae768db9361049ace4c0e04a4fd6b359518bd7b73a73dd97"
checksum = "6db2770f06117d490610c7488547d543617b21bfa07796d7a12f6f1bd53850d1"
dependencies = [
"rand_chacha 0.9.0",
"rand_core 0.9.3",
@@ -9314,7 +9565,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6a8615d50dcf34fa31f7ab52692afec947c4dd0ab803cc87cb3b0b4570ff7463"
dependencies = [
"num-traits",
"rand 0.9.1",
"rand 0.9.2",
]

[[package]]
@@ -9370,7 +9621,7 @@ dependencies = [
"simd_helpers",
"system-deps",
"thiserror 1.0.69",
"toml",
"toml 0.8.20",
"v_frame",
"y4m",
]
@@ -10508,7 +10759,7 @@ dependencies = [
"serde",
"syn 2.0.101",
"tempfile",
"toml",
"toml 0.8.20",
"unindent",
"xshell",
]
@@ -11260,7 +11511,7 @@ dependencies = [
"sync_wrapper",
"system-configuration",
"tokio",
"tokio-rustls",
"tokio-rustls 0.26.2",
"tokio-util",
"tower 0.5.2",
"tower-service",
@@ -11660,7 +11911,7 @@ dependencies = [
"num-derive",
"num-traits",
"paste",
"rand 0.9.1",
"rand 0.9.2",
"serde",
"serde_repr",
"socket2 0.5.8",
@@ -11731,6 +11982,18 @@ dependencies = [
"webpki",
]

[[package]]
name = "rustls"
version = "0.21.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3f56a14d1f48b391359b22f731fd4bd7e43c97f3c50eee276f3aa09c94784d3e"
dependencies = [
"log",
"ring 0.17.14",
"rustls-webpki 0.101.7",
"sct",
]

[[package]]
name = "rustls"
version = "0.23.25"
@@ -11824,6 +12087,26 @@ version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f87165f0995f63a9fbeea62b64d10b4d9d8e78ec6d7d51fb2125fda7bb36788f"

[[package]]
name = "rustls-webpki"
version = "0.100.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5f6a5fc258f1c1276dfe3016516945546e2d5383911efc0fc4f1cdc5df3a4ae3"
dependencies = [
"ring 0.16.20",
"untrusted 0.7.1",
]

[[package]]
name = "rustls-webpki"
version = "0.101.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8b6275d1ee7a1cd780b64aca7726599a1dbc893b1e64144529e55c3c2f745765"
dependencies = [
"ring 0.17.14",
"untrusted 0.9.0",
]

[[package]]
name = "rustls-webpki"
version = "0.102.8"
@@ -12228,9 +12511,9 @@ dependencies = [

[[package]]
name = "serde_json"
version = "1.0.140"
version = "1.0.141"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "20068b6e96dc6c9bd23e01df8827e6c7e1f2fddd43c21810382803c136b99373"
checksum = "30b9eff21ebe718216c6ec64e1d9ac57087aad11efc64e32002bce4a0d4c03d3"
dependencies = [
"indexmap 2.8.0",
"itoa",
@@ -12239,6 +12522,16 @@ dependencies = [
"serde",
]

[[package]]
name = "serde_path_to_error"
version = "0.1.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "59fab13f937fa393d08645bf3a84bdfe86e296747b506ada67bb15f10f218b2a"
dependencies = [
"itoa",
"serde",
]

[[package]]
name = "serde_plain"
version = "1.0.2"
@@ -12268,6 +12561,15 @@ dependencies = [
"serde",
]

[[package]]
name = "serde_spanned"
version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "40734c41988f7306bb04f0ecf60ec0f3f1caa34290e4e8ea471dcd3346483b83"
dependencies = [
"serde",
]

[[package]]
name = "serde_urlencoded"
version = "0.7.1"
@@ -13221,7 +13523,7 @@ dependencies = [
"cfg-expr",
"heck 0.5.0",
"pkg-config",
"toml",
"toml 0.8.20",
"version-compare",
]

@@ -13257,6 +13559,12 @@ version = "0.12.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1"

[[package]]
name = "target-triple"
version = "0.1.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1ac9aa371f599d22256307c24a9d748c041e548cbf599f35d890f9d365361790"

[[package]]
name = "tempfile"
version = "3.19.1"
@@ -13515,6 +13823,16 @@ dependencies = [
"zerovec",
]

[[package]]
name = "tinytemplate"
version = "1.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "be4d6b5f19ff7664e8c98d03e2139cb510db9b0a60b55f8e8709b689d939b6bc"
dependencies = [
"serde",
"serde_json",
]

[[package]]
name = "tinyvec"
version = "1.9.0"
@@ -13540,7 +13858,7 @@ dependencies = [
"pin-project-lite",
"thiserror 2.0.12",
"tokio",
"tokio-rustls",
"tokio-rustls 0.26.2",
]

[[package]]
@@ -13638,6 +13956,16 @@ dependencies = [
"tokio",
]

[[package]]
name = "tokio-rustls"
version = "0.24.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c28327cf380ac148141087fbfb9de9d7bd4e84ab5d2c28fbc911d753de8a7081"
dependencies = [
"rustls 0.21.12",
"tokio",
]

[[package]]
name = "tokio-rustls"
version = "0.26.2"
@@ -13720,11 +14048,26 @@ checksum = "cd87a5cdd6ffab733b2f74bc4fd7ee5fff6634124999ac278c35fc78c6120148"
dependencies = [
"indexmap 2.8.0",
"serde",
"serde_spanned",
"toml_datetime",
"serde_spanned 0.6.8",
"toml_datetime 0.6.8",
"toml_edit",
]

[[package]]
name = "toml"
version = "0.9.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "41ae868b5a0f67631c14589f7e250c1ea2c574ee5ba21c6c8dd4b1485705a5a1"
dependencies = [
"indexmap 2.8.0",
"serde",
"serde_spanned 1.0.0",
"toml_datetime 0.7.0",
"toml_parser",
"toml_writer",
"winnow",
]

[[package]]
name = "toml_datetime"
version = "0.6.8"
@@ -13734,6 +14077,15 @@ dependencies = [
"serde",
]

[[package]]
name = "toml_datetime"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bade1c3e902f58d73d3f294cd7f20391c1cb2fbcb643b73566bc773971df91e3"
dependencies = [
"serde",
]

[[package]]
name = "toml_edit"
version = "0.22.24"
@@ -13742,11 +14094,26 @@ checksum = "17b4795ff5edd201c7cd6dca065ae59972ce77d1b80fa0a84d94950ece7d1474"
dependencies = [
"indexmap 2.8.0",
"serde",
"serde_spanned",
"toml_datetime",
"serde_spanned 0.6.8",
"toml_datetime 0.6.8",
"winnow",
]

[[package]]
name = "toml_parser"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "97200572db069e74c512a14117b296ba0a80a30123fbbb5aa1f4a348f639ca30"
dependencies = [
"winnow",
]

[[package]]
name = "toml_writer"
version = "1.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fcc842091f2def52017664b53082ecbbeb5c7731092bad69d2c63050401dfd64"

[[package]]
name = "tonic"
version = "0.12.3"
@@ -13755,7 +14122,7 @@ checksum = "877c5b330756d856ffcc4553ab34a5684481ade925ecc54bcd1bf02b1d0d4d52"
dependencies = [
"async-stream",
"async-trait",
"axum",
"axum 0.7.9",
"base64 0.22.1",
"bytes",
"h2 0.4.8",
@@ -13772,7 +14139,7 @@ dependencies = [
"rustls-pemfile 2.2.0",
"socket2 0.5.8",
"tokio",
"tokio-rustls",
"tokio-rustls 0.26.2",
"tokio-stream",
"tower 0.4.13",
"tower-layer",
@@ -13858,6 +14225,7 @@ dependencies = [
"tokio",
"tower-layer",
"tower-service",
"tracing",
]

[[package]]
@@ -14005,6 +14373,21 @@ version = "0.2.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b"

[[package]]
name = "trybuild"
version = "1.0.106"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "65af40ad689f2527aebbd37a0a816aea88ff5f774ceabe99de5be02f2f91dae2"
dependencies = [
"glob",
"serde",
"serde_derive",
"serde_json",
"target-triple",
"termcolor",
"toml 0.9.4",
]

[[package]]
name = "ttf-parser"
version = "0.25.1"
@@ -14352,7 +14735,7 @@ checksum = "458f7a779bf54acc9f347480ac654f68407d3aab21269a6e3c9f922acd9e2da9"
dependencies = [
"getrandom 0.3.2",
"js-sys",
"rand 0.9.1",
"rand 0.9.2",
"serde",
"uuid-macro-internal",
"wasm-bindgen",
@@ -14810,6 +15193,15 @@ dependencies = [
"webpki",
]

[[package]]
name = "webpki-roots"
version = "0.23.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b03058f88386e5ff5310d9111d53f48b17d732b401aeb83a8d5190f2ac459338"
dependencies = [
"rustls-webpki 0.100.3",
]

[[package]]
name = "webpki-roots"
version = "0.26.8"
@@ -15676,9 +16068,9 @@ dependencies = [

[[package]]
name = "winnow"
version = "0.7.4"
version = "0.7.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0e97b544156e9bebe1a0ffbc03484fc1ffe3100cbce3ffb17eac35f7cdd7ab36"
checksum = "f3edebf492c8125044983378ecb5766203ad3b4c2f7a922bd7dd207f6d443e95"
dependencies = [
"memchr",
]
@@ -15888,6 +16280,12 @@ dependencies = [
"linked-hash-map",
]

[[package]]
name = "yansi"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cfe53a6657fd280eaa890a3bc59152892ffa3e30101319d168b781ed6529b049"

[[package]]
name = "yoke"
version = "0.7.5"
@@ -16549,7 +16947,7 @@ dependencies = [
"time",
"tls-listener",
"tokio",
"tokio-rustls",
"tokio-rustls 0.26.2",
"tokio-util",
"tracing",
"webpki-roots 0.26.8",


+ 1
- 0
Cargo.toml View File

@@ -34,6 +34,7 @@ members = [
"node-hub/dora-rerun",
"node-hub/terminal-print",
"node-hub/openai-proxy-server",
"node-hub/dora-openai-websocket",
"node-hub/dora-kit-car",
"node-hub/dora-object-to-pose",
"node-hub/dora-mistral-rs",


+ 29
- 0
node-hub/dora-openai-websocket/Cargo.toml View File

@@ -0,0 +1,29 @@
[package]
name = "dora-openai-websocket"
version = "0.1.0"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
dora-node-api = { workspace = true }
dora-cli = { workspace = true }
tokio = { version = "1.25.0", features = ["full", "macros"] }
tokio-rustls = "0.24.0"
rustls-pemfile = "1.0"
hyper-util = { version = "0.1.0", features = ["tokio"] }
http-body-util = { version = "0.1.0" }
hyper = { version = "1", features = ["http1", "server", "client"] }
assert2 = "0.3.4"
trybuild = "1.0.106"
criterion = "0.4.0"
anyhow = "1.0.71"
webpki-roots = "0.23.0"
bytes = "1.4.0"
axum = "0.8.1"
fastwebsockets = { version = "0.10.0", features = ["upgrade"] }
serde_json = "1.0.141"
serde = "1.0.219"
base = "0.1.0"
base64 = "0.22.1"
rand = "0.9.2"

+ 474
- 0
node-hub/dora-openai-websocket/src/main.rs View File

@@ -0,0 +1,474 @@
// Copyright 2023 Divy Srivastava <dj.srivastava23@gmail.com>
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use base64::engine::general_purpose;
use base64::Engine;
use dora_cli::command::Executable;
use dora_cli::command::Start;
use dora_node_api::arrow::array::AsArray;
use dora_node_api::arrow::datatypes::DataType;
use dora_node_api::dora_core::config::DataId;
use dora_node_api::dora_core::config::NodeId;
use dora_node_api::dora_core::topics::DORA_COORDINATOR_PORT_CONTROL_DEFAULT;
use dora_node_api::into_vec;
use dora_node_api::DoraNode;
use dora_node_api::IntoArrow;
use dora_node_api::MetadataParameters;
use fastwebsockets::upgrade;
use fastwebsockets::Frame;
use fastwebsockets::OpCode;
use fastwebsockets::Payload;
use fastwebsockets::WebSocketError;
use http_body_util::Empty;
use hyper::body::Bytes;
use hyper::body::Incoming;
use hyper::server::conn::http1;
use hyper::service::service_fn;
use hyper::Request;
use hyper::Response;
use rand::random;
use serde;
use serde::Deserialize;
use serde::Serialize;
use std::fs;
use std::io::{self, Write};
use std::net::IpAddr;
use std::net::Ipv4Addr;
use std::time::Duration;
use tokio::net::TcpListener;

#[derive(Serialize, Deserialize, Debug)]
pub struct ErrorDetails {
pub code: Option<String>,
pub message: String,
pub param: Option<String>,
#[serde(rename = "type")]
pub error_type: Option<String>,
}

#[derive(Serialize, Deserialize, Debug)]
#[serde(tag = "type")]
pub enum OpenAIRealtimeMessage {
#[serde(rename = "session.update")]
SessionUpdate { session: SessionConfig },
#[serde(rename = "input_audio_buffer.append")]
InputAudioBufferAppend {
audio: String, // base64 encoded audio
},
#[serde(rename = "input_audio_buffer.commit")]
InputAudioBufferCommit,
#[serde(rename = "response.create")]
ResponseCreate { response: ResponseConfig },
#[serde(rename = "conversation.item.create")]
ConversationItemCreate { item: ConversationItem },
#[serde(rename = "conversation.item.truncate")]
ConversationItemTruncate {
item_id: String,
content_index: u32,
audio_end_ms: u32,
#[serde(skip_serializing_if = "Option::is_none")]
event_id: Option<String>,
},
}

#[derive(Serialize, Deserialize, Debug)]
pub struct SessionConfig {
pub modalities: Vec<String>,
pub instructions: String,
pub voice: String,
pub input_audio_format: String,
pub output_audio_format: String,
pub input_audio_transcription: Option<TranscriptionConfig>,
pub turn_detection: Option<TurnDetectionConfig>,
pub tools: Vec<serde_json::Value>,
pub tool_choice: String,
pub temperature: f32,
pub max_response_output_tokens: Option<u32>,
}

#[derive(Serialize, Deserialize, Debug)]
pub struct TranscriptionConfig {
pub model: String,
}

#[derive(Serialize, Deserialize, Debug)]
pub struct TurnDetectionConfig {
#[serde(rename = "type")]
pub detection_type: String,
pub threshold: f32,
pub prefix_padding_ms: u32,
pub silence_duration_ms: u32,
pub interrupt_response: bool,
pub create_response: bool,
}

#[derive(Serialize, Deserialize, Debug)]
pub struct ResponseConfig {
pub modalities: Vec<String>,
pub instructions: Option<String>,
pub voice: Option<String>,
pub output_audio_format: Option<String>,
pub tools: Option<Vec<serde_json::Value>>,
pub tool_choice: Option<String>,
pub temperature: Option<f32>,
pub max_output_tokens: Option<u32>,
}

#[derive(Serialize, Deserialize, Debug)]
pub struct ConversationItem {
pub id: Option<String>,
#[serde(rename = "type")]
pub item_type: String,
pub status: Option<String>,
pub role: String,
pub content: Vec<ContentPart>,
}

#[derive(Serialize, Deserialize, Debug)]
#[serde(tag = "type")]
pub enum ContentPart {
#[serde(rename = "input_text")]
InputText { text: String },
#[serde(rename = "input_audio")]
InputAudio {
audio: String,
transcript: Option<String>,
},
#[serde(rename = "text")]
Text { text: String },
#[serde(rename = "audio")]
Audio {
audio: String,
transcript: Option<String>,
},
}

// Incoming message types from OpenAI
#[derive(Serialize, Deserialize, Debug)]
#[serde(tag = "type")]
pub enum OpenAIRealtimeResponse {
#[serde(rename = "error")]
Error { error: ErrorDetails },
#[serde(rename = "session.created")]
SessionCreated { session: serde_json::Value },
#[serde(rename = "session.updated")]
SessionUpdated { session: serde_json::Value },
#[serde(rename = "conversation.item.created")]
ConversationItemCreated { item: serde_json::Value },
#[serde(rename = "conversation.item.truncated")]
ConversationItemTruncated { item: serde_json::Value },
#[serde(rename = "response.audio.delta")]
ResponseAudioDelta {
response_id: String,
item_id: String,
output_index: u32,
content_index: u32,
delta: String, // base64 encoded audio
},
#[serde(rename = "response.audio.done")]
ResponseAudioDone {
response_id: String,
item_id: String,
output_index: u32,
content_index: u32,
},
#[serde(rename = "response.text.delta")]
ResponseTextDelta {
response_id: String,
item_id: String,
output_index: u32,
content_index: u32,
delta: String,
},
#[serde(rename = "response.audio_transcript.delta")]
ResponseAudioTranscriptDelta {
response_id: String,
item_id: String,
output_index: u32,
content_index: u32,
delta: String,
},
#[serde(rename = "response.done")]
ResponseDone { response: serde_json::Value },
#[serde(rename = "input_audio_buffer.speech_started")]
InputAudioBufferSpeechStarted {
audio_start_ms: u32,
item_id: String,
},
#[serde(rename = "input_audio_buffer.speech_stopped")]
InputAudioBufferSpeechStopped { audio_end_ms: u32, item_id: String },
#[serde(other)]
Other,
}

fn convert_pcm16_to_f32(bytes: &[u8]) -> Vec<f32> {
let mut samples = Vec::with_capacity(bytes.len() / 2);

for chunk in bytes.chunks_exact(2) {
let pcm16_sample = i16::from_le_bytes([chunk[0], chunk[1]]);
let f32_sample = pcm16_sample as f32 / 32767.0;
samples.push(f32_sample);
}

samples
}

fn convert_f32_to_pcm16(samples: &[f32]) -> Vec<u8> {
let mut pcm16_bytes = Vec::with_capacity(samples.len() * 2);

for &sample in samples {
// Clamp to [-1.0, 1.0] and convert to i16
let clamped = sample.max(-1.0).min(1.0);
let pcm16_sample = (clamped * 32767.0) as i16;
pcm16_bytes.extend_from_slice(&pcm16_sample.to_le_bytes());
}

pcm16_bytes
}

/// Replaces a placeholder in a file and writes the result to an output file.
///
/// # Arguments
///
/// * `input_path` - Path to the input file with placeholder text.
/// * `placeholder` - The placeholder text to search for (e.g., "{{PLACEHOLDER}}").
/// * `replacement` - The text to replace the placeholder with.
/// * `output_path` - Path to write the modified content.
fn replace_placeholder_in_file(
input_path: &str,
placeholder: &str,
replacement: &str,
output_path: &str,
) -> io::Result<()> {
// Read the file content into a string
let content = fs::read_to_string(input_path)?;

// Replace the placeholder
let modified_content = content.replace(placeholder, replacement);

// Write the modified content to the output file
let mut file = fs::File::create(output_path)?;
file.write_all(modified_content.as_bytes())?;

Ok(())
}

async fn handle_client(fut: upgrade::UpgradeFut) -> Result<(), WebSocketError> {
let mut ws = fastwebsockets::FragmentCollector::new(fut.await?);

let frame = ws.read_frame().await?;
if frame.opcode != OpCode::Text {
return Err(WebSocketError::InvalidConnectionHeader);
}
let data: OpenAIRealtimeMessage = serde_json::from_slice(&frame.payload).unwrap();
let OpenAIRealtimeMessage::SessionUpdate { session } = data else {
return Err(WebSocketError::InvalidConnectionHeader);
};

let input_audio_transcription = session
.input_audio_transcription
.map_or("moyoyo-whisper".to_string(), |t| t.model);
let id = random::<u16>();
let node_id = format!("server-{id}");
let dataflow = format!("{input_audio_transcription}-{}.yml", id);
let template = format!("{input_audio_transcription}-template-metal.yml");
println!("Filling template: {}", template);
replace_placeholder_in_file(&template, "NODE_ID", &node_id, &dataflow).unwrap();
// Copy configuration file but replace the node ID with "server-id"
// Read the configuration file and replace the node ID with "server-id"
dora_cli::command::Command::Start(Start {
dataflow,
name: Some(node_id.to_string()),
coordinator_addr: IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)),
coordinator_port: DORA_COORDINATOR_PORT_CONTROL_DEFAULT,
attach: false,
detach: true,
hot_reload: false,
uv: true,
})
.execute()
.unwrap();
let (mut node, mut events) =
DoraNode::init_from_node_id(NodeId::from(node_id.clone())).unwrap();
let serialized_data = OpenAIRealtimeResponse::SessionCreated {
session: serde_json::Value::Null,
};

let payload =
Payload::Bytes(Bytes::from(serde_json::to_string(&serialized_data).unwrap()).into());
let frame = Frame::text(payload);
ws.write_frame(frame).await?;
loop {
let mut frame = ws.read_frame().await?;
let mut finished = false;
match frame.opcode {
OpCode::Close => break,
OpCode::Text | OpCode::Binary => {
let data: OpenAIRealtimeMessage = serde_json::from_slice(&frame.payload).unwrap();

match data {
OpenAIRealtimeMessage::InputAudioBufferAppend { audio } => {
// println!("Received audio data: {}", audio);
let f32_data = audio;
// Decode base64 encoded audio data
let f32_data = f32_data.trim();
if f32_data.is_empty() {
continue;
}

if let Ok(f32_data) = general_purpose::STANDARD.decode(f32_data) {
let f32_data = convert_pcm16_to_f32(&f32_data);
// Downsample to 16 kHz from 24 kHz
let f32_data = f32_data
.into_iter()
.enumerate()
.filter(|(i, _)| i % 3 != 0)
.map(|(_, v)| v)
.collect::<Vec<f32>>();
let mut parameter = MetadataParameters::default();
parameter.insert(
"sample_rate".to_string(),
dora_node_api::Parameter::Integer(16000),
);
node.send_output(
DataId::from("audio".to_string()),
parameter,
f32_data.into_arrow(),
)
.unwrap();
let ev = events.recv_async_timeout(Duration::from_millis(10)).await;

// println!("Received event: {:?}", ev);
let frame = match ev {
Some(dora_node_api::Event::Input {
id,
metadata: _,
data,
}) => {
if data.data_type() == &DataType::Utf8 {
let data = data.as_string::<i32>();
let str = data.value(0);
let serialized_data =
OpenAIRealtimeResponse::ResponseAudioTranscriptDelta {
response_id: "123".to_string(),
item_id: "123".to_string(),
output_index: 123,
content_index: 123,
delta: str.to_string(),
};

frame.payload = Payload::Bytes(
Bytes::from(
serde_json::to_string(&serialized_data).unwrap(),
)
.into(),
);
frame.opcode = OpCode::Text;
frame
} else if id.contains("audio") {
let data: Vec<f32> = into_vec(&data).unwrap();
let data = convert_f32_to_pcm16(&data);
let serialized_data =
OpenAIRealtimeResponse::ResponseAudioDelta {
response_id: "123".to_string(),
item_id: "123".to_string(),
output_index: 123,
content_index: 123,
delta: general_purpose::STANDARD.encode(data),
};
finished = true;

frame.payload = Payload::Bytes(
Bytes::from(
serde_json::to_string(&serialized_data).unwrap(),
)
.into(),
);
frame.opcode = OpCode::Text;
frame
} else {
unimplemented!()
}
}
Some(dora_node_api::Event::Error(_)) => {
// println!("Error in input: {}", s);
continue;
}
_ => break,
};
ws.write_frame(frame).await?;
if finished {
let serialized_data = OpenAIRealtimeResponse::ResponseDone {
response: serde_json::Value::Null,
};

let payload = Payload::Bytes(
Bytes::from(serde_json::to_string(&serialized_data).unwrap())
.into(),
);
println!("Sending response done: {:?}", serialized_data);
let frame = Frame::text(payload);
ws.write_frame(frame).await?;
}
}
}
OpenAIRealtimeMessage::InputAudioBufferCommit => break,
_ => {}
}
}
_ => break,
}
}

Ok(())
}
async fn server_upgrade(
mut req: Request<Incoming>,
) -> Result<Response<Empty<Bytes>>, WebSocketError> {
let (response, fut) = upgrade::upgrade(&mut req)?;

tokio::task::spawn(async move {
if let Err(e) = tokio::task::unconstrained(handle_client(fut)).await {
eprintln!("Error in websocket connection: {}", e);
}
});

Ok(response)
}

fn main() -> Result<(), WebSocketError> {
let rt = tokio::runtime::Builder::new_multi_thread()
.enable_io()
.enable_time()
.build()
.unwrap();

rt.block_on(async move {
let listener = TcpListener::bind("127.0.0.1:8123").await?;
println!("Server started, listening on {}", "127.0.0.1:8123");
loop {
let (stream, _) = listener.accept().await?;
println!("Client connected");
tokio::spawn(async move {
let io = hyper_util::rt::TokioIo::new(stream);
let conn_fut = http1::Builder::new()
.serve_connection(io, service_fn(server_upgrade))
.with_upgrades();
if let Err(e) = conn_fut.await {
println!("An error occurred: {:?}", e);
}
});
}
})
}

Loading…
Cancel
Save