// vi:ft=cpp
/* SPDX-License-Identifier: MIT */
/*
 * Copyright (C) 2017-2021, 2024 Kernkonzept GmbH.
 * Author(s): Sarah Hoffmann <sarah.hoffmann@kernkonzept.com>
 *
 */
#pragma once

#include <l4/cxx/unique_ptr>
#include <l4/re/util/unique_cap>

#include <climits>

#include <l4/l4virtio/virtio.h>
#include <l4/l4virtio/virtio_block.h>
#include <l4/l4virtio/server/l4virtio>
#include <l4/sys/cxx/ipc_epiface>

namespace L4virtio { namespace Svr {

template <typename Ds_data> class Block_dev_base;

/**
 * A request to read or write data.
 */
template<typename Ds_data>
class Block_request
{
  friend class Block_dev_base<Ds_data>;
  enum { Header_size = sizeof(l4virtio_block_header_t) };

public:
  struct Data_block
  {
    /// Pointer to virtio memory descriptor.
    Driver_mem_region_t<Ds_data> *mem;
    /// Virtual address of the data block (in device space).
    void *addr;
    /// Length of datablock in bytes (max 4MB).
    l4_uint32_t len;

    Data_block() = default;

    Data_block(Driver_mem_region_t<Ds_data> *m, Virtqueue::Desc const &desc,
                Request_processor const *)
    : mem(m), addr(m->local(desc.addr)), len(desc.len)
    {}
  };



  /**
   * Compute the total size of the data in the request.
   *
   * \retval Size in bytes or 0 if there was an error.
   *
   * \throws L4::Runtime_error(-L4_EIO) Request has a bad format.
   *
   * Note that this operation is relatively expensive as
   * it has to iterate over the complete list of blocks.
   */
  unsigned data_size() const
  {
    Request_processor rp;
    Data_block data;

    rp.start(_mem_list, _request, &data);

    unsigned total = data.len;

    try
      {
        while (rp.has_more())
          {
            rp.next(_mem_list, &data);
            total += data.len;
          }
      }
    catch (Bad_descriptor const &e)
      {
        // need to convert the exception because e contains a raw pointer to rp
        throw L4::Runtime_error(-L4_EIO, "bad virtio descriptor");
      }

    if (total < Header_size + 1)
      throw L4::Runtime_error(-L4_EIO, "virtio request too short");

    return total - Header_size - 1;
  }

  /**
   * Check if the request contains more data blocks.
   */
  bool has_more()
  {
    // peek into the remaining data
    while (_data.len == 0 && _rp.has_more())
      _rp.next(_mem_list, &_data);

    // there always must be one byte left for status
    return (_data.len > 1 || _rp.has_more());
  }

  /**
   * Return next block in scatter-gather list.
   *
   * \return Information about the next data block.
   *
   * \throws L4::Runtime_error  No more data block is available.
   * \throws Bad_descriptor     Virtio request is corrupted.
   */
  Data_block next_block()
  {
    Data_block out;

    if (_data.len == 0)
      {
        if (!_rp.has_more())
          throw L4::Runtime_error(-L4_EEXIST,
                                  "No more data blocks in virtio request");

        if (_todo_blocks == 0)
          throw Bad_descriptor(&_rp, Bad_descriptor::Bad_size);
        --_todo_blocks;

        _rp.next(_mem_list, &_data);
      }

    if (_data.len > _max_block_size)
      throw Bad_descriptor(&_rp, Bad_descriptor::Bad_size);

    out = _data;

    if (!_rp.has_more())
      {
        --(out.len);
        _data.len = 1;
        _data.addr = static_cast<char *>(_data.addr) + out.len;
      }
    else
      _data.len = 0; // is consumed

    return out;
  }

