/*
 *  Sshtools - Java SSH2 API
 *
 *  Copyright (C) 2002 Lee David Painter.
 *
 *  Written by: 2002 Lee David Painter <lee@sshtools.com>
 *
 *  This program is free software; you can redistribute it and/or
 *  modify it under the terms of the GNU Library General Public License
 *  as published by the Free Software Foundation; either version 2 of
 *  the License, or (at your option) any later version.
 *  This program is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *  GNU Library General Public License for more details.
 *
 *  You should have received a copy of the GNU Library General Public
 *  License along with this program; if not, write to the Free Software
 *  Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */
package com.sshtools.j2ssh.connection;

import com.sshtools.j2ssh.transport.AsyncService;
import com.sshtools.j2ssh.transport.MessageAlreadyRegisteredException;
import com.sshtools.j2ssh.transport.MessageStoreEOFException;
import com.sshtools.j2ssh.transport.ServiceOperationException;
import com.sshtools.j2ssh.transport.SshMessage;
import com.sshtools.j2ssh.transport.TransportProtocolException;
import com.sshtools.j2ssh.transport.TransportProtocolState;

import com.sshtools.j2ssh.util.InvalidStateException;

import org.apache.log4j.Logger;

import java.io.IOException;

import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;


/**
 * The "ssh-connection" service: opens and tracks channels, routes incoming
 * channel messages to the owning {@link Channel}, and sends/receives global
 * and channel requests over the underlying transport.
 *
 * <p>Open channels are kept in {@code activeChannels}, keyed by their local
 * channel id. Incoming asynchronous messages (those listed by
 * {@link #getAsyncMessageFilter()}) are dispatched by
 * {@link #onMessageReceived(SshMessage)}; replies that a caller waits for
 * (open confirmations, request success/failure) are read synchronously from
 * the message store.</p>
 */
public class ConnectionProtocol extends AsyncService {
    private static Logger log = Logger.getLogger(ConnectionProtocol.class);

    /**
     * Reason code sent in SSH_MSG_CHANNEL_OPEN_FAILURE when the requested
     * channel type is not supported (RFC 4254, section 5.1).
     */
    private static final long SSH_OPEN_UNKNOWN_CHANNEL_TYPE = 3;

    /** Open channels, keyed by local channel id ({@code Long} -> {@code Channel}). */
    private final Map activeChannels = new HashMap();

    /** Channel factories, keyed by the channel type name they create. */
    private final Map allowedChannels = new HashMap();

    /** Global request handlers, keyed by request name. */
    private final Map globalRequests = new HashMap();

    /**
     * Next local channel id to hand out. Only read/incremented while holding
     * the {@code activeChannels} monitor, because ids are allocated both by
     * callers of {@link #openChannel(Channel, ChannelEventListener)} and by
     * the incoming-message path in {@code onMsgChannelOpen}.
     */
    private long nextChannelId = 0;

    /**
     * Creates the service under the name "ssh-connection".
     */
    public ConnectionProtocol() {
        super("ssh-connection");
    }

    /**
     * Registers a factory used to create channels the remote side asks to
     * open with the given channel type. Replaces any existing factory for
     * that type.
     *
     * @param channelName the channel type name
     * @param cf the factory to use for that type
     * @throws IOException declared for API compatibility; not thrown here
     */
    public void addChannelFactory(String channelName, ChannelFactory cf)
        throws IOException {
        allowedChannels.put(channelName, cf);
    }

    /**
     * Stops accepting remote requests to open channels of the given type.
     *
     * @param channelName the channel type name
     */
    public void removeChannelFactory(String channelName) {
        allowedChannels.remove(channelName);
    }

    /**
     * @param channelName the channel type name
     * @return true if a factory is registered for that channel type
     */
    public boolean containsChannelFactory(String channelName) {
        return allowedChannels.containsKey(channelName);
    }

