| @@ -30,6 +30,8 @@ members = [ | |||
| "libraries/shared-memory-server", | |||
| "libraries/extensions/download", | |||
| "libraries/extensions/telemetry/*", | |||
| "node-hub/dora-mcp-host", | |||
| "node-hub/dora-mcp-server", | |||
| "node-hub/dora-record", | |||
| "node-hub/dora-rerun", | |||
| "node-hub/terminal-print", | |||
| @@ -48,7 +50,7 @@ members = [ | |||
| ] | |||
| [workspace.package] | |||
| edition = "2024" | |||
| edition = "2021" | |||
| rust-version = "1.85.0" | |||
| # Make sure to also bump `apis/node/python/__init__.py` version. | |||
| version = "0.3.12" | |||
| @@ -104,7 +106,6 @@ edition.workspace = true | |||
| license = "Apache-2.0" | |||
| publish = false | |||
| [features] | |||
| # enables examples that depend on a sourced ROS2 installation | |||
| ros2-examples = [] | |||
| @@ -0,0 +1,18 @@ | |||
| # Dora Openai MCP Host Example | |||
| This is a quick example to showcase how use the `dora-openai-server` to receive and send back data. | |||
| Dora Openai Server is still experimental and may change in the future. | |||
| Make sure to have, dora, uv and cargo installed. | |||
| ```bash | |||
| uv venv -p 3.11 --seed | |||
| uv pip install -e ../../apis/python/node --reinstall | |||
| dora build dataflow.yml --uv | |||
| dora run dataflow.yml --uv | |||
| # In a separate terminal | |||
| uv run test_client.py | |||
| dora stop | |||
| ``` | |||
| @@ -0,0 +1,44 @@ | |||
| nodes: | |||
| - id: mcp-server | |||
| build: cargo build -p dora-mcp-server --release | |||
| path: ../../target/release/dora-mcp-server | |||
| outputs: | |||
| - local | |||
| - telepathy | |||
| inputs: | |||
| local_reply: local/reply | |||
| telepathy_reply: telepathy/reply | |||
| env: | |||
| CONFIG: mcp_server.toml | |||
| - id: local | |||
| path: nodes/local.py | |||
| inputs: | |||
| text: mcp-server/local | |||
| outputs: | |||
| - reply | |||
| - id: telepathy | |||
| path: nodes/telepathy.py | |||
| inputs: | |||
| text: mcp-server/telepathy | |||
| outputs: | |||
| - reply | |||
| - id: dora-echo | |||
| build: pip install -e ../../node-hub/dora-echo | |||
| path: dora-echo | |||
| inputs: | |||
| echo: dora-mcp-host/text | |||
| outputs: | |||
| - echo | |||
| - id: dora-mcp-host | |||
| build: cargo build -p dora-mcp-host --release | |||
| path: ../../target/release/dora-mcp-host | |||
| outputs: | |||
| - text | |||
| inputs: | |||
| text: dora-echo/echo | |||
| env: | |||
| CONFIG: mcp_host.toml | |||
| @@ -0,0 +1,6 @@ | |||
| { | |||
| "$schema": "http://json-schema.org/draft-07/schema#", | |||
| "type": "object", | |||
| "properties": { | |||
| } | |||
| } | |||
| @@ -0,0 +1,13 @@ | |||
| { | |||
| "$schema": "http://json-schema.org/draft-07/schema#", | |||
| "type": "object", | |||
| "properties": { | |||
| "location": { | |||
| "type": "string", | |||
| "description": "location" | |||
| } | |||
| }, | |||
| "required": [ | |||
| "location" | |||
| ] | |||
| } | |||
| @@ -0,0 +1,51 @@ | |||
| listen_addr = "0.0.0.0:8118" | |||
| endpoint = "v1" | |||
| [[providers]] | |||
| id = "moonshot" | |||
| kind ="openai" | |||
| api_key = "env:MOONSHOT_API_KEY" | |||
| api_url = "https://api.moonshot.cn/v1" | |||
| [[providers]] | |||
| id = "gemini" | |||
| kind ="gemini" | |||
| # api_key = "env:GEMINI_API_KEY" | |||
| api_url = "https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash:generateContent" | |||
| # [[providers]] | |||
| # id = "deepseek" | |||
| # kind ="deepseek" | |||
| # # api_key = "env:DEEPSEEK_API_KEY" | |||
| # api_url = "https://api.deepseek.com" | |||
| # [[providers]] | |||
| # id = "openai" | |||
| # # api_key = "env:OPENAI_API_KEY" | |||
| # api_url = "https://api.openai.com/v1" | |||
| [[providers]] | |||
| id = "dora" | |||
| kind ="dora" | |||
| output = "output" | |||
| [[models]] | |||
| id = "kimi-latest" | |||
| # default = true | |||
| route = { provider = "moonshot", model = "kimi-latest" } | |||
| [[models]] | |||
| id = "gemini-2.0-flash" | |||
| route = { provider = "gemini", model = "gemini-2.0-flash" } | |||
| [[mcp.servers]] | |||
| name = "amap-maps" | |||
| protocol = "stdio" | |||
| command = "npx" | |||
| args = ["-y", "@amap/amap-maps-mcp-server"] | |||
| envs = {AMAP_MAPS_API_KEY = "your_amap_maps_api_key_here"} | |||
| [[mcp.servers]] | |||
| name = "local" | |||
| protocol = "streamable" | |||
| url = "http://127.0.0.1:8228/mcp" | |||
| @@ -0,0 +1,28 @@ | |||
| name = "MCP Server Example" | |||
| version = "0.1.0" | |||
| # You can set your custom listen address and endpoint here. | |||
| # Default listen address is "0.0.0.0:8008" and endpoint is "mcp". | |||
| listen_addr = "0.0.0.0:8228" | |||
| endpoint = "mcp" | |||
| [[mcp_tools]] | |||
| name = "signature_dish" | |||
| description = "Tell you the name of the most signature dish in a certain restaurant." | |||
| args = [] | |||
| input_schema = "restaurant_object.json" | |||
| output = "local" | |||
| [[mcp_tools]] | |||
| name = "happiest_kindergarten" | |||
| description = "Tell you the happiest kindergarten in this location and how many children are in this kindergarten." | |||
| args = [] | |||
| input_schema = "local_object.json" | |||
| output = "local" | |||
| [[mcp_tools]] | |||
| name = "telepathy" | |||
| description = "Know who the current user's favorite star is and what their favorite movie is." | |||
| args = [] | |||
| input_schema = "empty_object.json" | |||
| output = "telepathy" | |||
| @@ -0,0 +1,52 @@ | |||
| """ | |||
| This is just a simple demonstration of an MCP server. | |||
| The example returns some local information about the user's request, such as the tallest building, | |||
| the happiest kindergarten, the best restaurant, etc. | |||
| """ | |||
| import pyarrow as pa | |||
| from dora import Node | |||
| import json | |||
| import random | |||
| signature_dishes = [ | |||
| "Kung Pao Chicken", | |||
| "Mapo Tofu", | |||
| "Twice Cooked Pork", | |||
| "Sweet and Sour Pork", | |||
| "Boiled Fish in Chili Oil", | |||
| "Peking Duck", | |||
| "Xiaolongbao", | |||
| "Red Braised Pork", | |||
| "Fish-Flavored Shredded Pork", | |||
| "Dongpo Pork", | |||
| "White Cut Chicken", | |||
| "Steamed Egg Custard", | |||
| "Fish with Pickled Cabbage", | |||
| "Saliva Chicken", | |||
| "Spicy Beef and Ox Tongue", | |||
| "Laziji (Spicy Diced Chicken)", | |||
| "Steamed Sea Bass", | |||
| "Ants Climbing a Tree", | |||
| "Beggar's Chicken", | |||
| "Buddha Jumps Over the Wall" | |||
| ] | |||
| node = Node() | |||
| for event in node: | |||
| if event["type"] == "INPUT": | |||
| if 'metadata' in event: | |||
| data = json.loads(event["value"][0].as_py()) | |||
| name = data.get("name", "") | |||
| location = data.get("arguments", {}).get("location", "") | |||
| match name: | |||
| case "signature_dish": | |||
| random_dish = random.choice(signature_dishes) | |||
| node.send_output("reply", pa.array([f'{{"content":[{{"type": "text", "text": "{{\\"signature_dish\\": \\"{random_dish}\\"}}"}}]}}']), metadata=event["metadata"]) | |||
| case "happiest_kindergarten": | |||
| node.send_output("reply", pa.array([f'{{"content":[{{"type": "text", "text": "{{\\"kindergarten\\":\\"Golden Sun Kindergarten\\", \\"children\\": 300}}"}}]}}']), metadata=event["metadata"]) | |||
| case _: | |||
| print(f"Unknown command: {name}") | |||
| @@ -0,0 +1,46 @@ | |||
| """ | |||
| This is just a simple demonstration of an MCP server. | |||
| This MCP server has the ability of telepathy and can know who the current | |||
| user's favorite star is and what their favorite movie is. | |||
| """ | |||
| import pyarrow as pa | |||
| from dora import Node | |||
| import json | |||
| import random | |||
| star_movie_pairs = [ | |||
| {"star": "Tom Hanks", "movie": "Forrest Gump"}, | |||
| {"star": "Leonardo DiCaprio", "movie": "Titanic"}, | |||
| {"star": "Will Smith", "movie": "Men in Black"}, | |||
| {"star": "Robert Downey Jr.", "movie": "Iron Man"}, | |||
| {"star": "Johnny Depp", "movie": "Pirates of the Caribbean"}, | |||
| {"star": "Brad Pitt", "movie": "Fight Club"}, | |||
| {"star": "Angelina Jolie", "movie": "Maleficent"}, | |||
| {"star": "Scarlett Johansson", "movie": "Black Widow"}, | |||
| {"star": "Chris Evans", "movie": "Captain America"}, | |||
| {"star": "Ryan Reynolds", "movie": "Deadpool"}, | |||
| {"star": "Emma Stone", "movie": "La La Land"}, | |||
| {"star": "Jennifer Lawrence", "movie": "The Hunger Games"}, | |||
| {"star": "Morgan Freeman", "movie": "The Shawshank Redemption"}, | |||
| {"star": "Denzel Washington", "movie": "Training Day"}, | |||
| {"star": "Matt Damon", "movie": "The Martian"}, | |||
| ] | |||
| node = Node() | |||
| for event in node: | |||
| if event["type"] == "INPUT": | |||
| if 'metadata' in event: | |||
| data = json.loads(event["value"][0].as_py()) | |||
| name = data.get("name", "") | |||
| location = data.get("arguments", {}).get("location", "") | |||
| match name: | |||
| case "telepathy": | |||
| random_pair = random.choice(star_movie_pairs) | |||
| star = random_pair["star"] | |||
| movie = random_pair["movie"] | |||
| node.send_output("reply", pa.array([f'{{"content":[{{"type": "text", "text": "{{\\"star\\":\\"{star}\\", \\"movie\\":\\"{movie}\\"}}"}}]}}']), metadata=event["metadata"]) | |||
| case _: | |||
| print(f"Unknown command: {name}") | |||
| @@ -0,0 +1,13 @@ | |||
| { | |||
| "$schema": "http://json-schema.org/draft-07/schema#", | |||
| "type": "object", | |||
| "properties": { | |||
| "restaurant": { | |||
| "type": "string", | |||
| "description": "restaurant name" | |||
| } | |||
| }, | |||
| "required": [ | |||
| "restaurant" | |||
| ] | |||
| } | |||
| @@ -0,0 +1,102 @@ | |||
| """TODO: Add docstring.""" | |||
| from openai import OpenAI | |||
| import httpx | |||
| transport = httpx.HTTPTransport(proxy=None) | |||
| http_client = httpx.Client(transport=transport) | |||
| client = OpenAI(base_url="http://127.0.0.1:8118/v1", api_key="dummy_api_key", http_client=http_client) | |||
| def test_list_models(): | |||
| """TODO: Add docstring.""" | |||
| try: | |||
| models = client.models.list() | |||
| print("Available models:") | |||
| for model in models.data: | |||
| print(f"- {model.id}") | |||
| except Exception as e: | |||
| print(f"Error listing models: {e}") | |||
| def test_chat_completion(user_input): | |||
| """TODO: Add docstring.""" | |||
| try: | |||
| response = client.chat.completions.create( | |||
| model="kimi-latest", | |||
| messages=[ | |||
| {"role": "system", "content": "You are a helpful assistant."}, | |||
| {"role": "user", "content": user_input}, | |||
| ], | |||
| ) | |||
| print("Chat Completion Response:") | |||
| print(response.choices[0].message.content) | |||
| except Exception as e: | |||
| print(f"Error in chat completion: {e}") | |||
| def test_chat_completion_image_url(user_input): | |||
| """TODO: Add docstring.""" | |||
| try: | |||
| response = client.chat.completions.create( | |||
| model="kimi-latest", | |||
| messages=[ | |||
| { | |||
| "role": "user", | |||
| "content": [ | |||
| {"type": "text", "text": "What is in this image?"}, | |||
| { | |||
| "type": "image_url", | |||
| "image_url": { | |||
| "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg" | |||
| }, | |||
| }, | |||
| ], | |||
| } | |||
| ], | |||
| ) | |||
| print("Chat Completion Response:") | |||
| print(response.choices[0].message.content) | |||
| except Exception as e: | |||
| print(f"Error in chat completion: {e}") | |||
| def test_chat_completion_image_base64(user_input): | |||
| """TODO: Add docstring.""" | |||
| try: | |||
| response = client.chat.completions.create( | |||
| model="gpt-3.5-turbo", | |||
| messages=[ | |||
| { | |||
| "role": "user", | |||
| "content": [ | |||
| {"type": "text", "text": "What is in this image?"}, | |||
| { | |||
| "type": "image_url", | |||
| "image_url": { | |||
| "url": "" | |||
| }, | |||
| }, | |||
| ], | |||
| } | |||
| ], | |||
| ) | |||
| print("Chat Completion Response:") | |||
| print(response.choices[0].message.content) | |||
| except Exception as e: | |||
| print(f"Error in chat completion: {e}") | |||
| if __name__ == "__main__": | |||
| print("Testing API endpoints...") | |||
| # test_list_models() | |||
| print("\n" + "=" * 50 + "\n") | |||
| while True: | |||
| chat_input = input("Enter a message for chat completion: ") | |||
| test_chat_completion(chat_input) | |||
| print("\n" + "=" * 50 + "\n") | |||
| # test_chat_completion_image_url(chat_input) | |||
| # print("\n" + "=" * 50 + "\n") | |||
| # test_chat_completion_image_base64(chat_input) | |||
| # print("\n" + "=" * 50 + "\n") | |||
| @@ -0,0 +1,20 @@ | |||
| # Dora MCP Server Example | |||
| This is a quick example to showcase how use the `dora-mcp-server` to receive and send back data. | |||
| Dora MCP Server is still experimental and may change in the future. | |||
| Make sure to have, dora, uv and cargo installed. | |||
| ```bash | |||
| uv venv -p 3.11 --seed | |||
| uv pip install -e ../../apis/python/node --reinstall | |||
| dora build dataflow.yml --uv | |||
| dora run dataflow.yml --uv | |||
| ``` | |||
| You can use mpc inspector to test: | |||
| ```bash | |||
| npx @modelcontextprotocol/inspector | |||
| ``` | |||
| @@ -0,0 +1,32 @@ | |||
| name = "MCP Server Example" | |||
| version = "0.1.0" | |||
| # You can set your custom listen address and endpoint here. | |||
| # Default listen address is "0.0.0.0:8008" and endpoint is "mcp". | |||
| # In this example, the final service url is: http://0.0.0.0:8181/mcp | |||
| listen_addr = "0.0.0.0:8181" | |||
| endpoint = "mcp" | |||
| [[mcp_tools]] | |||
| name = "counter_decrement" # (Required) type: String, Unique identifier for the tool | |||
| title = "Decrement Counter" # (Optional) type: String, Human-readable name of the tool for display purposes | |||
| input_schema = "empty_object.json" # (Required) JSON Schema defining expected parameters | |||
| output = "counter" # (Required) type: String, Set the output name | |||
| [mcp_tools.annotations] # (Optional) Additional properties describing a Tool to clients | |||
| title = "decrement current value of the counter" # type: String, A human-readable title for the tool | |||
| [[mcp_tools]] | |||
| name = "counter_increment" | |||
| title = "Increment Counter" | |||
| input_schema = "empty_object.json" | |||
| output = "counter" | |||
| [mcp_tools.annotations] | |||
| title = "Increment current value of the counter" | |||
| [[mcp_tools]] | |||
| name = "counter_get_value" | |||
| title = "Get Counter Value" | |||
| input_schema = "empty_object.json" | |||
| output = "counter" | |||
| [mcp_tools.annotations] | |||
| title = "Get the current value of the counter" | |||
| @@ -0,0 +1,17 @@ | |||
| nodes: | |||
| - id: mcp-server | |||
| build: cargo build -p dora-mcp-server --release | |||
| path: ../../target/release/dora-mcp-server | |||
| outputs: | |||
| - counter | |||
| inputs: | |||
| counter_reply: counter/reply | |||
| env: | |||
| CONFIG: config.toml | |||
| - id: counter | |||
| path: nodes/counter.py | |||
| inputs: | |||
| text: mcp-server/counter | |||
| outputs: | |||
| - reply | |||
| @@ -0,0 +1,6 @@ | |||
| { | |||
| "$schema": "http://json-schema.org/draft-07/schema#", | |||
| "type": "object", | |||
| "title": "EmptyObject", | |||
| "description": "Input parameters for the counter tool" | |||
| } | |||
| @@ -0,0 +1,26 @@ | |||
| """TODO: Add docstring.""" | |||
| import pyarrow as pa | |||
| from dora import Node | |||
| import json | |||
| node = Node() | |||
| count = 0 | |||
| for event in node: | |||
| if event["type"] == "INPUT": | |||
| if 'metadata' in event: | |||
| data = json.loads(event["value"][0].as_py()) | |||
| name = data.get("name", "") | |||
| match name: | |||
| case "counter_increment": | |||
| count += 1 | |||
| node.send_output("reply", pa.array([f'{{"content":[{{"type": "text", "text": "{count}"}}]}}']), metadata=event["metadata"]) | |||
| case "counter_decrement": | |||
| count -= 1 | |||
| node.send_output("reply", pa.array([f'{{"content":[{{"type": "text", "text": "{count}"}}]}}']), metadata=event["metadata"]) | |||
| case "counter_get_value": | |||
| node.send_output("reply", pa.array([f'{{"content":[{{"type": "text", "text": "{count}"}}]}}']), metadata=event["metadata"]) | |||
| case _: | |||
| print(f"Unknown command: {name}") | |||
| @@ -0,0 +1,44 @@ | |||
| [package] | |||
| name = "dora-mcp-host" | |||
| version.workspace = true | |||
| edition.workspace = true | |||
| rust-version.workspace = true | |||
| documentation.workspace = true | |||
| description.workspace = true | |||
| license.workspace = true | |||
| repository.workspace = true | |||
| # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html | |||
| [dependencies] | |||
| chrono = "0.4.31" | |||
| dora-node-api = { workspace = true, features = ["tracing"] } | |||
| eyre = "0.6.8" | |||
| figment = { version = "0.10.0", features = ["env", "json", "toml", "yaml"] } | |||
| futures = "0.3.31" | |||
| indexmap = { version = "2.6.0", features = ["serde"] } | |||
| mime_guess = "2.0.4" | |||
| outfox-openai = { version = "0.1.0", git = "https://github.com/outfox-ai/outfox.git" } | |||
| reqwest = { version = "0.12.22" } | |||
| rmcp = { version = "0.3.2", git = "https://github.com/modelcontextprotocol/rust-sdk.git", rev = "fbc7ab7", features = [ | |||
| "client", | |||
| "transport-child-process", | |||
| "transport-sse-client", | |||
| "transport-streamable-http-client", | |||
| "reqwest", | |||
| ] } | |||
| salvo = { version = "0.81.0", default-features = false, features = [ | |||
| "affix-state", | |||
| "cors", | |||
| "server", | |||
| "http1", | |||
| "http2", | |||
| ] } | |||
| serde = { version = "1.0.130", features = ["derive"] } | |||
| serde_json = "1.0.68" | |||
| thiserror = "2.0.12" | |||
| tokio = { version = "1.46.1", features = ["full"] } | |||
| tokio-stream = "0.1.11" | |||
| tracing = "0.1.27" | |||
| url = "2.2.2" | |||
| uuid = { version = "1.10", features = ["v4"] } | |||
| @@ -0,0 +1,212 @@ | |||
| use tokio::sync::mpsc; | |||
| use eyre::{eyre, Result}; | |||
| use futures::channel::oneshot; | |||
| use outfox_openai::spec::{CreateChatCompletionRequest, CreateChatCompletionResponse}; | |||
| use reqwest::Client as HttpClient; | |||
| use salvo::async_trait; | |||
| use crate::config::{DeepseekConfig, DoraConfig, GeminiConfig, OpenaiConfig}; | |||
| use crate::{utils::get_env_or_value, ServerEvent}; | |||
| #[async_trait] | |||
| pub trait ChatClient: Send + Sync { | |||
| async fn complete( | |||
| &self, | |||
| request: CreateChatCompletionRequest, | |||
| ) -> Result<CreateChatCompletionResponse>; | |||
| } | |||
| #[derive(Debug)] | |||
| pub struct GeminiClient { | |||
| api_key: String, | |||
| api_url: String, | |||
| client: HttpClient, | |||
| } | |||
| impl GeminiClient { | |||
| pub fn new(config: &GeminiConfig) -> Self { | |||
| let client = if config.proxy { | |||
| HttpClient::new() | |||
| } else { | |||
| HttpClient::builder() | |||
| .no_proxy() | |||
| .build() | |||
| .unwrap_or_else(|_| HttpClient::new()) | |||
| }; | |||
| Self { | |||
| api_key: get_env_or_value(&config.api_key), | |||
| api_url: get_env_or_value(&config.api_url), | |||
| client, | |||
| } | |||
| } | |||
| } | |||
| #[async_trait] | |||
| impl ChatClient for GeminiClient { | |||
| async fn complete( | |||
| &self, | |||
| request: CreateChatCompletionRequest, | |||
| ) -> Result<CreateChatCompletionResponse> { | |||
| let response = self | |||
| .client | |||
| .post(&self.api_url) | |||
| .header("X-goog-api-key", self.api_key.clone()) | |||
| .header("Content-Type", "application/json") | |||
| .json(&request) | |||
| .send() | |||
| .await?; | |||
| if !response.status().is_success() { | |||
| let error_text = response.text().await?; | |||
| return Err(eyre!("API Error: {}", error_text)); | |||
| } | |||
| let text_data = response.text().await?; | |||
| println!("Received response: {}", text_data); | |||
| let completion: CreateChatCompletionResponse = serde_json::from_str(&text_data) | |||
| .map_err(eyre::Report::from) | |||
| .unwrap(); | |||
| Ok(completion) | |||
| } | |||
| } | |||
| #[derive(Debug)] | |||
| pub struct DeepseekClient { | |||
| api_key: String, | |||
| api_url: String, | |||
| client: HttpClient, | |||
| } | |||
| impl DeepseekClient { | |||
| pub fn new(config: &DeepseekConfig) -> Self { | |||
| let client = if config.proxy { | |||
| HttpClient::new() | |||
| } else { | |||
| HttpClient::builder() | |||
| .no_proxy() | |||
| .build() | |||
| .unwrap_or_else(|_| HttpClient::new()) | |||
| }; | |||
| Self { | |||
| api_key: get_env_or_value(&config.api_key), | |||
| api_url: get_env_or_value(&config.api_url), | |||
| client, | |||
| } | |||
| } | |||
| } | |||
| #[async_trait] | |||
| impl ChatClient for DeepseekClient { | |||
| async fn complete( | |||
| &self, | |||
| request: CreateChatCompletionRequest, | |||
| ) -> Result<CreateChatCompletionResponse> { | |||
| let response = self | |||
| .client | |||
| .post(format!("{}/chat/completions", self.api_url)) | |||
| .header("Authorization", format!("Bearer {}", self.api_key)) | |||
| .header("Content-Type", "application/json") | |||
| .json(&request) | |||
| .send() | |||
| .await?; | |||
| if !response.status().is_success() { | |||
| let error_text = response.text().await?; | |||
| return Err(eyre!("API Error: {}", error_text)); | |||
| } | |||
| let text_data = response.text().await?; | |||
| println!("Received response: {}", text_data); | |||
| let completion: CreateChatCompletionResponse = | |||
| serde_json::from_str(&text_data).map_err(eyre::Report::from)?; | |||
| Ok(completion) | |||
| } | |||
| } | |||
| #[derive(Debug)] | |||
| pub struct OpenaiClient { | |||
| api_key: String, | |||
| api_url: String, | |||
| client: HttpClient, | |||
| } | |||
| impl OpenaiClient { | |||
| pub fn new(config: &OpenaiConfig) -> Self { | |||
| let client = if config.proxy { | |||
| HttpClient::new() | |||
| } else { | |||
| HttpClient::builder() | |||
| .no_proxy() | |||
| .build() | |||
| .unwrap_or_else(|_| HttpClient::new()) | |||
| }; | |||
| Self { | |||
| api_key: get_env_or_value(&config.api_key), | |||
| api_url: get_env_or_value(&config.api_url), | |||
| client, | |||
| } | |||
| } | |||
| } | |||
| #[async_trait] | |||
| impl ChatClient for OpenaiClient { | |||
| async fn complete( | |||
| &self, | |||
| request: CreateChatCompletionRequest, | |||
| ) -> Result<CreateChatCompletionResponse> { | |||
| let response = self | |||
| .client | |||
| .post(format!("{}/chat/completions", self.api_url)) | |||
| .header("Authorization", format!("Bearer {}", self.api_key)) | |||
| .header("Content-Type", "application/json") | |||
| .json(&request) | |||
| .send() | |||
| .await?; | |||
| if !response.status().is_success() { | |||
| let error_text = response.text().await?; | |||
| return Err(eyre!("API Error: {}", error_text)); | |||
| } | |||
| let text_data = response.text().await?; | |||
| println!("Received response: {}", text_data); | |||
| let completion: CreateChatCompletionResponse = | |||
| serde_json::from_str(&text_data).map_err(eyre::Report::from)?; | |||
| Ok(completion) | |||
| } | |||
| } | |||
| #[derive(Debug)] | |||
| pub struct DoraClient { | |||
| output: String, | |||
| event_sender: mpsc::Sender<ServerEvent>, | |||
| } | |||
| impl DoraClient { | |||
| pub fn new(config: &DoraConfig, event_sender: mpsc::Sender<ServerEvent>) -> Self { | |||
| Self { | |||
| output: config.output.clone(), | |||
| event_sender, | |||
| } | |||
| } | |||
| } | |||
| #[async_trait] | |||
| impl ChatClient for DoraClient { | |||
| async fn complete( | |||
| &self, | |||
| request: CreateChatCompletionRequest, | |||
| ) -> Result<CreateChatCompletionResponse> { | |||
| let (tx, rx) = oneshot::channel(); | |||
| self.event_sender | |||
| .send(ServerEvent::CallNode { | |||
| output: self.output.clone(), | |||
| request, | |||
| reply: tx, | |||
| }) | |||
| .await?; | |||
| rx.await | |||
| .map_err(|e| eyre::eyre!("Failed to parse call tool result: {e}")) | |||
| } | |||
| } | |||
| @@ -0,0 +1,313 @@ | |||
| use std::{ | |||
| collections::HashMap, | |||
| path::PathBuf, | |||
| process::Stdio, | |||
| sync::{Arc, OnceLock}, | |||
| }; | |||
| use tokio::sync::mpsc; | |||
| use figment::providers::{Env, Format, Json, Toml, Yaml}; | |||
| use figment::Figment; | |||
| use rmcp::{service::RunningService, transport::ConfigureCommandExt, RoleClient, ServiceExt}; | |||
| use serde::{Deserialize, Serialize}; | |||
| use crate::client::{ChatClient, DeepseekClient, DoraClient, GeminiClient, OpenaiClient}; | |||
| use crate::{ChatSession, ServerEvent, tool::ToolSet}; | |||
| pub static CONFIG: OnceLock<Config> = OnceLock::new(); | |||
| pub fn init() { | |||
| let config_file = Env::var("CONFIG").unwrap_or("config.toml".into()); | |||
| let config_path = PathBuf::from(config_file); | |||
| if !config_path.exists() { | |||
| eprintln!("Config file not found at: {}", config_path.display()); | |||
| std::process::exit(1); | |||
| } | |||
| let raw_config = match config_path | |||
| .extension() | |||
| .unwrap_or_default() | |||
| .to_str() | |||
| .unwrap_or_default() | |||
| { | |||
| "yaml" | "yml" => Figment::new().merge(Yaml::file(config_path)), | |||
| "json" => Figment::new().merge(Json::file(config_path)), | |||
| "toml" => Figment::new().merge(Toml::file(config_path)), | |||
| ext => { | |||
| eprintln!("unsupport config file format: {ext:?}"); | |||
| std::process::exit(1); | |||
| } | |||
| }; | |||
| let conf = match raw_config.extract::<Config>() { | |||
| Ok(s) => s, | |||
| Err(e) => { | |||
| eprintln!("It looks like your config is invalid. The following error occurred: {e}"); | |||
| std::process::exit(1); | |||
| } | |||
| }; | |||
| CONFIG.set(conf).expect("config should be set"); | |||
| } | |||
| pub fn get() -> &'static Config { | |||
| CONFIG.get().unwrap() | |||
| } | |||
| #[derive(Clone, Debug, Deserialize)] | |||
| pub struct Config { | |||
| #[serde(default = "default_listen_addr")] | |||
| pub listen_addr: String, | |||
| #[serde(default = "default_endpoint")] | |||
| pub endpoint: Option<String>, | |||
| pub providers: Vec<ProviderConfig>, | |||
| pub models: Vec<ModelConfig>, | |||
| pub mcp: Option<McpConfig>, | |||
| // #[serde(default = "default_true")] | |||
| // pub support_tool: bool, | |||
| } | |||
| fn default_listen_addr() -> String { | |||
| "0.0.0.0:8008".to_owned() | |||
| } | |||
| fn default_endpoint() -> Option<String> { | |||
| Some("v1".to_owned()) | |||
| } | |||
| // fn default_true() -> bool { | |||
| // true | |||
| // } | |||
| #[derive(Clone, Debug, Deserialize)] | |||
| #[serde(tag = "kind", rename_all = "snake_case")] | |||
| pub enum ProviderConfig { | |||
| Gemini(GeminiConfig), | |||
| Deepseek(DeepseekConfig), | |||
| Openai(OpenaiConfig), | |||
| Dora(DoraConfig), | |||
| } | |||
| impl ProviderConfig { | |||
| pub fn id(&self) -> &str { | |||
| match self { | |||
| ProviderConfig::Gemini(config) => &config.id, | |||
| ProviderConfig::Deepseek(config) => &config.id, | |||
| ProviderConfig::Openai(config) => &config.id, | |||
| ProviderConfig::Dora(config) => &config.id, | |||
| } | |||
| } | |||
| } | |||
| #[derive(Clone, Debug, Deserialize)] | |||
| pub struct GeminiConfig { | |||
| pub id: String, | |||
| #[serde(default = "default_gemini_api_key")] | |||
| pub api_key: String, | |||
| #[serde(default = "default_gemini_api_url")] | |||
| pub api_url: String, | |||
| #[serde(default)] | |||
| pub proxy: bool, | |||
| } | |||
| fn default_gemini_api_key() -> String { | |||
| std::env::var("GEMINI_API_KEY").unwrap_or_default() | |||
| } | |||
| fn default_gemini_api_url() -> String { | |||
| std::env::var("GEMINI_API_URL").unwrap_or_else(|_|"https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash:generateContent".to_owned()) | |||
| } | |||
| #[derive(Clone, Debug, Deserialize)] | |||
| pub struct DeepseekConfig { | |||
| pub id: String, | |||
| #[serde(default = "default_deepseek_api_key")] | |||
| pub api_key: String, | |||
| #[serde(default = "default_deepseek_api_url")] | |||
| pub api_url: String, | |||
| #[serde(default)] | |||
| pub proxy: bool, | |||
| } | |||
| fn default_deepseek_api_key() -> String { | |||
| std::env::var("DEEPSEEK_API_KEY").unwrap_or_default() | |||
| } | |||
| fn default_deepseek_api_url() -> String { | |||
| std::env::var("DEEPSEEK_API_URL").unwrap_or_else(|_| "https://api.deepseek.com".to_owned()) | |||
| } | |||
| #[derive(Clone, Debug, Deserialize)] | |||
| pub struct OpenaiConfig { | |||
| pub id: String, | |||
| #[serde(default = "default_openai_api_key")] | |||
| pub api_key: String, | |||
| #[serde(default = "default_openai_api_url")] | |||
| pub api_url: String, | |||
| #[serde(default)] | |||
| pub proxy: bool, | |||
| } | |||
| fn default_openai_api_key() -> String { | |||
| std::env::var("OPENAI_API_KEY").unwrap_or_default() | |||
| } | |||
| fn default_openai_api_url() -> String { | |||
| std::env::var("OPENAI_API_URL") | |||
| .unwrap_or_else(|_| "https://api.openai.com/v1/chat/completions".to_owned()) | |||
| } | |||
| #[derive(Clone, Debug, Deserialize)] | |||
| pub struct DoraConfig { | |||
| pub id: String, | |||
| pub output: String, | |||
| } | |||
| #[derive(Clone, Debug, Deserialize)] | |||
| #[serde(tag = "model", rename_all = "snake_case")] | |||
| pub struct ModelConfig { | |||
| pub id: String, | |||
| pub object: Option<String>, | |||
| pub created: Option<u32>, | |||
| pub owned_by: Option<String>, | |||
| // #[serde(default)] | |||
| // pub default: bool, | |||
| pub route: ModelRouteConfig, | |||
| } | |||
| #[derive(Clone, Debug, Deserialize)] | |||
| #[serde(tag = "route", rename_all = "snake_case")] | |||
| pub struct ModelRouteConfig { | |||
| pub provider: String, | |||
| pub model: Option<String>, | |||
| } | |||
| #[derive(Clone, Default, Debug, Deserialize)] | |||
| pub struct McpConfig { | |||
| #[serde(default)] | |||
| pub servers: Vec<McpServerConfig>, | |||
| } | |||
| #[derive(Debug, Serialize, Deserialize, Clone)] | |||
| pub struct McpServerConfig { | |||
| pub name: String, | |||
| #[serde(flatten)] | |||
| pub transport: McpServerTransportConfig, | |||
| } | |||
| #[derive(Debug, Serialize, Deserialize, Clone)] | |||
| #[serde(tag = "protocol", rename_all = "snake_case")] | |||
| pub enum McpServerTransportConfig { | |||
| Streamable { | |||
| url: String, | |||
| }, | |||
| Sse { | |||
| url: String, | |||
| }, | |||
| Stdio { | |||
| command: String, | |||
| #[serde(default)] | |||
| args: Vec<String>, | |||
| #[serde(default)] | |||
| envs: HashMap<String, String>, | |||
| }, | |||
| } | |||
| impl McpServerTransportConfig { | |||
| pub async fn start(&self) -> eyre::Result<RunningService<RoleClient, ()>> { | |||
| let client = match self { | |||
| McpServerTransportConfig::Streamable { url } => { | |||
| for _ in 0..5 { | |||
| let transport = | |||
| rmcp::transport::StreamableHttpClientTransport::from_uri(url.to_string()); | |||
| match ().serve(transport).await { | |||
| Ok(client) => return Ok(client), | |||
| Err(e) => { | |||
| println!("failed to start streamable transport: {e}"); | |||
| tracing::warn!("failed to start streamable transport: {e}"); | |||
| tokio::time::sleep(std::time::Duration::from_secs(2)).await; | |||
| } | |||
| } | |||
| } | |||
| eyre::bail!("failed to start streamable transport after 5 attempts"); | |||
| } | |||
| McpServerTransportConfig::Sse { url } => { | |||
| let transport = | |||
| rmcp::transport::sse_client::SseClientTransport::start(url.to_owned()).await?; | |||
| ().serve(transport).await? | |||
| } | |||
| McpServerTransportConfig::Stdio { | |||
| command, | |||
| args, | |||
| envs, | |||
| } => { | |||
| let transport = rmcp::transport::TokioChildProcess::new( | |||
| tokio::process::Command::new(command).configure(|cmd| { | |||
| cmd.args(args) | |||
| .envs(envs) | |||
| .stderr(Stdio::inherit()) | |||
| .stdout(Stdio::inherit()); | |||
| }), | |||
| )?; | |||
| ().serve(transport).await? | |||
| } | |||
| }; | |||
| Ok(client) | |||
| } | |||
| } | |||
| impl Config { | |||
| pub async fn create_mcp_clients( | |||
| &self, | |||
| ) -> eyre::Result<HashMap<String, RunningService<RoleClient, ()>>> { | |||
| let mut clients = HashMap::new(); | |||
| if let Some(mcp_config) = &self.mcp { | |||
| for server in &mcp_config.servers { | |||
| let client = server.transport.start().await?; | |||
| clients.insert(server.name.clone(), client); | |||
| } | |||
| } | |||
| Ok(clients) | |||
| } | |||
| fn create_chat_clients( | |||
| &self, | |||
| server_events_tx: mpsc::Sender<ServerEvent>, | |||
| ) -> HashMap<String, Arc<dyn ChatClient>> { | |||
| let mut clients: HashMap<String, Arc<dyn ChatClient>> = HashMap::new(); | |||
| for provider in &self.providers { | |||
| let client: Arc<dyn ChatClient> = match provider { | |||
| ProviderConfig::Gemini(config) => Arc::new(GeminiClient::new(config)), | |||
| ProviderConfig::Deepseek(config) => Arc::new(DeepseekClient::new(config)), | |||
| ProviderConfig::Openai(config) => Arc::new(OpenaiClient::new(config)), | |||
| ProviderConfig::Dora(config) => { | |||
| Arc::new(DoraClient::new(config, server_events_tx.clone())) | |||
| } | |||
| }; | |||
| clients.insert(provider.id().to_owned(), client); | |||
| } | |||
| clients | |||
| } | |||
| pub async fn create_session( | |||
| &self, | |||
| server_events_tx: mpsc::Sender<ServerEvent>, | |||
| ) -> eyre::Result<ChatSession> { | |||
| let mut tool_set = ToolSet::default(); | |||
| if self.mcp.is_some() { | |||
| let mcp_clients = self.create_mcp_clients().await?; | |||
| for (name, client) in mcp_clients.iter() { | |||
| tracing::info!("load MCP tool: {name}"); | |||
| let server = client.peer().clone(); | |||
| let tools = crate::get_mcp_tools(server.clone()).await?; | |||
| for tool in tools { | |||
| tool_set.add_tool(tool); | |||
| } | |||
| } | |||
| tool_set.set_clients(mcp_clients); | |||
| } | |||
| Ok(ChatSession::new( | |||
| self.create_chat_clients(server_events_tx), | |||
| get().models.clone(), | |||
| tool_set, | |||
| )) | |||
| } | |||
| } | |||
| @@ -0,0 +1,62 @@ | |||
| use salvo::async_trait; | |||
| use salvo::http::{StatusCode, StatusError}; | |||
| use salvo::prelude::{Depot, Request, Response, Writer}; | |||
| use thiserror::Error; | |||
| use crate::ServerEvent; | |||
| #[allow(clippy::large_enum_variant)] | |||
| #[derive(Error, Debug)] | |||
| pub enum AppError { | |||
| #[error("public: `{0}`")] | |||
| Public(String), | |||
| #[error("internal: `{0}`")] | |||
| Internal(String), | |||
| #[error("salvo internal error: `{0}`")] | |||
| Salvo(#[from] ::salvo::Error), | |||
| #[error("serde json: `{0}`")] | |||
| SerdeJson(#[from] serde_json::error::Error), | |||
| #[error("http: `{0}`")] | |||
| StatusError(#[from] salvo::http::StatusError), | |||
| #[error("http parse: `{0}`")] | |||
| HttpParse(#[from] salvo::http::ParseError), | |||
| #[error("recv: `{0}`")] | |||
| Recv(#[from] tokio::sync::oneshot::error::RecvError), | |||
| #[error("send: `{0}`")] | |||
| Send(#[from] tokio::sync::mpsc::error::SendError<ServerEvent>), | |||
| #[error("canceled: `{0}`")] | |||
| Canceled(#[from] futures::channel::oneshot::Canceled), | |||
| #[error("error report: `{0}`")] | |||
| ErrReport(#[from] eyre::Report), | |||
| // #[error("reqwest: `{0}`")] | |||
| // Reqwest(#[from] reqwest::Error), | |||
| } | |||
| impl AppError { | |||
| pub fn public<S: Into<String>>(msg: S) -> Self { | |||
| Self::Public(msg.into()) | |||
| } | |||
| pub fn internal<S: Into<String>>(msg: S) -> Self { | |||
| Self::Internal(msg.into()) | |||
| } | |||
| } | |||
| #[async_trait] | |||
| impl Writer for AppError { | |||
| async fn write(mut self, _req: &mut Request, _depot: &mut Depot, res: &mut Response) { | |||
| let code = match &self { | |||
| AppError::StatusError(e) => e.code, | |||
| _ => StatusCode::INTERNAL_SERVER_ERROR, | |||
| }; | |||
| res.status_code(code); | |||
| let data = match self { | |||
| AppError::Salvo(e) => StatusError::internal_server_error().brief(e.to_string()), | |||
| AppError::Public(msg) => StatusError::internal_server_error().brief(msg), | |||
| AppError::Internal(_msg) => StatusError::internal_server_error(), | |||
| AppError::StatusError(e) => e, | |||
| e => StatusError::internal_server_error().brief(e.to_string()), | |||
| }; | |||
| res.render(data); | |||
| } | |||
| } | |||
| @@ -0,0 +1,183 @@ | |||
| use std::collections::HashMap; | |||
| use dora_node_api::{ | |||
| arrow::array::{AsArray, StringArray}, | |||
| dora_core::config::DataId, | |||
| merged::{MergeExternalSend, MergedEvent}, | |||
| DoraNode, Event, MetadataParameters, Parameter, | |||
| }; | |||
| use eyre::{Context, ContextCompat}; | |||
| use futures::channel::oneshot; | |||
| use outfox_openai::spec::{ | |||
| ChatChoice, ChatCompletionResponseMessage, CompletionUsage, CreateChatCompletionRequest, | |||
| CreateChatCompletionResponse, FinishReason, Role, | |||
| }; | |||
| use salvo::cors::*; | |||
| use salvo::prelude::*; | |||
| use tokio::sync::mpsc; | |||
| mod client; | |||
| mod error; | |||
| mod routing; | |||
| mod utils; | |||
| use error::AppError; | |||
| mod config; | |||
| mod session; | |||
| use session::ChatSession; | |||
| mod tool; | |||
| use tool::get_mcp_tools; | |||
| use utils::gen_call_id; | |||
| pub type AppResult<T> = Result<T, AppError>; | |||
| #[tokio::main] | |||
| async fn main() -> eyre::Result<()> { | |||
| config::init(); | |||
| let (server_events_tx, server_events_rx) = mpsc::channel(3); | |||
| let server_events = tokio_stream::wrappers::ReceiverStream::new(server_events_rx); | |||
| let config = config::get(); | |||
| let chat_session = config | |||
| .create_session(server_events_tx.clone()) | |||
| .await | |||
| .context("failed to create chat session")?; | |||
| let mut reply_channels: HashMap< | |||
| String, | |||
| ( | |||
| oneshot::Sender<CreateChatCompletionResponse>, | |||
| u32, | |||
| Option<String>, | |||
| ), | |||
| > = HashMap::new(); | |||
| let acceptor = TcpListener::new(&config.listen_addr).bind().await; | |||
| tokio::spawn(async move { | |||
| let service = Service::new(routing::root(config.endpoint.clone(), chat_session.into())) | |||
| .hoop( | |||
| Cors::new() | |||
| .allow_origin(AllowOrigin::any()) | |||
| .allow_methods(AllowMethods::any()) | |||
| .allow_headers(AllowHeaders::any()) | |||
| .into_handler(), | |||
| ); | |||
| Server::new(acceptor).serve(service).await; | |||
| if let Err(err) = server_events_tx.send(ServerEvent::Result(Ok(()))).await { | |||
| tracing::warn!("server result channel closed: {err}"); | |||
| } | |||
| }); | |||
| let (mut node, events) = DoraNode::init_from_env()?; | |||
| let merged = events.merge_external_send(server_events); | |||
| let events = futures::executor::block_on_stream(merged); | |||
| for event in events { | |||
| match event { | |||
| MergedEvent::External(event) => match event { | |||
| ServerEvent::Result(server_result) => { | |||
| server_result.context("server failed")?; | |||
| break; | |||
| } | |||
| ServerEvent::CallNode { | |||
| output, | |||
| request, | |||
| reply, | |||
| } => { | |||
| let mut metadata = MetadataParameters::default(); | |||
| let call_id = gen_call_id(); | |||
| metadata.insert("__dora_call_id".into(), Parameter::String(call_id.clone())); | |||
| let texts = request | |||
| .messages | |||
| .iter() | |||
| .map(|msg| msg.to_texts().join("\n")) | |||
| .collect::<Vec<_>>(); | |||
| node.send_output( | |||
| DataId::from(output), | |||
| Default::default(), | |||
| StringArray::from(texts), | |||
| ) | |||
| .context("failed to send dora output")?; | |||
| reply_channels.insert(call_id, (reply, 0_u32, Some(request.model))); | |||
| } | |||
| }, | |||
| MergedEvent::Dora(event) => match event { | |||
| Event::Input { id, data, metadata } => { | |||
| let Some(Parameter::String(call_id)) = | |||
| metadata.parameters.get("__dora_call_id") | |||
| else { | |||
| tracing::warn!("No call ID found in metadata for id: {}", id); | |||
| continue; | |||
| }; | |||
| let (reply_channel, prompt_tokens, model) = | |||
| reply_channels.remove(call_id).context("no reply channel")?; | |||
| let data = data.as_string::<i32>(); | |||
| let data = data.iter().fold("".to_string(), |mut acc, s| { | |||
| if let Some(s) = s { | |||
| acc.push('\n'); | |||
| acc.push_str(s); | |||
| } | |||
| acc | |||
| }); | |||
| let data = CreateChatCompletionResponse { | |||
| id: format!("completion-{}", uuid::Uuid::new_v4()), | |||
| object: "chat.completion".to_string(), | |||
| created: chrono::Utc::now().timestamp() as u32, | |||
| model: model.unwrap_or_default(), | |||
| usage: Some(CompletionUsage { | |||
| prompt_tokens, | |||
| completion_tokens: data.len() as u32, | |||
| total_tokens: prompt_tokens + data.len() as u32, | |||
| prompt_tokens_details: None, | |||
| completion_tokens_details: None, | |||
| }), | |||
| choices: vec![ChatChoice { | |||
| index: 0, | |||
| message: ChatCompletionResponseMessage { | |||
| role: Role::Assistant, | |||
| content: Some(data), | |||
| tool_calls: None, | |||
| audio: None, | |||
| refusal: None, | |||
| }, | |||
| finish_reason: Some(FinishReason::Stop), | |||
| logprobs: None, | |||
| }], | |||
| service_tier: None, | |||
| system_fingerprint: None, | |||
| }; | |||
| if reply_channel.send(data).is_err() { | |||
| tracing::warn!( | |||
| "failed to send chat completion reply because channel closed early" | |||
| ); | |||
| } | |||
| } | |||
| Event::Stop(_) => { | |||
| break; | |||
| } | |||
| Event::InputClosed { id, .. } => { | |||
| tracing::info!("Input channel closed for id: {}", id); | |||
| } | |||
| event => { | |||
| eyre::bail!("unexpected event: {:#?}", event) | |||
| } | |||
| }, | |||
| } | |||
| } | |||
| Ok(()) | |||
| } | |||
| #[allow(clippy::large_enum_variant)] | |||
| pub enum ServerEvent { | |||
| Result(eyre::Result<()>), | |||
| CallNode { | |||
| output: String, | |||
| request: CreateChatCompletionRequest, | |||
| reply: oneshot::Sender<CreateChatCompletionResponse>, | |||
| }, | |||
| } | |||
| @@ -0,0 +1,118 @@ | |||
| use std::sync::Arc; | |||
| use outfox_openai::spec::{CreateChatCompletionRequest, Model}; | |||
| use salvo::prelude::*; | |||
| use crate::session::ChatSession; | |||
| use crate::AppResult; | |||
| pub fn root(endpoint: Option<String>, chat_session: Arc<ChatSession>) -> Router { | |||
| Router::with_hoop(affix_state::inject(chat_session)) | |||
| .push( | |||
| if let Some(endpoint) = endpoint { | |||
| Router::with_path(endpoint) | |||
| } else { | |||
| Router::new() | |||
| } | |||
| .push(Router::with_path("chat/completions").post(chat_completions)) | |||
| .push(Router::with_path("models").get(list_models)) | |||
| .push(Router::with_path("embeddings").get(todo)) | |||
| .push(Router::with_path("files").get(todo)) | |||
| .push(Router::with_path("chunks").get(todo)) | |||
| .push(Router::with_path("info").get(todo)) | |||
| .push(Router::with_path("realtime").get(todo)), | |||
| ) | |||
| .push(Router::with_path("{**path}").get(index)) | |||
| } | |||
| #[handler] | |||
| async fn todo(res: &mut Response) { | |||
| res.render(Text::Plain("TODO")); | |||
| } | |||
| #[handler] | |||
| async fn index(res: &mut Response) { | |||
| res.render(Text::Plain("Hello")); | |||
| } | |||
| #[handler] | |||
| async fn list_models(depot: &mut Depot, res: &mut Response) { | |||
| let chat_session = depot | |||
| .obtain::<Arc<ChatSession>>() | |||
| .expect("chat session must be exists"); | |||
| let mut models = Vec::new(); | |||
| for model in &chat_session.models { | |||
| // TODO: fill correct data | |||
| models.push(Model { | |||
| id: model.id.clone(), | |||
| object: model.object.clone().unwrap_or("object".into()), | |||
| created: model.created.unwrap_or_default(), | |||
| owned_by: model.owned_by.clone().unwrap_or("dora".into()), | |||
| }); | |||
| } | |||
| res.render(Json(models)); | |||
| } | |||
| #[handler] | |||
| async fn chat_completions( | |||
| req: &mut Request, | |||
| depot: &mut Depot, | |||
| res: &mut Response, | |||
| ) -> AppResult<()> { | |||
| tracing::info!("Handling the coming chat completion request."); | |||
| let chat_session = depot | |||
| .obtain::<Arc<ChatSession>>() | |||
| .expect("chat session must be exists"); | |||
| tracing::info!("Prepare the chat completion request."); | |||
| let mut chat_request = match req.parse_json::<CreateChatCompletionRequest>().await { | |||
| Ok(chat_requst) => chat_requst, | |||
| Err(e) => { | |||
| println!( | |||
| "parse request error: {e}, payload: {}", | |||
| String::from_utf8_lossy(req.payload().await?) | |||
| ); | |||
| return Err(e.into()); | |||
| } | |||
| }; | |||
| // check if the user id is provided | |||
| if chat_request.user.is_none() { | |||
| chat_request.user = Some(crate::utils::gen_chat_id()) | |||
| }; | |||
| let id = chat_request.user.clone().unwrap(); | |||
| // log user id | |||
| tracing::info!("user: {}", chat_request.user.clone().unwrap()); | |||
| // let stream = chat_request.stream; | |||
| // let (tx, rx) = oneshot::channel(); | |||
| // request_tx | |||
| // .send(ServerEvent::CompletionRequest { | |||
| // request: chat_request, | |||
| // reply: tx, | |||
| // }) | |||
| // .await?; | |||
| // if let Some(true) = stream { | |||
| // // let result = async { | |||
| // // let chat_completion_object = rx.await?; | |||
| // // Ok::<_, AppError>(serde_json::to_string(&chat_completion_object)?) | |||
| // // }; | |||
| // let result = chat_session.chat(chat_request).await?; | |||
| // let stream = futures::stream::once(result); | |||
| // let _ = res.add_header("Content-Type", "text/event-stream", true); | |||
| // let _ = res.add_header("Cache-Control", "no-cache", true); | |||
| // let _ = res.add_header("Connection", "keep-alive", true); | |||
| // let _ = res.add_header("user", id, true); | |||
| // res.stream(stream); | |||
| // } else { | |||
| let response = chat_session.chat(chat_request).await?; | |||
| let _ = res.add_header("user", id, true); | |||
| res.render(Json(response)); | |||
| // }; | |||
| tracing::info!("Send the chat completion response."); | |||
| Ok(()) | |||
| } | |||
| @@ -0,0 +1,221 @@ | |||
| use std::collections::HashMap; | |||
| use std::sync::{Arc, Mutex}; | |||
| use eyre::Result; | |||
| use outfox_openai::spec::{ | |||
| ChatCompletionRequestMessage, ChatCompletionRequestUserMessage, | |||
| ChatCompletionRequestUserMessageContent, ChatCompletionResponseMessage, ChatCompletionTool, | |||
| ChatCompletionToolType, CreateChatCompletionRequest, CreateChatCompletionResponse, | |||
| FunctionCall, FunctionObject, | |||
| }; | |||
| use crate::client::ChatClient; | |||
| use crate::config::ModelConfig; | |||
| use crate::tool::ToolSet; | |||
| pub struct ChatSession { | |||
| pub chat_clients: HashMap<String, Arc<dyn ChatClient>>, | |||
| pub models: Vec<ModelConfig>, | |||
| pub tool_set: ToolSet, | |||
| pub messages: Mutex<Vec<ChatCompletionRequestMessage>>, | |||
| } | |||
| impl ChatSession { | |||
| pub fn new( | |||
| chat_clients: HashMap<String, Arc<dyn ChatClient>>, | |||
| models: Vec<ModelConfig>, | |||
| tool_set: ToolSet, | |||
| ) -> Self { | |||
| Self { | |||
| chat_clients, | |||
| models, | |||
| tool_set, | |||
| messages: Default::default(), | |||
| } | |||
| } | |||
| // pub fn default_model(&self) -> Option<&str> { | |||
| // let model = self | |||
| // .models | |||
| // .iter() | |||
| // .find(|model| model.default) | |||
| // .map(|model| &*model.id); | |||
| // if model.is_none() { | |||
| // self.models.first().map(|model| &*model.id) | |||
| // } else { | |||
| // model | |||
| // } | |||
| // } | |||
| pub fn route(&self, model: &str) -> Option<(Arc<dyn ChatClient>, String)> { | |||
| let model = self.models.iter().find(|m| m.id == model)?; | |||
| let route = &model.route; | |||
| let client = self.chat_clients.get(&route.provider)?; | |||
| Some(( | |||
| client.clone(), | |||
| route.model.clone().unwrap_or(model.id.clone()), | |||
| )) | |||
| } | |||
| // pub fn add_system_prompt(&mut self, prompt: impl ToString) { | |||
| // let mut messages = self.messages.lock().expect("messages should locked"); | |||
| // messages.push( | |||
| // ChatCompletionRequestSystemMessage { | |||
| // content: PartibleTextContent::Text(prompt.to_string()), | |||
| // name: None, | |||
| // } | |||
| // .into(), | |||
| // ); | |||
| // } | |||
| // pub fn get_tools(&self) -> Vec<Arc<dyn ToolTrait>> { | |||
| // self.tool_set.tools() | |||
| // } | |||
| pub async fn analyze_tool_call(&self, response: &ChatCompletionResponseMessage) { | |||
| let mut tool_calls_func = Vec::new(); | |||
| if let Some(tool_calls) = response.tool_calls.as_ref() { | |||
| for tool_call in tool_calls { | |||
| // if tool_call.r#type == "function" { | |||
| tool_calls_func.push(tool_call.function.clone()); | |||
| // } | |||
| } | |||
| } else { | |||
| // check if message contains tool call | |||
| if let Some(text) = &response.content { | |||
| if text.contains("Tool:") { | |||
| let lines: Vec<&str> = text.split('\n').collect(); | |||
| // simple parse tool call | |||
| let mut tool_name = None; | |||
| let mut args_text = Vec::new(); | |||
| let mut parsing_args = false; | |||
| for line in lines { | |||
| if line.starts_with("Tool:") { | |||
| tool_name = line.strip_prefix("Tool:").map(|s| s.trim().to_string()); | |||
| parsing_args = false; | |||
| } else if line.starts_with("Inputs:") { | |||
| parsing_args = true; | |||
| } else if parsing_args { | |||
| args_text.push(line.trim()); | |||
| } | |||
| } | |||
| if let Some(name) = tool_name { | |||
| tool_calls_func.push(FunctionCall { | |||
| name, | |||
| arguments: args_text.join("\n"), | |||
| }); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| // call tool | |||
| for tool_call in tool_calls_func { | |||
| let tool = self.tool_set.get_tool(&tool_call.name); | |||
| if let Some(tool) = tool { | |||
| // call tool | |||
| let args = serde_json::from_str::<serde_json::Value>(&tool_call.arguments) | |||
| .unwrap_or_default(); | |||
| match tool.call(args).await { | |||
| Ok(result) => { | |||
| if result.is_error.is_some_and(|b| b) { | |||
| let mut messages = | |||
| self.messages.lock().expect("messages should locked"); | |||
| messages.push( | |||
| ChatCompletionRequestUserMessage::new( | |||
| "tool call failed, mcp call error", | |||
| ) | |||
| .into(), | |||
| ); | |||
| } else if let Some(contents) = &result.content { | |||
| contents.iter().for_each(|content| { | |||
| if let Some(content_text) = content.as_text() { | |||
| let json_result = serde_json::from_str::<serde_json::Value>( | |||
| &content_text.text, | |||
| ) | |||
| .unwrap_or_default(); | |||
| let pretty_result = | |||
| serde_json::to_string_pretty(&json_result).unwrap(); | |||
| tracing::debug!("call tool result: {}", pretty_result); | |||
| let mut messages = | |||
| self.messages.lock().expect("messages should locked"); | |||
| messages.push( | |||
| ChatCompletionRequestUserMessage::new(format!( | |||
| "call tool result: {pretty_result}" | |||
| )) | |||
| .into(), | |||
| ); | |||
| } | |||
| }); | |||
| } | |||
| } | |||
| Err(e) => { | |||
| tracing::error!("tool call failed: {}", e); | |||
| let mut messages = self.messages.lock().expect("messages should locked"); | |||
| messages.push( | |||
| ChatCompletionRequestUserMessage { | |||
| content: ChatCompletionRequestUserMessageContent::Text(format!( | |||
| "tool call failed: {e}" | |||
| )), | |||
| name: None, | |||
| } | |||
| .into(), | |||
| ); | |||
| } | |||
| } | |||
| } else { | |||
| println!("tool not found: {}", tool_call.name); | |||
| } | |||
| } | |||
| } | |||
| pub async fn chat( | |||
| &self, | |||
| mut request: CreateChatCompletionRequest, | |||
| ) -> Result<CreateChatCompletionResponse> { | |||
| { | |||
| let mut messages = self.messages.lock().expect("messages should locked"); | |||
| for message in std::mem::take(&mut request.messages) { | |||
| messages.push(message); | |||
| } | |||
| request.messages = messages.clone(); | |||
| } | |||
| let tools = self.tool_set.tools(); | |||
| let tool_definitions = if !tools.is_empty() { | |||
| Some( | |||
| tools | |||
| .iter() | |||
| .map(|tool| ChatCompletionTool { | |||
| kind: ChatCompletionToolType::Function, | |||
| function: FunctionObject { | |||
| name: tool.name(), | |||
| description: Some(tool.description()), | |||
| parameters: Some(tool.parameters()), | |||
| strict: None, | |||
| }, | |||
| }) | |||
| .collect(), | |||
| ) | |||
| } else { | |||
| None | |||
| }; | |||
| let (client, model) = self.route(&request.model).expect("failed to route model"); | |||
| request.model = model.clone(); | |||
| request.tools = tool_definitions; | |||
| // send request | |||
| let response = client.complete(request).await?; | |||
| // get choice | |||
| if let Some(choice) = response.choices.first() { | |||
| // analyze tool call | |||
| self.analyze_tool_call(&choice.message).await; | |||
| let request = { | |||
| let messages = self.messages.lock().expect("messages should locked"); | |||
| CreateChatCompletionRequest::new(model, messages.clone()) | |||
| }; | |||
| client.complete(request).await | |||
| } else { | |||
| Ok(response) | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,95 @@ | |||
| use std::{collections::HashMap, sync::Arc}; | |||
| use eyre::Result; | |||
| use rmcp::{ | |||
| model::{CallToolRequestParam, CallToolResult, Tool as McpTool}, | |||
| service::{RunningService, ServerSink}, | |||
| RoleClient, | |||
| }; | |||
| use salvo::async_trait; | |||
| use serde_json::Value; | |||
| #[async_trait] | |||
| pub trait Tool: Send + Sync { | |||
| fn name(&self) -> String; | |||
| fn description(&self) -> String; | |||
| fn parameters(&self) -> Value; | |||
| async fn call(&self, args: Value) -> Result<CallToolResult>; | |||
| } | |||
| pub struct McpToolAdapter { | |||
| tool: McpTool, | |||
| server: ServerSink, | |||
| } | |||
| impl McpToolAdapter { | |||
| pub fn new(tool: McpTool, server: ServerSink) -> Self { | |||
| Self { tool, server } | |||
| } | |||
| } | |||
| #[async_trait] | |||
| impl Tool for McpToolAdapter { | |||
| fn name(&self) -> String { | |||
| self.tool.name.clone().to_string() | |||
| } | |||
| fn description(&self) -> String { | |||
| self.tool | |||
| .description | |||
| .clone() | |||
| .unwrap_or_default() | |||
| .to_string() | |||
| } | |||
| fn parameters(&self) -> Value { | |||
| serde_json::to_value(&self.tool.input_schema).unwrap_or(serde_json::json!({})) | |||
| } | |||
| async fn call(&self, args: Value) -> Result<CallToolResult> { | |||
| let arguments = match args { | |||
| Value::Object(map) => Some(map), | |||
| _ => None, | |||
| }; | |||
| let call_result = self | |||
| .server | |||
| .call_tool(CallToolRequestParam { | |||
| name: self.tool.name.clone(), | |||
| arguments, | |||
| }) | |||
| .await?; | |||
| Ok(call_result) | |||
| } | |||
| } | |||
| #[derive(Default)] | |||
| pub struct ToolSet { | |||
| tools: HashMap<String, Arc<dyn Tool>>, | |||
| clients: HashMap<String, RunningService<RoleClient, ()>>, | |||
| } | |||
| impl ToolSet { | |||
| pub fn set_clients(&mut self, clients: HashMap<String, RunningService<RoleClient, ()>>) { | |||
| self.clients = clients; | |||
| } | |||
| pub fn add_tool<T: Tool + 'static>(&mut self, tool: T) { | |||
| self.tools.insert(tool.name(), Arc::new(tool)); | |||
| } | |||
| pub fn get_tool(&self, name: &str) -> Option<Arc<dyn Tool>> { | |||
| self.tools.get(name).cloned() | |||
| } | |||
| pub fn tools(&self) -> Vec<Arc<dyn Tool>> { | |||
| self.tools.values().cloned().collect() | |||
| } | |||
| } | |||
| pub async fn get_mcp_tools(server: ServerSink) -> Result<Vec<McpToolAdapter>> { | |||
| let tools = server.list_all_tools().await?; | |||
| Ok(tools | |||
| .into_iter() | |||
| .map(|tool| McpToolAdapter::new(tool, server.clone())) | |||
| .collect()) | |||
| } | |||
| @@ -0,0 +1,14 @@ | |||
| pub(crate) fn gen_call_id() -> String { | |||
| format!("call-{}", uuid::Uuid::new_v4()) | |||
| } | |||
| pub(crate) fn gen_chat_id() -> String { | |||
| format!("chatcmpl-{}", uuid::Uuid::new_v4()) | |||
| } | |||
| pub fn get_env_or_value(value: &str) -> String { | |||
| if let Some(stripped) = value.strip_prefix("env:") { | |||
| std::env::var(stripped).unwrap_or_else(|_| value.to_string()) | |||
| } else { | |||
| value.to_string() | |||
| } | |||
| } | |||
| @@ -0,0 +1,30 @@ | |||
| [package] | |||
| name = "dora-mcp-server" | |||
| version.workspace = true | |||
| edition.workspace = true | |||
| rust-version.workspace = true | |||
| documentation.workspace = true | |||
| description.workspace = true | |||
| license.workspace = true | |||
| repository.workspace = true | |||
| # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html | |||
| [dependencies] | |||
| chrono = "0.4.31" | |||
| dora-node-api = { workspace = true, features = ["tracing"] } | |||
| eyre = "0.6.8" | |||
| futures = "0.3.31" | |||
| indexmap = { version = "2.6.0", features = ["serde"] } | |||
| mime_guess = "2.0.4" | |||
| rmcp = { version = "0.2.1", features = ["server"] } | |||
| salvo = { version = "0.80.0", default-features = false, features = ["affix-state", "cors", "server", "http1", "http2"] } | |||
| serde = { version = "1.0.130", features = ["derive"] } | |||
| serde_json = "1.0.68" | |||
| thiserror = "2.0.12" | |||
| tokio = { version = "1.46.1", features = ["full"] } | |||
| tokio-stream = "0.1.11" | |||
| tracing = "0.1.27" | |||
| url = "2.2.2" | |||
| uuid = { version = "1.10", features = ["v4"] } | |||
| figment = { version = "0.10.0", features = ["env", "json", "toml", "yaml"] } | |||
| @@ -0,0 +1,66 @@ | |||
| # Dora MCP Server | |||
| This node can provide an MCP Server, which will proxy the request to one or more other nodes in the dora application. | |||
| Dora MCP Server is still experimental and may change in the future. | |||
| ## How to use | |||
| ```yaml | |||
| nodes: | |||
| - id: mcp-server | |||
| build: cargo build -p dora-mcp-server --release | |||
| path: ../../target/release/dora-mcp-server | |||
| outputs: | |||
| - counter | |||
| inputs: | |||
| counter_reply: counter/reply | |||
| env: | |||
| MCP_SERVER_CONFIG: config.toml | |||
| ``` | |||
| use `MCP_SERVER_CONFIG` set config file, it supports toml, json or yaml format. | |||
| An example config file: | |||
| ```toml | |||
| name = "MCP Server Example" | |||
| version = "0.1.0" | |||
| # You can set your custom listen address and endpoint here. | |||
| # Default listen address is "0.0.0.0:8008" and endpoint is "mcp". | |||
| # In this example, the final service url is: http://0.0.0.0:8181/mcp | |||
| listen_addr = "0.0.0.0:8181" | |||
| endpoint = "mcp" | |||
| [[mcp_tools]] | |||
| name = "counter_decrement" # (Required) type: String, Unique identifier for the tool | |||
| title = "Decrement Counter" # (Optional) type: String, Human-readable name of the tool for display purposes | |||
| input_schema = "empty_object.json" # (Required) JSON Schema defining expected parameters | |||
| output = "counter" # (Required) type: String, Set the output name | |||
| [mcp_tools.annotations] # (Optional) Additional properties describing a Tool to clients | |||
| title = "decrement current value of the counter" # type: String, A human-readable title for the tool | |||
| [[mcp_tools]] | |||
| name = "counter_increment" | |||
| title = "Increment Counter" | |||
| input_schema = "empty_object.json" | |||
| output = "counter" | |||
| [mcp_tools.annotations] | |||
| title = "Increment current value of the counter" | |||
| [[mcp_tools]] | |||
| name = "counter_get_value" | |||
| title = "Get Counter Value" | |||
| input_schema = "empty_object.json" | |||
| output = "counter" | |||
| [mcp_tools.annotations] | |||
| title = "Get the current value of the counter" | |||
| ``` | |||
| You can use mpc inspector to test: | |||
| ```bash | |||
| npx @modelcontextprotocol/inspector | |||
| ``` | |||
| @@ -0,0 +1,102 @@ | |||
| use std::path::{Path, PathBuf}; | |||
| use std::sync::OnceLock; | |||
| use figment::providers::{Env, Format, Json, Toml, Yaml}; | |||
| use figment::Figment; | |||
| use rmcp::model::{JsonObject, ToolAnnotations}; | |||
| use serde::{Deserialize, Serialize}; | |||
| pub static CONFIG: OnceLock<Config> = OnceLock::new(); | |||
| fn figment_from_path<P: AsRef<Path>>(path: P) -> Figment { | |||
| let ext = path | |||
| .as_ref() | |||
| .extension() | |||
| .and_then(|s| s.to_str()) | |||
| .unwrap_or_default(); | |||
| match ext { | |||
| "yaml" | "yml" => Figment::new().merge(Yaml::file(path)), | |||
| "json" => Figment::new().merge(Json::file(path)), | |||
| "toml" => Figment::new().merge(Toml::file(path)), | |||
| _ => panic!("Unsupported config file format: {ext}"), | |||
| } | |||
| } | |||
| pub fn init() { | |||
| let config_file = Env::var("CONFIG").unwrap_or("config.toml".into()); | |||
| let config_path = PathBuf::from(config_file); | |||
| if !config_path.exists() { | |||
| eprintln!("Config file not found at: {}", config_path.display()); | |||
| std::process::exit(1); | |||
| } | |||
| let raw_config = figment_from_path(config_path); | |||
| let conf = match raw_config.extract::<Config>() { | |||
| Ok(s) => s, | |||
| Err(e) => { | |||
| eprintln!("It looks like your config is invalid. The following error occurred: {e}"); | |||
| std::process::exit(1); | |||
| } | |||
| }; | |||
| CONFIG.set(conf).expect("config should be set"); | |||
| } | |||
| pub fn get() -> &'static Config { | |||
| CONFIG.get().unwrap() | |||
| } | |||
| #[derive(Clone, Debug, Deserialize)] | |||
| pub struct Config { | |||
| #[serde(default = "default_listen_addr")] | |||
| pub listen_addr: String, | |||
| #[serde(default = "default_endpoint")] | |||
| pub endpoint: Option<String>, | |||
| pub name: String, | |||
| pub version: String, | |||
| pub mcp_tools: Vec<McpToolConfig>, | |||
| } | |||
| fn default_listen_addr() -> String { | |||
| "0.0.0.0:8008".to_owned() | |||
| } | |||
| fn default_endpoint() -> Option<String> { | |||
| Some("mcp".to_owned()) | |||
| } | |||
| #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] | |||
| pub struct McpToolConfig { | |||
| /// Unique identifier for the tool | |||
| pub name: String, | |||
| /// Optional human-readable name of the tool for display purposes | |||
| #[serde(skip_serializing_if = "Option::is_none")] | |||
| pub title: Option<String>, | |||
| /// Human-readable description of functionality | |||
| #[serde(skip_serializing_if = "Option::is_none")] | |||
| pub description: Option<String>, | |||
| /// A JSON Schema object defining the expected parameters for the tool | |||
| pub input_schema: InputSchema, | |||
| #[serde(skip_serializing_if = "Option::is_none")] | |||
| /// Optional properties describing tool behavior | |||
| pub annotations: Option<ToolAnnotations>, | |||
| pub output: String, | |||
| } | |||
| #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] | |||
| #[serde(untagged)] | |||
| pub enum InputSchema { | |||
| Object(JsonObject), | |||
| FilePath(String), | |||
| } | |||
| impl InputSchema { | |||
| pub fn schema(&self) -> JsonObject { | |||
| match self { | |||
| InputSchema::Object(obj) => obj.clone(), | |||
| InputSchema::FilePath(path) => figment_from_path(path) | |||
| .extract::<JsonObject>() | |||
| .expect("should read input schema from file"), | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,57 @@ | |||
| use salvo::async_trait; | |||
| use salvo::http::{StatusCode, StatusError}; | |||
| use salvo::prelude::{Depot, Request, Response, Writer}; | |||
| use thiserror::Error; | |||
| #[derive(Error, Debug)] | |||
| pub enum AppError { | |||
| #[error("public: `{0}`")] | |||
| Public(String), | |||
| #[error("internal: `{0}`")] | |||
| Internal(String), | |||
| #[error("salvo internal error: `{0}`")] | |||
| Salvo(#[from] ::salvo::Error), | |||
| #[error("serde json: `{0}`")] | |||
| SerdeJson(#[from] serde_json::error::Error), | |||
| #[error("http: `{0}`")] | |||
| StatusError(#[from] salvo::http::StatusError), | |||
| #[error("http parse: `{0}`")] | |||
| HttpParse(#[from] salvo::http::ParseError), | |||
| #[error("recv: `{0}`")] | |||
| Recv(#[from] tokio::sync::oneshot::error::RecvError), | |||
| #[error("canceled: `{0}`")] | |||
| Canceled(#[from] futures::channel::oneshot::Canceled), | |||
| #[error("error report: `{0}`")] | |||
| ErrReport(#[from] eyre::Report), | |||
| // #[error("reqwest: `{0}`")] | |||
| // Reqwest(#[from] reqwest::Error), | |||
| } | |||
| impl AppError { | |||
| pub fn public<S: Into<String>>(msg: S) -> Self { | |||
| Self::Public(msg.into()) | |||
| } | |||
| pub fn internal<S: Into<String>>(msg: S) -> Self { | |||
| Self::Internal(msg.into()) | |||
| } | |||
| } | |||
| #[async_trait] | |||
| impl Writer for AppError { | |||
| async fn write(mut self, _req: &mut Request, _depot: &mut Depot, res: &mut Response) { | |||
| let code = match &self { | |||
| AppError::StatusError(e) => e.code, | |||
| _ => StatusCode::INTERNAL_SERVER_ERROR, | |||
| }; | |||
| res.status_code(code); | |||
| let data = match self { | |||
| AppError::Salvo(e) => StatusError::internal_server_error().brief(e.to_string()), | |||
| AppError::Public(msg) => StatusError::internal_server_error().brief(msg), | |||
| AppError::Internal(_msg) => StatusError::internal_server_error(), | |||
| AppError::StatusError(e) => e, | |||
| e => StatusError::internal_server_error().brief(e.to_string()), | |||
| }; | |||
| res.render(data); | |||
| } | |||
| } | |||
| @@ -0,0 +1,176 @@ | |||
| use std::collections::HashMap; | |||
| use std::sync::Arc; | |||
| use dora_node_api::{ | |||
| arrow::array::{AsArray, StringArray}, | |||
| dora_core::config::DataId, | |||
| merged::{MergeExternalSend, MergedEvent}, | |||
| DoraNode, Event, MetadataParameters, Parameter, | |||
| }; | |||
| use eyre::{Context, ContextCompat}; | |||
| use futures::channel::oneshot; | |||
| use rmcp::model::{ClientRequest, JsonRpcRequest}; | |||
| use salvo::cors::*; | |||
| use salvo::prelude::*; | |||
| use tokio::sync::mpsc; | |||
| mod mcp_server; | |||
| use mcp_server::McpServer; | |||
| mod error; | |||
| mod routing; | |||
| use error::AppError; | |||
| mod config; | |||
| mod utils; | |||
| use config::Config; | |||
| use utils::gen_call_id; | |||
| pub type AppResult<T> = Result<T, crate::AppError>; | |||
| #[tokio::main] | |||
| async fn main() -> eyre::Result<()> { | |||
| config::init(); | |||
| let (server_events_tx, server_events_rx) = mpsc::channel(3); | |||
| let server_events = tokio_stream::wrappers::ReceiverStream::new(server_events_rx); | |||
| let mut reply_channels: HashMap<String, oneshot::Sender<String>> = HashMap::new(); | |||
| let config = config::get(); | |||
| let mcp_server = Arc::new(McpServer::new(config)); | |||
| salvo::http::request::set_global_secure_max_size(8_000_000); // set max size to 8MB | |||
| let acceptor = TcpListener::new(&config.listen_addr).bind().await; | |||
| tokio::spawn({ | |||
| let server_events_tx = server_events_tx.clone(); | |||
| let mcp_server = mcp_server.clone(); | |||
| async move { | |||
| let service = Service::new(routing::root( | |||
| config.endpoint.clone(), | |||
| mcp_server, | |||
| server_events_tx.clone(), | |||
| )) | |||
| .hoop( | |||
| Cors::new() | |||
| .allow_origin(AllowOrigin::any()) | |||
| .allow_methods(AllowMethods::any()) | |||
| .allow_headers(AllowHeaders::any()) | |||
| .into_handler(), | |||
| ); | |||
| Server::new(acceptor).serve(service).await; | |||
| if let Err(err) = server_events_tx.send(ServerEvent::Result(Ok(()))).await { | |||
| tracing::warn!("server result channel closed: {err}"); | |||
| } | |||
| } | |||
| }); | |||
| let (mut node, events) = DoraNode::init_from_env()?; | |||
| let merged = events.merge_external_send(server_events); | |||
| let events = futures::executor::block_on_stream(merged); | |||
| for event in events { | |||
| match event { | |||
| MergedEvent::External(event) => match event { | |||
| ServerEvent::Result(server_result) => { | |||
| server_result.context("server failed")?; | |||
| break; | |||
| } | |||
| ServerEvent::CallNode { | |||
| output, | |||
| data, | |||
| reply, | |||
| } => { | |||
| let mut metadata = MetadataParameters::default(); | |||
| let call_id = gen_call_id(); | |||
| metadata.insert("__dora_call_id".into(), Parameter::String(call_id.clone())); | |||
| node.send_output( | |||
| DataId::from(output.clone()), | |||
| metadata, | |||
| StringArray::from(vec![data]), | |||
| ) | |||
| .context("failed to send dora output")?; | |||
| reply_channels.insert(call_id, reply); | |||
| } | |||
| }, | |||
| MergedEvent::Dora(event) => match event { | |||
| Event::Input { id, data, metadata } => { | |||
| match id.as_str() { | |||
| "request" => { | |||
| let data = data.as_string::<i32>().iter().fold( | |||
| "".to_string(), | |||
| |mut acc, s| { | |||
| if let Some(s) = s { | |||
| acc.push('\n'); | |||
| acc.push_str(s); | |||
| } | |||
| acc | |||
| }, | |||
| ); | |||
| let request = | |||
| serde_json::from_str::<JsonRpcRequest<ClientRequest>>(&data) | |||
| .context("failed to parse call tool from string")?; | |||
| if let Ok(result) = | |||
| mcp_server.handle_request(request, &server_events_tx).await | |||
| { | |||
| node.send_output( | |||
| DataId::from("response".to_owned()), | |||
| metadata.parameters, | |||
| StringArray::from( | |||
| vec![serde_json::to_string(&result).unwrap()], | |||
| ), | |||
| ) | |||
| .context("failed to send dora output")?; | |||
| } | |||
| } | |||
| _ => { | |||
| let Some(Parameter::String(call_id)) = | |||
| metadata.parameters.get("__dora_call_id") | |||
| else { | |||
| tracing::warn!("No call ID found in metadata for id: {}", id); | |||
| continue; | |||
| }; | |||
| let reply_channel = | |||
| reply_channels.remove(call_id).context("no reply channel")?; | |||
| let data = data.as_string::<i32>(); | |||
| let data = data.iter().fold("".to_string(), |mut acc, s| { | |||
| if let Some(s) = s { | |||
| acc.push('\n'); | |||
| acc.push_str(s); | |||
| } | |||
| acc | |||
| }); | |||
| if reply_channel.send(data).is_err() { | |||
| tracing::warn!("failed to send reply because channel closed early"); | |||
| } | |||
| // node.send_output(DataId::from("response".to_owned()), metadata, data) | |||
| // .context("failed to send dora output")?; | |||
| } | |||
| }; | |||
| } | |||
| Event::Stop(_) => { | |||
| break; | |||
| } | |||
| Event::InputClosed { id, .. } => { | |||
| tracing::info!("Input channel closed for id: {}", id); | |||
| } | |||
| event => { | |||
| eyre::bail!("unexpected event: {:#?}", event) | |||
| } | |||
| }, | |||
| } | |||
| } | |||
| Ok(()) | |||
| } | |||
| enum ServerEvent { | |||
| Result(eyre::Result<()>), | |||
| CallNode { | |||
| output: String, | |||
| data: String, | |||
| reply: oneshot::Sender<String>, | |||
| }, | |||
| } | |||
| @@ -0,0 +1,129 @@ | |||
| use std::sync::Arc; | |||
| use futures::channel::oneshot; | |||
| use rmcp::model::{ | |||
| CallToolRequest, CallToolResult, EmptyResult, Implementation, InitializeResult, | |||
| ListToolsResult, ProtocolVersion, ServerCapabilities, ServerResult, Tool, | |||
| }; | |||
| use rmcp::model::{ClientRequest, JsonRpcRequest}; | |||
| use serde::Deserialize; | |||
| use tokio::sync::mpsc; | |||
| use crate::{Config, ServerEvent}; | |||
| #[derive(Debug)] | |||
| pub struct McpServer { | |||
| tools: Vec<McpTool>, | |||
| server_info: Implementation, | |||
| } | |||
| #[derive(Deserialize, Debug)] | |||
| pub struct McpTool { | |||
| pub output: String, | |||
| #[serde(flatten)] | |||
| pub inner: Tool, | |||
| } | |||
| impl McpServer { | |||
| pub fn new(config: &Config) -> Self { | |||
| let mut tools = Vec::new(); | |||
| for tool_config in &config.mcp_tools { | |||
| let tool = Tool { | |||
| name: tool_config.name.clone().into(), | |||
| description: tool_config.description.clone().map(|s| s.into()), | |||
| input_schema: Arc::new(tool_config.input_schema.schema()), | |||
| annotations: tool_config.annotations.clone(), | |||
| }; | |||
| tools.push(McpTool { | |||
| inner: tool, | |||
| output: tool_config.output.clone(), | |||
| }); | |||
| } | |||
| Self { | |||
| tools, | |||
| server_info: Implementation { | |||
| name: config.name.clone(), | |||
| version: config.version.clone(), | |||
| }, | |||
| } | |||
| } | |||
| // pub fn tools(&self) -> Vec<&Tool> { | |||
| // self.tools.iter().map(|t| &t.inner).collect() | |||
| // } | |||
| // pub fn server_info(&self) -> &Implementation { | |||
| // &self.server_info | |||
| // } | |||
| pub async fn handle_ping(&self) -> eyre::Result<EmptyResult> { | |||
| Ok(EmptyResult {}) | |||
| } | |||
| pub async fn handle_initialize(&self) -> eyre::Result<InitializeResult> { | |||
| Ok(InitializeResult { | |||
| protocol_version: ProtocolVersion::V_2025_03_26, | |||
| server_info: self.server_info.clone(), | |||
| capabilities: ServerCapabilities { | |||
| tools: Some(Default::default()), | |||
| ..Default::default() | |||
| }, | |||
| instructions: None, | |||
| }) | |||
| } | |||
| pub async fn handle_tools_list(&self) -> eyre::Result<ListToolsResult> { | |||
| Ok(ListToolsResult { | |||
| tools: self.tools.iter().map(|t| t.inner.clone()).collect(), | |||
| next_cursor: None, | |||
| }) | |||
| } | |||
| pub async fn handle_tools_call( | |||
| &self, | |||
| request: CallToolRequest, | |||
| request_tx: &mpsc::Sender<ServerEvent>, | |||
| ) -> eyre::Result<CallToolResult> { | |||
| let (tx, rx) = oneshot::channel(); | |||
| let tool = self | |||
| .tools | |||
| .iter() | |||
| .find(|t| t.inner.name == request.params.name) | |||
| .ok_or_else(|| eyre::eyre!("Tool not found: {}", request.params.name))?; | |||
| request_tx | |||
| .send(ServerEvent::CallNode { | |||
| output: tool.output.clone(), | |||
| data: serde_json::to_string(&request.params).unwrap(), | |||
| reply: tx, | |||
| }) | |||
| .await?; | |||
| let data: String = rx.await?; | |||
| serde_json::from_str(&data) | |||
| .map_err(|e| eyre::eyre!("Failed to parse call tool result: {e}")) | |||
| } | |||
| pub async fn handle_request( | |||
| &self, | |||
| rpc_request: JsonRpcRequest<ClientRequest>, | |||
| server_events_tx: &mpsc::Sender<ServerEvent>, | |||
| ) -> eyre::Result<ServerResult> { | |||
| match rpc_request.request { | |||
| ClientRequest::PingRequest(_request) => { | |||
| self.handle_ping().await.map(ServerResult::EmptyResult) | |||
| } | |||
| ClientRequest::InitializeRequest(_request) => self | |||
| .handle_initialize() | |||
| .await | |||
| .map(ServerResult::InitializeResult), | |||
| ClientRequest::ListToolsRequest(_request) => self | |||
| .handle_tools_list() | |||
| .await | |||
| .map(ServerResult::ListToolsResult), | |||
| ClientRequest::CallToolRequest(request) => self | |||
| .handle_tools_call(request, server_events_tx) | |||
| .await | |||
| .map(ServerResult::CallToolResult), | |||
| method => Err(eyre::eyre!("unexpected method: {:#?}", method)), | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,71 @@ | |||
| use std::sync::Arc; | |||
| use eyre::Context; | |||
| use rmcp::model::{ClientJsonRpcMessage, JsonRpcResponse, JsonRpcVersion2_0}; | |||
| use salvo::prelude::*; | |||
| use tokio::sync::mpsc; | |||
| use crate::{AppResult, McpServer, ServerEvent}; | |||
| pub fn root( | |||
| endpoint: Option<String>, | |||
| mcp_server: Arc<McpServer>, | |||
| server_events_tx: mpsc::Sender<ServerEvent>, | |||
| ) -> Router { | |||
| Router::with_hoop(affix_state::inject(mcp_server).inject(server_events_tx)).push( | |||
| if let Some(endpoint) = endpoint { | |||
| Router::with_path(endpoint) | |||
| } else { | |||
| Router::new() | |||
| } | |||
| .post(handle_post) | |||
| .delete(handle_delete), | |||
| ) | |||
| } | |||
| #[handler] | |||
| async fn handle_delete(res: &mut Response) { | |||
| res.render(Text::Plain("DELETE method is not supported")); | |||
| } | |||
| #[handler] | |||
| async fn handle_post(req: &mut Request, depot: &mut Depot, res: &mut Response) -> AppResult<()> { | |||
| tracing::debug!("Handling the coming chat completion request."); | |||
| let server_events_tx = depot | |||
| .obtain::<mpsc::Sender<ServerEvent>>() | |||
| .expect("server_events_tx must be exists"); | |||
| let mcp_server = depot | |||
| .obtain::<Arc<McpServer>>() | |||
| .expect("mcp server must be exists"); | |||
| tracing::debug!("Prepare the chat completion request."); | |||
| let rpc_message = serde_json::from_slice::<ClientJsonRpcMessage>(req.payload().await?) | |||
| .context("failed to parse request bodyxxx")?; | |||
| match rpc_message { | |||
| ClientJsonRpcMessage::Request(rpc_request) => { | |||
| let response = JsonRpcResponse { | |||
| jsonrpc: JsonRpcVersion2_0, | |||
| id: rpc_request.id.clone(), | |||
| result: mcp_server | |||
| .handle_request(rpc_request, server_events_tx) | |||
| .await | |||
| .unwrap(), | |||
| }; | |||
| res.render(Json(response)); | |||
| } | |||
| ClientJsonRpcMessage::Notification(_) | |||
| | ClientJsonRpcMessage::Response(_) | |||
| | ClientJsonRpcMessage::Error(_) => { | |||
| res.render(StatusCode::ACCEPTED); | |||
| } | |||
| _ => { | |||
| res.render( | |||
| StatusError::not_implemented().brief("Batch requests are not supported yet"), | |||
| ); | |||
| } | |||
| } | |||
| tracing::debug!("Send the chat completion response."); | |||
| Ok(()) | |||
| } | |||
| @@ -0,0 +1,3 @@ | |||
| pub(crate) fn gen_call_id() -> String { | |||
| format!("call-{}", uuid::Uuid::new_v4()) | |||
| } | |||