// vi:ft=cpp
/* SPDX-License-Identifier: MIT */
/*
 * Copyright (C) 2022, 2024 Kernkonzept GmbH.
 * Author(s): Stephan Gerhold <stephan.gerhold@kernkonzept.com>
 */
#pragma once

#include <cstring>
#include <functional>

#include <l4/cxx/exceptions>
#include <l4/cxx/minmax>
#include <l4/re/dataspace>
#include <l4/re/env>
#include <l4/re/error_helper>
#include <l4/re/util/unique_cap>
#include <l4/sys/consts.h>

#include <l4/l4virtio/client/l4virtio>
#include <l4/l4virtio/l4virtio>
#include <l4/l4virtio/virtio_net.h>
#include <l4/l4virtio/virtqueue>

namespace L4virtio { namespace Driver {

/**
 * Simple class for accessing a virtio net device.
 *
 * The driver uses two virtqueues: queue 0 for receiving (RX) and queue 1 for
 * transmitting (TX). Every descriptor of both queues is permanently assigned
 * one fixed-size Packet buffer, i.e. descriptor number `n` of a queue always
 * refers to packet buffer `n` of that queue.
 *
 * Descriptor numbers and lengths read back from the queues are provided by
 * the device and are therefore validated before they are used.
 */
class Virtio_net_device : public L4virtio::Driver::Device
{
public:
  /**
   * Structure for a network packet (header including data) with maximum size,
   * assuming that no extra features have been negotiated.
   */
  struct Packet
  {
    l4virtio_net_header_t hdr;
    l4_uint8_t data[1500 + 14]; /* MTU + Ethernet header */
  };

  /**
   * Return the maximum receive queue size allowed by the device.
   *
   * Descriptor numbers accepted by rx_pkt() and finish_rx() are smaller than
   * the size of the RX queue.
   */
  int rx_queue_size() const
  { return max_queue_size(0); }

  /**
   * Return the maximum transmit queue size allowed by the device.
   * tx() will fail if the amount of queued packets exceeds this size.
   */
  int tx_queue_size() const
  { return max_queue_size(1); }