    /**
     * Registers the handler invoked for incoming global requests with the
     * given name. Unregistered request names are answered with a failure
     * (when the peer asks for a reply).
     *
     * @param requestName the global request name
     * @param handler the handler to invoke
     */
    public void allowGlobalRequest(String requestName,
        GlobalRequestHandler handler) {
        globalRequests.put(requestName, handler);
    }

    /**
     * Opens a channel without an event listener.
     *
     * @param channel the channel to open
     * @return true if the remote side confirmed the open, false if it refused
     * @throws IOException if sending fails or the message store hits EOF
     */
    public synchronized boolean openChannel(Channel channel)
        throws IOException {
        return openChannel(channel, null);
    }

    /**
     * @return true while the transport state is CONNECTED or
     *         PERFORMING_KEYEXCHANGE
     */
    public boolean isConnected() {
        return (transport.getState().getValue() == TransportProtocolState.CONNECTED)
            || (transport.getState().getValue() == TransportProtocolState.PERFORMING_KEYEXCHANGE);
    }

    /**
     * Sends SSH_MSG_CHANNEL_OPEN for the channel and blocks until the remote
     * side answers with either an open confirmation or an open failure.
     *
     * <p>On confirmation the channel is registered, initialised with the
     * remote parameters and opened. On failure its state is set to
     * {@code CHANNEL_CLOSED}.</p>
     *
     * @param channel the channel to open
     * @param eventListener listener passed to {@code Channel.init}; may be null
     * @return true if the channel was opened, false if the remote refused
     * @throws IOException if sending fails
     * @throws ServiceOperationException if the message store reaches EOF while
     *         waiting for the reply
     */
    public synchronized boolean openChannel(Channel channel,
        ChannelEventListener eventListener) throws IOException {
        // The activeChannels monitor is held across the whole handshake:
        // it serialises id allocation with onMsgChannelOpen, and lookups via
        // getChannel (which also lock activeChannels) wait until the new
        // channel has been registered and initialised.
        synchronized (activeChannels) {
            long channelId = nextChannelId++;

            SshMsgChannelOpen msg = new SshMsgChannelOpen(
                    channel.getChannelType(), channelId,
                    channel.getLocalWindow().getWindowSpace(),
                    channel.getLocalPacketSize(), channel.getChannelOpenData());

            transport.sendMessage(msg, this);

            // The reply is either a confirmation or a failure
            int[] messageIdFilter = new int[] {
                    SshMsgChannelOpenConfirmation.SSH_MSG_CHANNEL_OPEN_CONFIRMATION,
                    SshMsgChannelOpenFailure.SSH_MSG_CHANNEL_OPEN_FAILURE
                };

            try {
                SshMessage result = messageStore.getMessage(messageIdFilter);

                if (result.getMessageId() == SshMsgChannelOpenConfirmation.SSH_MSG_CHANNEL_OPEN_CONFIRMATION) {
                    SshMsgChannelOpenConfirmation conf = (SshMsgChannelOpenConfirmation) result;
                    activeChannels.put(new Long(channelId), channel);

                    log.debug("Initiating channel");
                    channel.init(this, channelId, conf.getSenderChannel(),
                        conf.getInitialWindowSize(),
                        conf.getMaximumPacketSize(), eventListener);

                    channel.open();
                    log.info("Channel is open");

                    return true;
                } else {
                    // Make sure the channel's state is closed
                    channel.getState().setValue(ChannelState.CHANNEL_CLOSED);

                    return false;
                }
            } catch (MessageStoreEOFException mse) {
                throw new ServiceOperationException(mse.getMessage());
            }
        }
    }

    /**
     * No-op.
     */
    protected void onStop() {
    }

