import type {
  ConversationMessage,
  ConversationMessageStream,
} from '@ceros/gemma-api-spec'
import { EventStreamContentType } from '@microsoft/fetch-event-source'
import { injected } from 'brandi'

import { DI_TYPE } from '@/di.types.js'
import { delayedRejection } from '@/utils/async'
import { ExponentialBackoff } from '@/utils/exponential-backoff'
import { HeartbeatMonitor } from '@/utils/heartbeat-monitor'

import type { ApiClient } from './api-client'
import { FatalError, RetriableError } from './api-client'

/**
 * Names of the server-sent events handled by `ConversationStreamService`.
 *
 * FUTURE: these could come from api-spec if all of them were defined there.
 */
type ConversationStreamEvent =
  | 'message_finished'
  | 'message_generating'
  | 'ping'
  | 'started'

/**
 * Exhaustiveness guard for `switch` statements over a union.
 *
 * Typed to accept only `never`, so the compiler reports an error at the call
 * site when a union member has no matching `case`. If it is ever reached at
 * runtime, it throws.
 *
 * @throws Error always, with the offending value in the message.
 */
function assertNever(value: never): never {
  const description = `Unhandled event: ${value}`
  throw new Error(description)
}

/**
 * Error raised for failures of the conversation update stream.
 *
 * @param message - Human-readable description; defaults to `'Stream error'`.
 * @param cause - Optional underlying error, exposed as the standard
 *   `Error.cause` property.
 */
export class StreamError extends Error {
  constructor(message = 'Stream error', cause?: unknown) {
    // `Error` only reads the cause from an options bag (`{ cause }`).
    // Passing the cause positionally, as before, silently dropped it.
    // Omit the bag when there is no cause so `'cause' in error` stays false.
    super(message, cause === undefined ? undefined : { cause })
    this.name = 'StreamError'
  }
}

/**
 * Adds an optional `streaming` flag to a payload type `T`.
 *
 * The flag is pinned to the literal `B`: for example, `true` for messages that
 * are still being generated and `false` for finished ones.
 */
export type WithStreaming<T extends object, B extends boolean> = T &
  Partial<Record<'streaming', B>>

export class ConversationStreamService {
  constructor(private apiClient: ApiClient) {}