  /**
   * Establish a connection to the device and set up shared memory.
   *
   * \param srvcap   IPC capability of the channel to the server.
   * \param queue_ds (optional) External queue dataspace. If this capability is
   *                 invalid, the function will attempt to allocate a dataspace
   *                 on its own. Note that the external queue dataspace must be
   *                 large enough. The capability is stored in a Unique_cap
   *                 member of this object.
   *
   * This function starts a handshake with the device and sets up the
   * virtqueues for communication and the additional data structures for
   * the network device.
   *
   * The shared memory is laid out as
   * `[RX queue][TX queue][RX packet buffers][TX packet buffers]`.
   *
   * Errors are reported through L4Re::chksys() / L4Re::chkcap().
   */
  void setup_device(L4::Cap<L4virtio::Device> srvcap,
                    L4::Cap<L4Re::Dataspace>
                      queue_ds = L4::Cap<L4Re::Dataspace>::Invalid)
  {
    // Contact device.
    driver_connect(srvcap, false);

    if (_config->device != L4VIRTIO_ID_NET)
      L4Re::chksys(-L4_ENODEV, "Device is not a network device.");

    // One RX queue (0) and one TX queue (1) are required.
    if (_config->num_queues < 2)
      L4Re::chksys(-L4_EINVAL, "Invalid number of queues reported.");

    auto rxqsz = rx_queue_size();
    auto txqsz = tx_queue_size();

    // Allocate memory for RX/TX queue and RX/TX packet buffers
    //
    // NOTE(review): total_size(n) presumably already returns the size of a
    // complete queue with n entries; if so, the multiplication by the queue
    // size reserves far more memory than needed. Left unchanged because the
    // layout defines how large an external queue dataspace must be -- confirm
    // against the Virtqueue documentation before shrinking it.
    auto rxqoff = 0;
    auto txqoff = l4_round_size(rxqoff + rxqsz * _rxq.total_size(rxqsz),
                                L4virtio::Virtqueue::Desc_align);
    auto rxpktoff = l4_round_size(txqoff + txqsz * _txq.total_size(txqsz),
                                  L4virtio::Virtqueue::Desc_align);
    auto txpktoff = rxpktoff + rxqsz * sizeof(Packet);
    auto totalsz = txpktoff + txqsz * sizeof(Packet);

    auto *e = L4Re::Env::env();

    if (queue_ds)
      _queue_ds = L4Re::Util::Unique_cap<L4Re::Dataspace>(queue_ds);
    else
      {
        _queue_ds = L4Re::chkcap(L4Re::Util::make_unique_cap<L4Re::Dataspace>(),
                                 "Allocate queue dataspace capability");
        L4Re::chksys(e->mem_alloc()->alloc(totalsz, _queue_ds.get(),
                                           L4Re::Mem_alloc::Continuous
                                           | L4Re::Mem_alloc::Pinned),
                     "Allocate memory for virtio structures");
      }

    // Map the dataspace locally ...
    L4Re::chksys(e->rm()->attach(&_queue_region, totalsz,
                                 L4Re::Rm::F::Search_addr | L4Re::Rm::F::RW,
                                 L4::Ipc::make_cap_rw(_queue_ds.get()), 0,
                                 L4_PAGESHIFT),
                 "Attach dataspace for virtio structures");

    // ... and make it known to the device. `devaddr` is the address the
    // device uses to refer to offset 0 of the dataspace.
    l4_uint64_t devaddr;
    L4Re::chksys(register_ds(_queue_ds.get(), 0, totalsz, &devaddr),
                 "Register queue dataspace with device");

    _rxq.init_queue(rxqsz, _queue_region.get() + rxqoff);
    _txq.init_queue(txqsz, _queue_region.get() + txqoff);

    config_queue(0, rxqsz, devaddr + rxqoff,
                 devaddr + rxqoff + _rxq.avail_offset(),
                 devaddr + rxqoff + _rxq.used_offset());
    config_queue(1, txqsz, devaddr + txqoff,
                 devaddr + txqoff + _txq.avail_offset(),
                 devaddr + txqoff + _txq.used_offset());

    _rxpkts = reinterpret_cast<Packet*>(_queue_region.get() + rxpktoff);
    _txpkts = reinterpret_cast<Packet*>(_queue_region.get() + txpktoff);

    // Prepare descriptors to save work later: descriptor n always points to
    // packet buffer n. RX descriptors additionally get the write flag set.
    for (l4_uint16_t descno = 0; descno < rxqsz; ++descno)
      {
        auto &desc = _rxq.desc(descno);
        desc.addr = L4virtio::Ptr<void>(devaddr + rxpktoff +
                                        descno * sizeof(Packet));
        desc.len = sizeof(Packet);
        desc.flags.write() = 1;
      }
    for (l4_uint16_t descno = 0; descno < txqsz; ++descno)
      {
        auto &desc = _txq.desc(descno);
        desc.addr = L4virtio::Ptr<void>(devaddr + txpktoff +
                                        descno * sizeof(Packet));
        // Overwritten with the actual packet length in tx().
        desc.len = sizeof(Packet);
      }

    // Setup notification IRQ
    _driver_notification_irq =
      L4Re::chkcap(L4Re::Util::make_unique_cap<L4::Irq>(),
                   "Allocate notification capability");

    L4Re::chksys(l4_error(e->factory()->create(_driver_notification_irq.get())),
                 "Create irq for notifications from device");

    L4Re::chksys(_device->bind(0, _driver_notification_irq.get()),
                 "Bind driver notification interrupt");

    // Finish handshake with device
    l4virtio_set_feature(_config->driver_features_map,
                         L4VIRTIO_FEATURE_VERSION_1);
    l4virtio_set_feature(_config->driver_features_map, L4VIRTIO_NET_F_MAC);
    // NOTE(review): if driver_acknowledge() reports failure through its
    // return value, that result is dropped here -- confirm and check it.
    driver_acknowledge();
  }

  /**
   * Return a reference to the device configuration.
   */
  l4virtio_net_config_t const &device_config() const
  {
    return *_config->device_config<l4virtio_net_config_t>();
  }

  /**
   * Bind the rx notification IRQ to the specified thread.
   *
   * \param thread  Thread to bind the notification IRQ to.
   * \param label   Label to assign to the IRQ.
   *
   * \return  Result of the bind operation as returned by l4_error().
   */
  int bind_rx_notification_irq(L4::Cap<L4::Thread> thread, l4_umword_t label)
  {
    return l4_error(_driver_notification_irq->bind_thread(thread, label));
  }

  /**
   * Return a reference to the RX packet buffer of the specified descriptor,
   * e.g. from wait_rx().
   *
   * \param descno  Descriptor number in the virtio queue.
   *
   * \throws L4::Bounds_error  `descno` is not a valid RX descriptor number.
   */
  Packet &rx_pkt(l4_uint16_t descno)
  {
    if (descno >= _rxq.num())
      throw L4::Bounds_error("Invalid used descriptor number in RX queue");
    return _rxpkts[descno];
  }