    /**
     * Sends SSH_MSG_CHANNEL_DATA to the channel's remote id.
     *
     * <p>NOTE(review): unlike {@link #sendChannelExtData}, this does not
     * consume remote window space; presumably the caller does — confirm.</p>
     *
     * @param channel the channel the data belongs to
     * @param data the payload
     * @throws IOException if sending fails
     */
    public synchronized void sendChannelData(Channel channel, byte[] data)
        throws IOException {
        if (log.isDebugEnabled()) {
            log.debug("Sending " + String.valueOf(data.length)
                + " bytes for channel id "
                + String.valueOf(channel.getLocalChannelId()));
        }

        SshMsgChannelData msg = new SshMsgChannelData(
                channel.getRemoteChannelId(), data);

        transport.sendMessage(msg, this);
    }

    /**
     * Sends SSH_MSG_CHANNEL_EOF for a registered channel.
     *
     * @param channel the channel
     * @throws ServiceOperationException if the channel is not registered
     * @throws IOException if sending fails
     */
    public synchronized void sendChannelEOF(Channel channel)
        throws IOException {
        synchronized (activeChannels) {
            if (!activeChannels.containsValue(channel)) {
                throw new ServiceOperationException(
                    "Attempt to send EOF for a non existent channel "
                    + String.valueOf(channel.getLocalChannelId()));
            }

            SshMsgChannelEOF msg = new SshMsgChannelEOF(
                    channel.getRemoteChannelId());

            transport.sendMessage(msg, this);
        }
    }

    /**
     * Consumes remote window space for the payload and sends
     * SSH_MSG_CHANNEL_EXTENDED_DATA.
     *
     * @param channel the channel the data belongs to
     * @param extendedType the extended data type code
     * @param data the payload
     * @throws IOException if sending fails
     */
    public synchronized void sendChannelExtData(Channel channel,
        int extendedType, byte[] data) throws IOException {
        channel.getRemoteWindow().consumeWindowSpace(data.length);

        SshMsgChannelExtendedData msg = new SshMsgChannelExtendedData(
                channel.getRemoteChannelId(), extendedType, data);

        transport.sendMessage(msg, this);
    }

    /**
     * Sends SSH_MSG_CHANNEL_REQUEST and, if a reply is wanted, blocks until
     * SSH_MSG_CHANNEL_SUCCESS or SSH_MSG_CHANNEL_FAILURE arrives.
     *
     * <p>NOTE(review): the reply is matched by message id only, not by
     * recipient channel.</p>
     *
     * @param channel the channel the request targets
     * @param requestType the request type name
     * @param wantReply whether to ask for, and wait for, a reply
     * @param requestData request-specific data
     * @return false only if a reply was wanted and it was a failure
     * @throws IOException if sending or waiting fails
     */
    public synchronized boolean sendChannelRequest(Channel channel,
        String requestType, boolean wantReply, byte[] requestData)
        throws IOException {
        boolean success = true;

        log.debug("Sending " + requestType + " request for the "
            + channel.getChannelType() + " channel");

        SshMsgChannelRequest msg = new SshMsgChannelRequest(
                channel.getRemoteChannelId(), requestType, wantReply, requestData);

        transport.sendMessage(msg, this);

        if (wantReply) {
            int[] messageIdFilter = new int[] {
                    SshMsgChannelSuccess.SSH_MSG_CHANNEL_SUCCESS,
                    SshMsgChannelFailure.SSH_MSG_CHANNEL_FAILURE
                };

            log.debug("Waiting for channel request reply");

            SshMessage reply = messageStore.getMessage(messageIdFilter);

            switch (reply.getMessageId()) {
            case SshMsgChannelSuccess.SSH_MSG_CHANNEL_SUCCESS:
                log.debug("Channel request succeeded");
                success = true;

                break;

            case SshMsgChannelFailure.SSH_MSG_CHANNEL_FAILURE:
                log.debug("Channel request failed");
                success = false;

                break;
            }
        }

        return success;
    }

