Skip to main content
While CopilotKit provides built-in agent implementations (BuiltInAgent, LangGraphAgent, CrewAIAgent), you can create custom agents for specialized requirements by implementing the agent protocol.

When to Create a Custom Agent

Consider building a custom agent when you need to:
  • Integrate with proprietary AI systems or APIs
  • Implement custom orchestration logic
  • Add specialized middleware or processing
  • Create domain-specific agent behaviors
  • Integrate with existing agent frameworks not directly supported

Agent Protocol Overview

All CopilotKit agents implement a common protocol defined by an abstract base class — `AbstractAgent` (from `@ag-ui/client`) in TypeScript and `Agent` in Python:

TypeScript (Node.js)

import { AbstractAgent, RunAgentInput, BaseEvent } from "@ag-ui/client";
import { Observable } from "rxjs";

/**
 * Minimal skeleton of a custom agent.
 *
 * Replace each stub with a real implementation. The stubs throw so that an
 * unfinished agent fails loudly instead of silently returning `undefined`
 * (and so the skeleton type-checks against its declared return types).
 */
export class CustomAgent extends AbstractAgent {
  /** Execute one run and return the stream of AG-UI events it produces. */
  run(input: RunAgentInput): Observable<BaseEvent> {
    // Your implementation
    throw new Error(`CustomAgent.run is not implemented (runId: ${input.runId})`);
  }

  /** Create a copy of the agent. */
  clone() {
    // Your implementation
    throw new Error("CustomAgent.clone is not implemented");
  }

  /** Cancel the run that is currently in flight, if any. */
  abortRun(): void {
    // Handle run cancellation
  }
}

Python

from copilotkit.agent import Agent
from typing import List, Optional
from copilotkit.types import Message
from copilotkit.action import ActionDict

class CustomAgent(Agent):
    """Skeleton of a Python agent: fill in ``execute`` and ``get_state``."""

    def __init__(self, name: str, description: Optional[str] = None):
        super().__init__(name=name, description=description)

    async def execute(
        self,
        *,
        state: dict,
        messages: List[Message],
        thread_id: str,
        actions: Optional[List[ActionDict]] = None,
        **kwargs
    ):
        """Run the agent for one turn of ``thread_id``.

        Your implementation goes here.
        """

    async def get_state(self, *, thread_id: str):
        """Return the stored state for ``thread_id``.

        Your implementation goes here.
        """

Creating a Custom TypeScript Agent

Basic Structure

import {
  AbstractAgent,
  RunAgentInput,
  BaseEvent,
  EventType,
  RunStartedEvent,
  TextMessageChunkEvent,
  RunFinishedEvent,
  RunErrorEvent
} from "@ag-ui/client";
import { Observable } from "rxjs";

/** Connection and sampling settings for the custom agent. */
export interface CustomAgentConfig {
  /** URL that receives the POSTed conversation. */
  endpoint: string;
  /** Token sent as `Authorization: Bearer <apiKey>`. */
  apiKey: string;
  /** Model identifier forwarded in the request body, if set. */
  model?: string;
  /** Sampling temperature forwarded in the request body, if set. */
  temperature?: number;
}

/**
 * Example agent that POSTs the conversation to `config.endpoint` and streams
 * the response body back as AG-UI `TEXT_MESSAGE_CHUNK` events.
 *
 * Every call to `run` owns its own `AbortController`, so overlapping runs
 * cannot cancel (or clear) each other's in-flight requests.
 */
export class CustomAgent extends AbstractAgent {
  /** Controller of the most recently started run; cancelled by `abortRun()`. */
  private abortController?: AbortController;

  constructor(private config: CustomAgentConfig) {
    super();
  }

  /**
   * Start a run and return a cold Observable of AG-UI events.
   *
   * Each subscription starts a new request; unsubscribing aborts that
   * subscription's request only.
   */
  run(input: RunAgentInput): Observable<BaseEvent> {
    return new Observable<BaseEvent>((subscriber) => {
      const abortController = new AbortController();
      this.abortController = abortController;

      // executeRun reports its own failures through the subscriber, so the
      // returned promise is intentionally not awaited.
      void this.executeRun(input, subscriber, abortController);

      // Teardown: cancel this run's request (a no-op once it has finished).
      return () => {
        abortController.abort();
      };
    });
  }

  /**
   * Drive one run: emit RUN_STARTED, stream text chunks, then RUN_FINISHED.
   * On failure, emit RUN_ERROR and error the subscriber.
   *
   * @param input - The run request from the runtime.
   * @param subscriber - Sink for the run's events (an rxjs `Subscriber`).
   * @param abortController - Controller owned by this run.
   */
  private async executeRun(
    input: RunAgentInput,
    subscriber: {
      next(event: BaseEvent): void;
      error(err: unknown): void;
      complete(): void;
    },
    abortController: AbortController,
  ): Promise<void> {
    // One message id per run so chunks from different runs are never merged.
    const messageId = `msg-${input.runId}`;

    try {
      const startEvent: RunStartedEvent = {
        type: EventType.RUN_STARTED,
        threadId: input.threadId,
        runId: input.runId,
      };
      subscriber.next(startEvent);

      // Your agent logic here
      const chunks = await this.callYourAPI(input, abortController.signal);

      for await (const chunk of chunks) {
        if (abortController.signal.aborted) break;

        const textEvent: TextMessageChunkEvent = {
          type: EventType.TEXT_MESSAGE_CHUNK,
          role: "assistant",
          messageId,
          delta: chunk,
        };
        subscriber.next(textEvent);
      }

      const finishedEvent: RunFinishedEvent = {
        type: EventType.RUN_FINISHED,
        threadId: input.threadId,
        runId: input.runId,
      };
      subscriber.next(finishedEvent);
      subscriber.complete();
    } catch (error) {
      const errorEvent: RunErrorEvent = {
        type: EventType.RUN_ERROR,
        message: String(error),
      };
      subscriber.next(errorEvent);
      subscriber.error(error);
    } finally {
      // Only clear the shared field if a newer run has not replaced it.
      if (this.abortController === abortController) {
        this.abortController = undefined;
      }
    }
  }

  /**
   * POST the conversation to the configured endpoint and return an async
   * iterator over the decoded response body.
   *
   * @throws Error if the endpoint answers with a non-2xx status, so an error
   *   page is never streamed to the user as assistant text.
   */
  private async callYourAPI(input: RunAgentInput, signal: AbortSignal) {
    const response = await fetch(this.config.endpoint, {
      method: "POST",
      headers: {
        "Content-Type": "application/json",
        "Authorization": `Bearer ${this.config.apiKey}`,
      },
      body: JSON.stringify({
        messages: input.messages,
        model: this.config.model,
        temperature: this.config.temperature,
      }),
      signal,
    });

    if (!response.ok) {
      throw new Error(
        `Request to ${this.config.endpoint} failed: ${response.status} ${response.statusText}`,
      );
    }

    return this.streamResponse(response);
  }

  /**
   * Yield the response body as UTF-8 text, one piece per network chunk.
   * Empty pieces (e.g. a chunk ending mid-character) are skipped.
   */
  private async *streamResponse(response: Response): AsyncGenerator<string> {
    const reader = response.body?.getReader();
    if (!reader) return;

    const decoder = new TextDecoder();
    try {
      while (true) {
        const { done, value } = await reader.read();
        if (done) break;
        const text = decoder.decode(value, { stream: true });
        if (text) yield text;
      }
      // Flush bytes of a trailing multi-byte character still buffered.
      const tail = decoder.decode();
      if (tail) yield tail;
    } finally {
      // Runs on completion and when the consumer stops early (break/abort);
      // cancelling releases the underlying connection.
      await reader.cancel().catch(() => undefined);
    }
  }

  /** Create an agent with the same config and a copy of the middleware list. */
  clone() {
    const cloned = new CustomAgent(this.config);
    // Copy middlewares if needed
    // @ts-expect-error - accessing protected property
    cloned.middlewares = [...this.middlewares];
    return cloned;
  }

  /** Abort the most recently started run, if it is still in flight. */
  abortRun(): void {
    this.abortController?.abort();
  }
}

Supporting Tool Calls

import {
  ToolCallStartEvent,
  ToolCallArgsEvent,
  ToolCallEndEvent,
  ToolCallResultEvent,
} from "@ag-ui/client";

/**
 * Stream each tool call to the client (START → ARGS → END), then execute
 * the tool if it is listed in `input.tools` and emit its result.
 *
 * @param toolCalls - Tool calls returned by your model API.
 * @param input - The current run request (supplies `tools`).
 * @param subscriber - Sink for the emitted events.
 * @param parentMessageId - Assistant message the tool calls belong to.
 */
private async handleToolCalls(
  toolCalls: any[],
  input: RunAgentInput,
  subscriber: any,
  parentMessageId = "msg-1"
) {
  for (const toolCall of toolCalls) {
    const startEvent: ToolCallStartEvent = {
      type: EventType.TOOL_CALL_START,
      parentMessageId,
      toolCallId: toolCall.id,
      toolCallName: toolCall.name,
    };
    subscriber.next(startEvent);

    // Many APIs already deliver arguments as a JSON string; stringifying it
    // again would make the client parse a string instead of an object.
    const argsDelta =
      typeof toolCall.arguments === "string"
        ? toolCall.arguments
        : JSON.stringify(toolCall.arguments);
    const argsEvent: ToolCallArgsEvent = {
      type: EventType.TOOL_CALL_ARGS,
      toolCallId: toolCall.id,
      delta: argsDelta,
    };
    subscriber.next(argsEvent);

    const endEvent: ToolCallEndEvent = {
      type: EventType.TOOL_CALL_END,
      toolCallId: toolCall.id,
    };
    subscriber.next(endEvent);

    const tool = input.tools.find((t) => t.name === toolCall.name);
    if (tool) {
      const result = await this.executeTool(tool, toolCall.arguments);

      const resultEvent: ToolCallResultEvent = {
        type: EventType.TOOL_CALL_RESULT,
        role: "tool",
        // Unique per tool call; a shared id would merge separate results.
        messageId: `msg-tool-${toolCall.id}`,
        toolCallId: toolCall.id,
        content: JSON.stringify(result),
      };
      subscriber.next(resultEvent);
    }
  }
}

Creating a Custom Python Agent

Basic Structure

from copilotkit.agent import Agent
from copilotkit.types import Message
from copilotkit.action import ActionDict
from copilotkit.protocol import (
    emit_runtime_events,
    agent_state_message,
    text_message,
)
from typing import List, Optional, Dict, Any
import uuid
import json

class CustomAgent(Agent):
    """Custom agent that forwards a conversation to an HTTP endpoint and
    streams the response back as CopilotKit runtime events.

    Thread state is kept in the in-process dict ``state_storage``, so it is
    lost when the process exits and is not shared between processes.
    """

    def __init__(
        self,
        name: str,
        api_key: str,
        endpoint: str,
        model: str = "default",
        description: Optional[str] = None
    ):
        super().__init__(name=name, description=description)
        self.api_key = api_key
        self.endpoint = endpoint
        self.model = model
        # thread_id -> last state recorded for that thread (in memory only)
        self.state_storage: Dict[str, Dict] = {}

    async def execute(
        self,
        *,
        state: dict,
        messages: List[Message],
        thread_id: str,
        actions: Optional[List[ActionDict]] = None,
        **kwargs
    ):
        """Run one turn, yielding runtime events.

        Yields one text event per streamed chunk, then a final agent-state
        event. Any exception is reported as a single ``"error"`` event
        instead of propagating.
        """
        run_id = str(uuid.uuid4())

        try:
            # The HTTP session and response must stay open while the body is
            # streamed, so consume it inside the context manager.
            async with self.call_api(
                messages=messages,
                state=state,
                actions=actions
            ) as response:
                async for chunk in self.stream_response(response):
                    yield emit_runtime_events(
                        text_message(
                            id=f"msg-{run_id}",
                            content=chunk
                        )
                    )

            # Store a copy so later mutation of the caller's dict does not
            # silently change the stored thread state.
            self.state_storage[thread_id] = dict(state)

            yield emit_runtime_events(
                agent_state_message(
                    thread_id=thread_id,
                    agent_name=self.name,
                    node_name="main",
                    run_id=run_id,
                    active=False,
                    role="assistant",
                    state=json.dumps(state),
                    running=False
                )
            )

        except Exception as e:
            yield emit_runtime_events({
                "type": "error",
                "message": str(e)
            })

    def call_api(
        self,
        messages: List[Message],
        state: dict,
        actions: Optional[List[ActionDict]] = None
    ) -> Any:
        """Open a streaming POST to the configured endpoint.

        Returns an async context manager that yields the ``aiohttp``
        response. The session and response stay open until the ``async with``
        block exits. Entering it raises ``aiohttp.ClientResponseError`` on a
        non-2xx status.
        """
        import aiohttp
        from contextlib import asynccontextmanager

        @asynccontextmanager
        async def _open_stream():
            async with aiohttp.ClientSession() as session:
                async with session.post(
                    self.endpoint,
                    headers={
                        "Authorization": f"Bearer {self.api_key}",
                        "Content-Type": "application/json"
                    },
                    # NOTE(review): assumes Message objects are
                    # JSON-serializable; convert them to dicts if not.
                    json={
                        "messages": messages,
                        "state": state,
                        "model": self.model,
                        "tools": actions or []
                    }
                ) as response:
                    response.raise_for_status()
                    yield response

        return _open_stream()

    async def stream_response(self, response):
        """Yield the response body line by line, decoded as UTF-8."""
        async for line in response.content:
            if line:
                yield line.decode('utf-8')

    async def get_state(self, *, thread_id: str):
        """Return the stored state for ``thread_id``, or an empty placeholder.

        A thread counts as existing whenever state was stored for it, even if
        that state is an empty dict.
        """
        stored_state = self.state_storage.get(thread_id)

        if stored_state is not None:
            return {
                "threadId": thread_id,
                "threadExists": True,
                "state": stored_state,
                "messages": stored_state.get("messages", [])
            }

        return {
            "threadId": thread_id,
            "threadExists": False,
            "state": {},
            "messages": []
        }

    def dict_repr(self):
        """Return a dictionary describing this agent."""
        return {
            "name": self.name,
            "description": self.description or "",
            "type": "custom"
        }

Supporting Tool Execution

async def execute_tool(
    self,
    tool_name: str,
    tool_args: dict,
    actions: List[ActionDict]
) -> Any:
    """Run the named tool after checking that it was offered in ``actions``.

    Raises:
        ValueError: if no entry in ``actions`` has ``name == tool_name``.
    """
    tool_is_known = any(action["name"] == tool_name for action in actions)
    if not tool_is_known:
        raise ValueError(f"Tool {tool_name} not found")

    # Delegate the actual execution to your own tool backend
    return await self.call_tool_api(tool_name, tool_args)

async def handle_tool_calls(
    self,
    tool_calls: List[dict],
    actions: List[ActionDict],
    run_id: str
):
    """Execute each tool call in order, yielding one result event per call.

    This is an async generator, so it cannot ``return`` a value (doing so is
    a SyntaxError). If you also need the raw results, collect them in the
    caller as the events arrive.
    """
    for tool_call in tool_calls:
        tool_name = tool_call["name"]
        tool_args = tool_call["arguments"]
        tool_id = tool_call["id"]

        result = await self.execute_tool(tool_name, tool_args, actions)

        yield emit_runtime_events({
            "type": "tool_result",
            "toolCallId": tool_id,
            "result": result
        })

Event Types Reference

Essential Events

RUN_STARTED
event
Emitted when agent execution begins
{
  type: EventType.RUN_STARTED,
  threadId: string,
  runId: string
}
TEXT_MESSAGE_CHUNK
event
Emitted for streaming text responses
{
  type: EventType.TEXT_MESSAGE_CHUNK,
  role: "assistant",
  messageId: string,
  delta: string
}
RUN_FINISHED
event
Emitted when agent execution completes successfully
{
  type: EventType.RUN_FINISHED,
  threadId: string,
  runId: string
}
RUN_ERROR
event
Emitted when an error occurs
{
  type: EventType.RUN_ERROR,
  message: string
}
TOOL_CALL_START
event
{
  type: EventType.TOOL_CALL_START,
  parentMessageId: string,
  toolCallId: string,
  toolCallName: string
}
TOOL_CALL_ARGS
event
{
  type: EventType.TOOL_CALL_ARGS,
  toolCallId: string,
  delta: string  // Incremental JSON string
}
TOOL_CALL_END
event
{
  type: EventType.TOOL_CALL_END,
  toolCallId: string
}
TOOL_CALL_RESULT
event
{
  type: EventType.TOOL_CALL_RESULT,
  role: "tool",
  messageId: string,
  toolCallId: string,
  content: string
}

State Management Events

STATE_SNAPSHOT
event
Replace entire state
{
  type: EventType.STATE_SNAPSHOT,
  snapshot: any
}
STATE_DELTA
event
Apply incremental state updates
{
  type: EventType.STATE_DELTA,
  delta: JSONPatchOperation[]
}

Integration with CopilotRuntime

TypeScript

import { CopilotRuntime, InMemoryAgentRunner } from "@copilotkitnext/runtime";
import { CustomAgent } from "./custom-agent";

// Fail fast with a clear message instead of passing `undefined` as the key.
const customApiKey = process.env.CUSTOM_API_KEY;
if (!customApiKey) {
  throw new Error("CUSTOM_API_KEY environment variable is not set");
}

const customAgent = new CustomAgent({
  apiKey: customApiKey,
  endpoint: "https://api.example.com/v1/chat",
  model: "custom-model-v1",
  temperature: 0.7
});

// The same instance is registered under two names. NOTE(review): "default"
// is presumably used when the client names no agent; confirm against the
// CopilotRuntime docs.
const runtime = new CopilotRuntime({
  agents: {
    custom: customAgent,
    default: customAgent
  },
  runner: new InMemoryAgentRunner()
});

Python

import os

from copilotkit.integrations.fastapi import add_fastapi_endpoint
from fastapi import FastAPI

app = FastAPI()

# Raises KeyError at startup if CUSTOM_API_KEY is not set.
custom_agent = CustomAgent(
    name="custom_agent",
    api_key=os.environ["CUSTOM_API_KEY"],
    endpoint="https://api.example.com/v1/chat",
    model="custom-model-v1"
)

# Serve the agent under /copilotkit on this FastAPI app.
add_fastapi_endpoint(
    app,
    agents=[custom_agent],
    endpoint="/copilotkit"
)

Advanced Patterns

Multi-Step Reasoning

/**
 * Call the model repeatedly, executing requested tools between calls, until
 * it stops asking for tools or `maxSteps` iterations have run.
 *
 * NOTE(review): assumes a non-streaming variant of `callYourAPI` that
 * resolves to `{ message, toolCalls }`. Adapt this if your API streams.
 *
 * @param input - The current run request.
 * @param subscriber - Sink for emitted events.
 * @param signal - Cancels the loop and in-flight calls. Defaults to a
 *   signal that never aborts.
 */
private async executeMultiStepReasoning(
  input: RunAgentInput,
  subscriber: any,
  signal: AbortSignal = new AbortController().signal
) {
  const maxSteps = 5;
  // Working copy of the conversation; input.messages itself is not mutated.
  const messages = [...input.messages];

  for (let step = 0; step < maxSteps; step++) {
    if (signal.aborted) break;

    const response = await this.callYourAPI({ ...input, messages }, signal);

    // No tool calls left: the model has produced its final answer.
    // Emit it to the subscriber here (e.g. as TEXT_MESSAGE_CHUNK events).
    if (!response.toolCalls || response.toolCalls.length === 0) break;

    await this.handleToolCalls(response.toolCalls, input, subscriber);

    // Keep the assistant message that requested the tools. Append the tool
    // results as well, or the model will not see them on the next step.
    messages.push(response.message);
  }
}

Custom State Serialization

import pickle
import base64

class StatefulCustomAgent(CustomAgent):
    """CustomAgent variant that stores thread state as base64-encoded pickles.

    SECURITY: ``pickle.loads`` can execute arbitrary code. Only deserialize
    data this process wrote itself. Never load state that came from clients,
    shared caches, or any other untrusted storage. Prefer JSON when the state
    is JSON-compatible.
    """

    def serialize_state(self, state: dict) -> str:
        """Pickle ``state`` and return it as base64 text."""
        serialized = pickle.dumps(state)
        return base64.b64encode(serialized).decode('utf-8')

    def deserialize_state(self, serialized: str) -> dict:
        """Inverse of ``serialize_state``. Trusted input only (see class doc)."""
        decoded = base64.b64decode(serialized)
        return pickle.loads(decoded)

    async def get_state(self, *, thread_id: str):
        """Return the thread's state, decoding it with ``deserialize_state``.

        NOTE(review): assumes ``state_storage`` holds strings produced by
        ``serialize_state``. Whatever writes ``state_storage`` must serialize
        with it too.
        """
        stored = self.state_storage.get(thread_id)
        if stored is not None:
            return {
                "threadId": thread_id,
                "threadExists": True,
                "state": self.deserialize_state(stored),
                "messages": []
            }
        # The parent's get_state is async; await it so callers receive the
        # dict rather than an un-awaited coroutine.
        return await super().get_state(thread_id=thread_id)

Request Batching

/**
 * Agent that groups runs arriving within `batchInterval` ms into a single
 * backend call. `callBatchAPI` and `emitResponse` are left for you to
 * implement.
 */
class BatchingAgent extends AbstractAgent {
  /** Pending runs for the next batch, each with the subscriber awaiting it. */
  private requestQueue: Array<{
    input: RunAgentInput;
    subscriber: {
      next(event: BaseEvent): void;
      error(err: unknown): void;
      complete(): void;
    };
  }> = [];
  private batchInterval = 100; // ms

  run(input: RunAgentInput): Observable<BaseEvent> {
    return new Observable<BaseEvent>((subscriber) => {
      this.requestQueue.push({ input, subscriber });

      // The first request into an empty queue schedules the flush; requests
      // arriving before it fires join the same batch.
      if (this.requestQueue.length === 1) {
        setTimeout(() => void this.processBatch(), this.batchInterval);
      }
    });
  }

  /** Send all queued runs in one call and route each response to its run. */
  private async processBatch() {
    const batch = this.requestQueue;
    this.requestQueue = [];

    try {
      const responses = await this.callBatchAPI(
        batch.map((b) => b.input)
      );

      batch.forEach((item, index) => {
        this.emitResponse(responses[index], item.subscriber);
      });
    } catch (error) {
      // Without this, one failed batch call leaves every run waiting forever
      // (and surfaces as an unhandled rejection).
      for (const item of batch) {
        item.subscriber.error(error);
      }
    }
  }
}

Testing Custom Agents

Unit Testing

import { describe, it, expect } from "vitest";
import { CustomAgent } from "./custom-agent";
import { firstValueFrom } from "rxjs";

describe("CustomAgent", () => {
  it("should emit run started event", async () => {
    const agent = new CustomAgent({
      apiKey: "test-key",
      endpoint: "http://localhost:3000"
    });

    const observable = agent.run({
      threadId: "test-thread",
      runId: "test-run",
      messages: [],
      tools: [],
      context: [],
      state: {}
    });

    // Subscribe exactly once: the Observable is cold, so a separate
    // subscribe() call would start a second run and a second request.
    // firstValueFrom unsubscribes after the first event, which also cancels
    // the run before the (test) endpoint matters.
    const firstEvent = await firstValueFrom(observable);

    expect(firstEvent.type).toBe("RUN_STARTED");
  });
});

Integration Testing

import pytest
from copilotkit.types import Message

@pytest.mark.asyncio
async def test_custom_agent_execution():
    """Run one turn end to end and check the output mentions the input.

    NOTE(review): needs a live endpoint at http://localhost:3000 that
    presumably echoes the user's message — TODO confirm or replace with a
    stub server.
    """
    agent = CustomAgent(
        name="test_agent",
        api_key="test-key",
        endpoint="http://localhost:3000"
    )
    conversation = [Message(role="user", content="Hello")]

    collected = [
        event
        async for event in agent.execute(
            state={},
            messages=conversation,
            thread_id="test-thread"
        )
    ]

    assert len(collected) > 0
    assert any("Hello" in str(event) for event in collected)

Best Practices

1. Always Handle Cancellation

Pass abort signals through to your I/O, and clean up per-run resources when the run ends:
const abortController = new AbortController();
// Publish the controller so abortRun() can cancel this request.
this.abortController = abortController;
try {
  await fetch(url, { signal: abortController.signal });
} finally {
  // Clear it only if a newer run has not replaced it in the meantime.
  if (this.abortController === abortController) {
    this.abortController = undefined;
  }
}

2. Emit Events in Correct Order

Always emit RUN_STARTED first, and end every run with exactly one terminal event: RUN_FINISHED on success or RUN_ERROR on failure.

3. Handle Errors Gracefully

Catch errors inside your run logic and report them as RUN_ERROR events, so the client can show what went wrong instead of the stream failing with an unexplained exception.

4. Implement Clone Properly

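Make `clone()` return an independent agent: copy its configuration and middleware, and deep-copy any mutable state so that the clone and the original cannot affect each other.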

5. Thread-Safe State Management

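Runs can overlap, so do not keep per-run state (such as a single `abortController` field) on the agent instance without guarding it. Prefer per-run variables. In Python, protect shared structures such as `state_storage` with an `asyncio.Lock` (or equivalent) when concurrent runs can modify them.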

Next Steps

BuiltInAgent

Study the built-in implementation

LangGraph Integration

Learn from LangGraph integration

Event Protocol

Deep dive into the event system

Runtime API

Understand runtime integration