  /**
   * Block until a network packet has been received from the device and return
   * the descriptor number.
   *
   * \pre The calling thread must be bound to the rx notification IRQ via
   *      `bind_rx_notification_irq()`.
   *
   * \param[out] len  (optional) Length of valid data in RX packet, without
   *                  the virtio-net header. The value is never larger than
   *                  the data area of a Packet and is 0 if the device
   *                  reported fewer bytes than the header occupies.
   *
   * \return  Descriptor number of received packet.
   *
   * The packet data can be obtained with rx_pkt(). finish_rx() should be
   * called after the packet buffer can be returned to the RX queue.
   *
   */
  l4_uint16_t wait_rx(l4_uint32_t *len = nullptr)
  {
    l4_uint16_t descno;
    // Wait until used descriptor becomes available.
    for (;;)
      {
        descno = _rxq.find_next_used(len);
        if (descno != Virtqueue::Eoq)
          break;

        L4Re::chksys(_driver_notification_irq->receive(), "Wait for RX");
      }

    if (len)
      {
        // The length comes from the device via find_next_used(). Make sure
        // it is neither smaller than the header (the unsigned subtraction
        // below would wrap around and report a full buffer of valid data)
        // nor larger than the buffer.
        if (*len < sizeof(_rxpkts[0].hdr))
          *len = 0;
        else
          *len = cxx::min(*len - sizeof(_rxpkts[0].hdr),
                          sizeof(_rxpkts[0].data));
      }
    return descno;
  }

  /**
   * Free an RX descriptor number to make it available for the RX queue again.
   *
   * \param descno  Descriptor number in the virtio queue.
   *
   * \throws L4::Bounds_error  `descno` is not a valid RX descriptor number.
   *
   * Usually queue_rx() should be called afterwards to queue the freed
   * descriptor(s).
   */
  void finish_rx(l4_uint16_t descno)
  {
    // Same validation as in rx_pkt(): the number originates from the device.
    if (descno >= _rxq.num())
      throw L4::Bounds_error("Invalid used descriptor number in RX queue");
    _rxq.free_descriptor(descno, descno);
  }

  /**
   * Queue new available descriptors in the RX queue.
   *
   * All descriptors that can currently be allocated from the RX queue are
   * enqueued, then the device is notified once.
   */
  void queue_rx()
  {
    l4_uint16_t descno;
    while ((descno = _rxq.alloc_descriptor()) != Virtqueue::Eoq)
      _rxq.enqueue_descriptor(descno);
    notify(_rxq);
  }

  /**
   * Attempt to allocate a descriptor in the TX queue and transmit the packet,
   * after calling the prepare callback.
   *
   * \param prepare  Function that fills the packet with data, should return
   *                 the length of the data copied to the packet.
   *
   * \retval true   The packet was queued.
   * \retval false  TX queue is full.
   *
   * \throws L4::Bounds_error  The device reported an invalid used descriptor
   *                           or `prepare` returned a length larger than the
   *                           data area of a Packet.
   *
   * The prepare callback should fill the packet with data and return the
   * length of the packet data (without the size of the virtio-net packet
   * header). Exceptions thrown by the callback are passed on to the caller;
   * the descriptor allocated for the packet is released in that case.
   */
  bool tx(std::function<l4_uint32_t(Packet&)> prepare)
  {
    auto descno = _txq.alloc_descriptor();
    if (descno == Virtqueue::Eoq)
      {
        // Try again after cleaning old descriptors that have already been used
        free_used_tx_descriptors();
        descno = _txq.alloc_descriptor();
        if (descno == Virtqueue::Eoq)
          return false;
      }

    auto &pkt = _txpkts[descno];

    l4_uint32_t datalen;
    try
      {
        datalen = prepare(pkt);
      }
    catch (...)
      {
        // Do not lose the descriptor if the callback fails.
        _txq.free_descriptor(descno, descno);
        throw;
      }

    // A larger length would make the descriptor cover memory beyond this
    // packet's buffer.
    if (datalen > sizeof(pkt.data))
      {
        _txq.free_descriptor(descno, descno);
        throw L4::Bounds_error("Packet length exceeds TX packet buffer");
      }

    _txq.desc(descno).len = sizeof(pkt.hdr) + datalen;
    send(_txq, descno);
    return true;
  }

private:
  /**
   * Free all TX descriptors that the device has marked as used.
   *
   * \throws L4::Bounds_error  The device reported a descriptor number outside
   *                           of the TX queue.
   */
  void free_used_tx_descriptors()
  {
    l4_uint16_t used;
    while ((used = _txq.find_next_used()) != Virtqueue::Eoq)
      {
        if (used >= _txq.num())
          throw L4::Bounds_error("Invalid used descriptor number in TX queue");
        _txq.free_descriptor(used, used);
      }
  }

private:
  /// Dataspace holding both virtqueues and all packet buffers.
  L4Re::Util::Unique_cap<L4Re::Dataspace> _queue_ds;
  /// Local mapping of `_queue_ds`.
  L4Re::Rm::Unique_region<l4_uint8_t *> _queue_region;
  /// IRQ bound to the device in setup_device(); wait_rx() blocks on it.
  L4Re::Util::Unique_cap<L4::Irq> _driver_notification_irq;
  /// Receive (queue 0) and transmit (queue 1) virtqueues.
  L4virtio::Driver::Virtqueue _rxq, _txq;
  /// Packet buffers inside `_queue_region`, indexed by descriptor number.
  /// Set by setup_device().
  Packet *_rxpkts = nullptr;
  Packet *_txpkts = nullptr;
};

} }