    /**
     * Sends SSH_MSG_CHANNEL_FAILURE to the channel's remote id.
     *
     * @param channel the channel
     * @throws IOException if sending fails
     */
    public synchronized void sendChannelRequestFailure(Channel channel)
        throws IOException {
        SshMsgChannelFailure msg = new SshMsgChannelFailure(
                channel.getRemoteChannelId());

        transport.sendMessage(msg, this);
    }

    /**
     * Sends SSH_MSG_CHANNEL_SUCCESS to the channel's remote id.
     *
     * @param channel the channel
     * @throws IOException if sending fails
     */
    public synchronized void sendChannelRequestSuccess(Channel channel)
        throws IOException {
        SshMsgChannelSuccess msg = new SshMsgChannelSuccess(
                channel.getRemoteChannelId());

        transport.sendMessage(msg, this);
    }

    /**
     * Sends SSH_MSG_CHANNEL_WINDOW_ADJUST to the channel's remote id.
     *
     * @param channel the channel
     * @param bytesToAdd number of bytes to add to the remote's view of our window
     * @throws IOException if sending fails
     */
    public synchronized void sendChannelWindowAdjust(Channel channel,
        long bytesToAdd) throws IOException {
        SshMsgChannelWindowAdjust msg = new SshMsgChannelWindowAdjust(
                channel.getRemoteChannelId(), bytesToAdd);

        transport.sendMessage(msg, this);
    }

    /**
     * Sends SSH_MSG_GLOBAL_REQUEST and, if a reply is wanted, blocks until
     * SSH_MSG_REQUEST_SUCCESS or SSH_MSG_REQUEST_FAILURE arrives.
     *
     * @param requestName the global request name
     * @param wantReply whether to ask for, and wait for, a reply
     * @param requestData request-specific data
     * @return false only if a reply was wanted and it was a failure
     * @throws IOException if sending or waiting fails
     */
    public synchronized boolean sendGlobalRequest(String requestName,
        boolean wantReply, byte[] requestData) throws IOException {
        boolean success = true;

        // Pass the caller's wantReply through: hard-coding true made the
        // peer reply even when nobody waits for it, leaving stray
        // REQUEST_SUCCESS/FAILURE messages in the message store.
        SshMsgGlobalRequest msg = new SshMsgGlobalRequest(requestName,
                wantReply, requestData);

        transport.sendMessage(msg, this);

        if (wantReply) {
            int[] messageIdFilter = new int[] {
                    SshMsgRequestSuccess.SSH_MSG_REQUEST_SUCCESS,
                    SshMsgRequestFailure.SSH_MSG_REQUEST_FAILURE
                };

            log.debug("Waiting for global request reply");

            SshMessage reply = messageStore.getMessage(messageIdFilter);

            switch (reply.getMessageId()) {
            case SshMsgRequestSuccess.SSH_MSG_REQUEST_SUCCESS:
                log.debug("Global request succeeded");
                success = true;

                break;

            case SshMsgRequestFailure.SSH_MSG_REQUEST_FAILURE:
                log.debug("Global request failed");
                success = false;

                break;
            }
        }

        return success;
    }

    /**
     * @return the ids of the messages delivered asynchronously to
     *         {@link #onMessageReceived(SshMessage)}
     */
    protected int[] getAsyncMessageFilter() {
        // Exactly the ids handled by onMessageReceived; no unused 0 slots.
        return new int[] {
            SshMsgGlobalRequest.SSH_MSG_GLOBAL_REQUEST,
            SshMsgChannelOpen.SSH_MSG_CHANNEL_OPEN,
            SshMsgChannelClose.SSH_MSG_CHANNEL_CLOSE,
            SshMsgChannelEOF.SSH_MSG_CHANNEL_EOF,
            SshMsgChannelExtendedData.SSH_MSG_CHANNEL_EXTENDED_DATA,
            SshMsgChannelData.SSH_MSG_CHANNEL_DATA,
            SshMsgChannelRequest.SSH_MSG_CHANNEL_REQUEST,
            SshMsgChannelWindowAdjust.SSH_MSG_CHANNEL_WINDOW_ADJUST
        };
    }