  /// Return the block request header.
  l4virtio_block_header_t const &header() const
  { return _header; }

private:
  Block_request(Virtqueue::Request req, Driver_mem_list_t<Ds_data> *mem_list,
                unsigned max_blocks, l4_uint32_t max_block_size)
  : _mem_list(mem_list),
    _request(req),
    _todo_blocks(max_blocks),
    _max_block_size(max_block_size)
  {
    // read header which should be in the first block
    _rp.start(mem_list, _request, &_data);
    --_todo_blocks;

    if (_data.len < Header_size)
      throw Bad_descriptor(&_rp, Bad_descriptor::Bad_size);

    _header = *(static_cast<l4virtio_block_header_t *>(_data.addr));

    _data.addr = static_cast<char *>(_data.addr) + Header_size;
    _data.len -= Header_size;

    // if there is no space for status bit we cannot really recover
    if (!_rp.has_more() && _data.len == 0)
      throw Bad_descriptor(&_rp, Bad_descriptor::Bad_size);
  }

  int release_request(Virtqueue *queue, l4_uint8_t status, unsigned sz)
  {
    // write back status
    // If there was an error on the way or the status byte is in its
    // own block, fast-forward to the last block.
    if (_rp.has_more())
      {
        while (_rp.next(_mem_list, &_data) && _todo_blocks > 0)
          --_todo_blocks;

        if (_todo_blocks > 0 && _data.len > 0)
          *(static_cast<l4_uint8_t *>(_data.addr) + _data.len - 1) = status;
        else
          return -L4_EIO; // too many data blocks
      }
    else if (_data.len > 0)
      *(static_cast<l4_uint8_t *>(_data.addr)) = status;
    else
      return -L4_EIO; // no space for final status byte

    // now release the head
    queue->consumed(_request, sz);

    return L4_EOK;
  }

  /**
   * The list of memory areas for the device.
   * Points to the memory list of the parent device, which always must
   * have a longer livespan than the request.
   */
  Driver_mem_list_t<Ds_data> *_mem_list;
  /// Type and destination information.
  l4virtio_block_header_t _header;
  /// Request processor containing the current state.
  Request_processor _rp;
  /// Current data chunk in flight.
  Data_block _data;

  /// Original virtio request.
  Virtqueue::Request _request;
  /// Number of blocks that may still be processed.
  unsigned _todo_blocks;
  /// Maximum length of a single block.
  l4_uint32_t _max_block_size;
};

/**
 * Feature bits of a virtio block device.
 *
 * Extends the generic device features with the block-specific ones. Each
 * feature is declared with CXX_BITFIELD_MEMBER on the `raw` value and is
 * accessed through a function of the same name, e.g. `size_max()`.
 */
struct Block_features : public Dev_config::Features
{
  /// Create a feature set without an explicit initial value.
  Block_features() = default;
  /// Create a feature set from the raw feature word `raw`.
  Block_features(l4_uint32_t raw) : Dev_config::Features(raw) {}

  /** Maximum size of any single segment is in size_max. */
  CXX_BITFIELD_MEMBER( 1,  1, size_max, raw);
  /** Maximum number of segments in a request is in seg_max. */
  CXX_BITFIELD_MEMBER( 2,  2, seg_max, raw);
  /** Disk-style geometry specified in geometry. */
  CXX_BITFIELD_MEMBER( 4,  4, geometry, raw);
  /** Device is read-only. */
  CXX_BITFIELD_MEMBER( 5,  5, ro, raw);
  /** Block size of disk is in blk_size. */
  CXX_BITFIELD_MEMBER( 6,  6, blk_size, raw);
  /** Cache flush command support. */
  CXX_BITFIELD_MEMBER( 9,  9, flush, raw);
  /** Device exports information about optimal IO alignment. */
  CXX_BITFIELD_MEMBER(10, 10, topology, raw);
  /** Device can toggle its cache between writeback and writethrough modes. */
  CXX_BITFIELD_MEMBER(11, 11, config_wce, raw);
  /** Device supports multiqueue. */
  CXX_BITFIELD_MEMBER(12, 12, mq, raw);
  /** Device can support discard command. */
  CXX_BITFIELD_MEMBER(13, 13, discard, raw);
  /** Device can support write zeroes command. */
  CXX_BITFIELD_MEMBER(14, 14, write_zeroes, raw);
};


/**
 * Base class for virtio block devices.
 *
 * Derive from this class to implement a specific block device. It owns
 * the single request queue and the device configuration and provides
 * helpers to announce optional features to the client.
 */
template <typename Ds_data>
class Block_dev_base : public L4virtio::Svr::Device_t<Ds_data>
{
private:
  /// IRQ that is triggered to notify the client.
  L4Re::Util::Unique_cap<L4::Irq> _kick_guest_irq;
  /// The request queue of the device.
  Virtqueue _queue;
  /// Queue size given at construction, also limits the blocks per request.
  unsigned _vq_max;
  /// Largest segment size accepted in incoming requests.
  l4_uint32_t _max_block_size = UINT_MAX;
  /// Device configuration including the block-specific config space.
  Dev_config_t<l4virtio_block_config_t> _dev_config;

public:
  using Request = Block_request<Ds_data>;

protected:
  /// Return the feature set in negotiated feature word 0.
  Block_features negotiated_features() const
  {
    return _dev_config.negotiated_features(0);
  }

