package com.sshtools.j2ssh.agent;


import java.io.*;
import java.net.*;
import com.sshtools.j2ssh.subsystem.*;
import com.sshtools.j2ssh.io.*;
import java.util.*;
import com.sshtools.j2ssh.transport.*;
import org.apache.log4j.Logger;
import com.sshtools.j2ssh.transport.publickey.*;
import java.util.Arrays;
import com.sshtools.j2ssh.connection.Channel;

/**
 * Client side of the j2ssh key-agent protocol.
 *
 * <p>Each request is written as a length prefix (produced by
 * {@code ByteArrayWriter.encodeInt}) followed by the message bytes. Each
 * request is answered by exactly one reply, framed the same way. The first
 * byte of a reply selects the message class registered in {@link #messages}.
 *
 * <p>Instances hold no locks. Requests and replies share one pair of streams,
 * so requests must not be issued concurrently on the same instance.
 */
public class SshAgentClient {

  /** Stream the agent's replies are read from. */
  InputStream in;
  /** Stream requests are written to. */
  OutputStream out;
  /** True when this client talks to an agent over a forwarded channel. */
  boolean isForwarded = false;

  private static Logger log = Logger.getLogger(SshAgentClient.class);

  /** Operation name passed to the agent by {@link #hashAndSign}. */
  public static final String HASH_AND_SIGN = "hash-and-sign";

  /** Maps a reply message id ({@code Integer}) to its {@code SubsystemMessage} class. */
  HashMap messages = new HashMap();

  /**
   * Connects to an agent listening on a TCP "host:port" location and performs
   * the version handshake.
   *
   * @param application    application name sent in the version request
   * @param location       agent address in the form {@code host:port}
   * @param connectTimeout connect timeout in milliseconds; values {@code <= 0}
   *                       mean no timeout
   * @return a connected, version-checked client
   * @throws AgentNotAvailableException if the location is missing or
   *         malformed, or the connection times out
   * @throws IOException on any other I/O failure or if the handshake fails
   */
  public static SshAgentClient connectLocalAgent(String application, String location, int connectTimeout)
      throws AgentNotAvailableException,
      IOException {

    Socket socket = connectAgentSocket(location, connectTimeout);
    boolean handshakeDone = false;
    try {
      SshAgentClient client = new SshAgentClient(false,
                                                 application,
                                                 socket.getInputStream(),
                                                 socket.getOutputStream());
      handshakeDone = true;
      return client;
    }
    catch (SocketTimeoutException ex) {
      throw new AgentNotAvailableException();
    }
    finally {
      // The client owns the socket only after a successful handshake.
      if (!handshakeDone) {
        closeQuietly(socket);
      }
    }
  }

  /**
   * Opens a TCP socket to an agent location of the form {@code host:port}.
   *
   * <p>The separator is the last ':' so that IPv6 literals such as
   * {@code ::1:1234} are split correctly.
   *
   * @param location       agent address in the form {@code host:port}
   * @param connectTimeout connect timeout in milliseconds; values {@code <= 0}
   *                       mean no timeout
   * @return the connected socket
   * @throws AgentNotAvailableException if {@code location} is null, lacks a
   *         ':', has a non-numeric port, or the connect times out
   * @throws IOException on any other connect failure (the socket is closed)
   */
  public static Socket connectAgentSocket(String location, int connectTimeout)
      throws AgentNotAvailableException,
      IOException {

    if (location == null) {
      throw new AgentNotAvailableException();
    }

    int idx = location.lastIndexOf(':');
    if (idx == -1) {
      throw new AgentNotAvailableException();
    }

    String host = location.substring(0, idx);
    int port;
    try {
      port = Integer.parseInt(location.substring(idx + 1));
    }
    catch (NumberFormatException ex) {
      throw new AgentNotAvailableException();
    }

    Socket socket = new Socket();
    boolean connected = false;
    try {
      // Socket.connect treats 0 as "no timeout"; negative values would be rejected.
      socket.connect(new InetSocketAddress(host, port),
                     connectTimeout > 0 ? connectTimeout : 0);
      connected = true;
      return socket;
    }
    catch (SocketTimeoutException ex) {
      throw new AgentNotAvailableException();
    }
    finally {
      if (!connected) {
        closeQuietly(socket);
      }
    }
  }

  /**
   * Creates a client over a forwarded channel and sends a forwarding notice.
   *
   * @param application      application name (not sent on forwarded connections)
   * @param forwardedChannel channel whose streams carry the agent protocol
   * @return the client
   * @throws IOException if the forwarding notice cannot be sent
   */
  public static SshAgentClient connectForwardedAgent(String application,
      Channel forwardedChannel) throws IOException {

    return new SshAgentClient(true,
                              application,
                              forwardedChannel.getInputStream(),
                              forwardedChannel.getOutputStream());

  }