    /**
     * Sends SSH_MSG_CHANNEL_CLOSE to the channel's remote id. The channel
     * stays registered; it is removed when the remote's close is received
     * in {@code onMsgChannelClose}.
     *
     * @param channel the channel
     * @throws IOException if sending fails
     */
    protected synchronized void closeChannel(Channel channel)
        throws IOException {
        SshMsgChannelClose msg = new SshMsgChannelClose(
                channel.getRemoteChannelId());

        transport.sendMessage(msg, this);
    }

    /**
     * Handles an incoming global request by delegating to the registered
     * handler. A reply is sent only when the peer asked for one
     * (RFC 4254, section 4).
     *
     * @param requestName the global request name
     * @param wantReply whether the peer wants a reply
     * @param requestData request-specific data
     * @throws IOException if sending the reply fails
     */
    protected void onGlobalRequest(String requestName, boolean wantReply,
        byte[] requestData) throws IOException {
        log.debug("Processing " + requestName + " global request");

        GlobalRequestHandler handler = (GlobalRequestHandler) globalRequests
            .get(requestName);

        if (handler == null) {
            // Unsupported request: answer only if a reply was requested
            if (wantReply) {
                sendGlobalRequestFailure();
            }

            return;
        }

        GlobalRequestResponse response = handler.processGlobalRequest(requestName,
                requestData);

        if (wantReply) {
            if (response.hasSucceeded()) {
                sendGlobalRequestSuccess(response.getResponseData());
            } else {
                sendGlobalRequestFailure();
            }
        }
    }

    /**
     * Routes an asynchronously delivered message to its handler.
     *
     * @param msg the received message
     * @throws ServiceOperationException if the message id was not registered
     *         in {@link #getAsyncMessageFilter()}
     * @throws IOException if a handler fails
     */
    protected void onMessageReceived(SshMessage msg) throws IOException {
        switch (msg.getMessageId()) {
        case SshMsgGlobalRequest.SSH_MSG_GLOBAL_REQUEST:
            onMsgGlobalRequest((SshMsgGlobalRequest) msg);

            break;

        case SshMsgChannelOpen.SSH_MSG_CHANNEL_OPEN:
            onMsgChannelOpen((SshMsgChannelOpen) msg);

            break;

        case SshMsgChannelClose.SSH_MSG_CHANNEL_CLOSE:
            onMsgChannelClose((SshMsgChannelClose) msg);

            break;

        case SshMsgChannelEOF.SSH_MSG_CHANNEL_EOF:
            onMsgChannelEOF((SshMsgChannelEOF) msg);

            break;

        case SshMsgChannelData.SSH_MSG_CHANNEL_DATA:
            onMsgChannelData((SshMsgChannelData) msg);

            break;

        case SshMsgChannelExtendedData.SSH_MSG_CHANNEL_EXTENDED_DATA:
            onMsgChannelExtendedData((SshMsgChannelExtendedData) msg);

            break;

        case SshMsgChannelRequest.SSH_MSG_CHANNEL_REQUEST:
            onMsgChannelRequest((SshMsgChannelRequest) msg);

            break;

        case SshMsgChannelWindowAdjust.SSH_MSG_CHANNEL_WINDOW_ADJUST:
            onMsgChannelWindowAdjust((SshMsgChannelWindowAdjust) msg);

            break;

        default:
            // If we never registered it, why are we getting it?
            log.debug("Message not handled");
            throw new ServiceOperationException(
                "Unregistered message received!");
        }
    }

    /**
     * No-op.
     */
    protected void onServiceAccept() {
    }