  /// Return the feature set in host feature word 0.
  Block_features device_features() const
  {
    return _dev_config.host_features(0);
  }

  /// Replace host feature word 0 with the raw value of `df`.
  void set_device_features(Block_features df)
  {
    _dev_config.host_features(0) = df.raw;
  }

  /**
   * Sets the maximum size of any single segment reported to client.
   *
   * \param sz  Maximum segment size.
   *
   * The same limit is enforced for incoming requests: a request with a
   * larger segment results in an IO error being reported to the client.
   * process_request() may therefore rely on all segments of a received
   * request being at most this large.
   */
  void set_size_max(l4_uint32_t sz)
  {
    _max_block_size = sz;

    _dev_config.priv_config()->size_max = sz;
    Block_features feats = device_features();
    feats.size_max() = true;
    set_device_features(feats);
  }

  /**
   * Sets the maximum number of segments in a request
   * that is reported to client.
   *
   * \param sz  Maximum number of segments.
   */
  void set_seg_max(l4_uint32_t sz)
  {
    _dev_config.priv_config()->seg_max = sz;

    Block_features feats = device_features();
    feats.seg_max() = true;
    set_device_features(feats);
  }

  /**
   * Set disk geometry that is reported to the client.
   *
   * \param cylinders  Number of cylinders.
   * \param heads      Number of heads.
   * \param sectors    Number of sectors.
   */
  void set_geometry(l4_uint16_t cylinders, l4_uint8_t heads, l4_uint8_t sectors)
  {
    l4virtio_block_config_t volatile *cfg = _dev_config.priv_config();
    cfg->geometry.cylinders = cylinders;
    cfg->geometry.heads = heads;
    cfg->geometry.sectors = sectors;

    Block_features feats = device_features();
    feats.geometry() = true;
    set_device_features(feats);
  }

  /**
   * Sets block disk size to be reported to the client.
   *
   * \param sz  Block size.
   *
   * Setting this does not change the logical sector size used
   * for addressing the device.
   */
  void set_blk_size(l4_uint32_t sz)
  {
    _dev_config.priv_config()->blk_size = sz;

    Block_features feats = device_features();
    feats.blk_size() = true;
    set_device_features(feats);
  }

  /**
   * Sets the I/O alignment information reported back to the client.
   *
   * \param physical_block_exp Number of logical blocks per physical block(log2)
   * \param alignment_offset Offset of the first aligned logical block
   * \param min_io_size Suggested minimum I/O size in blocks
   * \param opt_io_size Optimal I/O size in blocks
   */
  void set_topology(l4_uint8_t physical_block_exp,
                    l4_uint8_t alignment_offset,
                    l4_uint32_t min_io_size,
                    l4_uint32_t opt_io_size)
  {
    l4virtio_block_config_t volatile *cfg = _dev_config.priv_config();
    cfg->topology.physical_block_exp = physical_block_exp;
    cfg->topology.alignment_offset = alignment_offset;
    cfg->topology.min_io_size = min_io_size;
    cfg->topology.opt_io_size = opt_io_size;

    Block_features feats = device_features();
    feats.topology() = true;
    set_device_features(feats);
  }

