// cortex-hub / frontend / src / services / api / aiService.js
import { fetchWithAuth, API_BASE_URL, getUserId } from './apiClient';
import { convertPcmToFloat32 } from "../audioUtils";

/**
 * Uploads recorded audio to the speech-to-text endpoint and returns the transcript.
 *
 * The blob is sent as multipart form data under the field `audio_file`
 * with the filename `audio.wav`.
 *
 * @param {Blob} audioBlob - Audio payload to transcribe.
 * @param {string|null} [providerName=null] - Optional STT provider; when truthy it is
 *   URL-encoded and sent as the `provider_name` query parameter.
 * @returns {Promise<*>} The `transcript` property of the value resolved by `fetchWithAuth`.
 */
export const transcribeAudio = async (audioBlob, providerName = null) => {
  const upload = new FormData();
  upload.append("audio_file", audioBlob, "audio.wav");

  // Only add a query string when a provider was explicitly requested.
  const querySuffix = providerName
    ? `?provider_name=${encodeURIComponent(providerName)}`
    : '';
  const endpoint = `/stt/transcribe${querySuffix}`;

  const { transcript } = await fetchWithAuth(endpoint, {
    method: "POST",
    body: upload
  });
  return transcript;
};

/**
 * Sends a text prompt to the LLM endpoint and consumes the streaming text response (SSE).
 *
 * The response body is read incrementally, split into events on blank lines
 * ("\n\n"), and every event starting with "data: " is parsed as JSON.
 *   - `type === "content"` events have their `content` appended to the answer.
 *   - `type === "finish"` events supply `message_id` and, when present, `provider`.
 *
 * @param {string|number} sessionId - Session identifier, interpolated into the request path.
 * @param {string} prompt - User prompt sent as `prompt` in the JSON body.
 * @param {string} [providerName="gemini"] - Provider sent as `provider_name`; also the
 *   fallback for `provider_used` when the stream does not report a provider.
 * @param {?function(Object): void} [onMessage=null] - Called with each parsed event object.
 *   Exceptions thrown by the callback are caught and logged together with parse failures.
 * @returns {Promise<{answer: string, message_id: *, provider_used: *}>} Aggregated result
 *   once the stream ends.
 * @throws {Error} When the HTTP status is not ok; the message is the response's `detail`
 *   field, the stringified error body, or "LLM API failed" if the body is not JSON.
 */
export const chatWithAI = async (sessionId, prompt, providerName = "gemini", onMessage = null) => {
  const userId = getUserId();
  const response = await fetch(`${API_BASE_URL}/sessions/${sessionId}/chat`, {
    method: "POST",
    headers: { "Content-Type": "application/json", "X-User-ID": userId },
    body: JSON.stringify({ prompt: prompt, provider_name: providerName }),
  });

  if (!response.ok) {
    let detail = "LLM API failed";
    try {
      const errBody = await response.json();
      detail = errBody.detail || JSON.stringify(errBody);
    } catch { }
    throw new Error(detail);
  }

  // Handle Streaming Response
  const reader = response.body.getReader();
  const decoder = new TextDecoder();
  let accumulatedBuffer = "";

  // Track final result for backward compatibility
  let fullAnswer = "";
  let lastMessageId = null;
  let finalProvider = providerName;

  // Parses one SSE event block and folds it into the aggregated result.
  const handleEvent = (part) => {
    if (!part.startsWith("data: ")) return;
    try {
      const jsonStr = part.slice(6).trim();
      if (!jsonStr) return;
      const data = JSON.parse(jsonStr);

      // Accumulate content and info
      if (data.type === "content" && data.content) {
        fullAnswer += data.content;
      } else if (data.type === "finish") {
        lastMessageId = data.message_id;
        // Keep the requested provider if the finish event does not name one.
        finalProvider = data.provider || finalProvider;
      }

      // Pass to streaming callback if provided
      if (onMessage) onMessage(data);
    } catch (e) {
      console.warn("Failed to parse SSE line:", part, e);
    }
  };

  while (true) {
    const { done, value } = await reader.read();
    if (done) break;

    accumulatedBuffer += decoder.decode(value, { stream: true });

    // The last piece may be an incomplete event; keep it for the next read.
    const parts = accumulatedBuffer.split("\n\n");
    accumulatedBuffer = parts.pop();

    for (const part of parts) {
      handleEvent(part);
    }
  }

  // Flush the decoder and process whatever is still buffered: a final event
  // that was not terminated by a blank line would otherwise be lost.
  accumulatedBuffer += decoder.decode();
  for (const part of accumulatedBuffer.split("\n\n")) {
    handleEvent(part);
  }

  // Return the full result as the standard API used to
  return {
    answer: fullAnswer,
    message_id: lastMessageId,
    provider_used: finalProvider
  };
};

/**
 * Streams speech from the TTS endpoint and processes each chunk as it arrives.
 *
 * Each network read is prepended with any byte carried over from the previous
 * read, trimmed to an even byte length, converted with `convertPcmToFloat32`,
 * and handed to `onData`. A single odd byte is carried to the next read; one
 * still pending when the stream ends is not emitted.
 * NOTE(review): the 2-byte alignment suggests 16-bit PCM samples — confirm against
 * `convertPcmToFloat32` and the server's output format.
 *
 * @param {string} text - Text to synthesize, sent as `text` in the JSON body.
 * @param {function(*, number, number): void} onData - Called per read with
 *   (converted samples, total chunk count from the `X-TTS-Chunk-Count` header, 1-based read index).
 * @param {?function(): (void|Promise<void>)} onDone - Invoked (and awaited) once the
 *   call finishes, whether it succeeded or failed.
 * @param {string|null} [providerName=null] - Optional TTS provider, sent URL-encoded
 *   as the `provider_name` query parameter.
 * @returns {Promise<void>}
 * @throws {Error} On transport failure, a non-ok HTTP status, or an interrupted stream.
 */