    /**
     * Registers every connection-protocol message class with the message
     * store.
     *
     * @param startMode service start mode (not used here)
     * @throws IOException declared by the service contract
     */
    protected void onServiceInit(int startMode) throws IOException {
        log.info("Registering connection protocol messages");

        messageStore.registerMessage(SshMsgChannelOpenConfirmation.SSH_MSG_CHANNEL_OPEN_CONFIRMATION,
            SshMsgChannelOpenConfirmation.class);

        messageStore.registerMessage(SshMsgChannelOpenFailure.SSH_MSG_CHANNEL_OPEN_FAILURE,
            SshMsgChannelOpenFailure.class);

        messageStore.registerMessage(SshMsgChannelOpen.SSH_MSG_CHANNEL_OPEN,
            SshMsgChannelOpen.class);

        messageStore.registerMessage(SshMsgChannelClose.SSH_MSG_CHANNEL_CLOSE,
            SshMsgChannelClose.class);

        messageStore.registerMessage(SshMsgChannelEOF.SSH_MSG_CHANNEL_EOF,
            SshMsgChannelEOF.class);

        messageStore.registerMessage(SshMsgChannelData.SSH_MSG_CHANNEL_DATA,
            SshMsgChannelData.class);

        messageStore.registerMessage(SshMsgChannelExtendedData.SSH_MSG_CHANNEL_EXTENDED_DATA,
            SshMsgChannelExtendedData.class);

        messageStore.registerMessage(SshMsgChannelFailure.SSH_MSG_CHANNEL_FAILURE,
            SshMsgChannelFailure.class);

        messageStore.registerMessage(SshMsgChannelRequest.SSH_MSG_CHANNEL_REQUEST,
            SshMsgChannelRequest.class);

        messageStore.registerMessage(SshMsgChannelSuccess.SSH_MSG_CHANNEL_SUCCESS,
            SshMsgChannelSuccess.class);

        messageStore.registerMessage(SshMsgChannelWindowAdjust.SSH_MSG_CHANNEL_WINDOW_ADJUST,
            SshMsgChannelWindowAdjust.class);

        messageStore.registerMessage(SshMsgGlobalRequest.SSH_MSG_GLOBAL_REQUEST,
            SshMsgGlobalRequest.class);

        messageStore.registerMessage(SshMsgRequestFailure.SSH_MSG_REQUEST_FAILURE,
            SshMsgRequestFailure.class);

        messageStore.registerMessage(SshMsgRequestSuccess.SSH_MSG_REQUEST_SUCCESS,
            SshMsgRequestSuccess.class);
    }

    /**
     * No-op.
     */
    protected void onServiceRequest() {
    }

    /**
     * Sends SSH_MSG_CHANNEL_FAILURE to the channel's remote id.
     *
     * @param channel the channel
     * @throws IOException if sending fails
     */
    protected void sendChannelFailure(Channel channel)
        throws IOException {
        SshMsgChannelFailure msg = new SshMsgChannelFailure(
                channel.getRemoteChannelId());

        transport.sendMessage(msg, this);
    }

    /**
     * Sends SSH_MSG_CHANNEL_OPEN_CONFIRMATION carrying the channel's local id,
     * local window space, local packet size and confirmation data.
     *
     * @param channel the channel being confirmed
     * @throws IOException if sending fails
     */
    protected void sendChannelOpenConfirmation(Channel channel)
        throws IOException {
        SshMsgChannelOpenConfirmation msg = new SshMsgChannelOpenConfirmation(
                channel.getRemoteChannelId(), channel.getLocalChannelId(),
                channel.getLocalWindow().getWindowSpace(),
                channel.getLocalPacketSize(),
                channel.getChannelConfirmationData());

        transport.sendMessage(msg, this);
    }