  /** Enables the flush command. */
  void set_flush()
  {
    Block_features feats = device_features();
    feats.flush() = true;
    set_device_features(feats);
  }

  /**
   * Sets cache mode and enables the writeback toggle.
   *
   * \param writeback Mode of the cache (0 for writethrough, 1 for writeback).
   */
  void set_config_wce(l4_uint8_t writeback)
  {
    l4virtio_block_config_t volatile *cfg = _dev_config.priv_config();
    cfg->writeback = writeback;

    Block_features feats = device_features();
    feats.config_wce() = true;
    set_device_features(feats);
  }

  /**
   * Get the writeback field from the configuration space.
   *
   * \return Value of the writeback field.
   */
  l4_uint8_t get_writeback()
  {
    l4virtio_block_config_t volatile *cfg = _dev_config.priv_config();
    return cfg->writeback;
  }

  /**
   * Sets constraints for and enables the discard command.
   *
   * \param max_discard_sectors Maximum discard sectors size.
   * \param max_discard_seg Maximum discard segment number.
   * \param discard_sector_alignment Can be used by the driver when splitting a
   *                                 request based on alignment.
   */
  void set_discard(l4_uint32_t max_discard_sectors, l4_uint32_t max_discard_seg,
                   l4_uint32_t discard_sector_alignment)
  {
    l4virtio_block_config_t volatile *cfg = _dev_config.priv_config();
    cfg->max_discard_sectors = max_discard_sectors;
    cfg->max_discard_seg = max_discard_seg;
    cfg->discard_sector_alignment = discard_sector_alignment;

    Block_features feats = device_features();
    feats.discard() = true;
    set_device_features(feats);
  }

  /**
   * Sets constraints for and enables the write zeroes command.
   *
   * \param max_write_zeroes_sectors Maximum write zeroes sectors size.
   * \param max_write_zeroes_seg Maximum write zeroes segment number.
   * \param write_zeroes_may_unmap Set if a write zeroes request can result in
   *                               deallocating one or more sectors.
   */
  void set_write_zeroes(l4_uint32_t max_write_zeroes_sectors,
                        l4_uint32_t max_write_zeroes_seg,
                        l4_uint8_t write_zeroes_may_unmap)
  {
    l4virtio_block_config_t volatile *cfg = _dev_config.priv_config();
    cfg->max_write_zeroes_sectors = max_write_zeroes_sectors;
    cfg->max_write_zeroes_seg = max_write_zeroes_seg;
    cfg->write_zeroes_may_unmap = write_zeroes_may_unmap;

    Block_features feats = device_features();
    feats.write_zeroes() = true;
    set_device_features(feats);
  }

public:
  /**
   * Create a new virtio block device.
   *
   * \param vendor     Vendor ID
   * \param queue_size Number of entries to provide in avail and used queue.
   * \param capacity   Size of the device in 512-byte sectors.
   * \param read_only  True, if the device should not be writable.
   *
   * Besides the read-only flag, the device always announces support for
   * indirect descriptors and L4VIRTIO_FEATURE_VERSION_1.
   */
  Block_dev_base(l4_uint32_t vendor, unsigned queue_size, l4_uint64_t capacity,
                 bool read_only)
  : L4virtio::Svr::Device_t<Ds_data>(&_dev_config),
    _vq_max(queue_size),
    _dev_config(vendor, L4VIRTIO_ID_BLOCK, 1)
  {
    this->reset_queue_config(0, queue_size);

    Block_features initial(0);
    initial.ring_indirect_desc() = true;
    initial.ro() = read_only;
    set_device_features(initial);

    _dev_config.set_host_feature(L4VIRTIO_FEATURE_VERSION_1);

    _dev_config.priv_config()->capacity = capacity;
  }

  /**
   * Reset the actual hardware device.
   *
   * Called from reset() after queue and configuration have been reset.
   */
  virtual void reset_device() = 0;

  /**
   * Return true, if the queues should not be processed further.
   */
  virtual bool queue_stopped() = 0;