  /**
   * Creates a client over the given streams and performs the initial exchange.
   * A forwarded client sends a forwarding notice. A non-forwarded client sends
   * a version request and requires the agent to answer with version 2.
   *
   * <p>NOTE(review): this calls the overridable methods
   * {@link #registerMessages}, {@link #sendForwardingNotice} and
   * {@link #sendVersionRequest}. A subclass override therefore runs before
   * the subclass's own fields are initialised.
   */
   SshAgentClient(boolean isForwarded,
                        String application,
                        InputStream in,
                        OutputStream out) throws IOException {

    log.info("New SshAgentClient instance created");
    this.in = in;
    this.out = out;
    this.isForwarded = isForwarded;

    registerMessages();

    if (isForwarded) {
      sendForwardingNotice();
    } else {
      sendVersionRequest(application);
    }
  }

  /**
   * Closes both streams. The output stream is closed even if closing the
   * input stream fails.
   *
   * @throws IOException if either close fails
   */
  public void close() throws IOException {
    try {
      in.close();
    }
    finally {
      out.close();
    }
  }

  /** Registers the reply message ids this client knows how to decode. */
  protected void registerMessages() {
    messages.put(
        new Integer(SshAgentVersionResponse.SSH_AGENT_VERSION_RESPONSE),
                    SshAgentVersionResponse.class);
    messages.put(
        new Integer(SshAgentSuccess.SSH_AGENT_SUCCESS),
                    SshAgentSuccess.class);
    messages.put(
        new Integer(SshAgentFailure.SSH_AGENT_FAILURE),
                    SshAgentFailure.class);

    messages.put(
        new Integer(SshAgentKeyList.SSH_AGENT_KEY_LIST),
                    SshAgentKeyList.class);

    messages.put(
        new Integer(SshAgentRandomData.SSH_AGENT_RANDOM_DATA),
                    SshAgentRandomData.class);

    messages.put(
        new Integer(SshAgentAlive.SSH_AGENT_ALIVE),
                    SshAgentAlive.class);

    messages.put(
        new Integer(SshAgentOperationComplete.SSH_AGENT_OPERATION_COMPLETE),
                    SshAgentOperationComplete.class);

  }

  /**
   * Sends a version request and checks that the agent answers with version 2.
   *
   * @param application application name included in the request
   * @throws IOException if the reply is not a version response, or reports a
   *         version other than 2
   */
  protected void sendVersionRequest(String application) throws IOException {

    SubsystemMessage msg = new SshAgentRequestVersion(application);
    sendMessage(msg);

    msg = readMessage();

    if (msg instanceof SshAgentVersionResponse) {
      SshAgentVersionResponse reply = (SshAgentVersionResponse) msg;
      if (reply.getVersion() != 2) {
        throw new IOException("The agent version is not compatible with version 2");
      }
    } else {
      throw new IOException("The agent did not respond with the appropriate version");
    }
  }

  /**
   * Adds a key pair to the agent.
   *
   * @throws IOException if the agent does not reply with success
   */
  public void addKey(SshPrivateKey prvkey,
                     SshPublicKey pubkey,
                     String description,
                     KeyConstraints constraints) throws IOException {

    SubsystemMessage msg = new SshAgentAddKey(prvkey, pubkey, description, constraints);
    sendMessage(msg);

    msg = readMessage();

    if (!(msg instanceof SshAgentSuccess)) {
      throw new IOException("The key could not be added");
    }
  }

  /**
   * Asks the agent to perform the {@link #HASH_AND_SIGN} private-key operation
   * on {@code data} with the key matching {@code key}.
   *
   * @return the data carried by the agent's operation-complete reply
   * @throws IOException if the agent replies with anything else
   */
  public byte[] hashAndSign(SshPublicKey key, byte[] data) throws IOException {

    SubsystemMessage msg = new SshAgentPrivateKeyOp(key, HASH_AND_SIGN, data);

    sendMessage(msg);

    msg = readMessage();

    if (msg instanceof SshAgentOperationComplete) {
      return ((SshAgentOperationComplete) msg).getData();
    } else {
      throw new IOException("The operation failed");
    }
  }

  /**
   * Lists the keys held by the agent.
   *
   * @return the map returned by {@code SshAgentKeyList.getKeys()}
   * @throws IOException if the agent does not reply with a key list
   */
  public Map listKeys() throws IOException {

    SubsystemMessage msg = new SshAgentListKeys();
    sendMessage(msg);

    msg = readMessage();

    if (msg instanceof SshAgentKeyList) {
      return ((SshAgentKeyList) msg).getKeys();
    } else {
      throw new IOException("The agent responded with an invalid message");
    }
  }

  /**
   * Locks the agent with a password.
   *
   * @return true if the agent replied with success
   */
  public boolean lockAgent(String password) throws IOException {

    SubsystemMessage msg = new SshAgentLock(password);
    sendMessage(msg);

    msg = readMessage();

    return (msg instanceof SshAgentSuccess);
  }

  /**
   * Unlocks the agent with a password.
   *
   * @return true if the agent replied with success
   */
  public boolean unlockAgent(String password) throws IOException {

    SubsystemMessage msg = new SshAgentUnlock(password);
    sendMessage(msg);

    msg = readMessage();

    return (msg instanceof SshAgentSuccess);
  }