    /**
     * Sends SSH_MSG_CHANNEL_OPEN_FAILURE.
     *
     * @param remoteChannelId the sender channel id from the remote's open request
     * @param reasonCode the failure reason code
     * @param additionalInfo human-readable description
     * @param languageTag language tag for {@code additionalInfo}
     * @throws IOException if sending fails
     */
    protected void sendChannelOpenFailure(long remoteChannelId,
        long reasonCode, String additionalInfo, String languageTag)
        throws IOException {
        SshMsgChannelOpenFailure msg = new SshMsgChannelOpenFailure(remoteChannelId,
                reasonCode, additionalInfo, languageTag);

        transport.sendMessage(msg, this);
    }

    /**
     * Sends SSH_MSG_REQUEST_FAILURE.
     *
     * @throws IOException if sending fails
     */
    protected void sendGlobalRequestFailure() throws IOException {
        SshMsgRequestFailure msg = new SshMsgRequestFailure();

        transport.sendMessage(msg, this);
    }

    /**
     * Sends SSH_MSG_REQUEST_SUCCESS with optional response data.
     *
     * @param requestData response-specific data
     * @throws IOException if sending fails
     */
    protected void sendGlobalRequestSuccess(byte[] requestData)
        throws IOException {
        SshMsgRequestSuccess msg = new SshMsgRequestSuccess(requestData);

        transport.sendMessage(msg, this);
    }

    /**
     * Looks up a registered channel by local id.
     *
     * @param channelId the local channel id
     * @return the channel
     * @throws ServiceOperationException if no channel has that id
     */
    private Channel getChannel(long channelId) throws ServiceOperationException {
        synchronized (activeChannels) {
            Long l = new Long(channelId);

            if (!activeChannels.containsKey(l)) {
                throw new ServiceOperationException("Non existent channel "
                    + l.toString() + " requested");
            }

            return (Channel) activeChannels.get(l);
        }
    }

    /**
     * Handles the remote's SSH_MSG_CHANNEL_CLOSE: closes the channel if it is
     * not already closed, then unregisters it.
     */
    private void onMsgChannelClose(SshMsgChannelClose msg)
        throws IOException {
        Channel channel = getChannel(msg.getRecipientChannel());

        if (channel == null) {
            throw new ServiceOperationException(
                "Remote computer tried to close a " + "non existent channel!");
        }

        // If we have not already closed it, close it now
        if (channel.getState().getValue() != ChannelState.CHANNEL_CLOSED) {
            channel.close();
        }

        removeChannel(channel);
    }

    /**
     * Hands incoming SSH_MSG_CHANNEL_DATA to the target channel.
     */
    private void onMsgChannelData(SshMsgChannelData msg)
        throws IOException {
        if (log.isDebugEnabled()) {
            log.debug("Received " + String.valueOf(msg.getChannelData().length)
                + " bytes of data for channel id "
                + String.valueOf(msg.getRecipientChannel()));
        }

        Channel channel = getChannel(msg.getRecipientChannel());

        if (channel == null) {
            throw new ServiceOperationException(
                "Remote computer sent data for non existent channel");
        }

        channel.onChannelData(msg);
    }

    /**
     * Notifies the target channel that the remote sent SSH_MSG_CHANNEL_EOF.
     */
    private void onMsgChannelEOF(SshMsgChannelEOF msg)
        throws IOException {
        Channel channel = getChannel(msg.getRecipientChannel());

        if (channel == null) {
            throw new ServiceOperationException(
                "Remote side tried to set a non " + "existent channel to EOF!");
        }

        try {
            channel.onInputStreamEOF();
        } catch (IOException ioe) {
            // Keep the underlying reason rather than discarding it
            throw new ServiceOperationException(
                "Failed to close the ChannelInputStream: " + ioe.getMessage());
        }
    }

    /**
     * Consumes local window space for incoming SSH_MSG_CHANNEL_EXTENDED_DATA
     * and queues the message on the channel's {@code incoming} store.
     */
    private void onMsgChannelExtendedData(SshMsgChannelExtendedData msg)
        throws IOException {
        Channel channel = getChannel(msg.getRecipientChannel());

        if (channel == null) {
            throw new ServiceOperationException(
                "Remote computer sent data for non existent channel");
        }

        channel.getLocalWindow().consumeWindowSpace(msg.getChannelData().length);

        channel.incoming.addMessage(msg);
    }