  /**
   * Releases resources related to a request and notifies the client.
   *
   * \param req     Pointer to request that has finished.
   * \param sz      Number of bytes consumed.
   * \param status  Status of request (see L4virtio_block_status). It is
   *                written into the status byte of the request here.
   *
   * This function must be called when an asynchronous request finishes,
   * either successfully or with an error. Nothing is done when the device
   * is in fail state or the queue is not ready. If the status cannot be
   * written back, the device is put into the error state.
   */
  void finalize_request(cxx::unique_ptr<Request> req, unsigned sz,
                        l4_uint8_t status = L4VIRTIO_BLOCK_S_OK)
  {
    bool const usable = !_dev_config.status().fail_state() && _queue.ready();
    if (!usable)
      return;

    int const res = req->release_request(&_queue, status, sz);
    if (res < 0)
      this->device_error();

    if (!_queue.no_notify_guest())
      {
        _dev_config.add_irq_status(L4VIRTIO_IRQ_STATUS_VRING);
        _kick_guest_irq->trigger();
      }

    // The request is dropped when req goes out of scope.
  }

  /**
   * Set up the request queue.
   *
   * \param idx  Index of the queue, only queue 0 exists.
   *
   * \retval 0           Queue was set up.
   * \retval -L4_EINVAL  Unknown queue index or the setup failed.
   */
  int reconfig_queue(unsigned idx) override
  {
    if (idx != 0)
      return -L4_EINVAL;

    return this->setup_queue(&_queue, 0, _vq_max) ? 0 : -L4_EINVAL;
  }

  /**
   * Disable the queue, reset queue configuration and configuration
   * header and finally reset the actual device via reset_device().
   */
  void reset() override
  {
    _queue.disable();
    _dev_config.reset_queue(0, _vq_max);
    _dev_config.reset_hdr();
    reset_device();
  }

protected:
  /**
   * Check whether a request is waiting in the queue.
   *
   * \retval false  The queue is not ready or stopped, or the device is in
   *                fail state.
   * \return Otherwise the result of Virtqueue::desc_avail().
   */
  bool check_for_new_requests()
  {
    bool const blocked = !_queue.ready() || queue_stopped()
                         || _dev_config.status().fail_state();
    if (blocked)
      return false;

    return _queue.desc_avail();
  }

  /**
   * Return one request if available.
   *
   * \return The next request or an empty pointer if the queue may not be
   *         processed, no request is pending or the request is malformed.
   *         In the latter case the device is put into the error state.
   */
  cxx::unique_ptr<Request> get_request()
  {
    cxx::unique_ptr<Request> result;

    bool const may_process = _queue.ready() && !queue_stopped()
                             && !_dev_config.status().fail_state();
    if (!may_process)
      return result;

    auto head = _queue.next_avail();
    if (!head)
      return result;

    try
      {
        result = cxx::unique_ptr<Request>(
          new Request(head, &(this->_mem_info), _vq_max, _max_block_size));
      }
    catch (Bad_descriptor const &)
      {
        this->device_error();
      }

    return result;
  }

private:
  /// Take over the IRQ capability received from the driver.
  void register_single_driver_irq() override
  {
    _kick_guest_irq = L4Re::Util::Unique_cap<L4::Irq>(
      L4Re::chkcap(this->server_iface()->template rcv_cap<L4::Irq>(0)));

    // provide a fresh receive capability slot for subsequent calls
    L4Re::chksys(this->server_iface()->realloc_rcv_cap(0));
  }

  /// Flag a configuration change and notify the driver.
  void trigger_driver_config_irq() override
  {
    _dev_config.add_irq_status(L4VIRTIO_IRQ_STATUS_CONFIG);
    _kick_guest_irq->trigger();
  }