  /**
   * Opens the server-sent-events stream of updates for a conversation and
   * dispatches each event to the matching callback.
   *
   * Reconnects are driven by `apiClient.fetchEventSource` through the
   * `onerror` handler below. It retries non-fatal errors using
   * `ExponentialBackoff(1000, 32000, 30)`. NOTE(review): those arguments are
   * presumably initial delay ms, max delay ms and max retries; confirm.
   * A `HeartbeatMonitor(15000, onHeartbeatMissed)` is started once the stream
   * opens and is fed on every message.
   *
   * @param conversationId - Conversation whose updates are subscribed to.
   * @param callbacks - Event handlers. `onError` receives every error passed to
   *   `onerror`; returning `true` stops reconnecting.
   * @param abortController - Its signal is forwarded to `fetchEventSource`.
   *   Aborting it also stops the heartbeat monitor.
   * @param timeout - Optional time limit, in ms, for the server to respond at
   *   all. A falsy value disables it.
   * @returns A promise with three outcomes:
   *   - Resolves once the server responds with an OK event stream.
   *   - Rejects with an `Error` if no response arrives within `timeout`.
   *   - Rejects with the terminal error if the stream fails permanently before
   *     opening. This assumes `fetchEventSource` rejects when `onerror`
   *     rethrows, as `@microsoft/fetch-event-source` does.
   */
  public async connect(
    conversationId: string,
    callbacks: {
      onStart?: () => void
      onPing?: (pingAt: string) => void
      onMessageGenerating: (
        message: WithStreaming<ConversationMessageStream, true>,
      ) => void
      onMessageFinished: (
        message: WithStreaming<ConversationMessage, false>,
      ) => void
      onClose?: () => void
      onError?: (error?: any) => boolean
      onHeartbeatMissed: () => void
    },
    abortController: AbortController,
    timeout?: number,
  ): Promise<void> {
    const backoff = new ExponentialBackoff(1000, 32000, 30)
    const heartbeatMonitor = new HeartbeatMonitor(
      15000,
      callbacks.onHeartbeatMissed,
    )

    // `once`: after the first abort the listener has nothing left to do, so
    // don't keep it (and the monitor it closes over) attached to the signal.
    abortController.signal.addEventListener(
      'abort',
      () => {
        heartbeatMonitor.stopMonitoring()
      },
      { once: true },
    )

    return new Promise<void>((resolveOpenedConnection, rejectConnection) => {
      // Settles as soon as the server answers at all, whatever the status.
      // It is raced against `timeout` below to detect requests that never get
      // a response (for example, blocked by a firewall).
      let markResponded: () => void = () => {}
      const responded = new Promise<void>((resolve) => {
        markResponded = resolve
      })

      const streaming = this.apiClient.fetchEventSource(
        `/v2/conversations/${conversationId}/updates`,
        {
          signal: abortController.signal,
          async onopen(response) {
            markResponded()

            if (
              response.ok &&
              response.headers
                .get('content-type')
                ?.startsWith(EventStreamContentType)
            ) {
              // Only a real event stream counts as an established connection.
              resolveOpenedConnection()
              heartbeatMonitor.startMonitoring()
              return
            }
            if (
              response.status >= 400 &&
              response.status < 500 &&
              response.status !== 429
            ) {
              // Client-side errors are usually non-retriable.
              throw new FatalError()
            }
            throw new RetriableError()
          },
          onmessage(payload) {
            // Any message shows the connection is alive.
            backoff.reset()
            heartbeatMonitor.receivedHeartbeat()

            const eventName = payload.event as ConversationStreamEvent
            switch (eventName) {
              case 'started':
                callbacks.onStart?.()
                break
              case 'ping':
                callbacks.onPing?.(payload.data)
                break
              case 'message_generating': {
                // NOTE(review): payload shape is trusted, not validated.
                const generating = JSON.parse(payload.data)
                generating.streaming = true
                callbacks.onMessageGenerating?.(generating)
                break
              }
              case 'message_finished': {
                const finished = JSON.parse(payload.data)
                finished.streaming = false
                callbacks.onMessageFinished?.(finished)
                break
              }
              default:
                // Compile-time exhaustiveness check. At runtime an unknown
                // event name throws here; presumably the event source then
                // reports that via `onerror`.
                assertNever(eventName)
            }
          },
          onclose() {
            heartbeatMonitor.stopMonitoring()
            callbacks.onClose?.()
          },
          onerror(error) {
            heartbeatMonitor.stopMonitoring()

            const shouldExit = callbacks.onError?.(error)
            if (shouldExit || error instanceof FatalError) {
              throw error // rethrow to prevent a reconnect
            }
            if (backoff.shouldRetry()) {
              return backoff.getNextDelay() // delay before reconnecting
            }
            throw new FatalError('Max retries reached', error)
          },
        },
      )

      // The stream's promise rejects once `onerror` rethrows. Forward that
      // rejection so `connect()` neither hangs nor leaves it unhandled. If the
      // connection already opened this is a no-op; the error has already been
      // passed to `callbacks.onError`.
      Promise.resolve(streaming).catch((error: unknown) => {
        rejectConnection(error)
      })

      const timedOut = timeout
        ? delayedRejection(timeout)
        : new Promise<never>(() => {}) // never settles when no timeout is given
      Promise.race([responded, timedOut]).catch(() => {
        rejectConnection(new Error('Stream connection timed out'))
      })
    })
  }
}

// brandi registration: the constructor's single parameter (`apiClient`) is
// resolved from the `DI_TYPE.ApiClient` token.
injected(ConversationStreamService, DI_TYPE.ApiClient)