    /**
     * Handles a remote SSH_MSG_CHANNEL_OPEN: creates the channel from the
     * registered factory, assigns it a local id, registers it, confirms the
     * open and opens it. Unsupported types and factory rejections are
     * answered with SSH_MSG_CHANNEL_OPEN_FAILURE.
     */
    private void onMsgChannelOpen(SshMsgChannelOpen msg)
        throws IOException {
        synchronized (activeChannels) {
            ChannelFactory cf = (ChannelFactory) allowedChannels.get(
                    msg.getChannelType());

            if (cf == null) {
                // RFC 4254 5.1: an unsupported type is UNKNOWN_CHANNEL_TYPE
                sendChannelOpenFailure(msg.getSenderChannelId(),
                    SSH_OPEN_UNKNOWN_CHANNEL_TYPE,
                    "The channel type is not supported", "");
                log.info("Request for channel type " + msg.getChannelType()
                    + " refused");

                return;
            }

            try {
                Channel channel = cf.createChannel(msg.getChannelType(),
                        msg.getChannelData());

                // nextChannelId is guarded by the activeChannels monitor
                channel.init(this,
                    nextChannelId++, msg.getSenderChannelId(),
                    msg.getInitialWindowSize(), msg.getMaximumPacketSize());

                activeChannels.put(new Long(channel.getLocalChannelId()),
                    channel);

                sendChannelOpenConfirmation(channel);

                channel.open();
            } catch (InvalidChannelException ice) {
                sendChannelOpenFailure(msg.getSenderChannelId(),
                    SshMsgChannelOpenFailure.SSH_OPEN_CONNECT_FAILED,
                    ice.getMessage(), "");
            }
        }
    }

    /**
     * Hands an incoming SSH_MSG_CHANNEL_REQUEST to the target channel.
     */
    private void onMsgChannelRequest(SshMsgChannelRequest msg)
        throws IOException {
        Channel channel = getChannel(msg.getRecipientChannel());

        if (channel == null) {
            log.warn("Remote computer tried to make a request for "
                + "a non existence channel!");

            // Previously fell through and dereferenced null
            return;
        }

        channel.onChannelRequest(msg.getRequestType(), msg.getWantReply(),
            msg.getChannelData());
    }

    /**
     * Adds the bytes from SSH_MSG_CHANNEL_WINDOW_ADJUST to the channel's
     * remote window.
     */
    private void onMsgChannelWindowAdjust(SshMsgChannelWindowAdjust msg)
        throws IOException {
        Channel channel = getChannel(msg.getRecipientChannel());

        if (channel == null) {
            throw new ServiceOperationException(
                "Remote computer tried to increase "
                + "window space for a non existent channel!");
        }

        channel.getRemoteWindow().increaseWindowSpace(msg.getBytesToAdd());
        log.debug(String.valueOf(msg.getBytesToAdd())
            + " bytes added to remote window");
        log.debug("Remote window space is "
            + String.valueOf(channel.getRemoteWindow().getWindowSpace()));
    }

    /**
     * Unpacks SSH_MSG_GLOBAL_REQUEST and delegates to
     * {@link #onGlobalRequest(String, boolean, byte[])}.
     */
    private void onMsgGlobalRequest(SshMsgGlobalRequest msg)
        throws IOException {
        onGlobalRequest(msg.getRequestName(), msg.getWantReply(),
            msg.getRequestData());
    }

    /**
     * Unregisters the channel by its local id.
     */
    private void removeChannel(Channel channel) {
        synchronized (activeChannels) {
            activeChannels.remove(new Long(channel.getLocalChannelId()));
        }
    }
}