export const streamSpeech = async (text, onData, onDone, providerName = null) => {
  const userId = getUserId();
  try {
    let url = `${API_BASE_URL}/speech?stream=true&as_wav=false`;
    if (providerName) {
      url += `&provider_name=${encodeURIComponent(providerName)}`;
    }

    const response = await fetch(url, {
      method: "POST",
      headers: { "Content-Type": "application/json", "X-User-ID": userId },
      body: JSON.stringify({ text }),
    }).catch(err => {
      console.error("Fetch transport error:", err);
      throw new Error(`Network transport failed: ${err.message}`);
    });

    if (!response.ok) {
      let detail = `HTTP error! Status: ${response.status}`;
      try {
        const errBody = await response.json();
        detail = errBody.detail || detail;
      } catch { }
      throw new Error(detail);
    }

    // Parse explicitly as base-10 and treat an unparseable header like a
    // missing one (0) so callers never receive NaN as the total.
    const parsedCount = Number.parseInt(response.headers.get("X-TTS-Chunk-Count") || "0", 10);
    const totalChunks = Number.isNaN(parsedCount) ? 0 : parsedCount;
    const reader = response.body.getReader();
    let leftover = new Uint8Array(0);
    let chunkIndex = 0;

    try {
      while (true) {
        const { done, value: chunk } = await reader.read();
        if (done) break;

        // Join the carried-over byte (if any) with the new data.
        let combined = new Uint8Array(leftover.length + chunk.length);
        combined.set(leftover);
        combined.set(chunk, leftover.length);

        // Convert only an even number of bytes; keep the odd one for the next read.
        let length = combined.length;
        if (length % 2 !== 0) length -= 1;

        const toConvert = combined.slice(0, length);
        leftover = combined.slice(length);
        const float32Raw = convertPcmToFloat32(toConvert);

        onData(float32Raw, totalChunks, ++chunkIndex);
      }
    } catch (readError) {
      console.error("Error reading response body stream:", readError);
      throw new Error(`Stream interrupted: ${readError.message}`);
    }
  } catch (error) {
    console.error("Failed to stream speech:", error);
    throw error;
  } finally {
    // Runs on both success and failure so callers can always clean up.
    if (onDone) {
      await Promise.resolve(onDone());
    }
  }
};

/**
 * Asks the backend to verify a provider configuration.
 *
 * @param {string} section - Config section name; appended to the path as `verify_<section>`.
 * @param {*} payload - Configuration to verify; handed to `fetchWithAuth` unchanged as the request body.
 * @returns {Promise<*>} Whatever `fetchWithAuth` resolves with.
 */
export const verifyProvider = async (section, payload) => {
  const endpoint = `/users/me/config/verify_${section}`;
  const request = { method: "POST", body: payload };
  const verification = await fetchWithAuth(endpoint, request);
  return verification;
};

/**
 * Retrieves the models available for a given provider.
 *
 * @param {string} providerName - Provider to query; sent as the `provider_name` query parameter.
 * @param {string} [section="llm"] - Config section; sent as the `section` query parameter.
 * @returns {Promise<*>} Whatever `fetchWithAuth` resolves with for the models endpoint.
 */
export const getProviderModels = async (providerName, section = "llm") => {
  const query = new URLSearchParams([
    ["provider_name", providerName],
    ["section", section],
  ]);
  const models = await fetchWithAuth(`/users/me/config/models?${query}`);
  return models;
};

/**
 * Retrieves every underlying provider for a config section.
 *
 * @param {string} [section="llm"] - Config section; sent as the `section` query parameter.
 * @returns {Promise<*>} Whatever `fetchWithAuth` resolves with for the providers endpoint.
 */
export const getAllProviders = async (section = "llm") => {
  const query = new URLSearchParams();
  query.set("section", section);
  const providers = await fetchWithAuth(`/users/me/config/providers?${query}`);
  return providers;
};

/**
 * Retrieves the available TTS voice names.
 *
 * Best-effort: any failure is logged and an empty array is returned instead of throwing.
 *
 * @param {string|null} [provider=null] - Optional provider; sent as the `provider` query
 *   parameter when truthy.
 * @param {string|null} [apiKey=null] - Optional API key; sent as `api_key` when truthy
 *   and not the literal string 'null'.
 * @returns {Promise<*>} The value resolved by `fetchWithAuth`, or `[]` on error.
 */
export const getVoices = async (provider = null, apiKey = null) => {
  try {
    const pairs = [];
    if (provider) {
      pairs.push(['provider', provider]);
    }
    // A stringified null is treated the same as no key at all.
    const hasUsableKey = Boolean(apiKey) && apiKey !== 'null';
    if (hasUsableKey) {
      pairs.push(['api_key', apiKey]);
    }

    const query = new URLSearchParams(pairs);
    const voices = await fetchWithAuth(`/speech/voices?${query}`, { method: 'GET' });
    return voices;
  } catch (e) {
    console.error("Failed to fetch voices", e);
    return [];
  }
};