  /**
   * Requests {@code count} bytes of random data from the agent.
   *
   * @return the bytes carried by the agent's random-data reply
   * @throws IOException if the agent replies with anything else
   */
  public byte[] getRandomData(int count) throws IOException {
    SubsystemMessage msg = new SshAgentRandom(count);
    sendMessage(msg);

    msg = readMessage();

    if (msg instanceof SshAgentRandomData) {
      return ((SshAgentRandomData) msg).getRandomData();
    } else {
      throw new IOException("Agent failed to provide the requested random data");
    }
  }

  /**
   * Pings the agent and checks that it echoes {@code padding} back.
   *
   * @throws IOException if the reply is not an alive message, or its padding
   *         differs from what was sent
   */
  public void ping(byte[] padding) throws IOException {

    SubsystemMessage msg = new SshAgentPing(padding);
    sendMessage(msg);

    msg = readMessage();

    if (msg instanceof SshAgentAlive) {
      if (!Arrays.equals(padding, ((SshAgentAlive) msg).getPadding())) {
        throw new IOException("Agent failed to reply with expected data");
      }
    } else {
      throw new IOException("Agent failed to respond to the ping request");
    }
  }

  /**
   * Removes a key from the agent.
   *
   * @throws IOException if the agent does not reply with success
   */
  public void deleteKey(SshPublicKey key, String description) throws IOException {

    SubsystemMessage msg = new SshAgentDeleteKey(key, description);
    sendMessage(msg);

    msg = readMessage();

    if (!(msg instanceof SshAgentSuccess)) {
      throw new IOException("The agent failed to delete the key");
    }
  }

  /**
   * Removes every key from the agent.
   *
   * @throws IOException if the agent does not reply with success
   */
  public void deleteAllKeys() throws IOException {

    SubsystemMessage msg = new SshAgentDeleteAllKeys();
    sendMessage(msg);

    msg = readMessage();

    if (!(msg instanceof SshAgentSuccess)) {
      throw new IOException("The agent failed to delete all keys");
    }
  }

  /**
   * Sends a forwarding notice containing the local host name and address.
   * The port is the fixed value 22. No reply is read.
   */
  protected void sendForwardingNotice() throws IOException {
    InetAddress addr = InetAddress.getLocalHost();

    SshAgentForwardingNotice msg = new SshAgentForwardingNotice(addr.getHostName(),
        addr.getHostAddress(),
        22);

    sendMessage(msg);
  }

  /**
   * Writes one message as a length prefix followed by its bytes, then flushes.
   */
  protected void sendMessage(SubsystemMessage msg) throws IOException {
    log.info("Sending message " + msg.getMessageName());
    byte[] msgdata = msg.toByteArray();
    out.write(ByteArrayWriter.encodeInt(msgdata.length));
    out.write(msgdata);
    out.flush();
  }

  /**
   * Reads one length-prefixed reply and decodes it using {@link #messages}.
   *
   * @return the decoded message
   * @throws InvalidMessageException if the stream ends early, the length is
   *         not positive, the id is unknown, or decoding fails
   */
  protected SubsystemMessage readMessage() throws InvalidMessageException {

    try {
      // The length prefix is exactly 4 bytes; read all of them before decoding.
      byte[] lendata = new byte[4];
      readFully(lendata);

      int len = (int) ByteArrayReader.readInt(lendata, 0);
      // A message needs at least its id byte; reject 0 and negative values
      // (including lengths that overflowed the int cast).
      if (len <= 0) {
        throw new InvalidMessageException("Invalid message length " + len);
      }

      byte[] msgdata = new byte[len];
      readFully(msgdata);

      Integer id = new Integer((int) msgdata[0] & 0xFF);
      Class cls = (Class) messages.get(id);
      if (cls == null) {
        throw new InvalidMessageException("Unrecognised message id " + id.toString());
      }

      SubsystemMessage msg = (SubsystemMessage) cls.newInstance();
      msg.fromByteArray(msgdata);
      log.info("Received message " + msg.getMessageName());
      return msg;
    }
    catch (InvalidMessageException ex) {
      // Already descriptive; do not re-wrap.
      throw ex;
    }
    catch (Exception ex) {
      String reason = ex.getMessage() != null ? ex.getMessage() : ex.getClass().getName();
      InvalidMessageException ime = new InvalidMessageException(reason);
      try {
        ime.initCause(ex);
      }
      catch (IllegalStateException ignored) {
        // The cause was already fixed by the constructor; keep the message only.
      }
      throw ime;
    }
  }

  /**
   * Fills {@code buf} completely from {@link #in}.
   *
   * @throws EOFException if the stream ends before {@code buf} is full
   */
  private void readFully(byte[] buf) throws IOException {
    int off = 0;
    while (off < buf.length) {
      int n = in.read(buf, off, buf.length - off);
      if (n < 0) {
        throw new EOFException("The agent closed the connection");
      }
      off += n;
    }
  }

  /** Closes {@code socket}, ignoring any failure (used only on error paths). */
  private static void closeQuietly(Socket socket) {
    try {
      socket.close();
    }
    catch (IOException ignored) {
      // Best effort: the original failure is the one worth reporting.
    }
  }

}