/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sshd.client.kex;

import java.math.BigInteger;
import java.util.Objects;
import org.apache.sshd.client.kex.AbstractDHClientKeyExchange;
import org.apache.sshd.common.NamedFactory;
import org.apache.sshd.common.SshException;
import org.apache.sshd.common.config.keys.KeyUtils;
import org.apache.sshd.common.kex.AbstractDH;
import org.apache.sshd.common.kex.DHFactory;
import org.apache.sshd.common.kex.KeyExchange;
import org.apache.sshd.common.kex.KeyExchangeFactory;
import org.apache.sshd.common.session.Session;
import org.apache.sshd.common.session.helpers.AbstractSession;
import org.apache.sshd.common.signature.Signature;
import org.apache.sshd.common.util.GenericUtils;
import org.apache.sshd.common.util.ValidateUtils;
import org.apache.sshd.common.util.buffer.Buffer;
import org.apache.sshd.common.util.buffer.ByteArrayBuffer;
import org.apache.sshd.common.util.security.SecurityUtils;

public class DHGEXClient
extends AbstractDHClientKeyExchange {
    protected final DHFactory factory;
    protected byte expected;
    protected int min = 1024;
    protected int prf;
    protected int max;
    protected AbstractDH dh;
    protected byte[] p;
    protected byte[] g;

    protected DHGEXClient(DHFactory factory) {
        this.factory = Objects.requireNonNull(factory, "No factory");
        this.max = SecurityUtils.getMaxDHGroupExchangeKeySize();
        this.prf = Math.min(4096, this.max);
    }

    @Override
    public final String getName() {
        return this.factory.getName();
    }

    public static KeyExchangeFactory newFactory(final DHFactory delegate) {
        return new KeyExchangeFactory(){

            @Override
            public String getName() {
                return delegate.getName();
            }

            @Override
            public KeyExchange create() {
                return new DHGEXClient(delegate);
            }

            public String toString() {
                return NamedFactory.class.getSimpleName() + "<" + KeyExchange.class.getSimpleName() + ">[" + this.getName() + "]";
            }
        };
    }

    @Override
    public void init(Session s, byte[] v_s, byte[] v_c, byte[] i_s, byte[] i_c) throws Exception {
        super.init(s, v_s, v_c, i_s, i_c);
        if (this.log.isDebugEnabled()) {
            this.log.debug("init({}) Send SSH_MSG_KEX_DH_GEX_REQUEST", (Object)s);
        }
        Buffer buffer = s.createBuffer((byte)34, 32);
        buffer.putInt(this.min);
        buffer.putInt(this.prf);
        buffer.putInt(this.max);
        s.writePacket(buffer);
        this.expected = (byte)31;
    }

    @Override
    public boolean next(int cmd, Buffer buffer) throws Exception {
        AbstractSession session = this.getSession();
        if (this.log.isDebugEnabled()) {
            this.log.debug("next({})[{}] process command={}", new Object[]{this, session, KeyExchange.getGroupKexOpcodeName(cmd)});
        }
        if (cmd != this.expected) {
            throw new SshException(3, "Protocol error: expected packet " + KeyExchange.getGroupKexOpcodeName(this.expected) + ", got " + KeyExchange.getGroupKexOpcodeName(cmd));
        }
        if (cmd == 31) {
            this.p = buffer.getMPIntAsBytes();
            this.g = buffer.getMPIntAsBytes();
            this.dh = this.getDH(new BigInteger(this.p), new BigInteger(this.g));
            this.hash = this.dh.getHash();
            this.hash.init();
            this.e = this.dh.getE();
            if (this.log.isDebugEnabled()) {
                this.log.debug("next({})[{}] Send SSH_MSG_KEX_DH_GEX_INIT", (Object)this, (Object)session);
            }
            buffer = session.createBuffer((byte)32, this.e.length + 8);
            buffer.putMPInt(this.e);
            session.writePacket(buffer);
            this.expected = (byte)33;
            return false;
        }
        if (cmd == 33) {
            byte[] k_s = buffer.getBytes();
            this.f = buffer.getMPIntAsBytes();
            byte[] sig = buffer.getBytes();
            this.dh.setF(this.f);
            this.k = this.dh.getK();
            buffer = new ByteArrayBuffer(k_s);
            this.serverKey = buffer.getRawPublicKey();
            String keyAlg = KeyUtils.getKeyType(this.serverKey);
            if (GenericUtils.isEmpty(keyAlg)) {
                throw new SshException("Unsupported server key type");
            }
            buffer = new ByteArrayBuffer();
            buffer.putBytes(this.v_c);
            buffer.putBytes(this.v_s);
            buffer.putBytes(this.i_c);
            buffer.putBytes(this.i_s);
            buffer.putBytes(k_s);
            buffer.putInt(this.min);
            buffer.putInt(this.prf);
            buffer.putInt(this.max);
            buffer.putMPInt(this.p);
            buffer.putMPInt(this.g);
            buffer.putMPInt(this.e);
            buffer.putMPInt(this.f);
            buffer.putMPInt(this.k);
            this.hash.update(buffer.array(), 0, buffer.available());
            this.h = this.hash.digest();
            Signature verif = ValidateUtils.checkNotNull(NamedFactory.create(session.getSignatureFactories(), keyAlg), "No verifier located for algorithm=%s", (Object)keyAlg);
            verif.initVerifier(this.serverKey);
            verif.update(this.h);
            if (!verif.verify(sig)) {
                throw new SshException(3, "KeyExchange signature verification failed for key type=" + keyAlg);
            }
            return true;
        }
        throw new IllegalStateException("Unknown command value: " + KeyExchange.getGroupKexOpcodeName(cmd));
    }

    protected AbstractDH getDH(BigInteger p, BigInteger g) throws Exception {
        return this.factory.create(p, g);
    }
}