  /// Check that the queue is ready, otherwise reset the device.
  bool check_queues() override
  {
    if (_queue.ready())
      return true;

    reset();
    return false;
  }
};

/**
 * Virtio block device that can be registered with an object registry.
 *
 * Combines Block_dev_base with the L4virtio::Device IPC interface and an
 * IRQ endpoint for notifications from the client. Implement
 * process_request() to handle the requests.
 */
template <typename Ds_data>
struct Block_dev
: Block_dev_base<Ds_data>,
  L4::Epiface_t<Block_dev<Ds_data>, L4virtio::Device>
{
private:
  /// IRQ endpoint that forwards incoming notifications to kick().
  class Irq_object : public L4::Irqep_t<Irq_object>
  {
  public:
    explicit Irq_object(Block_dev<Ds_data> *parent) : _parent(parent) {}

    void handle_irq()
    {
      _parent->kick();
    }

  private:
    Block_dev<Ds_data> *_parent;
  };
  Irq_object _irq_handler;

protected:
  /// Return the endpoint that receives the notification IRQ.
  L4::Epiface *irq_iface()
  { return &_irq_handler; }

public:
  /**
   * Create a new virtio block device.
   *
   * \param vendor     Vendor ID
   * \param queue_size Number of entries to provide in avail and used queue.
   * \param capacity   Size of the device in 512-byte sectors.
   * \param read_only  True, if the device should not be writable.
   */
  Block_dev(l4_uint32_t vendor, unsigned queue_size, l4_uint64_t capacity,
            bool read_only)
  : Block_dev_base<Ds_data>(vendor, queue_size, capacity, read_only),
    _irq_handler(this)
  {}

  /**
   * Attach device to an object registry.
   *
   * \param registry Object registry that will be responsible for dispatching
   *                 requests.
   * \param service  Name of an existing capability the device should use.
   *                 If null, the device is registered without a name.
   *
   * \return Capability of the registered device object.
   *
   * This function registers the general virtio interface as well as the
   * interrupt handler which is used for receiving client notifications.
   */
  L4::Cap<void> register_obj(L4::Registry_iface *registry,
                             char const *service = nullptr)
  {
    L4Re::chkcap(registry->register_irq_obj(this->irq_iface()));
    L4::Cap<void> ret;
    if (service)
      ret = registry->register_obj(this, service);
    else
      ret = registry->register_obj(this);
    L4Re::chkcap(ret);

    return ret;
  }

  /**
   * Attach device to an object registry using an existing endpoint.
   *
   * \param registry Object registry that will be responsible for dispatching
   *                 requests.
   * \param ep       Endpoint capability the device should be registered with.
   *
   * \return Capability of the registered device object.
   *
   * Registers the interrupt handler for client notifications as well.
   */
  L4::Cap<void> register_obj(L4::Registry_iface *registry,
                             L4::Cap<L4::Rcv_endpoint> ep)
  {
    L4Re::chkcap(registry->register_irq_obj(this->irq_iface()));

    return L4Re::chkcap(registry->register_obj(this, ep));
  }

  typedef Block_request<Ds_data> Request;
  /**
   * Implements the actual processing of data in the device.
   *
   * \param req  The request to be processed.
   * \return If false, no further requests will be scheduled.
   *
   * Synchronous and asynchronous processing of the data is supported.
   * For asynchronous mode, the function should set up the worker
   * and then return false. In synchronous mode, the function should
   * return true, once processing is complete. If there is an error
   * and processing is aborted, the request needs to be finished
   * immediately with finalize_request(), passing the appropriate error
   * status, if the client is to be answered.
   */
  virtual bool process_request(cxx::unique_ptr<Request> &&req) = 0;

protected:
  /// Return the server interface of the registered device endpoint.
  L4::Ipc_svr::Server_iface *server_iface() const override
  {
    return this->L4::Epiface::server_iface();
  }

  /**
   * Process pending requests.
   *
   * Fetches requests one by one and passes them to process_request()
   * until no request is available or process_request() returns false.
   */
  void kick()
  {
    for (;;)
      {
        auto req = this->get_request();
        if (!req)
          return;
        if (!this->process_request(cxx::move(req)))
          return;
      }
  }

private:
  /// Return the IRQ the client has to trigger to notify the device.
  L4::Cap<L4::Irq> device_notify_irq() const override
  {
    return L4::cap_cast<L4::Irq>(_irq_handler.obj_cap());
  }
};

} }
